# Source code for vak.config.learncurve

"""parses [LEARNCURVE] section of config"""
import attr
from attr import converters, validators

from .eval import are_valid_post_tfm_kwargs, convert_post_tfm_kwargs
from .train import TrainConfig


@attr.s
class LearncurveConfig(TrainConfig):
    """Class that represents the [LEARNCURVE] section of a config.toml file.

    Subclass of ``TrainConfig`` that adds one attribute, ``post_tfm_kwargs``;
    the other attributes listed below are expected to come from the base class.

    Attributes
    ----------
    model : str
        Model name, e.g., ``model = "TweetyNet"``
    dataset_path : str
        Path to dataset, e.g., a csv file generated by running ``vak prep``.
    num_epochs : int
        Number of training epochs. One epoch = one iteration through the
        entire training set.
    normalize_spectrograms : bool
        If True, use spect.utils.data.SpectScaler to normalize the spectrograms.
        Normalization is done by subtracting off the mean for each frequency bin
        of the training set and then dividing by the std for that frequency bin.
        This same normalization is then applied to validation + test data.
    ckpt_step : int
        Step/epoch at which to save to checkpoint file.
        Default is None, in which case checkpoint is only saved at the last epoch.
    patience : int
        Number of epochs to wait without the error dropping before stopping the
        training. Default is None, in which case training continues for num_epochs.
    save_only_single_checkpoint_file : bool
        If True, save only one checkpoint file instead of separate files every
        time we save. Default is True.
    use_train_subsets_from_previous_run : bool
        If True, use training subsets saved in a previous run. Default is False.
        Requires setting the previous_run_path option in the config.toml file.
    post_tfm_kwargs : dict
        Keyword arguments to the post-processing transform.
        If None (the default), no additional clean-up is applied
        when transforming labeled timebins to segments.
        The transform used is
        ``vak.transforms.frame_labels.ToSegmentsWithPostProcessing``.
        Valid keyword argument names are 'majority_vote'
        and 'min_segment_dur', and their values should be appropriate
        for those arguments: a Boolean for ``majority_vote``,
        a float for ``min_segment_dur``.
        See the docstring of the transform for more details on
        these arguments and how they work.
    """

    # Both the converter and the validator are wrapped in ``optional`` so the
    # default ``None`` passes through untouched; attrs runs the converter
    # before the validator, so validation sees the converted value.
    post_tfm_kwargs = attr.ib(
        validator=validators.optional(are_valid_post_tfm_kwargs),
        converter=converters.optional(convert_post_tfm_kwargs),
        default=None,
    )