"""Class and functions for ``[vak.eval]`` table in configuration file."""
from __future__ import annotations
import pathlib
from attrs import converters, define, field, validators
from attrs.validators import instance_of
from ..common.converters import expanded_user_path
from .dataset import DatasetConfig
from .model import ModelConfig
from .trainer import TrainerConfig
def convert_post_tfm_kwargs(post_tfm_kwargs: dict) -> dict:
    """Convert ``post_tfm_kwargs`` to a :class:`dict`, filling in default values."""
    post_tfm_kwargs = dict(post_tfm_kwargs)

    if "min_segment_dur" not in post_tfm_kwargs:
        # because there's no null in TOML,
        # users leave the arg out of the config and we set it to None
        post_tfm_kwargs["min_segment_dur"] = None
    else:
        post_tfm_kwargs["min_segment_dur"] = float(
            post_tfm_kwargs["min_segment_dur"]
        )

    if "majority_vote" not in post_tfm_kwargs:
        # set default for this one too
        post_tfm_kwargs["majority_vote"] = False
    else:
        post_tfm_kwargs["majority_vote"] = bool(
            post_tfm_kwargs["majority_vote"]
        )

    return post_tfm_kwargs
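
# Illustrative sketch (not part of the vak API): how the converter above behaves.
# Keys left out of the TOML inline table are filled in with defaults, and values
# that are present are coerced to the expected types.
#
#     convert_post_tfm_kwargs({"majority_vote": 1})
#     # -> {"majority_vote": True, "min_segment_dur": None}
#
#     convert_post_tfm_kwargs({"min_segment_dur": 0.02})
#     # -> {"min_segment_dur": 0.02, "majority_vote": False}
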
def are_valid_post_tfm_kwargs(instance, attribute, value):
    """Check whether ``post_tfm_kwargs`` is valid."""
    if not isinstance(value, dict):
        raise TypeError(
            "'post_tfm_kwargs' should be declared in the toml config as an inline table "
            f"that parses as a dict, but type was: {type(value)}. "
            "Please declare it in a similar fashion: `{majority_vote = true, min_segment_dur = 0.02}`"
        )
    invalid_kwargs = [
        k
        for k in value.keys()
        if k not in {"majority_vote", "min_segment_dur"}
    ]
    if invalid_kwargs:
        raise ValueError(
            f"Invalid keyword argument name specified for 'post_tfm_kwargs': {invalid_kwargs}. "
            "Valid names are: {'majority_vote', 'min_segment_dur'}"
        )
    if "majority_vote" in value:
        if not isinstance(value["majority_vote"], bool):
            raise TypeError(
                "'post_tfm_kwargs' keyword argument 'majority_vote' "
                f"should be of type bool but was: {type(value['majority_vote'])}"
            )
    if "min_segment_dur" in value:
        if value["min_segment_dur"] and not isinstance(
            value["min_segment_dur"], float
        ):
            raise TypeError(
                "'post_tfm_kwargs' keyword argument 'min_segment_dur' type "
                f"should be float but was: {type(value['min_segment_dur'])}"
            )
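
# Illustrative sketch (not part of the vak API): attrs calls the validator above
# with ``(instance, attribute, value)``; only ``value`` is inspected here, so the
# behavior can also be exercised directly.
#
#     are_valid_post_tfm_kwargs(None, None, {"majority_vote": True})   # passes
#     are_valid_post_tfm_kwargs(None, None, {"min_dur": 0.02})         # ValueError: invalid name
#     are_valid_post_tfm_kwargs(None, None, {"majority_vote": "yes"})  # TypeError: not a bool
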
REQUIRED_KEYS = (
"checkpoint_path",
"dataset",
"output_dir",
"model",
"trainer",
)
@define
class EvalConfig:
    """Class that represents the ``[vak.eval]`` table in a configuration file.

    Attributes
    ----------
    checkpoint_path : str
        Path to directory with checkpoint files saved by torch, used to reload the model.
    output_dir : str
        Path to location where .csv files with evaluation metrics should be saved.
    model : vak.config.ModelConfig
        The model to use: its name,
        and the parameters to configure it.
        Must be an instance of :class:`vak.config.ModelConfig`.
    batch_size : int
        Number of samples per batch presented to the model during evaluation.
    dataset : vak.config.DatasetConfig
        The dataset to use: the path to it,
        optionally a path to a file representing splits,
        and the name, if it is a built-in dataset.
        Must be an instance of :class:`vak.config.DatasetConfig`.
    trainer : vak.config.TrainerConfig
        Configuration for :class:`lightning.pytorch.Trainer`.
        Must be an instance of :class:`vak.config.TrainerConfig`.
    num_workers : int
        Number of processes to use for parallel loading of data.
        Argument to :class:`torch.utils.data.DataLoader`. Default is 2.
    labelmap_path : str
        Path to 'labelmap.json' file.
    frames_standardizer_path : str
        Path to a saved :class:`vak.transforms.FramesStandardizer` object used to
        standardize (normalize) frames.
        If spectrograms were standardized during training and this is not provided,
        evaluation will give incorrect results.
    post_tfm_kwargs : dict
        Keyword arguments to the post-processing transform.
        If None, then no additional clean-up is applied
        when transforming labeled timebins to segments,
        the default behavior.
        The transform used is
        :class:`vak.transforms.frame_labels.PostProcess`.
        Valid keyword argument names are 'majority_vote'
        and 'min_segment_dur', and values should be appropriate
        for those arguments: a Boolean for ``majority_vote``,
        a float value for ``min_segment_dur``.
        See the docstring of the transform for more details on
        these arguments and how they work.
    """
    # required, external files
    checkpoint_path: pathlib.Path = field(converter=expanded_user_path)
    output_dir: pathlib.Path = field(converter=expanded_user_path)

    # required, model / dataloader
    model = field(
        validator=instance_of(ModelConfig),
    )
    batch_size = field(converter=int, validator=instance_of(int))
    dataset: DatasetConfig = field(
        validator=instance_of(DatasetConfig),
    )
    trainer: TrainerConfig = field(
        validator=instance_of(TrainerConfig),
    )

    # "optional" but actually required for frame classification models
    # TODO: check model family in __post_init__ and raise ValueError if labelmap
    # TODO: not specified for a frame classification model?
    labelmap_path = field(
        converter=converters.optional(expanded_user_path), default=None
    )

    # optional, transform
    frames_standardizer_path = field(
        converter=converters.optional(expanded_user_path),
        default=None,
    )
    post_tfm_kwargs = field(
        validator=validators.optional(are_valid_post_tfm_kwargs),
        converter=converters.optional(convert_post_tfm_kwargs),
        default=None,
    )

    # optional, data loader
    num_workers = field(validator=instance_of(int), default=2)

    @classmethod
    def from_config_dict(cls, config_dict: dict) -> EvalConfig:
        """Return an :class:`EvalConfig` instance from a :class:`dict`.

        The :class:`dict` passed in should be the one found
        by loading a valid configuration toml file with
        :func:`vak.config.parse.from_toml_path`,
        and then using the key ``eval``,
        i.e., ``EvalConfig.from_config_dict(config_dict['eval'])``.
        """
        for required_key in REQUIRED_KEYS:
            if required_key not in config_dict:
                raise KeyError(
                    "The `[vak.eval]` table in a configuration file requires "
                    f"the option '{required_key}', but it was not found "
                    "when loading the configuration file into a Python dictionary. "
                    "Please check that the configuration file is formatted correctly."
                )
        config_dict["dataset"] = DatasetConfig.from_config_dict(
            config_dict["dataset"]
        )
        config_dict["model"] = ModelConfig.from_config_dict(
            config_dict["model"]
        )
        config_dict["trainer"] = TrainerConfig(**config_dict["trainer"])
        return cls(**config_dict)
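
# Illustrative sketch (paths, model name, and sub-table keys below are assumptions,
# not taken from a real project): ``from_config_dict`` expects a dict containing at
# least the keys in ``REQUIRED_KEYS``; the ``dataset``, ``model``, and ``trainer``
# values are themselves dicts that get converted into their own config classes
# before ``EvalConfig`` is instantiated.
#
#     config_dict = {
#         "checkpoint_path": "results/checkpoints/checkpoint.pt",  # assumed path
#         "output_dir": "results/eval",                            # assumed path
#         "batch_size": 4,
#         "model": {"TweetyNet": {}},          # assumed model table structure
#         "dataset": {"path": "prep/eval"},    # assumed dataset table structure
#         "trainer": {"accelerator": "cpu"},   # assumed trainer option
#     }
#     eval_config = EvalConfig.from_config_dict(config_dict)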