from typing import Callable, Dict, Optional
import numpy as np
from coniferest.coniferest import Coniferest
from coniferest.pineforest import PineForest
from .callback import prompt_decision_callback
from ..label import Label
class Session:
    """
    Active anomaly detection session

    Parameters
    ----------
    data : array-like, shape (n_samples, n_features), dtype is number
        2-D array of data points
    metadata : array-like, shape (n_samples,), dtype is any
        1-D array of metadata for each data point
    decision_callback : callable, optional
        Function to be called when expert decision is required, it must return
        `Label` object with the decision and may terminate the session via
        `Session.terminate()`. Default is `prompt_decision_callback`
        Signature: '(metadata, data, session) -> Label', where metadata is
        metadata of the object to be labeled, data is data of the object to be
        labeled, session is this session instance.
    on_refit_callbacks : list of callable, or callable, or None, optional
        Functions to be called when model is refitted (before
        "decision_callback"), default is empty list. This function may call
        `Session.terminate()`, in which case no further expert decision is
        requested.
        Signature: '(session) -> None', where session is this session instance.
    on_decision_callbacks : list of callable, or callable, or None, optional
        Functions to be called when expert decision is made (after
        "decision_callback"), default is empty list. This function may call
        `Session.terminate()`.
        Signature: '(metadata, data, session) -> None', where metadata is
        metadata of the object has just been labeled, data is data of this
        object, session is this session instance.
    known_labels : dict, optional
        Dictionary of known anomaly labels, keys are data/metadata indices,
        values are labels of type `Label`. Default is empty dictionary.
        The dictionary is copied, the caller's object is never modified.
    model : Coniferest or None, optional
        Anomaly detection model to use, default is `PineForest()`.

    Attributes
    ----------
    current : int or None
        Index of the last anomaly candidate, None if there was no candidate yet
    last_decision : Label or None
        Label of the last anomaly candidate or None if no decision was made
    scores : array-like, shape (n_samples,), or None
        Current anomaly scores for all data points, None before the first
        scoring in `run()`
    terminated : bool
        True if session is terminated
    known_labels : dict[int, Label]
        Current dictionary of known anomaly labels
    known_anomalies : array-like
        Array of indices of known anomalies
    known_regulars : array-like
        Array of indices of known regular objects
    known_unknowns : array-like
        Array of indices of known objects marked with `Label.UNKNOWN`
    model : Coniferest
        Anomaly detection model used

    Raises
    ------
    ValueError
        If `decision_callback` or any of the callbacks is not callable, or if
        `model` is not a `Coniferest` instance.

    Examples
    --------
    >>> from coniferest.datasets import ztf_m31
    >>> from coniferest.session import Label, Session
    >>> data, metadata = ztf_m31()
    >>> s = Session(
    ...     data=data,
    ...     metadata=metadata,
    ...     decision_callback=lambda *_: Label.ANOMALY,
    ...     on_decision_callbacks=[lambda _metadata, _data, session: session.terminate()],
    ... )
    >>> _ = s.run()
    >>> assert len(s.known_labels) == len(s.known_anomalies) == 1
    """

    @staticmethod
    def _prepare_callbacks(input_argument):
        """Normalize a callbacks argument into a new list of callables.

        Parameters
        ----------
        input_argument : list of callable, or callable, or None
            None gives an empty list, a list is shallow-copied, anything else
            is treated as a single callback and wrapped into a list.

        Returns
        -------
        list
            List of callbacks, independent of the list passed by the caller.

        Raises
        ------
        ValueError
            If at least one of the resulting items is not callable.
        """
        if input_argument is None:
            callbacks = []
        elif isinstance(input_argument, list):
            # Copy, so that later changes of the caller's list do not silently
            # change the session (known_labels is copied the same way).
            callbacks = list(input_argument)
        else:
            callbacks = [input_argument]

        if not all(callable(cb) for cb in callbacks):
            raise ValueError("At least one of the callbacks is not callable")

        return callbacks

    @staticmethod
    def _invoke_callbacks(callbacks, *args, **kwargs):
        """Call every callback in order with the same arguments."""
        for cb in callbacks:
            cb(*args, **kwargs)

    def __init__(
        self,
        data,
        metadata,
        decision_callback=prompt_decision_callback,
        *,
        on_refit_callbacks=None,
        on_decision_callbacks=None,
        known_labels: Optional[Dict[int, Label]] = None,
        model: Optional[Coniferest] = None,
    ):
        self._data = np.atleast_2d(data)
        self._metadata = np.atleast_1d(metadata)

        if not callable(decision_callback):
            raise ValueError("decision_callback is not a callable")
        self._decision_cb = decision_callback

        # Re-raise with a message naming the offending argument, keeping the
        # original error as the explicit cause.
        try:
            self._on_refit_cb = self._prepare_callbacks(on_refit_callbacks)
        except ValueError as err:
            raise ValueError("on_refit_callbacks contains not callable object") from err

        try:
            self._on_decision_cb = self._prepare_callbacks(on_decision_callbacks)
        except ValueError as err:
            raise ValueError("on_decision_callbacks contains not callable object") from err

        # Always work on a private copy of the labels
        self._known_labels = {} if known_labels is None else dict(known_labels)

        if model is None:
            model = PineForest()
        if not isinstance(model, Coniferest):
            raise ValueError("model is not a Coniferest object")
        self._model = model

        self._scores = None
        self._current = None
        self._terminated = False

    def run(self) -> 'Session':
        """Evaluate interactive anomaly detection session

        The model is fitted once on all data. Then, until the session is
        terminated, each iteration passes the known labels to the model,
        invokes "on_refit_callbacks", recomputes the scores, selects the
        not yet labeled object with the lowest score, asks
        "decision_callback" for its label, stores the label and invokes
        "on_decision_callbacks". The session terminates itself when every
        object is labeled.

        Returns
        -------
        Session
            This session instance.

        Raises
        ------
        RuntimeError
            If the session is already terminated.
        """
        if self._terminated:
            raise RuntimeError("Session is already terminated")

        self.model.fit(self._data)

        while not self._terminated:
            known_data = self._data[list(self._known_labels)]
            known_labels = np.fromiter(
                self._known_labels.values(),
                dtype=int,
                count=len(self._known_labels),
            )
            self.model.fit_known(self._data, known_data, known_labels)

            self._invoke_callbacks(self._on_refit_cb, self)

            self._scores = self.model.score_samples(self._data)

            # A refit callback is allowed to terminate the session; honor it
            # here instead of asking the expert for one more decision.
            if self._terminated:
                break

            # argsort is ascending: the candidate is the unlabeled object with
            # the lowest score (presumably the most anomalous one - confirm
            # against the model's score_samples convention).
            self._current = None
            for ind in np.argsort(self._scores):
                if ind not in self._known_labels:
                    self._current = ind
                    break

            # Nothing left to label
            if self._current is None:
                self.terminate()
                break

            decision = self._decision_cb(self._metadata[self._current], self._data[self._current], self)
            self._known_labels[self._current] = decision

            self._invoke_callbacks(
                self._on_decision_cb,
                self._metadata[self._current],
                self._data[self._current],
                self,
            )

        return self

    def terminate(self) -> None:
        """Mark the session as terminated, `run()` stops at its next check"""
        self._terminated = True

    def _indices_with_label(self, label) -> np.ndarray:
        """Indices of known objects having the given label.

        The dtype is forced to an integer index type: without it an empty
        selection would be a float64 array, which cannot be used for indexing.
        """
        return np.array(
            [idx for idx, known in self._known_labels.items() if known == label],
            dtype=np.intp,
        )

    @property
    def current(self) -> Optional[int]:
        """Index of the last anomaly candidate, None if there was none yet"""
        return self._current

    @property
    def last_decision(self) -> Optional[Label]:
        """Label of the last anomaly candidate, None if no decision was made"""
        return self._known_labels.get(self._current, None)

    @property
    def scores(self) -> np.ndarray:
        """Current anomaly scores, None until they are computed by `run()`"""
        return self._scores

    @property
    def known_labels(self) -> Dict[int, Label]:
        """Current dictionary of known labels (the session's own object)"""
        return self._known_labels

    @property
    def known_anomalies(self) -> np.ndarray:
        """Integer array of indices labeled as `Label.ANOMALY`"""
        return self._indices_with_label(Label.ANOMALY)

    @property
    def known_regulars(self) -> np.ndarray:
        """Integer array of indices labeled as `Label.REGULAR`"""
        return self._indices_with_label(Label.REGULAR)

    @property
    def known_unknowns(self) -> np.ndarray:
        """Integer array of indices labeled as `Label.UNKNOWN`"""
        return self._indices_with_label(Label.UNKNOWN)

    @property
    def model(self) -> Coniferest:
        """Anomaly detection model used"""
        return self._model

    @property
    def terminated(self) -> bool:
        """True if session is terminated"""
        return self._terminated