"""parses [EVAL] section of config"""
import attr
from attr import converters, validators
from attr.validators import instance_of
from ..common import device
from ..common.converters import expanded_user_path
from .validators import is_valid_model_name
def convert_post_tfm_kwargs(post_tfm_kwargs: dict) -> dict:
    """Normalize a ``post_tfm_kwargs`` mapping, filling in defaults.

    Returns a new dict where 'min_segment_dur' is coerced to float
    (or set to None when absent) and 'majority_vote' is coerced to bool
    (or set to False when absent). The input mapping is not mutated.
    """
    converted = dict(post_tfm_kwargs)  # copy so the caller's dict is untouched
    # TOML has no null value, so users simply leave 'min_segment_dur'
    # out of the config; a missing key is normalized to None here.
    if "min_segment_dur" in converted:
        converted["min_segment_dur"] = float(converted["min_segment_dur"])
    else:
        converted["min_segment_dur"] = None
    # likewise, an omitted 'majority_vote' defaults to False
    if "majority_vote" in converted:
        converted["majority_vote"] = bool(converted["majority_vote"])
    else:
        converted["majority_vote"] = False
    return converted
def are_valid_post_tfm_kwargs(instance, attribute, value):
    """attrs validator: check that ``post_tfm_kwargs`` is a valid dict.

    Parameters follow the attrs validator protocol
    (``instance``, ``attribute``, ``value``); only ``value`` is inspected.

    Raises
    ------
    TypeError
        If ``value`` is not a dict, if 'majority_vote' is present but not
        a bool, or if 'min_segment_dur' is present, truthy, and not a float.
    ValueError
        If any key is not one of {'majority_vote', 'min_segment_dur'}.
    """
    if not isinstance(value, dict):
        raise TypeError(
            "'post_tfm_kwargs' should be declared in toml config as an inline table "
            f"that parses as a dict, but type was: {type(value)}. "
            "Please declare in a similar fashion: `{majority_vote = True, min_segment_dur = 0.02}`"
        )
    # compute the invalid keys once, instead of an `any(...)` membership scan
    # followed by a second comprehension over the same keys
    valid_keys = {"majority_vote", "min_segment_dur"}
    invalid_kwargs = [k for k in value if k not in valid_keys]
    if invalid_kwargs:
        raise ValueError(
            # note trailing space: the two literals are concatenated into one message
            f"Invalid keyword argument name specified for 'post_tfm_kwargs': {invalid_kwargs}. "
            "Valid names are: {'majority_vote', 'min_segment_dur'}"
        )
    if "majority_vote" in value:
        if not isinstance(value["majority_vote"], bool):
            raise TypeError(
                "'post_tfm_kwargs' keyword argument 'majority_vote' "
                f"should be of type bool but was: {type(value['majority_vote'])}"
            )
    if "min_segment_dur" in value:
        # NOTE(review): truthiness guard means falsy values (None, 0) skip the
        # float check — preserved as-is since None is a documented sentinel;
        # confirm whether an int 0 should also be rejected.
        if value["min_segment_dur"] and not isinstance(
            value["min_segment_dur"], float
        ):
            raise TypeError(
                "'post_tfm_kwargs' keyword argument 'min_segment_dur' type "
                f"should be float but was: {type(value['min_segment_dur'])}"
            )
@attr.s
class EvalConfig:
    """class that represents [EVAL] section of config.toml file

    Attributes
    ----------
    dataset_path : str
        Path to dataset, e.g., a csv file generated by running ``vak prep``.
    checkpoint_path : str
        path to directory with checkpoint files saved by Torch, to reload model
    output_dir : str
        Path to location where .csv files with evaluation metrics should be saved.
    labelmap_path : str
        path to 'labelmap.json' file.
    model : str
        Model name, e.g., ``model = "TweetyNet"``
    batch_size : int
        number of samples per batch presented to models during training.
    num_workers : int
        Number of processes to use for parallel loading of data.
        Argument to torch.DataLoader. Default is 2.
    device : str
        Device on which to work with model + data.
        Defaults to 'cuda' if torch.cuda.is_available is True.
    spect_scaler_path : str
        path to a saved SpectScaler object used to normalize spectrograms.
        If spectrograms were normalized and this is not provided, will give
        incorrect results.
    post_tfm_kwargs : dict
        Keyword arguments to post-processing transform.
        If None, then no additional clean-up is applied
        when transforming labeled timebins to segments,
        the default behavior.
        The transform used is
        ``vak.transforms.frame_labels.PostProcess``.
        Valid keyword argument names are 'majority_vote'
        and 'min_segment_dur', and should be appropriate
        values for those arguments: Boolean for ``majority_vote``,
        a float value for ``min_segment_dur``.
        See the docstring of the transform for more details on
        these arguments and how they work.
    transform_params: dict, optional
        Parameters for data transform.
        Passed as keyword arguments.
        Optional, default is None.
    dataset_params: dict, optional
        Parameters for dataset.
        Passed as keyword arguments.
        Optional, default is None.
    """

    # NOTE: attr.ib declaration order defines the generated __init__
    # signature; do not reorder fields.

    # required, external files
    checkpoint_path = attr.ib(converter=expanded_user_path)
    output_dir = attr.ib(converter=expanded_user_path)

    # required, model / dataloader
    model = attr.ib(
        validator=[instance_of(str), is_valid_model_name],
    )
    batch_size = attr.ib(converter=int, validator=instance_of(int))

    # dataset_path is actually 'required' but we can't enforce that here because cli.prep looks at
    # what sections are defined to figure out where to add dataset_path after it creates the csv
    dataset_path = attr.ib(
        converter=converters.optional(expanded_user_path),
        default=None,
    )

    # "optional" but actually required for frame classification models
    # TODO: check model family in __post_init__ and raise ValueError if labelmap
    # TODO: not specified for a frame classification model?
    labelmap_path = attr.ib(
        converter=converters.optional(expanded_user_path), default=None
    )

    # optional, transform
    spect_scaler_path = attr.ib(
        converter=converters.optional(expanded_user_path),
        default=None,
    )
    # converter runs before validator (attrs semantics), so the validator
    # sees the normalized dict produced by convert_post_tfm_kwargs
    post_tfm_kwargs = attr.ib(
        validator=validators.optional(are_valid_post_tfm_kwargs),
        converter=converters.optional(convert_post_tfm_kwargs),
        default=None,
    )

    # optional, data loader
    num_workers = attr.ib(validator=instance_of(int), default=2)
    # default is computed once at class-definition time via device.get_default()
    device = attr.ib(validator=instance_of(str), default=device.get_default())
    transform_params = attr.ib(
        converter=converters.optional(dict),
        validator=validators.optional(instance_of(dict)),
        default=None,
    )
    dataset_params = attr.ib(
        converter=converters.optional(dict),
        validator=validators.optional(instance_of(dict)),
        default=None,
    )