
MNIST dataset

This notebook gives an example of Active Anomaly Detection with coniferest on the MNIST dataset.

Developers of coniferest:
- Matwey Kornilov (MSU)
- Vladimir Korolev
- Konstantin Malanchev (LINCC Frameworks / CMU), notebook author


Install and import the required libraries

[2]:
# Install packages
%pip install coniferest
%pip install datasets
[3]:
import datasets
import matplotlib.pyplot as plt
import numpy as np

from coniferest.isoforest import IsolationForest
from coniferest.pineforest import PineForest
from coniferest.session import Session
from coniferest.session.callback import TerminateAfter, prompt_decision_callback, Label

Download and load the MNIST dataset

Download the data from Hugging Face with the datasets library

[4]:
mnist = datasets.load_dataset("mnist")

Load the data into NumPy arrays

[5]:
images_train = np.asarray(mnist['train']['image'])
images_test = np.asarray(mnist['test']['image'])
digits_train = np.asarray(mnist['train']['label'])
digits_test = np.asarray(mnist['test']['label'])

images = np.concatenate([images_train, images_test])
digits = np.concatenate([digits_train, digits_test])

Plot some examples

[6]:
fig, ax = plt.subplots(2, 5, figsize=(10, 5))
for i in range(10):
    ax[i//5, i%5].imshow(images[digits == i][0], cmap='gray')
    ax[i//5, i%5].set_title(f'Digit {i}')
    ax[i//5, i%5].axis('off')
../_images/notebooks_mnist_11_0.png

Preprocess the data

Select the data to use:
- image : the original images
- fft : the power spectrum of the images
- both : the original images and the power spectrum together

[7]:
DATA = 'both'  # 'image', 'fft', 'both'

Compute the 2-D FFT of the images

[8]:
# Make 2-d FFT of the images
data_fft = np.fft.fft2(images)
# Get power spectrum
power_spectrum = np.square(np.abs(data_fft))
# Normalize the power spectrum by zero frequency
power_spectrum = power_spectrum / power_spectrum[:, 0, 0][:, None, None]
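
The normalization step above can be checked on a small example. The sketch below uses random arrays as a hypothetical stand-in for the MNIST images (the names toy_images, toy_power and so on are not part of the notebook). It illustrates that np.fft.fft2 transforms the last two axes only, that the zero-frequency term equals the squared sum of the pixels, and that dividing by it makes the spectrum independent of the overall brightness.

```python
import numpy as np

# Hypothetical stand-in for the MNIST array: 3 random 28x28 "images".
rng = np.random.default_rng(0)
toy_images = rng.integers(1, 256, size=(3, 28, 28)).astype(np.float64)

# fft2 transforms the last two axes by default, so the batch axis is untouched.
toy_fft = np.fft.fft2(toy_images)
toy_power = np.square(np.abs(toy_fft))

# The zero-frequency term is the squared sum of all pixels of each image.
zero_freq = toy_power[:, 0, 0]
pixel_sums = toy_images.sum(axis=(1, 2))

# Normalize each spectrum by its own zero-frequency term.
toy_normalized = toy_power / zero_freq[:, None, None]

# Doubling the brightness changes the raw spectrum but not the normalized one.
brighter_power = np.square(np.abs(np.fft.fft2(2.0 * toy_images)))
brighter_normalized = brighter_power / brighter_power[:, 0, 0][:, None, None]
```

After this normalization the [0, 0] element of every spectrum is exactly 1, so that feature carries no information for the forest.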

Plot some examples of power spectrum

[9]:
fig, ax = plt.subplots(2, 5, figsize=(10, 5))
for i in range(10):
    ax[i//5, i%5].imshow(np.log(power_spectrum[digits == i][0]), cmap='gray')
    ax[i//5, i%5].set_title(f'Digit {i}')
    ax[i//5, i%5].axis('off')
../_images/notebooks_mnist_18_0.png

Concatenate images and power spectrum

[10]:
if DATA == 'image':
    final = np.asarray(images, dtype=np.float32).reshape(-1, 28 * 28)
elif DATA == 'fft':
    final = np.asarray(power_spectrum.reshape(-1, 28 * 28), dtype=np.float32)
elif DATA == 'both':
    final = np.concatenate([images.reshape(-1, 28 * 28), power_spectrum.reshape(-1, 28 * 28)], axis=1)
else:
    raise ValueError(f"Unknown value for DATA: {DATA}")
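
The feature selection can also be written as a small helper, which makes the resulting shapes easy to verify. This is a minimal sketch, not the notebook's own code: build_features and the toy arrays are hypothetical names, and it assumes the models expect a two-dimensional array of shape (n_samples, n_features).

```python
import numpy as np

def build_features(images, power_spectrum, mode):
    """Flatten each 2-D array into a row and stack the selected blocks."""
    n = len(images)
    flat_images = np.asarray(images, dtype=np.float32).reshape(n, -1)
    flat_power = np.asarray(power_spectrum, dtype=np.float32).reshape(n, -1)
    if mode == "image":
        return flat_images
    if mode == "fft":
        return flat_power
    if mode == "both":
        # Pixels first, then power spectrum, side by side for each sample.
        return np.concatenate([flat_images, flat_power], axis=1)
    raise ValueError(f"Unknown mode: {mode}")

# Toy data: 4 "images" of 28x28 pixels and matching "spectra".
toy_images = np.ones((4, 28, 28))
toy_power = np.zeros((4, 28, 28))
shapes = {
    mode: build_features(toy_images, toy_power, mode).shape
    for mode in ("image", "fft", "both")
}
```

With 28x28 inputs, 'image' and 'fft' give 784 features per sample and 'both' gives 1568.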

Classic anomaly detection with Isolation Forest

[11]:
model = IsolationForest(random_seed=10, n_trees=1000)
model.fit(np.array(final))
scores = model.score_samples(np.array(final))
ordered_index = np.argsort(scores)
ordered_digits = digits[ordered_index]

print(f"Top 10 weirdest digits : {ordered_digits[:10]}")
print(f"Top 10 most normal digits : {ordered_digits[-10:]}")
Top 10 weirdest digits : [6 5 5 5 3 7 5 5 9 5]
Top 10 most normal digits : [1 1 1 1 1 1 1 1 1 1]
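
The ranking relies on the convention used in the cell above: np.argsort sorts in ascending order, so the lowest scores come first and are treated as the weirdest objects. A minimal illustration with made-up scores and labels (toy_scores and toy_labels are hypothetical and not produced by any model):

```python
import numpy as np

# Hypothetical scores for five objects; lower score = more anomalous.
toy_scores = np.array([-0.40, -0.75, -0.52, -0.90, -0.45])
toy_labels = np.array([1, 5, 3, 6, 1])

# Ascending sort: the most anomalous object comes first.
order = np.argsort(toy_scores)

weirdest = toy_labels[order[:2]]      # labels of the two lowest scores
most_normal = toy_labels[order[-2:]]  # labels of the two highest scores
```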

Plot the top 10 weirdest digits and the top 10 most normal digits

[12]:
fig, ax = plt.subplots(2, 10, figsize=(20, 5))
for i in range(10):
    ax[0, i].imshow(images[ordered_index[i]], cmap='gray')
    ax[0, i].set_title(f'Digit {ordered_digits[i]}')
    ax[0, i].axis('off')
    ax[1, i].imshow(images[ordered_index[-i - 1]], cmap='gray')
    ax[1, i].set_title(f'Digit {ordered_digits[-i - 1]}')
    ax[1, i].axis('off')
fig.text(0.1, 0.9, 'Top 10 weirdest digits', ha='left', va='center', fontsize=16)
fig.text(0.1, 0.5, 'Top 10 most normal digits', ha='left', va='center', fontsize=16)
[12]:
Text(0.1, 0.5, 'Top 10 most normal digits')
../_images/notebooks_mnist_24_1.png

Anomaly detection with PineForest

Set expert budget

[13]:
EXPERT_BUDGET = 20

First, we need a function that shows us an image with its label and asks whether it is an anomaly.

Let’s say that even digits are anomalies.

[14]:
def decision(index, x, session):
    digit, image = digits[index], images[index]
    fig, ax = plt.subplots(1, 1, figsize=(2, 2))
    ax.imshow(image, cmap='gray')
    ax.set_title(f'Digit {digit}')
    ax.axis('off')
    plt.show()

    ### UNCOMMENT TO MAKE IT INTERACTIVE
    # return prompt_decision_callback(index, x, session)

    # Non-interactive
    return Label.ANOMALY if digit % 2 == 0 else Label.REGULAR

Create a model and a session.

[15]:
model = PineForest(
    # Number of trees to use for predictions
    n_trees=256,
    # Number of new trees to grow for each decision
    n_spare_trees=768,
    # Fix random seed for reproducibility
    random_seed=0,
)
session = Session(
    data=final,
    metadata=np.arange(len(final)),
    model=model,
    decision_callback=decision,
    on_decision_callbacks=[
        TerminateAfter(EXPERT_BUDGET),
    ],
)
session.run()
[Output: 20 images, one per expert decision (notebooks_mnist_31_0.png to notebooks_mnist_31_19.png)]
[15]:
<coniferest.session.Session at 0x7f8468b39bd0>

Let’s see what we have selected

[16]:
n_anomalies = len(session.known_anomalies)
n_total = len(session.known_labels)
print(f"Anomalies: {n_anomalies}/{n_total} ({n_anomalies/n_total:.2%})")
Anomalies: 11/20 (55.00%)

Let’s do the opposite: odd numbers are anomalies

[17]:
def decision(index, x, session):
    digit, image = digits[index], images[index]
    fig, ax = plt.subplots(1, 1, figsize=(2, 2))
    ax.imshow(image, cmap='gray')
    ax.set_title(f'Digit {digit}')
    ax.axis('off')
    plt.show()

    ### UNCOMMENT TO MAKE IT INTERACTIVE
    # return prompt_decision_callback(index, x, session)

    # Non-interactive
    return Label.ANOMALY if digit % 2 == 1 else Label.REGULAR

model = PineForest(
    # Number of trees to use for predictions
    n_trees=256,
    # Number of new trees to grow for each decision
    n_spare_trees=768,
    # Fix random seed for reproducibility
    random_seed=0,
)
session = Session(
    data=final,
    metadata=np.arange(len(final)),
    model=model,
    decision_callback=decision,
    on_decision_callbacks=[
        TerminateAfter(EXPERT_BUDGET),
    ],
)
session.run()

n_anomalies = len(session.known_anomalies)
n_total = len(session.known_labels)
print(f"Anomalies: {n_anomalies}/{n_total} ({n_anomalies/n_total:.2%})")
[Output: 20 images, one per expert decision (notebooks_mnist_35_0.png to notebooks_mnist_35_19.png)]
Anomalies: 17/20 (85.00%)
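
To judge such a fraction, it helps to compare it with the share of "anomalous" digits in the whole dataset, which is what a random selection would find on average. The helper below is a sketch under assumptions: hit_rate_vs_base_rate is a hypothetical name, and queried_indices stands for the indices shown to the expert (with coniferest these could presumably be taken from the session, for example the keys of session.known_labels if it maps index to label).

```python
import numpy as np

def hit_rate_vs_base_rate(digits, queried_indices, is_anomaly):
    """Return (fraction of anomalies among queried objects, fraction overall)."""
    digits = np.asarray(digits)
    mask = is_anomaly(digits)  # boolean array, True for anomalies
    base_rate = mask.mean()
    hit_rate = mask[np.asarray(queried_indices)].mean()
    return hit_rate, base_rate

# Toy example: ten digits, odd digits are anomalies, four objects were queried.
toy_digits = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
hit, base = hit_rate_vs_base_rate(toy_digits, [1, 3, 4, 5], lambda d: d % 2 == 1)
```

In the toy example three of the four queried digits are odd, so the hit rate is 0.75 against a base rate of 0.5.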

Change the decision function to make it interactive and try your own experiments. For example, say yes to weird sevens only.
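
For non-interactive experiments, the two decision functions above differ only in which digits count as anomalies, so the rule can be produced by a small factory. This is a sketch, not part of coniferest: make_digit_rule is a hypothetical helper, and the labels are passed in as arguments so that the example runs without the library. With coniferest you would pass Label.ANOMALY and Label.REGULAR. Note that it flags every seven, not only the weird ones, which still requires a human in the loop.

```python
import numpy as np

def make_digit_rule(digits, anomalous_digits, anomaly_label, regular_label):
    """Return a callback with the (index, x, session) signature used above."""
    digits = np.asarray(digits)
    anomalous = frozenset(int(d) for d in anomalous_digits)

    def decision(index, x, session):
        # Only the true digit label is used; x and session are ignored.
        if int(digits[index]) in anomalous:
            return anomaly_label
        return regular_label

    return decision

# Toy check with string labels standing in for Label.ANOMALY / Label.REGULAR.
toy_digits = np.array([7, 2, 7, 1])
decide = make_digit_rule(toy_digits, {7}, "ANOMALY", "REGULAR")
answers = [decide(i, None, None) for i in range(len(toy_digits))]
```

The returned function can be passed as decision_callback to Session in the same way as the hand-written versions.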