# vak.config.eval

"""parses [EVAL] section of config"""
import attr
from attr import converters, validators
from attr.validators import instance_of

from .validators import is_valid_model_name
from .. import device
from ..converters import comma_separated_list, expanded_user_path


def convert_post_tfm_kwargs(post_tfm_kwargs: dict) -> dict:
    """Coerce values in ``post_tfm_kwargs`` and fill in defaults.

    Returns a new dict (the input is not mutated) where
    ``min_segment_dur`` is a float or None, and ``majority_vote``
    is a bool (default False).
    """
    converted = dict(post_tfm_kwargs)

    # TOML has no null value, so users simply leave the key out of the
    # config when they want "no minimum duration"; we normalize that to None.
    if 'min_segment_dur' in converted:
        converted['min_segment_dur'] = float(converted['min_segment_dur'])
    else:
        converted['min_segment_dur'] = None

    # same idea: absent key means "don't apply majority vote"
    if 'majority_vote' in converted:
        converted['majority_vote'] = bool(converted['majority_vote'])
    else:
        converted['majority_vote'] = False

    return converted


def are_valid_post_tfm_kwargs(instance, attribute, value):
    """attrs validator: check that ``post_tfm_kwargs`` is valid.

    Parameters
    ----------
    instance, attribute
        Standard attrs validator arguments (unused here).
    value : dict
        The value assigned to ``post_tfm_kwargs``.

    Raises
    ------
    TypeError
        If ``value`` is not a dict, if ``majority_vote`` is not a bool,
        or if ``min_segment_dur`` is truthy but not a float.
    ValueError
        If ``value`` contains any key other than
        'majority_vote' or 'min_segment_dur'.
    """
    if not isinstance(value, dict):
        raise TypeError(
            "'post_tfm_kwargs' should be declared in toml config as an inline table "
            f"that parses as a dict, but type was: {type(value)}. "
            "Please declare in a similar fashion: `{majority_vote = True, min_segment_dur = 0.02}`"
        )
    # compute the invalid keys once, instead of one pass for any() and another for the message
    invalid_kwargs = [
        k for k in value.keys() if k not in {'majority_vote', 'min_segment_dur'}
    ]
    if invalid_kwargs:
        raise ValueError(
            f"Invalid keyword argument name specified for 'post_tfm_kwargs': {invalid_kwargs}. "
            "Valid names are: {'majority_vote', 'min_segment_dur'}"
        )
    if 'majority_vote' in value:
        if not isinstance(value['majority_vote'], bool):
            raise TypeError(
                "'post_tfm_kwargs' keyword argument 'majority_vote' "
                f"should be of type bool but was: {type(value['majority_vote'])}"
            )
    if 'min_segment_dur' in value:
        # only type-check a truthy value; falsy (e.g. None, 0) means "no minimum"
        if value['min_segment_dur'] and not isinstance(value['min_segment_dur'], float):
            raise TypeError(
                "'post_tfm_kwargs' keyword argument 'min_segment_dur' type "
                f"should be float but was: {type(value['min_segment_dur'])}"
            )


@attr.s
class EvalConfig:
    """class that represents [EVAL] section of config.toml file

    Attributes
    ----------
    csv_path : str
        path to where dataset was saved as a csv.
    checkpoint_path : str
        path to directory with checkpoint files saved by Torch, to reload model
    output_dir : str
        Path to location where .csv files with evaluation metrics should be saved.
    labelmap_path : str
        path to 'labelmap.json' file.
    models : list
        of model names. e.g., 'models = TweetyNet, GRUNet, ConvNet'
    batch_size : int
        number of samples per batch presented to models during training.
    num_workers : int
        Number of processes to use for parallel loading of data.
        Argument to torch.DataLoader. Default is 2.
    device : str
        Device on which to work with model + data.
        Defaults to 'cuda' if torch.cuda.is_available is True.
    spect_scaler_path : str
        path to a saved SpectScaler object used to normalize spectrograms.
        If spectrograms were normalized and this is not provided, will give
        incorrect results.
    post_tfm_kwargs : dict
        Keyword arguments to post-processing transform.
        If None, then no additional clean-up is applied
        when transforming labeled timebins to segments,
        the default behavior.
        The transform used is
        ``vak.transforms.labeled_timebins.ToSegmentsWithPostProcessing`.
        Valid keyword argument names are 'majority_vote'
        and 'min_segment_dur', and should be appropriate
        values for those arguments: Boolean for ``majority_vote``,
        a float value for ``min_segment_dur``.
        See the docstring of the transform for more details on
        these arguments and how they work.
    """
    # required, external files
    checkpoint_path = attr.ib(converter=expanded_user_path)
    labelmap_path = attr.ib(converter=expanded_user_path)
    output_dir = attr.ib(converter=expanded_user_path)

    # required, model / dataloader
    models = attr.ib(
        converter=comma_separated_list,
        validator=[instance_of(list), is_valid_model_name],
    )
    batch_size = attr.ib(converter=int, validator=instance_of(int))

    # csv_path is actually 'required' but we can't enforce that here because cli.prep looks at
    # what sections are defined to figure out where to add csv_path after it creates the csv
    csv_path = attr.ib(
        converter=converters.optional(expanded_user_path),
        default=None,
    )

    # optional, transform
    spect_scaler_path = attr.ib(
        converter=converters.optional(expanded_user_path),
        default=None,
    )
    # default is None (not an empty dict): callers check for None before
    # passing these kwargs into the post-processing transform
    post_tfm_kwargs = attr.ib(
        validator=validators.optional(are_valid_post_tfm_kwargs),
        converter=converters.optional(convert_post_tfm_kwargs),
        default=None,
    )

    # optional, data loader
    num_workers = attr.ib(validator=instance_of(int), default=2)
    device = attr.ib(validator=instance_of(str), default=device.get_default())