MNIST dataset¶
This notebook gives an example of Active Anomaly Detection with coniferest and the MNIST dataset.
Developers of coniferest:
- Matwey Kornilov (MSU)
- Vladimir Korolev
- Konstantin Malanchev (LINCC Frameworks / CMU), notebook author
[1]:
## Install and import the required libraries
[2]:
# Install packages
%pip install coniferest
%pip install datasets
[3]:
import datasets
import matplotlib.pyplot as plt
import numpy as np
from coniferest.isoforest import IsolationForest
from coniferest.pineforest import PineForest
from coniferest.session import Session
from coniferest.session.callback import TerminateAfter, prompt_decision_callback, Label
Download and load the MNIST dataset¶
Download the data from Hugging Face with the datasets library
[4]:
mnist = datasets.load_dataset("mnist")
Load the data into numpy arrays
[5]:
images_train = np.asarray(mnist['train']['image'])
images_test = np.asarray(mnist['test']['image'])
digits_train = np.asarray(mnist['train']['label'])
digits_test = np.asarray(mnist['test']['label'])
images = np.concatenate([images_train, images_test])
digits = np.concatenate([digits_train, digits_test])
Plot some examples
[6]:
fig, ax = plt.subplots(2, 5, figsize=(10, 5))
for i in range(10):
    ax[i//5, i%5].imshow(images[digits == i][0], cmap='gray')
    ax[i//5, i%5].set_title(f'Digit {i}')
    ax[i//5, i%5].axis('off')

Preprocess the data¶
Select the data to use:
- image: the original images
- fft: the power spectrum of the images
- both: the original images and the power spectrum together
[7]:
DATA = 'both' # 'image', 'fft', 'both'
Make 2-d FFT of the images
[8]:
# Make 2-d FFT of the images
data_fft = np.fft.fft2(images)
# Get power spectrum
power_spectrum = np.square(np.abs(data_fft))
# Normalize the power spectrum by zero frequency
power_spectrum = power_spectrum / power_spectrum[:, 0, 0][:, None, None]
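The normalization step can be checked on a small synthetic batch. The sketch below uses random arrays as a stand-in for MNIST (toy_images and the other toy_ names are hypothetical) and relies on the fact that the zero-frequency FFT term is the sum of all pixels, so its power is the squared pixel sum.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical stand-in for MNIST: 4 random 28x28 "images" with pixel values 0..255
toy_images = rng.integers(0, 256, size=(4, 28, 28)).astype(np.float64)

# fft2 transforms the last two axes, so the batch axis is left untouched
toy_fft = np.fft.fft2(toy_images)
toy_power = np.square(np.abs(toy_fft))

# The zero-frequency power of each image equals its squared pixel sum
dc = toy_power[:, 0, 0]
pixel_sums = toy_images.sum(axis=(1, 2))

# Dividing by the zero-frequency term makes the [0, 0] entry equal to 1
toy_normalized = toy_power / dc[:, None, None]
```

Note that an all-black image would have a zero-frequency term of zero, so this division would produce non-finite values for it.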
Plot some examples of power spectrum
[9]:
fig, ax = plt.subplots(2, 5, figsize=(10, 5))
for i in range(10):
    ax[i//5, i%5].imshow(np.log(power_spectrum[digits == i][0]), cmap='gray')
    ax[i//5, i%5].set_title(f'Digit {i}')
    ax[i//5, i%5].axis('off')

Concatenate images and power spectrum
[10]:
if DATA == 'image':
    # Flatten each 28x28 image into a 784-element feature vector
    final = np.asarray(images.reshape(-1, 28 * 28), dtype=np.float32)
elif DATA == 'fft':
    final = np.asarray(power_spectrum.reshape(-1, 28 * 28), dtype=np.float32)
elif DATA == 'both':
    final = np.concatenate([images.reshape(-1, 28 * 28), power_spectrum.reshape(-1, 28 * 28)], axis=1)
else:
    raise ValueError(f"Unknown value for DATA: {DATA}")
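To make the resulting feature layout concrete, here is a small sketch on random stand-in arrays (toy_images and toy_power are hypothetical, not the MNIST data). It assumes the same flattening as the 'both' branch above: each 28×28 array becomes 784 values, so every sample ends up with 1568 features.

```python
import numpy as np

rng = np.random.default_rng(1)
n_samples = 3
# Hypothetical stand-ins for the images and their power spectra
toy_images = rng.random((n_samples, 28, 28))
toy_power = rng.random((n_samples, 28, 28))

# Flatten each 28x28 array into a row of 784 values
flat_images = toy_images.reshape(-1, 28 * 28)
flat_power = toy_power.reshape(-1, 28 * 28)

# Put the two feature sets side by side: pixels first, power spectrum second
toy_final = np.concatenate([flat_images, flat_power], axis=1)

print(toy_final.shape)  # (3, 1568)
```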
Classic anomaly detection with Isolation forest¶
[11]:
model = IsolationForest(random_seed=10, n_trees=1000)
model.fit(np.array(final))
scores = model.score_samples(np.array(final))
ordered_index = np.argsort(scores)
ordered_digits = digits[ordered_index]
print(f"Top 10 weirdest digits : {ordered_digits[:10]}")
print(f"Top 10 most normal digits : {ordered_digits[-10:]}")
Top 10 weirdest digits : [6 5 5 5 3 7 5 5 9 5]
Top 10 most normal digits : [1 1 1 1 1 1 1 1 1 1]
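The ranking step relies on np.argsort returning indices in ascending score order; in this notebook the lowest scores are treated as the most anomalous. A toy illustration with made-up scores and labels (toy_scores and toy_digits are hypothetical values, not model output):

```python
import numpy as np

# Made-up anomaly scores and the digit label of each sample
toy_scores = np.array([-0.40, -0.75, -0.42, -0.90, -0.38])
toy_digits = np.array([1, 5, 1, 6, 1])

# Indices sorted from the lowest score to the highest score
toy_order = np.argsort(toy_scores)
toy_ordered_digits = toy_digits[toy_order]

print(toy_ordered_digits[:2])   # labels of the two lowest-scoring samples
print(toy_ordered_digits[-2:])  # labels of the two highest-scoring samples
```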
Plot the top 10 weirdest digits and the top 10 most normal digits
[12]:
fig, ax = plt.subplots(2, 10, figsize=(20, 5))
for i in range(10):
    ax[0, i].imshow(images[ordered_index[i]], cmap='gray')
    ax[0, i].set_title(f'Digit {ordered_digits[i]}')
    ax[0, i].axis('off')
    ax[1, i].imshow(images[ordered_index[-i - 1]], cmap='gray')
    ax[1, i].set_title(f'Digit {ordered_digits[-i - 1]}')
    ax[1, i].axis('off')
fig.text(0.1, 0.9, 'Top 10 weirdest digits', ha='left', va='center', fontsize=16)
fig.text(0.1, 0.5, 'Top 10 most normal digits', ha='left', va='center', fontsize=16)
[12]:
Text(0.1, 0.5, 'Top 10 most normal digits')

Anomaly detection with PineForest¶
Set expert budget
[13]:
EXPERT_BUDGET = 20
First, we need a function that shows us an image and its label and asks us whether it is an anomaly.
Let’s say that even digits are anomalies.
[14]:
def decision(index, x, session):
    digit, image = digits[index], images[index]
    fig, ax = plt.subplots(1, 1, figsize=(2, 2))
    ax.imshow(image, cmap='gray')
    ax.set_title(f'Digit {digit}')
    ax.axis('off')
    plt.show()
    ### UNCOMMENT TO MAKE IT INTERACTIVE
    # return prompt_decision_callback(index, x, session)
    # Non-interactive
    return Label.ANOMALY if digit % 2 == 0 else Label.REGULAR
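The non-interactive rule can be isolated from the plotting code. The sketch below uses a stand-in enum so that it runs without coniferest; ToyLabel and label_by_parity are hypothetical names introduced for illustration, not part of the library.

```python
from enum import Enum


class ToyLabel(Enum):
    """Stand-in for the Label values used in the notebook (an assumption)."""
    ANOMALY = "anomaly"
    REGULAR = "regular"


def label_by_parity(digit, anomalous_remainder=0):
    """Label a digit as an anomaly when digit % 2 equals anomalous_remainder."""
    if digit % 2 == anomalous_remainder:
        return ToyLabel.ANOMALY
    return ToyLabel.REGULAR


# Apply the rule to each of the ten digit classes
labels = [label_by_parity(d) for d in range(10)]
n_anomalous_classes = sum(label is ToyLabel.ANOMALY for label in labels)
print(n_anomalous_classes)  # 5 of the 10 digit classes are even
```

Passing anomalous_remainder=1 gives the opposite rule, where odd digits are the anomalies.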
Create a model and a session.
[15]:
model = PineForest(
    # Number of trees to use for predictions
    n_trees=256,
    # Number of new trees to grow for each decision
    n_spare_trees=768,
    # Fix random seed for reproducibility
    random_seed=0,
)
session = Session(
    data=final,
    metadata=np.arange(len(final)),
    model=model,
    decision_callback=decision,
    on_decision_callbacks=[
        TerminateAfter(EXPERT_BUDGET),
    ],
)
session.run()
[15]:
<coniferest.session.Session at 0x7f8468b39bd0>
Let’s see what we have selected
[16]:
n_anomalies = len(session.known_anomalies)
n_total = len(session.known_labels)
print(f"Anomalies: {n_anomalies}/{n_total} ({n_anomalies/n_total:.2%})")
Anomalies: 11/20 (55.00%)
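A hit rate is easier to interpret next to the base rate, that is, the fraction of the whole dataset that matches the anomaly rule. A minimal sketch, using a short made-up label array in place of digits (toy_digits is hypothetical):

```python
import numpy as np

# Made-up labels standing in for the real `digits` array
toy_digits = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1])

# Fraction of samples that the "even digits are anomalies" rule would flag
is_even = toy_digits % 2 == 0
base_rate = is_even.mean()
print(f"Base rate of even digits: {base_rate:.2%}")
```

With the real data, np.mean(digits % 2 == 0) gives the share of even digits. A session that finds anomalies at about that rate is doing no better than random selection.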
Let’s do the opposite: odd numbers are anomalies
[17]:
def decision(index, x, session):
    digit, image = digits[index], images[index]
    fig, ax = plt.subplots(1, 1, figsize=(2, 2))
    ax.imshow(image, cmap='gray')
    ax.set_title(f'Digit {digit}')
    ax.axis('off')
    plt.show()
    ### UNCOMMENT TO MAKE IT INTERACTIVE
    # return prompt_decision_callback(index, x, session)
    # Non-interactive
    return Label.ANOMALY if digit % 2 == 1 else Label.REGULAR
model = PineForest(
    # Number of trees to use for predictions
    n_trees=256,
    # Number of new trees to grow for each decision
    n_spare_trees=768,
    # Fix random seed for reproducibility
    random_seed=0,
)
session = Session(
    data=final,
    metadata=np.arange(len(final)),
    model=model,
    decision_callback=decision,
    on_decision_callbacks=[
        TerminateAfter(EXPERT_BUDGET),
    ],
)
session.run()
n_anomalies = len(session.known_anomalies)
n_total = len(session.known_labels)
print(f"Anomalies: {n_anomalies}/{n_total} ({n_anomalies/n_total:.2%})")
Anomalies: 17/20 (85.00%)
Change the decision function to make it interactive and try your own experiments. For example, say yes to weird sevens only.
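One way to set up such an experiment is to answer automatically for every other digit and defer to a human only when a seven is shown. The sketch below is an assumption about how such a wrapper could look, not part of coniferest: make_sevens_decision, ask, fake_expert and the string labels are hypothetical. In the notebook, ask would be replaced by a call to prompt_decision_callback and the strings by Label values.

```python
def make_sevens_decision(digits, ask):
    """Build a decision function that only consults `ask` for sevens."""
    def decision(index, x, session):
        # Every digit other than seven is labelled regular without asking
        if digits[index] != 7:
            return "regular"
        # Sevens are passed on to the expert
        return ask(index, x, session)
    return decision


# Toy data: four samples, of which the sevens at positions 1 and 2 get reviewed
toy_digits = [3, 7, 7, 1]
weird_sevens = {2}  # pretend the expert finds only sample 2 weird


def fake_expert(index, x, session):
    """Hypothetical expert standing in for an interactive prompt."""
    return "anomaly" if index in weird_sevens else "regular"


decision = make_sevens_decision(toy_digits, fake_expert)
answers = [decision(i, None, None) for i in range(len(toy_digits))]
print(answers)  # ['regular', 'regular', 'anomaly', 'regular']
```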