"""Source code for ``vak.config.model``: get a model's configuration from a .toml config file."""

from __future__ import annotations

import pathlib

import toml

from .. import models

# Sub-tables a model's configuration may declare. ``config_from_toml_dict``
# fills in any of these that the config omits with an empty dict.
MODEL_TABLES = ["network", "optimizer", "loss", "metrics"]


def config_from_toml_dict(toml_dict: dict, model_name: str) -> dict:
    """Get configuration for a model from a .toml configuration file
    loaded into a ``dict``.

    Parameters
    ----------
    toml_dict : dict
        Configuration from a .toml file, loaded into a dictionary.
        It is not modified by this function.
    model_name : str
        Name of a model, specified as the ``model`` option in a table
        (such as TRAIN or PREDICT), that should have its own corresponding
        table specifying its configuration: hyperparameters such as
        learning rate, etc.

    Returns
    -------
    model_config : dict
        Model configuration in a ``dict``, as loaded from a .toml file,
        and used by the model method ``from_config``. Every name in
        ``MODEL_TABLES`` is present as a key; tables the config did not
        declare are set to an empty ``dict``. This is a shallow copy of
        ``toml_dict[model_name]``.

    Raises
    ------
    ValueError
        If ``model_name`` is not in ``models.registry.MODEL_NAMES``,
        or if ``toml_dict`` has no table named ``model_name``.
    """
    if model_name not in models.registry.MODEL_NAMES:
        raise ValueError(
            f"Invalid model name: {model_name}.\nValid model names are: {models.registry.MODEL_NAMES}"
        )

    try:
        model_config = toml_dict[model_name]
    except KeyError as e:
        raise ValueError(
            f"A config section specifies the model name '{model_name}', "
            f"but there is no section named '{model_name}' in the config."
        ) from e

    # Copy before filling in defaults: writing into the nested table directly
    # would silently mutate the caller's ``toml_dict``.
    model_config = dict(model_config)

    # check if config declares parameters for required attributes;
    # if not, just put an empty dict that will get passed as the "kwargs"
    for attr in MODEL_TABLES:
        model_config.setdefault(attr, {})
    return model_config
def config_from_toml_path(
    toml_path: str | pathlib.Path, model_name: str
) -> dict:
    """Get configuration for a model from a .toml configuration file,
    given the path to the file.

    Parameters
    ----------
    toml_path : str, pathlib.Path
        Path to configuration file in .toml format.
    model_name : str
        Name of a model, specified as the ``model`` option in a table
        (such as TRAIN or PREDICT), that should have its own corresponding
        table specifying its configuration: hyperparameters such as
        learning rate, etc.

    Returns
    -------
    model_config : dict
        Model configuration in a ``dict``, as loaded from a .toml file,
        and used by the model method ``from_config``.

    Raises
    ------
    FileNotFoundError
        If ``toml_path`` does not exist or is not a regular file.
    ValueError
        Propagated from ``config_from_toml_dict`` when ``model_name`` is
        invalid or has no corresponding table in the config.
    """
    toml_path = pathlib.Path(toml_path)
    if not toml_path.is_file():
        raise FileNotFoundError(
            f"File not found, or not recognized as a file: {toml_path}"
        )

    # TOML documents are required to be UTF-8; don't depend on the
    # platform's default locale encoding (e.g. cp1252 on Windows).
    with toml_path.open("r", encoding="utf-8") as fp:
        config_dict = toml.load(fp)
    return config_from_toml_dict(config_dict, model_name)