Source code for vak.config.eval

"""parses [EVAL] section of config"""
import attr
from attr import converters, validators
from attr.validators import instance_of

from ..common import device
from ..common.converters import expanded_user_path
from .validators import is_valid_model_name


def convert_post_tfm_kwargs(post_tfm_kwargs: dict) -> dict:
    """Convert a ``post_tfm_kwargs`` mapping from a TOML config into
    the canonical form used by the post-processing transform.

    Returns a new dict (the input is not mutated) in which
    ``min_segment_dur`` is a float or None and ``majority_vote`` is a bool.
    Missing keys are filled with their defaults.
    """
    converted = dict(post_tfm_kwargs)

    if "min_segment_dur" in converted:
        converted["min_segment_dur"] = float(converted["min_segment_dur"])
    else:
        # because there's no null in TOML,
        # users leave arg out of config then we set it to None
        converted["min_segment_dur"] = None

    if "majority_vote" in converted:
        converted["majority_vote"] = bool(converted["majority_vote"])
    else:
        # set default for this one too
        converted["majority_vote"] = False

    return converted
def are_valid_post_tfm_kwargs(instance, attribute, value):
    """Check if ``post_tfm_kwargs`` is valid.

    An ``attrs``-style validator: ``instance`` and ``attribute`` are
    supplied by ``attrs`` and unused here; only ``value`` is inspected.

    Raises
    ------
    TypeError
        If ``value`` is not a dict, if ``majority_vote`` is not a bool,
        or if ``min_segment_dur`` is truthy but not a float.
    ValueError
        If ``value`` contains any key other than
        'majority_vote' or 'min_segment_dur'.
    """
    if not isinstance(value, dict):
        raise TypeError(
            "'post_tfm_kwargs' should be declared in toml config as an inline table "
            f"that parses as a dict, but type was: {type(value)}. "
            "Please declare in a similar fashion: `{majority_vote = True, min_segment_dur = 0.02}`"
        )
    # compute the invalid-key list once, and reuse it for both the
    # check and the error message (the original built it twice)
    valid_names = {"majority_vote", "min_segment_dur"}
    invalid_kwargs = [k for k in value if k not in valid_names]
    if invalid_kwargs:
        raise ValueError(
            # fixed: a space was missing between the two sentences,
            # producing "...{invalid_kwargs}.Valid names are..."
            f"Invalid keyword argument name specified for 'post_tfm_kwargs': {invalid_kwargs}. "
            "Valid names are: {'majority_vote', 'min_segment_dur'}"
        )
    if "majority_vote" in value:
        if not isinstance(value["majority_vote"], bool):
            raise TypeError(
                "'post_tfm_kwargs' keyword argument 'majority_vote' "
                f"should be of type bool but was: {type(value['majority_vote'])}"
            )
    if "min_segment_dur" in value:
        # None (and other falsy values) are allowed; only a truthy
        # non-float value is rejected
        if value["min_segment_dur"] and not isinstance(
            value["min_segment_dur"], float
        ):
            raise TypeError(
                "'post_tfm_kwargs' keyword argument 'min_segment_dur' type "
                f"should be float but was: {type(value['min_segment_dur'])}"
            )
@attr.s
class EvalConfig:
    """class that represents [EVAL] section of config.toml file

    Attributes
    ----------
    dataset_path : str
        Path to dataset, e.g., a csv file generated by running ``vak prep``.
    checkpoint_path : str
        path to directory with checkpoint files saved by Torch, to reload model
    output_dir : str
        Path to location where .csv files with evaluation metrics should be saved.
    labelmap_path : str
        path to 'labelmap.json' file.
    model : str
        Model name, e.g., ``model = "TweetyNet"``
    batch_size : int
        number of samples per batch presented to models during training.
    num_workers : int
        Number of processes to use for parallel loading of data.
        Argument to torch.DataLoader. Default is 2.
    device : str
        Device on which to work with model + data.
        Defaults to 'cuda' if torch.cuda.is_available is True.
    spect_scaler_path : str
        path to a saved SpectScaler object used to normalize spectrograms.
        If spectrograms were normalized and this is not provided, will give
        incorrect results.
    post_tfm_kwargs : dict
        Keyword arguments to post-processing transform.
        If None, then no additional clean-up is applied
        when transforming labeled timebins to segments,
        the default behavior. The transform used is
        ``vak.transforms.frame_labels.PostProcess`.
        Valid keyword argument names are 'majority_vote'
        and 'min_segment_dur', and should be appropriate
        values for those arguments: Boolean for ``majority_vote``,
        a float value for ``min_segment_dur``.
        See the docstring of the transform for more details on
        these arguments and how they work.
    transform_params: dict, optional
        Parameters for data transform. Passed as keyword arguments.
        Optional, default is None.
    dataset_params: dict, optional
        Parameters for dataset. Passed as keyword arguments.
        Optional, default is None.
    """

    # required, external files
    checkpoint_path = attr.ib(converter=expanded_user_path)
    output_dir = attr.ib(converter=expanded_user_path)

    # required, model / dataloader
    model = attr.ib(
        validator=[instance_of(str), is_valid_model_name],
    )
    batch_size = attr.ib(converter=int, validator=instance_of(int))

    # dataset_path is actually 'required' but we can't enforce that here because cli.prep looks at
    # what sections are defined to figure out where to add dataset_path after it creates the csv
    dataset_path = attr.ib(
        converter=converters.optional(expanded_user_path),
        default=None,
    )

    # "optional" but actually required for frame classification models
    # TODO: check model family in __post_init__ and raise ValueError if labelmap
    # TODO: not specified for a frame classification model?
    labelmap_path = attr.ib(
        converter=converters.optional(expanded_user_path), default=None
    )

    # optional, transform
    spect_scaler_path = attr.ib(
        converter=converters.optional(expanded_user_path),
        default=None,
    )

    # converter runs before validator, so post_tfm_kwargs is validated
    # after defaults for missing keys have been filled in
    post_tfm_kwargs = attr.ib(
        validator=validators.optional(are_valid_post_tfm_kwargs),
        converter=converters.optional(convert_post_tfm_kwargs),
        default=None,
    )

    # optional, data loader
    num_workers = attr.ib(validator=instance_of(int), default=2)
    # device.get_default() is evaluated once at class-definition time,
    # not per-instance
    device = attr.ib(validator=instance_of(str), default=device.get_default())

    transform_params = attr.ib(
        converter=converters.optional(dict),
        validator=validators.optional(instance_of(dict)),
        default=None,
    )

    dataset_params = attr.ib(
        converter=converters.optional(dict),
        validator=validators.optional(instance_of(dict)),
        default=None,
    )