# Source code for vak.transforms.defaults.get

"""Helper function that gets default transforms for a model."""

from __future__ import annotations

from typing import Callable, Literal

from ... import models
from . import frame_classification, parametric_umap


def get_default_transform(
    model_name: str,
    mode: Literal["eval", "predict", "train"],
    transform_kwargs: dict | None = None,
) -> Callable:
    """Get default transform for a model,
    according to its family and what mode
    the model is being used in.

    Parameters
    ----------
    model_name : str
        Name of model. Used to look up the model's family in
        ``models.registry.MODEL_FAMILY_FROM_NAME``.
    mode : str
        One of {'eval', 'predict', 'train'}.
        Forwarded to the frame classification transform getter;
        not used for models in the ParametricUMAPModel family.
    transform_kwargs : dict, optional
        Keyword arguments forwarded to the family-specific
        function that builds the default transform.
        Default is None.

    Returns
    -------
    item_transform : callable
        Transform to be applied to input :math:`x` to a model and,
        during training, the target :math:`y`.

    Raises
    ------
    ValueError
        If no model family is registered for ``model_name``,
        or if no default transform is defined for the model's family.
    """
    try:
        model_family = models.registry.MODEL_FAMILY_FROM_NAME[model_name]
    except KeyError as e:
        raise ValueError(
            f"No model family found for the model name specified: {model_name}"
        ) from e

    if model_family == "FrameClassificationModel":
        return frame_classification.get_default_frame_classification_transform(
            mode, transform_kwargs
        )
    elif model_family == "ParametricUMAPModel":
        # The ParametricUMAP getter takes no ``mode`` argument.
        return parametric_umap.get_default_parametric_umap_transform(
            transform_kwargs
        )

    # Without this, a registered family with no branch above would make the
    # function silently return None, violating the ``-> Callable`` contract
    # and failing later at the call site with an unhelpful error.
    raise ValueError(
        f"No default transform is defined for model family '{model_family}' "
        f"(model name: {model_name})"
    )