
MNIST dataset

This notebook gives an example of Active Anomaly Detection with coniferest on the MNIST dataset.

Developers of coniferest:
- Matwey Kornilov (MSU)
- Vladimir Korolev
- Konstantin Malanchev (LINCC Frameworks / CMU), notebook author


Install and import the required libraries
[2]:
# Install packages
%pip install coniferest
%pip install datasets
[3]:
import datasets
import matplotlib.pyplot as plt
import numpy as np

from coniferest.isoforest import IsolationForest
from coniferest.pineforest import PineForest
from coniferest.session import Session
from coniferest.session.callback import TerminateAfter, prompt_decision_callback, Label

Download and load the MNIST dataset

Download the data from Hugging Face with the datasets library

[4]:
mnist = datasets.load_dataset("mnist")

Load the data into numpy arrays

[5]:
images_train = np.asarray(mnist['train']['image'])
images_test = np.asarray(mnist['test']['image'])
digits_train = np.asarray(mnist['train']['label'])
digits_test = np.asarray(mnist['test']['label'])

images = np.concatenate([images_train, images_test])
digits = np.concatenate([digits_train, digits_test])

Plot some examples

[6]:
fig, ax = plt.subplots(2, 5, figsize=(10, 5))
for i in range(10):
    ax[i//5, i%5].imshow(images[digits == i][0], cmap='gray')
    ax[i//5, i%5].set_title(f'Digit {i}')
    ax[i//5, i%5].axis('off')
../_images/notebooks_mnist_11_0.png

Preprocess the data

Select the data to use:
- image : the original images
- fft : the power spectrum of the images
- both : the original images and the power spectrum together

[7]:
DATA = 'both'  # 'image', 'fft', 'both'

Make 2-d FFT of the images

[8]:
# Make 2-d FFT of the images
data_fft = np.fft.fft2(images)
# Get power spectrum
power_spectrum = np.square(np.abs(data_fft))
# Normalize the power spectrum by zero frequency
power_spectrum = power_spectrum / power_spectrum[:, 0, 0][:, None, None]
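The normalization step can be checked on a small example. The sketch below uses a toy batch of random 4x4 arrays (hypothetical data, not MNIST) to illustrate two properties of the cell above: `np.fft.fft2` transforms only the last two axes, and dividing by the zero-frequency term makes the spectrum insensitive to overall image brightness.

```python
import numpy as np

# Toy batch of three 4x4 "images" (hypothetical data, not MNIST)
rng = np.random.default_rng(0)
toy_images = rng.integers(1, 256, size=(3, 4, 4)).astype(np.float64)

# fft2 acts on the last two axes, so the batch axis is left untouched
toy_power = np.square(np.abs(np.fft.fft2(toy_images)))

# The zero-frequency term is the squared sum of all pixel values
zero_freq = toy_power[:, 0, 0]
pixel_sum_squared = toy_images.sum(axis=(1, 2)) ** 2

# Normalize each spectrum by its own zero-frequency term
toy_normalized = toy_power / zero_freq[:, None, None]

# Doubling the brightness changes the raw spectrum but not the normalized one
brighter_power = np.square(np.abs(np.fft.fft2(2 * toy_images)))
brighter_normalized = brighter_power / brighter_power[:, 0, 0][:, None, None]

print(np.allclose(toy_normalized, brighter_normalized))
```

After normalization the `[0, 0]` element of every spectrum equals one, so it carries no information for the forest.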

Plot some examples of power spectrum

[9]:
fig, ax = plt.subplots(2, 5, figsize=(10, 5))
for i in range(10):
    ax[i//5, i%5].imshow(np.log(power_spectrum[digits == i][0]), cmap='gray')
    ax[i//5, i%5].set_title(f'Digit {i}')
    ax[i//5, i%5].axis('off')
../_images/notebooks_mnist_18_0.png

Concatenate images and power spectrum

[10]:
if DATA == 'image':
    # Flatten each 28x28 image so that the model receives a 2-d feature matrix
    final = np.asarray(images.reshape(-1, 28 * 28), dtype=np.float32)
elif DATA == 'fft':
    final = np.asarray(power_spectrum.reshape(-1, 28 * 28), dtype=np.float32)
elif DATA == 'both':
    final = np.concatenate([images.reshape(-1, 28 * 28), power_spectrum.reshape(-1, 28 * 28)], axis=1)
else:
    raise ValueError(f"Unknown value for DATA: {DATA}")
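Whichever option is selected, the models used later expect a 2-d matrix with one row per sample. The following sketch wraps the same branching in a helper and checks the resulting shapes on toy arrays; `build_features` and the toy inputs are illustrative names, not part of the notebook or of coniferest.

```python
import numpy as np

def build_features(images, power_spectrum, mode):
    """Return an (n_samples, n_features) float32 matrix for the chosen mode."""
    n_samples = len(images)
    flat_images = np.asarray(images, dtype=np.float32).reshape(n_samples, -1)
    flat_power = np.asarray(power_spectrum, dtype=np.float32).reshape(n_samples, -1)
    if mode == "image":
        return flat_images
    if mode == "fft":
        return flat_power
    if mode == "both":
        return np.concatenate([flat_images, flat_power], axis=1)
    raise ValueError(f"Unknown value for mode: {mode}")

# Toy stand-ins with the MNIST image shape (hypothetical data)
toy_images = np.zeros((5, 28, 28), dtype=np.uint8)
toy_power = np.ones((5, 28, 28))

shapes = {
    mode: build_features(toy_images, toy_power, mode).shape
    for mode in ("image", "fft", "both")
}
print(shapes)
# {'image': (5, 784), 'fft': (5, 784), 'both': (5, 1568)}
```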

Classic anomaly detection with Isolation forest

[11]:
model = IsolationForest(random_seed=10, n_trees=1000)
model.fit(np.array(final))
scores = model.score_samples(np.array(final))
ordered_index = np.argsort(scores)
ordered_digits = digits[ordered_index]

print(f"Top 10 weirdest digits : {ordered_digits[:10]}")
print(f"Top 10 most normal digits : {ordered_digits[-10:]}")
Top 10 weirdest digits : [6 5 5 5 3 7 5 5 9 5]
Top 10 most normal digits : [1 1 1 1 1 1 1 1 1 1]
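The ranking above relies on `np.argsort` sorting in ascending order, with the lowest scores treated as the weirdest objects. A minimal sketch with made-up scores and labels (not produced by any model) shows how the two ends of the ordering are read:

```python
import numpy as np

# Hypothetical scores: lower means more anomalous, as in the cell above
toy_scores = np.array([-0.42, -0.71, -0.39, -0.55])
toy_labels = np.array([1, 5, 1, 3])

# argsort returns the indices that sort the scores in ascending order
order = np.argsort(toy_scores)
ordered_labels = toy_labels[order]

print(order)               # [1 3 0 2]
print(ordered_labels[:2])  # two weirdest: [5 3]
print(ordered_labels[-2:]) # two most normal: [1 1]
```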

Plot the top 10 weirdest digits and the top 10 most normal digits

[12]:
fig, ax = plt.subplots(2, 10, figsize=(20, 5))
for i in range(10):
    ax[0, i].imshow(images[ordered_index[i]], cmap='gray')
    ax[0, i].set_title(f'Digit {ordered_digits[i]}')
    ax[0, i].axis('off')
    ax[1, i].imshow(images[ordered_index[-i - 1]], cmap='gray')
    ax[1, i].set_title(f'Digit {ordered_digits[-i - 1]}')
    ax[1, i].axis('off')
fig.text(0.1, 0.9, 'Top 10 weirdest digits', ha='left', va='center', fontsize=16)
fig.text(0.1, 0.5, 'Top 10 most normal digits', ha='left', va='center', fontsize=16)
[12]:
Text(0.1, 0.5, 'Top 10 most normal digits')
../_images/notebooks_mnist_24_1.png

Anomaly detection with PineForest

Set expert budget

[13]:
EXPERT_BUDGET = 20

First, we need a function that shows us an image and its label, and then asks whether it is an anomaly.

Let’s say that even digits are anomalies

[14]:
def decision(index, x, session):
    digit, image = digits[index], images[index]
    fig, ax = plt.subplots(1, 1, figsize=(2, 2))
    ax.imshow(image, cmap='gray')
    ax.set_title(f'Digit {digit}')
    ax.axis('off')
    plt.show()

    ### UNCOMMENT TO MAKE IT INTERACTIVE
    # return prompt_decision_callback(index, x, session)

    # Non-interactive
    return Label.ANOMALY if digit % 2 == 0 else Label.REGULAR

Create a model and a session.

[15]:
model = PineForest(
    # Number of trees to use for predictions
    n_trees=256,
    # Number of new trees to grow for each decision
    n_spare_trees=768,
    # Fix random seed for reproducibility
    random_seed=0,
)
session = Session(
    data=final,
    metadata=np.arange(len(final)),
    model=model,
    decision_callback=decision,
    on_decision_callbacks=[
        TerminateAfter(EXPERT_BUDGET),
    ],
)
session.run()
../_images/notebooks_mnist_31_0.png ... ../_images/notebooks_mnist_31_19.png (20 images, one per expert decision)
[15]:
<coniferest.session.Session at 0x7f8b7883be50>

Let’s see what we have selected

[16]:
n_anomalies = len(session.known_anomalies)
n_total = len(session.known_labels)
print(f"Anomalies: {n_anomalies}/{n_total} ({n_anomalies/n_total:.2%})")
Anomalies: 11/20 (55.00%)

Let’s do the opposite: odd numbers are anomalies

[17]:
def decision(index, x, session):
    digit, image = digits[index], images[index]
    fig, ax = plt.subplots(1, 1, figsize=(2, 2))
    ax.imshow(image, cmap='gray')
    ax.set_title(f'Digit {digit}')
    ax.axis('off')
    plt.show()

    ### UNCOMMENT TO MAKE IT INTERACTIVE
    # return prompt_decision_callback(index, x, session)

    # Non-interactive
    return Label.ANOMALY if digit % 2 == 1 else Label.REGULAR

model = PineForest(
    # Number of trees to use for predictions
    n_trees=256,
    # Number of new trees to grow for each decision
    n_spare_trees=768,
    # Fix random seed for reproducibility
    random_seed=0,
)
session = Session(
    data=final,
    metadata=np.arange(len(final)),
    model=model,
    decision_callback=decision,
    on_decision_callbacks=[
        TerminateAfter(EXPERT_BUDGET),
    ],
)
session.run()

n_anomalies = len(session.known_anomalies)
n_total = len(session.known_labels)
print(f"Anomalies: {n_anomalies}/{n_total} ({n_anomalies/n_total:.2%})")
../_images/notebooks_mnist_35_0.png ... ../_images/notebooks_mnist_35_19.png (20 images, one per expert decision)
Anomalies: 19/20 (95.00%)
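The two runs used the same expert budget but found very different numbers of anomalies. As a small worked comparison, the sketch below recomputes both fractions from the counts printed above; `hit_rate` is an illustrative helper, not a coniferest function.

```python
def hit_rate(n_anomalies, n_total):
    """Fraction of expert-labelled objects that turned out to be anomalies."""
    return n_anomalies / n_total

# Counts taken from the two outputs above
even_run = hit_rate(11, 20)
odd_run = hit_rate(19, 20)

print(f"even: {even_run:.2%}, odd: {odd_run:.2%}, difference: {odd_run - even_run:.2%}")
# even: 55.00%, odd: 95.00%, difference: 40.00%
```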

Change the decision function to make it interactive and try your own experiments. For example, say yes to weird sevens only.
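As a starting point for such an experiment, here is a minimal non-interactive sketch of a callback that flags every seven. It keeps the `(index, x, session)` signature used in this notebook, but `Label` below is a local stand-in enum so that the snippet runs without coniferest, and `toy_digits` is made-up data. In the notebook itself you would use `Label` from `coniferest.session.callback` and the real `digits` array. Judging which sevens are actually weird still requires the interactive prompt.

```python
from enum import Enum

class Label(Enum):
    # Stand-in for coniferest.session.callback.Label (values are arbitrary)
    ANOMALY = "anomaly"
    REGULAR = "regular"

# Hypothetical labels for a handful of samples
toy_digits = [7, 1, 7, 4]

def decision(index, x, session):
    """Flag every seven as an anomaly; x and session are unused here."""
    return Label.ANOMALY if toy_digits[index] == 7 else Label.REGULAR

answers = [decision(i, None, None) for i in range(len(toy_digits))]
print([answer.name for answer in answers])
# ['ANOMALY', 'REGULAR', 'ANOMALY', 'REGULAR']
```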