Source code for vak.config.model

"""Class representing the model table of a toml configuration file."""

from __future__ import annotations

from attrs import asdict, define, field
from attrs.validators import instance_of

from .. import models

MODEL_TABLES = [
    "network",
    "optimizer",
    "loss",
    "metrics",
]


[docs] @define class ModelConfig: """Class representing the model table of a toml configuration file. Attributes ---------- name : str network : dict Keyword arguments for the network class, or a :class:`dict` of ``dict``s mapping network names to keyword arguments. optimizer: dict Keyword arguments for the optimizer class. loss : dict Keyword arguments for the class representing the loss function. metrics: dict A :class:`dict` of ``dict``s mapping metric names to keyword arguments. """ name: str network: dict = field(validator=instance_of(dict)) optimizer: dict = field(validator=instance_of(dict)) loss: dict = field(validator=instance_of(dict)) metrics: dict = field(validator=instance_of(dict))
[docs] @classmethod def from_config_dict(cls, config_dict: dict): """Return :class:`ModelConfig` instance from a :class:`dict`. The :class:`dict` passed in should be the one found by loading a valid configuration toml file with :func:`vak.config.parse.from_toml_path`, and then using a top-level table key, followed by key ``'model'``. E.g., ``config_dict['train']['model']` or ``config_dict['predict']['model']``. Examples -------- config_dict = vak.config.parse.from_toml_path(toml_path) model_config = vak.config.Model.from_config_dict(config_dict['train']) """ model_name = list(config_dict.keys()) if len(model_name) == 0: raise ValueError( "Did not find a single key in `config_dict` corresponding to model name. " f"Instead found no keys. Config dict:\n{config_dict}\n" "A configuration file should specify a single model per top-level table." ) if len(model_name) > 1: raise ValueError( "Did not find a single key in `config_dict` corresponding to model name. " f"Instead found multiple keys: {model_name}.\nConfig dict:\n{config_dict}.\n" "A configuration file should specify a single model per top-level table." ) model_name = model_name[0] MODEL_NAMES = list(models.registry.MODEL_NAMES) if model_name not in MODEL_NAMES: raise ValueError( f"Model name not found in registry: {model_name}\n" f"Model names in registry:\n{MODEL_NAMES}" ) # NOTE: we are getting model_config here model_config = config_dict[model_name] if not all(key in MODEL_TABLES for key in model_config.keys()): invalid_keys = ( key for key in model_config.keys() if key not in MODEL_TABLES ) raise ValueError( f"The following sub-tables in the model config are not valid: {invalid_keys}\n" f"Valid sub-table names are: {MODEL_TABLES}" ) # for any tables not specified, default to empty dict so we can still use ``**`` operator on it for model_table in MODEL_TABLES: if model_table not in model_config: model_config[model_table] = {} return cls(name=model_name, **model_config)
[docs] def asdict(self): """Convert this :class:`ModelConfig` instance to a :class:`dict` that can be passed into functions that take a ``model_config`` argument, like :func:`vak.train` and :func:`vak.predict`. """ return asdict(self)