vak.models.definition.ModelDefinition
- class vak.models.definition.ModelDefinition(network: Module | dict, loss: dict, optimizer: Optimizer, metrics: dict, default_config: dict)
Bases: object
A class that represents the definition of a neural network model.
A model definition is any class that has the following class variables:
- network: torch.nn.Module or dict
Neural network. If a dict, should map string network names to torch.nn.Module classes.
- loss: torch.nn.Module, callable
Either a built-in loss module, or a callable function that computes loss.
- optimizer: torch.optim.Optimizer
Optimizer used to optimize neural network parameters during training.
- metrics: dict
Metrics used to evaluate the network. Should map string names of metrics to callable classes that compute each metric.
- default_config: dict
Specifies default keyword arguments to use when instantiating any classes in network, optimizer, loss, or metrics. Used by vak.models.base.Model and its sub-classes that represent model families. For example, those classes will do:
network = self.definition.network(**self.definition.default_config['network'])
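The pattern above can be sketched without vak or torch installed. This is a minimal sketch under stated assumptions. Every class name in it (TinyNet, TinyLoss, TinyOptimizer, TinyAccuracy, TinyModelDefinition, SketchModel) is hypothetical and stands in for a torch class or a vak model-family class. The per-metric layout of default_config['metrics'] is also an assumption.

```python
class TinyNet:
    """Hypothetical stand-in for a torch.nn.Module subclass."""
    def __init__(self, num_classes: int = 2, hidden_size: int = 8):
        self.num_classes = num_classes
        self.hidden_size = hidden_size


class TinyLoss:
    """Hypothetical stand-in for a loss module such as torch.nn.CrossEntropyLoss."""
    def __init__(self, reduction: str = "mean"):
        self.reduction = reduction


class TinyOptimizer:
    """Hypothetical stand-in for a torch.optim.Optimizer subclass."""
    def __init__(self, params, lr: float = 0.001):
        self.params = list(params)
        self.lr = lr


class TinyAccuracy:
    """Hypothetical stand-in for a metric class: fraction of matching labels."""
    def __call__(self, y_pred, y_true):
        matches = sum(p == t for p, t in zip(y_pred, y_true))
        return matches / len(y_true)


class TinyModelDefinition:
    # Class variables following the layout described above.
    network = TinyNet
    loss = TinyLoss
    optimizer = TinyOptimizer
    metrics = {"acc": TinyAccuracy}
    # Assumption: one dict of keyword arguments per component
    # (and per metric name inside "metrics").
    default_config = {
        "network": {"num_classes": 10},
        "loss": {"reduction": "sum"},
        "optimizer": {"lr": 0.003},
        "metrics": {},
    }


class SketchModel:
    """Hypothetical model-family class that mirrors the instantiation pattern."""
    definition = TinyModelDefinition

    def __init__(self):
        cfg = self.definition.default_config
        self.network = self.definition.network(**cfg["network"])
        self.loss = self.definition.loss(**cfg["loss"])
        # A real torch optimizer would receive self.network.parameters();
        # an empty list keeps this sketch free of torch.
        self.optimizer = self.definition.optimizer([], **cfg["optimizer"])
        self.metrics = {
            name: metric_cls(**cfg["metrics"].get(name, {}))
            for name, metric_cls in self.definition.metrics.items()
        }


model = SketchModel()
print(model.network.num_classes)             # 10
print(model.optimizer.lr)                    # 0.003
print(model.metrics["acc"]([1, 0], [1, 1]))  # 0.5
```

Keeping only classes in the definition and only keyword arguments in default_config means that changing a hyperparameter touches the config, not the class.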
Note it is not necessary to sub-class this class; it exists mainly for type-checking purposes.
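Because sub-classing is optional, what matters is that the five class variables are present. A minimal duck-typing check might look like the following. The names missing_attributes and REQUIRED_ATTRIBUTES are hypothetical and not part of vak. The check only tests that each attribute exists, not that it has the documented type.

```python
from __future__ import annotations

# Hypothetical: the class variables a model definition is expected to have.
REQUIRED_ATTRIBUTES = ("network", "loss", "optimizer", "metrics", "default_config")


def missing_attributes(definition: type) -> list[str]:
    """Return the names of required class variables that ``definition`` lacks.

    Hypothetical helper: checks presence only, not types.
    """
    return [name for name in REQUIRED_ATTRIBUTES if not hasattr(definition, name)]


class DuckTypedDefinition:  # does not sub-class ModelDefinition
    network = object  # placeholder; a real definition names a torch.nn.Module subclass
    loss = object
    optimizer = object
    metrics = {}
    default_config = {}


class IncompleteDefinition:
    network = object
    loss = object


print(missing_attributes(DuckTypedDefinition))   # []
print(missing_attributes(IncompleteDefinition))  # ['optimizer', 'metrics', 'default_config']
```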
For more detail, see vak.models.decorator.model() and vak.models.ModelFactory.
- __init__(network: Module | dict, loss: dict, optimizer: Optimizer, metrics: dict, default_config: dict) -> None
Methods
__init__(network, loss, optimizer, metrics, ...)
Attributes
network
loss
optimizer
metrics
default_config
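The __init__ signature takes one argument per attribute listed above. The sketch below mirrors that layout with a standard-library dataclass. DefinitionRecord is a hypothetical stand-in rather than vak's class, and its type hints are loosened so it runs without torch. It also shows two forms the docstring allows: network as a dict of named classes, and loss as a plain callable.

```python
from dataclasses import dataclass, fields
from typing import Any, Callable, Dict, Union


@dataclass
class DefinitionRecord:
    """Hypothetical stand-in whose fields follow the __init__ signature above.

    Type hints are loosened (type/Callable/Any) so the sketch runs without torch.
    """
    network: Union[type, Dict[str, type]]
    loss: Callable[..., Any]
    optimizer: type
    metrics: Dict[str, Callable[..., Any]]
    default_config: Dict[str, Dict[str, Any]]


record = DefinitionRecord(
    network={"encoder": object, "decoder": object},  # dict form: name -> class
    # Loss given as a plain callable: sum of absolute differences.
    loss=lambda y_pred, y_true: sum(abs(p - t) for p, t in zip(y_pred, y_true)),
    optimizer=object,  # placeholder for a torch.optim.Optimizer subclass
    metrics={},
    default_config={"network": {}, "optimizer": {"lr": 0.001}},
)

print([f.name for f in fields(DefinitionRecord)])
# ['network', 'loss', 'optimizer', 'metrics', 'default_config']
print(record.loss([1.0, 2.0], [1.5, 2.0]))  # 0.5
```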