Source code for vak.config.learncurve
"""Class that represents ``[vak.learncurve]`` table in configuration file."""
from __future__ import annotations
from attrs import converters, define, field, validators
from .dataset import DatasetConfig
from .eval import are_valid_post_tfm_kwargs, convert_post_tfm_kwargs
from .model import ModelConfig
from .train import TrainConfig
from .trainer import TrainerConfig
REQUIRED_KEYS = (
"dataset",
"model",
"root_results_dir",
"trainer",
)
[docs]
@define
class LearncurveConfig(TrainConfig):
"""Class that represents ``[vak.learncurve]`` table in configuration file.
Attributes
----------
model : vak.config.ModelConfig
The model to use: its name,
and the parameters to configure it.
Must be an instance of :class:`vak.config.ModelConfig`
num_epochs : int
number of training epochs. One epoch = one iteration through the entire
training set.
batch_size : int
number of samples per batch presented to models during training.
root_results_dir : str
directory in which results will be created.
The vak.cli.train function will create
a subdirectory in this directory each time it runs.
dataset : vak.config.DatasetConfig
The dataset to use: the path to it,
and optionally a path to a file representing splits,
and the name, if it is a built-in dataset.
Must be an instance of :class:`vak.config.DatasetConfig`.
trainer : vak.config.TrainerConfig
Configuration for :class:`lightning.pytorch.Trainer`.
Must be an instance of :class:`vak.config.TrainerConfig`.
num_workers : int
Number of processes to use for parallel loading of data.
Argument to torch.DataLoader.
shuffle: bool
if True, shuffle training data before each epoch. Default is True.
standardize_frames : bool
if True, use :class:`vak.transforms.FramesStandardizer` to standardize the frames.
Normalization is done by subtracting off the mean for each row
of the training set and then dividing by the std for that frequency bin.
This same normalization is then applied to validation + test data.
val_step : int
Step on which to estimate accuracy using validation set.
If val_step is n, then validation is carried out every time
the global step / n is a whole number, i.e., when val_step modulo the global step is 0.
Default is None, in which case no validation is done.
ckpt_step : int
step/epoch at which to save to checkpoint file.
Default is None, in which case checkpoint is only saved at the last epoch.
patience : int
number of epochs to wait without the error dropping before stopping the
training. Default is None, in which case training continues for num_epochs
post_tfm_kwargs : dict
Keyword arguments to post-processing transform.
If None, then no additional clean-up is applied
when transforming labeled timebins to segments,
the default behavior.
The transform used is
``vak.transforms.frame_labels.ToSegmentsWithPostProcessing`.
Valid keyword argument names are 'majority_vote'
and 'min_segment_dur', and should be appropriate
values for those arguments: Boolean for ``majority_vote``,
a float value for ``min_segment_dur``.
See the docstring of the transform for more details on
these arguments and how they work.
"""
post_tfm_kwargs = field(
validator=validators.optional(are_valid_post_tfm_kwargs),
converter=converters.optional(convert_post_tfm_kwargs),
default=None,
)
# we over-ride this method from TrainConfig mainly so the docstring is correct.
# TODO: can we do this by just over-writing `__doc__` for the method on this class?
[docs]
@classmethod
def from_config_dict(cls, config_dict: dict) -> LearncurveConfig:
"""Return :class:`LearncurveConfig` instance from a :class:`dict`.
The :class:`dict` passed in should be the one found
by loading a valid configuration toml file with
:func:`vak.config.parse.from_toml_path`,
and then using key ``learncurve``,
i.e., ``LearncurveConfig.from_config_dict(config_dict['learncurve'])``.
"""
for required_key in REQUIRED_KEYS:
if required_key not in config_dict:
raise KeyError(
"The `[vak.learncurve]` table in a configuration file requires "
f"the option '{required_key}', but it was not found "
"when loading the configuration file into a Python dictionary. "
"Please check that the configuration file is formatted correctly."
)
config_dict["model"] = ModelConfig.from_config_dict(
config_dict["model"]
)
config_dict["dataset"] = DatasetConfig.from_config_dict(
config_dict["dataset"]
)
config_dict["trainer"] = TrainerConfig(**config_dict["trainer"])
return cls(**config_dict)