Source code for coniferest.session.callback

import click
import webbrowser

from typing import List, Optional

import numpy as np

from coniferest.datasets import Label


def prompt_decision_callback(metadata, data, session) -> Label:
    """
    Ask the user on the terminal whether the object is an anomaly.

    The question is a yes/no confirmation that mentions ``metadata``.
    ``data`` is accepted for callback-signature compatibility and is not
    used here.

    Returns
    -------
    Label
        ``Label.ANOMALY`` when the user confirms, ``Label.REGULAR`` when the
        user declines, and ``Label.UNKNOWN`` when the prompt is aborted
        (e.g. by a keyboard interrupt); in the latter case the session is
        terminated before returning.
    """
    try:
        is_anomaly = click.confirm(f"Is {metadata} anomaly?")
    except click.Abort:
        # The user aborted the prompt: stop the session and leave the
        # object unlabeled.
        session.terminate()
        return Label.UNKNOWN

    if is_anomaly:
        return Label.ANOMALY
    return Label.REGULAR
def viewer_decision_callback(metadata, data, session) -> Label:
    """
    Open SNAD Viewer for ZTF DR object and prompt the user for a label.

    Metadata must be ZTF DR object ID. The viewer page is opened in a new
    browser tab. If no browser can be found or launched, the URL is printed
    to the terminal instead so the user can open it manually.

    Returns
    -------
    Label
        Whatever ``prompt_decision_callback`` returns for the same arguments.
    """
    url = f"https://ztf.snad.space/view/{metadata}"

    try:
        # open_new_tab() reports failure by returning False rather than
        # raising, so the result has to be checked as well.
        opened = webbrowser.get().open_new_tab(url)
    except webbrowser.Error:
        # Raised when no usable browser controller is available.
        opened = False

    if not opened:
        click.echo(f"Check {url} for details")

    return prompt_decision_callback(metadata, data, session)
class TerminateAfter:
    """
    Stop the session once a given number of iterations has been reached.

    Meant to be registered as an "on decision callback":
        Session(..., on_decision_callbacks=[TerminateAfter(budget)])

    Parameters
    ----------
    budget : int
        Number of iterations after which session will be terminated.
    """

    def __init__(self, budget: int):
        # Number of times the callback has been invoked so far.
        self.iteration = 0
        self.budget = budget

    def __call__(self, metadata, data, session) -> None:
        """Count one more iteration and terminate when the budget is hit."""
        self.iteration = self.iteration + 1
        budget_reached = self.iteration >= self.budget
        if budget_reached:
            session.terminate()
class TerminateAfterNAnomalies:
    """
    Terminate session after given number of newly labeled anomalies.

    This callback to be used as "on decision callback":
        Session(..., on_decision_callbacks=[TerminateAfterNAnomalies(budget)])

    Parameters
    ----------
    budget : int
        Number of anomalies to stop after.
    """

    def __init__(self, budget: int):
        self.budget = budget
        # Number of calls so far whose first argument equaled Label.ANOMALY.
        self.anomalies_count = 0

    def __call__(self, label, _data, session) -> None:
        """Count the call if ``label`` is an anomaly; terminate at budget."""
        # NOTE(review): the first positional argument is compared against
        # Label.ANOMALY here, while TerminateAfter names the same position
        # ``metadata`` — confirm what Session passes to on-decision callbacks.
        if label == Label.ANOMALY:
            self.anomalies_count += 1
        if self.anomalies_count >= self.budget:
            session.terminate()