Source code for vak.cli.predict

import logging
from pathlib import Path

from .. import common, config
from .. import predict as predict_module
from ..common.logging import config_logging_for_cli, log_version

# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)


def predict(toml_path):
    """Make predictions on a dataset with the trained model specified in a config.toml file.

    Function called by the command-line interface.

    Parameters
    ----------
    toml_path : str, Path
        Path to a configuration file in TOML format.

    Raises
    ------
    ValueError
        If the config file has no PREDICT section, or if its PREDICT
        section does not specify a dataset path.
    """
    toml_path = Path(toml_path)
    cfg = config.Config.from_toml_path(toml_path)

    if cfg.predict is None:
        raise ValueError(
            f"predict called with a config.toml file that does not have a PREDICT section: {toml_path}"
        )

    # ---- set up logging ----------------------------------------------------------------------------------------------
    # Logs are written into the prediction output directory with stem "predict".
    # NOTE(review): `force=True` presumably replaces previously configured handlers — confirm in config_logging_for_cli.
    config_logging_for_cli(
        log_dst=cfg.predict.output_dir,
        log_stem="predict",
        level="INFO",
        force=True,
    )
    log_version(logger)
    logger.info("Logging results to %s", cfg.predict.output_dir)

    if cfg.predict.dataset.path is None:
        # Each fragment ends with a space so the concatenated sentences don't run together.
        raise ValueError(
            "No value is specified for 'dataset_path' in this .toml config file. "
            "To generate a .csv file that represents the dataset, "
            f"please run the following command:\n'vak prep {toml_path}'"
        )

    predict_module.predict(
        model_config=cfg.predict.model.asdict(),
        dataset_config=cfg.predict.dataset.asdict(),
        trainer_config=cfg.predict.trainer.asdict(),
        checkpoint_path=cfg.predict.checkpoint_path,
        labelmap_path=cfg.predict.labelmap_path,
        num_workers=cfg.predict.num_workers,
        # Use the PREP section's spectrogram key when present, else the library-wide default.
        timebins_key=(
            cfg.prep.spect_params.timebins_key
            if cfg.prep
            else common.constants.TIMEBINS_KEY
        ),
        frames_standardizer_path=cfg.predict.frames_standardizer_path,
        annot_csv_filename=cfg.predict.annot_csv_filename,
        output_dir=cfg.predict.output_dir,
        min_segment_dur=cfg.predict.min_segment_dur,
        majority_vote=cfg.predict.majority_vote,
        save_net_outputs=cfg.predict.save_net_outputs,
    )