MNIST dataset
This notebook gives an example of active anomaly detection with coniferest and the MNIST dataset.
Developers of coniferest:
- Matwey Kornilov (MSU)
- Vladimir Korolev
- Konstantin Malanchev (LINCC Frameworks / CMU), notebook author
[1]:
## Install and import the required libraries
[2]:
# Install packages
%pip install coniferest
%pip install datasets
[3]:
import datasets
import matplotlib.pyplot as plt
import numpy as np
from coniferest.isoforest import IsolationForest
from coniferest.pineforest import PineForest
from coniferest.session import Session
from coniferest.session.callback import TerminateAfter, prompt_decision_callback, Label
Download and load the MNIST dataset
Download the data from Hugging Face with the datasets library
[4]:
mnist = datasets.load_dataset("mnist")
Load the data into numpy arrays
[5]:
images_train = np.asarray(mnist['train']['image'])
images_test = np.asarray(mnist['test']['image'])
digits_train = np.asarray(mnist['train']['label'])
digits_test = np.asarray(mnist['test']['label'])
images = np.concatenate([images_train, images_test])
digits = np.concatenate([digits_train, digits_test])
Plot some examples
[6]:
fig, ax = plt.subplots(2, 5, figsize=(10, 5))
for i in range(10):
    ax[i//5, i%5].imshow(images[digits == i][0], cmap='gray')
    ax[i//5, i%5].set_title(f'Digit {i}')
    ax[i//5, i%5].axis('off')
Preprocess the data
Select the data to use:
- image: the original images
- fft: the power spectrum of the images
- both: the original images and the power spectrum together
[7]:
DATA = 'both' # 'image', 'fft', 'both'
Make 2-d FFT of the images
[8]:
# Make 2-d FFT of the images
data_fft = np.fft.fft2(images)
# Get power spectrum
power_spectrum = np.square(np.abs(data_fft))
# Normalize the power spectrum by zero frequency
power_spectrum = power_spectrum / power_spectrum[:, 0, 0][:, None, None]
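To see what the normalization does, here is a minimal sketch on synthetic data rather than MNIST. The names `images_demo` and `normalized_power_spectrum` are illustrative and are not part of the notebook. The sketch checks three properties: the zero-frequency term is the squared sum of all pixels, it becomes exactly 1 after normalization, and the normalized spectrum does not change when the whole image is made brighter by a constant factor.

```python
import numpy as np

# Synthetic stand-in for MNIST: 4 random 28x28 "images" with strictly
# positive pixels, so the zero-frequency term is never zero.
rng = np.random.default_rng(0)
images_demo = rng.uniform(1.0, 255.0, size=(4, 28, 28))


def normalized_power_spectrum(batch):
    """Power spectrum of each image divided by its zero-frequency term."""
    spectrum = np.fft.fft2(batch)  # 2-D FFT over the last two axes
    power = np.square(np.abs(spectrum))
    return power / power[:, 0, 0][:, None, None]


ps = normalized_power_spectrum(images_demo)
ps_bright = normalized_power_spectrum(3.0 * images_demo)

# The zero-frequency term of the raw power spectrum is (sum of pixels)^2.
dc = np.square(np.abs(np.fft.fft2(images_demo)))[:, 0, 0]
dc_matches_pixel_sum = np.allclose(dc, images_demo.sum(axis=(1, 2)) ** 2)

print("DC term equals squared pixel sum:", dc_matches_pixel_sum)
print("DC term is 1 after normalization:", np.allclose(ps[:, 0, 0], 1.0))
print("Invariant to brightness scaling:", np.allclose(ps, ps_bright))
```

Note that an all-black image would have a zero-frequency term of zero, so this normalization would produce NaN values for it.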
Plot some examples of power spectrum
[9]:
fig, ax = plt.subplots(2, 5, figsize=(10, 5))
for i in range(10):
    ax[i//5, i%5].imshow(np.log(power_spectrum[digits == i][0]), cmap='gray')
    ax[i//5, i%5].set_title(f'Digit {i}')
    ax[i//5, i%5].axis('off')
Concatenate images and power spectrum
[10]:
if DATA == 'image':
    # Flatten each 28x28 image into one row, as in the other branches
    final = np.asarray(images.reshape(-1, 28 * 28), dtype=np.float32)
elif DATA == 'fft':
    final = np.asarray(power_spectrum.reshape(-1, 28 * 28), dtype=np.float32)
elif DATA == 'both':
    final = np.concatenate([images.reshape(-1, 28 * 28), power_spectrum.reshape(-1, 28 * 28)], axis=1)
else:
    raise ValueError(f"Unknown value for DATA: {DATA}")
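The selection above can also be written as a function, which makes it easy to check the shape of the feature matrix for each option. This is a sketch on synthetic data. `build_features` is a hypothetical helper, not a coniferest or notebook function, and it assumes that every option should yield a 2-D matrix with one row per image and a single dtype.

```python
import numpy as np


def build_features(images, power_spectrum, mode="both"):
    """Hypothetical helper: return a 2-D float32 matrix, one row per image."""
    n = len(images)
    flat_images = np.asarray(images, dtype=np.float32).reshape(n, -1)
    flat_power = np.asarray(power_spectrum, dtype=np.float32).reshape(n, -1)
    if mode == "image":
        return flat_images
    if mode == "fft":
        return flat_power
    if mode == "both":
        return np.concatenate([flat_images, flat_power], axis=1)
    raise ValueError(f"Unknown value for mode: {mode}")


# Synthetic stand-in for MNIST: 5 random 28x28 images
rng = np.random.default_rng(1)
images_demo = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)
power_demo = np.square(np.abs(np.fft.fft2(images_demo)))

for mode in ("image", "fft", "both"):
    print(mode, build_features(images_demo, power_demo, mode).shape)
# image (5, 784)
# fft (5, 784)
# both (5, 1568)
```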
Classic anomaly detection with Isolation Forest
[11]:
model = IsolationForest(random_seed=10, n_trees=1000)
model.fit(np.array(final))
scores = model.score_samples(np.array(final))
ordered_index = np.argsort(scores)
ordered_digits = digits[ordered_index]
print(f"Top 10 weirdest digits : {ordered_digits[:10]}")
print(f"Top 10 most normal digits : {ordered_digits[-10:]}")
Top 10 weirdest digits : [6 5 5 5 3 7 5 5 9 5]
Top 10 most normal digits : [1 1 1 1 1 1 1 1 1 1]
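Ten objects are a small sample, so it can help to count how often each digit appears among a larger set of top-ranked outliers. The sketch below shows one way to do this with `np.argsort` and `np.bincount`. It uses synthetic labels and scores, so the counts say nothing about MNIST. As in the cell above, it assumes that lower scores mean more anomalous objects. The names `digits_demo`, `scores_demo` and `top_k` are illustrative.

```python
import numpy as np

rng = np.random.default_rng(2)
# Synthetic stand-ins for the digit labels and the model scores
digits_demo = rng.integers(0, 10, size=1000)
scores_demo = rng.normal(size=1000)

# Ascending sort: the first indices belong to the lowest scores,
# which are treated here as the most anomalous objects.
order = np.argsort(scores_demo)

top_k = 100
counts = np.bincount(digits_demo[order[:top_k]], minlength=10)
for digit, count in enumerate(counts):
    print(f"Digit {digit}: {count} of the top {top_k} outliers")
```

With real data, replacing `digits_demo` and `scores_demo` by `digits` and `scores` would give the per-digit composition of the top of the ranking.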
Plot the top 10 weirdest digits and the top 10 most normal digits
[12]:
fig, ax = plt.subplots(2, 10, figsize=(20, 5))
for i in range(10):
    ax[0, i].imshow(images[ordered_index[i]], cmap='gray')
    ax[0, i].set_title(f'Digit {ordered_digits[i]}')
    ax[0, i].axis('off')
    ax[1, i].imshow(images[ordered_index[-i - 1]], cmap='gray')
    ax[1, i].set_title(f'Digit {ordered_digits[-i - 1]}')
    ax[1, i].axis('off')
fig.text(0.1, 0.9, 'Top 10 weirdest digits', ha='left', va='center', fontsize=16)
fig.text(0.1, 0.5, 'Top 10 most normal digits', ha='left', va='center', fontsize=16)
[12]:
Text(0.1, 0.5, 'Top 10 most normal digits')
Anomaly detection with PineForest
Set expert budget
[13]:
EXPERT_BUDGET = 20
First, we need a function that shows us an image and its label and asks us whether it is an anomaly.
Let’s say that even numbers are anomalies.
[14]:
def decision(index, x, session):
    digit, image = digits[index], images[index]
    fig, ax = plt.subplots(1, 1, figsize=(2, 2))
    ax.imshow(image, cmap='gray')
    ax.set_title(f'Digit {digit}')
    ax.axis('off')
    plt.show()

    ### UNCOMMENT TO MAKE IT INTERACTIVE
    # return prompt_decision_callback(index, x, session)

    # Non-interactive
    return Label.ANOMALY if digit % 2 == 0 else Label.REGULAR
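Since only the labelling rule changes between experiments, one option is to build the callback from a rule instead of rewriting the whole function. The sketch below is one possible way to do that and leaves out the plotting. `make_decision` is a hypothetical helper, not part of coniferest. If coniferest is not installed, the sketch falls back to a stand-in `Label` enum whose numeric values are placeholders.

```python
from enum import IntEnum

try:
    from coniferest.session.callback import Label
except ImportError:
    # Stand-in so the sketch runs without coniferest installed.
    # Only the member names matter here; the values are placeholders.
    class Label(IntEnum):
        ANOMALY = -1
        REGULAR = 1


def make_decision(digits, is_anomaly):
    """Hypothetical factory: build an (index, x, session) callback from a rule."""
    def decision(index, x, session):
        return Label.ANOMALY if is_anomaly(digits[index]) else Label.REGULAR
    return decision


# Toy labels instead of the MNIST digits
digits_demo = [3, 8, 7, 0]
even_is_anomaly = make_decision(digits_demo, lambda d: d % 2 == 0)
labels = [even_is_anomaly(i, None, None).name for i in range(len(digits_demo))]
print(labels)  # ['REGULAR', 'ANOMALY', 'REGULAR', 'ANOMALY']
```

Swapping the rule for `lambda d: d % 2 == 1` gives the opposite experiment without touching the rest of the callback.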
Create a model and a session.
[15]:
model = PineForest(
    # Number of trees to use for predictions
    n_trees=256,
    # Number of new trees to grow for each decision
    n_spare_trees=768,
    # Fix random seed for reproducibility
    random_seed=0,
)
session = Session(
    data=final,
    metadata=np.arange(len(final)),
    model=model,
    decision_callback=decision,
    on_decision_callbacks=[
        TerminateAfter(EXPERT_BUDGET),
    ],
)
session.run()
[15]:
<coniferest.session.Session at 0x7f8b7883be50>
Let’s see what we have selected
[16]:
n_anomalies = len(session.known_anomalies)
n_total = len(session.known_labels)
print(f"Anomalies: {n_anomalies}/{n_total} ({n_anomalies/n_total:.2%})")
Anomalies: 11/20 (55.00%)
Let’s do the opposite: odd numbers are anomalies
[17]:
def decision(index, x, session):
    digit, image = digits[index], images[index]
    fig, ax = plt.subplots(1, 1, figsize=(2, 2))
    ax.imshow(image, cmap='gray')
    ax.set_title(f'Digit {digit}')
    ax.axis('off')
    plt.show()

    ### UNCOMMENT TO MAKE IT INTERACTIVE
    # return prompt_decision_callback(index, x, session)

    # Non-interactive
    return Label.ANOMALY if digit % 2 == 1 else Label.REGULAR

model = PineForest(
    # Number of trees to use for predictions
    n_trees=256,
    # Number of new trees to grow for each decision
    n_spare_trees=768,
    # Fix random seed for reproducibility
    random_seed=0,
)
session = Session(
    data=final,
    metadata=np.arange(len(final)),
    model=model,
    decision_callback=decision,
    on_decision_callbacks=[
        TerminateAfter(EXPERT_BUDGET),
    ],
)
session.run()

n_anomalies = len(session.known_anomalies)
n_total = len(session.known_labels)
print(f"Anomalies: {n_anomalies}/{n_total} ({n_anomalies/n_total:.2%})")
Anomalies: 19/20 (95.00%)
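One way to read such a fraction is to compare it with the share of "anomalous" objects in the whole dataset, which is what a random selection would be expected to yield. The sketch below does this on toy labels, not on MNIST. `hit_rate_vs_base_rate` and the arrays are illustrative names, and the numbers are constructed by hand.

```python
import numpy as np


def hit_rate_vs_base_rate(shown_digits, all_digits, is_anomaly):
    """Fraction of anomalies among the shown objects and in the whole dataset."""
    hit_rate = float(np.mean(is_anomaly(np.asarray(shown_digits))))
    base_rate = float(np.mean(is_anomaly(np.asarray(all_digits))))
    return hit_rate, base_rate


# Toy dataset: 1000 labels, every digit exactly 100 times
all_demo = np.tile(np.arange(10), 100)
# Toy "expert session": 10 objects shown, 9 of them odd
shown_demo = np.array([1, 3, 5, 7, 9, 1, 3, 2, 9, 5])

hit, base = hit_rate_vs_base_rate(shown_demo, all_demo, lambda d: d % 2 == 1)
print(f"hit rate {hit:.0%}, base rate {base:.0%}")  # hit rate 90%, base rate 50%
```

In the notebook, the base rate could be computed from `digits` with the same rule that the `decision` function uses.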
Change the decision function to make it interactive and try your own experiments. For example, say yes to weird sevens only.