# Source code for vak.models.base

"""Base class for a model in ``vak``,
that other families of models should subclass.
"""
from __future__ import annotations

import inspect
from typing import Callable, ClassVar

import pytorch_lightning as lightning
import torch

from .definition import ModelDefinition
from .definition import validate as validate_definition


class Model(lightning.LightningModule):
    """Base class for a model in ``vak``,
    that other families of models should subclass.

    This class provides methods for working with
    neural network models, e.g. training the model
    and generating predictions,
    and it also converts a
    model definition into a model instance.

    It provides the methods for working with neural
    network models by subclassing
    ``lightning.LightningModule``, and it handles
    converting a model definition into a model instance
    inside its ``__init__`` method.
    Model definitions are declared programmatically
    using a ``vak.model.ModelDefinition``;
    see the documentation on that class for more detail.
    """

    # Set on subclasses; ``__init__`` raises if it is missing.
    definition: ClassVar[ModelDefinition]

    def __init__(
        self,
        network: torch.nn.Module | dict | None = None,
        loss: torch.nn.Module | Callable | None = None,
        optimizer: torch.optim.Optimizer | None = None,
        metrics: dict | None = None,
    ):
        """Initializes an instance of a model, using its definition.

        Takes in instances of the attributes defined by the class variable
        ``self.definition``: ``network``, ``loss``, ``optimizer``, and
        ``metrics``. If any of those arguments are ``None``, then ``__init__``
        instantiates the corresponding attribute with its defaults, taken
        from ``self.definition.default_config``. If any of those arguments
        are not an instance of the type defined by ``self.definition``,
        then a TypeError is raised.

        Parameters
        ----------
        network : torch.nn.Module, dict
            An instance of a ``torch.nn.Module``
            that implements a neural network,
            or a ``dict`` that maps human-readable string names
            to a set of such instances.
        loss : torch.nn.Module, callable
            An instance of a ``torch.nn.Module``
            that implements a loss function,
            or a callable Python function that
            computes a scalar loss.
        optimizer : torch.optim.Optimizer
            An instance of a ``torch.optim.Optimizer`` class
            used with ``loss`` to optimize
            the parameters of ``network``.
        metrics : dict
            A ``dict`` that maps human-readable string names
            to ``Callable`` functions, used to measure
            performance of the model.

        Raises
        ------
        ValueError
            If the class has no ``definition``, if the definition
            fails validation, or if an argument has an invalid value.
        TypeError
            If an argument is not an instance of the type
            declared by the definition.
        """
        # Kept as a function-scope import, as in the original;
        # presumably this avoids a circular import -- TODO confirm.
        from .decorator import ModelDefinitionValidationError

        super().__init__()

        # check that we are a sub-class of some other class
        # with required class variables
        if not hasattr(self, "definition"):
            raise ValueError(
                "This model does not have a definition. "
                "Define a model by wrapping a class with the required class variables with "
                "a ``vak.models`` decorator, e.g. ``vak.models.windowed_frame_classification_model``"
            )

        try:
            validate_definition(self.definition)
        except ModelDefinitionValidationError as err:
            raise ValueError(
                "Creating model instance failed because model definition is invalid."
            ) from err

        # ---- validate any instances that user passed in
        self.validate_init(network, loss, optimizer, metrics)

        default_config = self.definition.default_config

        if network is None:
            net_kwargs = default_config.get("network")
            if isinstance(self.definition.network, dict):
                network = {
                    network_name: network_class(**net_kwargs[network_name])
                    for network_name, network_class in self.definition.network.items()
                }
            else:
                network = self.definition.network(**net_kwargs)
        # NOTE(review): when ``network`` is a plain ``dict``, torch does not
        # register the contained modules as submodules of this module;
        # confirm whether a ``torch.nn.ModuleDict`` is wanted here.
        self.network = network

        if loss is None:
            if inspect.isclass(self.definition.loss):
                loss_kwargs = default_config.get("loss")
                loss = self.definition.loss(**loss_kwargs)
            else:
                # a function (or other callable) is used as-is,
                # the same fallback ``attributes_from_config`` uses
                loss = self.definition.loss
        self.loss = loss

        if optimizer is None:
            optimizer_kwargs = default_config.get("optimizer")
            optimizer = self.definition.optimizer(
                params=self._params_from_network(network), **optimizer_kwargs
            )
        self.optimizer = optimizer

        if metrics is None:
            metric_kwargs = default_config.get("metrics")
            metrics = {
                metric_name: metric_class(**metric_kwargs.get(metric_name, {}))
                for metric_name, metric_class in self.definition.metrics.items()
            }
        self.metrics = metrics

    @staticmethod
    def _params_from_network(network):
        """Return the parameters to optimize for ``network``.

        For a ``dict`` of networks, this is a flat ``list`` of the
        parameters of every network in the ``dict``; otherwise it is
        whatever ``network.parameters()`` returns.
        """
        if isinstance(network, dict):
            return [
                param
                for net_instance in network.values()
                for param in net_instance.parameters()
            ]
        return network.parameters()

    @classmethod
    def validate_init(
        cls,
        network: torch.nn.Module | dict | None = None,
        loss: torch.nn.Module | Callable | None = None,
        optimizer: torch.optim.Optimizer | None = None,
        metrics: dict | None = None,
    ) -> None:
        """Validate arguments to ``vak.models.base.Model.__init__``.

        Arguments that are ``None`` are not validated.

        Parameters
        ----------
        network : torch.nn.Module, dict
            An instance of a ``torch.nn.Module``
            that implements a neural network,
            or a ``dict`` where each key is a string
            that maps a human-readable name to
            a ``torch.nn.Module`` instance.
        loss : torch.nn.Module, callable
            An instance of a ``torch.nn.Module``
            that implements a loss function,
            or a callable Python function that
            computes a scalar loss.
        optimizer : torch.optim.Optimizer
            An instance of a ``torch.optim.Optimizer`` class
            used with ``loss`` to optimize
            the parameters of ``network``.
        metrics : dict
            A ``dict`` that maps human-readable string names
            to ``Callable`` functions, used to measure
            performance of the model.

        Returns
        -------
        None
            This method does not return values;
            it just raises an error if any value is invalid.

        Raises
        ------
        TypeError
            If an argument is not an instance of the type
            declared by ``cls.definition``.
        ValueError
            If the ``network`` dict has missing or extra keys,
            if ``loss`` is not the callable declared by the definition,
            or if ``metrics`` has a name not in the definition.
        """
        # Compare against ``None`` rather than relying on truthiness, so that
        # falsy values (e.g. an empty ``dict``) are validated instead of
        # being silently accepted.
        if network is not None:
            expected_network = cls.definition.network
            if inspect.isclass(expected_network):
                if not isinstance(network, expected_network):
                    raise TypeError(
                        f"``network`` should be an instance of {expected_network} "
                        f"but was of type {type(network)}"
                    )
            elif isinstance(expected_network, dict):
                if not isinstance(network, dict):
                    raise TypeError(
                        "Expected ``network`` to be a ``dict`` mapping network names "
                        f"to ``torch.nn.Module`` instances, but type was {type(network)}"
                    )
                expected_network_dict_keys = list(expected_network.keys())
                missing_keys = set(expected_network_dict_keys) - set(network)
                if missing_keys:
                    raise ValueError(
                        f"The following keys were missing from the ``network`` dict: {missing_keys}"
                    )
                extra_keys = set(network) - set(expected_network_dict_keys)
                if extra_keys:
                    raise ValueError(
                        f"The following keys in the ``network`` dict are not valid: {extra_keys}. "
                        f"Valid keys are: {expected_network_dict_keys}"
                    )
                for network_name, network_instance in network.items():
                    if not isinstance(
                        network_instance, expected_network[network_name]
                    ):
                        raise TypeError(
                            f"Network with name '{network_name}' in ``network`` dict "
                            f"should be an instance of {expected_network[network_name]} "
                            f"but was of type {type(network_instance)}"
                        )
            else:
                raise TypeError(
                    f"Invalid type for ``network``: {type(network)}"
                )

        if loss is not None:
            expected_loss = cls.definition.loss
            # ``issubclass`` raises TypeError when its first argument is not
            # a class (e.g. a loss *function*), so check ``isclass`` first.
            if inspect.isclass(expected_loss) and issubclass(
                expected_loss, torch.nn.Module
            ):
                if not isinstance(loss, expected_loss):
                    raise TypeError(
                        f"``loss`` should be an instance of {expected_loss} "
                        f"but was of type {type(loss)}"
                    )
            elif callable(expected_loss):
                if loss is not expected_loss:
                    raise ValueError(
                        f"``loss`` should be the following callable (probably a function): {expected_loss}"
                    )
            else:
                raise TypeError(f"Invalid type for ``loss``: {type(loss)}")

        if optimizer is not None:
            if not isinstance(optimizer, cls.definition.optimizer):
                raise TypeError(
                    f"``optimizer`` should be an instance of {cls.definition.optimizer} "
                    f"but was of type {type(optimizer)}"
                )

        if metrics is not None:
            if not isinstance(metrics, dict):
                raise TypeError(
                    "``metrics`` should be a ``dict`` mapping string metric names "
                    f"to callable metrics, but type of ``metrics`` was {type(metrics)}"
                )
            expected_metrics = cls.definition.metrics
            for metric_name, metric_callable in metrics.items():
                if metric_name not in expected_metrics:
                    raise ValueError(
                        f"``metrics`` has name '{metric_name}' but that name "
                        "is not in the model definition. "
                        f"Valid metric names are: {', '.join(expected_metrics)}"
                    )
                if not isinstance(
                    metric_callable, expected_metrics[metric_name]
                ):
                    raise TypeError(
                        f"metric '{metric_name}' should be an instance of {expected_metrics[metric_name]} "
                        f"but was of type {type(metric_callable)}"
                    )

    def load_state_dict_from_path(self, ckpt_path):
        """Loads a model from the path to a saved checkpoint.

        Loads the checkpoint and then calls
        ``self.load_state_dict`` with the ``state_dict``
        in that checkpoint.

        This method allows loading a state dict into an instance.
        It's necessary because
        ``lightning.LightningModule.load_from_checkpoint``
        is a ``classmethod``, so calling that method will trigger
        ``LightningModule.__init__`` instead of running
        ``vak.models.Model.__init__``.

        Parameters
        ----------
        ckpt_path : str, pathlib.Path
            Path to a checkpoint saved by a model in ``vak``.
            This checkpoint has the same key-value pairs as
            any other checkpoint saved by a
            ``lightning.LightningModule``.

        Returns
        -------
        None
            This method modifies the model state by loading the
            ``state_dict``; it does not return anything.
        """
        # NOTE(review): ``torch.load`` unpickles the file, so only load
        # checkpoints from trusted sources. No ``map_location`` is passed,
        # so tensors are restored to the device they were saved from.
        ckpt = torch.load(ckpt_path)
        self.load_state_dict(ckpt["state_dict"])

    @classmethod
    def attributes_from_config(cls, config: dict):
        """Get attributes for an instance of a model,
        given a configuration.

        Given a ``dict``, ``config``, return instances
        of the class variables declared by ``cls.definition``.
        Any of the keys ``"network"``, ``"loss"``, ``"optimizer"``
        or ``"metrics"`` missing from ``config`` falls back to the
        corresponding entry in ``cls.definition.default_config``.

        Parameters
        ----------
        config : dict
            Returned by calling ``vak.config.models.map_from_path``
            or ``vak.config.models.map_from_config_dict``.

        Returns
        -------
        network : torch.nn.Module, dict
            An instance of a ``torch.nn.Module``
            that implements a neural network,
            or a ``dict`` that maps human-readable string names
            to a set of such instances.
        loss : torch.nn.Module, callable
            An instance of a ``torch.nn.Module``
            that implements a loss function,
            or a callable Python function that
            computes a scalar loss.
        optimizer : torch.optim.Optimizer
            An instance of a ``torch.optim.Optimizer`` class
            used with ``loss`` to optimize
            the parameters of ``network``.
        metrics : dict
            A ``dict`` that maps human-readable string names
            to ``Callable`` functions, used to measure
            performance of the model.

        Raises
        ------
        TypeError
            If ``cls.definition.network`` is neither a class
            nor a ``dict`` of classes.
        """
        default_config = cls.definition.default_config

        network_kwargs = config.get("network", default_config["network"])
        if inspect.isclass(cls.definition.network):
            network = cls.definition.network(**network_kwargs)
        elif isinstance(cls.definition.network, dict):
            network = {
                net_name: net_class(**network_kwargs.get(net_name, {}))
                for net_name, net_class in cls.definition.network.items()
            }
        else:
            # Previously this case left ``network`` unbound and failed
            # later with an unrelated error; fail explicitly instead.
            raise TypeError(
                "Invalid type for ``network`` in model definition: "
                f"{type(cls.definition.network)}"
            )

        optimizer_kwargs = config.get(
            "optimizer", default_config["optimizer"]
        )
        optimizer = cls.definition.optimizer(
            params=cls._params_from_network(network), **optimizer_kwargs
        )

        if inspect.isclass(cls.definition.loss):
            loss_kwargs = config.get("loss", default_config["loss"])
            loss = cls.definition.loss(**loss_kwargs)
        else:
            loss = cls.definition.loss

        metrics_config = config.get("metrics", default_config["metrics"])
        metrics = {
            metric_name: metric_class(**metrics_config.get(metric_name, {}))
            for metric_name, metric_class in cls.definition.metrics.items()
        }

        return network, loss, optimizer, metrics

    @classmethod
    def from_config(cls, config: dict):
        """Return an initialized model instance from a config ``dict``

        Parameters
        ----------
        config : dict
            Returned by calling ``vak.config.models.map_from_path``
            or ``vak.config.models.map_from_config_dict``.

        Returns
        -------
        cls : vak.models.base.Model
            An instance of the model with its attributes
            initialized using parameters from ``config``.
        """
        network, loss, optimizer, metrics = cls.attributes_from_config(config)
        return cls(
            network=network, loss=loss, optimizer=optimizer, metrics=metrics
        )