vak.config.predict.PredictConfig

class vak.config.predict.PredictConfig(checkpoint_path, labelmap_path, model, batch_size, dataset: DatasetConfig, trainer: TrainerConfig, frames_standardizer_path=None, num_workers=2, annot_csv_filename=None, output_dir=<current working directory>, min_segment_dur=None, majority_vote=True, save_net_outputs=False)

Bases: object

Class that represents the [vak.predict] table of a configuration file.

checkpoint_path : str

Path to the checkpoint file saved by Torch, used to reload the trained model.

labelmap_path : str

Path to the ‘labelmap.json’ file.

model : vak.config.ModelConfig

The model to use: its name, and the parameters to configure it. Must be an instance of vak.config.ModelConfig.

batch_size : int

Number of samples per batch presented to models during training.

dataset : vak.config.DatasetConfig

The dataset to use: its path, optionally a path to a file representing splits, and its name, if it is a built-in dataset. Must be an instance of vak.config.DatasetConfig.

trainer : vak.config.TrainerConfig

Configuration for lightning.pytorch.Trainer. Must be an instance of vak.config.TrainerConfig.

num_workers : int

Number of processes to use for parallel loading of data. Passed as the num_workers argument to torch.utils.data.DataLoader. Default is 2.

frames_standardizer_path : str

Path to a saved vak.transforms.FramesStandardizer object used to standardize (normalize) frames. Default is None. If the spectrograms used to train the model were standardized and this path is not provided, the model will give incorrect results.

annot_csv_filename : str

Name of the .csv file containing predicted annotations. Default is None, in which case the name of the dataset .csv is used, with ‘.annot.csv’ appended to it.

output_dir : str

Path to the directory where the .csv containing predicted annotations should be saved. Defaults to the current working directory.

min_segment_dur : float

Minimum duration of a segment, in seconds. If specified, any segment with a duration less than min_segment_dur is removed from lbl_tb, the vector of labels predicted for each time bin. Default is None, in which case no segments are removed.

majority_vote : bool

If True, transform segments containing multiple labels into segments with a single label by taking a “majority vote”, i.e., assign all time bins in the segment the most frequently occurring label in the segment. This transform can only be applied if the labelmap contains an ‘unlabeled’ label, because unlabeled segments make it possible to identify the labeled segments. Default is True.

save_net_outputs : bool

If True, save the ‘raw’ outputs of the neural network before they are converted to annotations. Default is False. Typically the output will be “logits” to which a softmax transform might be applied. For each item in the dataset (each row in the dataset_path .csv), the output will be saved in a separate file in output_dir, with the extension {MODEL_NAME}.output.npz. E.g., if the input is a spectrogram with spect_path filename gy6or6_032312_081416.npz, and the network is TweetyNet, then the net output file will be gy6or6_032312_081416.tweetynet.output.npz.
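The effect of frames_standardizer_path can be sketched in plain Python. This is a minimal illustration, not the vak.transforms.FramesStandardizer implementation. It assumes that a spectrogram is a list of rows, one per frequency bin, and that standardization uses a per-frequency-bin mean and standard deviation computed over all time bins of the training data. The helpers fit_standardizer and standardize are hypothetical.

```python
from statistics import fmean, pstdev

Spectrogram = list[list[float]]  # rows are frequency bins, columns are time bins


def fit_standardizer(train_spects: list[Spectrogram]) -> tuple[list[float], list[float]]:
    """Hypothetical helper: per-frequency-bin mean and std over all training time bins."""
    n_freq_bins = len(train_spects[0])
    means, stds = [], []
    for f in range(n_freq_bins):
        # pool every time bin of frequency bin f across all training spectrograms
        values = [x for spect in train_spects for x in spect[f]]
        means.append(fmean(values))
        stds.append(pstdev(values))
    return means, stds


def standardize(spect: Spectrogram, means: list[float], stds: list[float]) -> Spectrogram:
    """Hypothetical helper: apply (x - mean) / std to each frequency bin."""
    return [
        # a constant frequency bin has std 0; divide by 1 instead to avoid ZeroDivisionError
        [(x - m) / (s if s > 0 else 1.0) for x in row]
        for row, m, s in zip(spect, means, stds)
    ]


# Statistics are fit once on training data, then reused unchanged at prediction time.
train = [[[1.0, 2.0, 3.0], [10.0, 10.0, 10.0]]]
means, stds = fit_standardizer(train)
new_spect = [[2.0, 4.0], [10.0, 12.0]]
print(standardize(new_spect, means, stds))
```

Because the statistics come from the training data, feeding unstandardized spectrograms to a model trained on standardized ones shifts every input away from the range the network saw during training, which is why omitting frames_standardizer_path in that case gives incorrect results.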
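The naming rule for files written when save_net_outputs is True can be expressed as a small helper. This is a hypothetical sketch reconstructed only from the TweetyNet example in the save_net_outputs description: it assumes the file name is the spectrogram file's stem, followed by the lowercased model name and ‘.output.npz’, placed in output_dir. net_output_path is not a vak function.

```python
from pathlib import Path


def net_output_path(spect_path: str, model_name: str, output_dir: str) -> Path:
    """Hypothetical helper: where the raw network output for one spectrogram would be saved."""
    stem = Path(spect_path).stem  # "gy6or6_032312_081416.npz" -> "gy6or6_032312_081416"
    return Path(output_dir) / f"{stem}.{model_name.lower()}.output.npz"


print(net_output_path("gy6or6_032312_081416.npz", "TweetyNet", "predictions"))
```

Path.stem strips only the last suffix, so for a file such as a.spect.npz this sketch would keep ‘.spect’ in the name; how vak itself handles compound suffixes is not covered here.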

__init__(checkpoint_path, labelmap_path, model, batch_size, dataset: DatasetConfig, trainer: TrainerConfig, frames_standardizer_path=None, num_workers=2, annot_csv_filename=None, output_dir=<current working directory>, min_segment_dur=None, majority_vote=True, save_net_outputs=False) → None

Method generated by attrs for class PredictConfig.

Methods

__init__(checkpoint_path, labelmap_path, ...)

Method generated by attrs for class PredictConfig.

from_config_dict(config_dict)

Return PredictConfig instance from a dict.

Attributes

checkpoint_path

labelmap_path

model

batch_size

dataset

trainer

frames_standardizer_path

num_workers

annot_csv_filename

output_dir

min_segment_dur

majority_vote

save_net_outputs

classmethod from_config_dict(config_dict: dict) → PredictConfig

Return PredictConfig instance from a dict.

The dict passed in should be the one returned by loading a valid configuration TOML file with vak.config.parse.from_toml_path() and then taking the value at key predict, i.e., PredictConfig.from_config_dict(config_dict['predict']).