Source code for vak.config.trainer
from __future__ import annotations
from attrs import asdict, define, field, validators
from .. import common
[docs]
def is_valid_accelerator(instance, attribute, value):
    """Check if ``accelerator`` is valid.

    Validator for the ``accelerator`` attribute of :class:`TrainerConfig`,
    written with the ``attrs`` validator signature
    ``(instance, attribute, value)``.

    Parameters
    ----------
    instance : object
        The instance being validated (unused).
    attribute : object
        The attribute being validated (unused).
    value : str
        The value given for ``accelerator``.

    Raises
    ------
    ValueError
        If ``value`` is ``"auto"`` (not supported by the ``vak`` CLI, see
        https://github.com/vocalpy/vak/issues/691), or if it is not one of
        ``"cpu"``, ``"gpu"``, ``"tpu"``, ``"ipu"``.
    """
    if value == "auto":
        # 'auto' is rejected with a dedicated message instead of the generic one below
        raise ValueError(
            "Using the 'auto' value for the `lightning.pytorch.Trainer` parameter `accelerator` currently "
            "breaks functionality for the command-line interface of `vak`. "
            # trailing newline keeps the URL from running into the next sentence
            "Please see this issue: https://github.com/vocalpy/vak/issues/691\n"
            "If you need to use the 'auto' mode of `lightning.pytorch.Trainer`, please use `vak` directly in a script."
        )
    if value not in ("cpu", "gpu", "tpu", "ipu"):
        raise ValueError(
            f"Invalid value for 'accelerator' key in 'trainer' table of configuration file: {value}. "
            'Value must be one of: {"cpu", "gpu", "tpu", "ipu"}'
        )
[docs]
def is_valid_devices(instance, attribute, value):
    """Check if ``devices`` is valid.

    Validator for the ``devices`` attribute of :class:`TrainerConfig`,
    written with the ``attrs`` validator signature
    ``(instance, attribute, value)``.

    Parameters
    ----------
    instance : object
        The instance being validated (unused).
    attribute : object
        The attribute being validated (unused).
    value : int or list of int
        The value given for ``devices``.

    Raises
    ------
    ValueError
        If ``value`` is neither an ``int`` nor a ``list`` whose elements
        are all ``int``.
    """
    is_int = isinstance(value, int)
    is_int_list = isinstance(value, list) and all(isinstance(el, int) for el in value)
    if not (is_int or is_int_list):
        raise ValueError(
            # f-string so the offending value actually appears in the message
            f"Invalid value for 'devices' key in 'trainer' table of configuration file: {value}"
        )
[docs]
@define
class TrainerConfig:
    """Class that represents ``trainer`` sub-table
    in a toml configuration file.

    Used to configure :class:`lightning.pytorch.Trainer`.

    Attributes
    ----------
    accelerator : str
        Value for the `accelerator` argument to :class:`lightning.pytorch.Trainer`.
        Default is the return value of :func:`vak.common.accelerator.get_default`,
        evaluated once when this module is imported.
    devices: int, list of int
        Number of devices (int) or exact device(s) (list of int) to use.
        If not specified, the default depends on ``accelerator``:
        ``1`` when ``accelerator`` is ``"cpu"``, and ``[0]`` (the first device)
        when ``accelerator`` is ``"gpu"``, ``"tpu"``, or ``"ipu"``.

    Notes
    -----
    Using the 'auto' value for the `lightning.pytorch.Trainer` parameter `accelerator` currently
    breaks functionality for the command-line interface of `vak`.
    Please see this issue: https://github.com/vocalpy/vak/issues/691
    If you need to use the 'auto' mode of `lightning.pytorch.Trainer`, please use `vak` directly in a script.

    Likewise, setting a value for the `lightning.pytorch.Trainer` parameter `devices` that is not either 1
    (meaning "use a single GPU") or a list with a single number (meaning "use this exact GPU") currently
    breaks functionality for the command-line interface of `vak`.
    Please see this issue: https://github.com/vocalpy/vak/issues/691
    If you need to use multiple GPUs, please use `vak` directly in a script.

    When ``accelerator`` is ``"cpu"``, ``devices`` must be an int > 0.
    """

    accelerator: str = field(
        validator=is_valid_accelerator,
        default=common.accelerator.get_default(),
    )
    devices: int | list[int] = field(
        # optional: ``None`` is a placeholder meaning "not specified",
        # replaced in ``__attrs_post_init__`` with a default based on accelerator
        validator=validators.optional(is_valid_devices),
        default=None,
    )

    def __attrs_post_init__(self):
        """Fill in the default for ``devices`` and check it against ``accelerator``.

        Raises
        ------
        ValueError
            If ``accelerator`` is ``"gpu"``, ``"tpu"`` or ``"ipu"`` and ``devices``
            is not ``1`` or a single-element list of int; or if ``accelerator``
            is ``"cpu"`` and ``devices`` is a list or an int < 1.
        """
        # attrs runs the field validators in ``__init__`` *before* this hook,
        # which is why ``devices`` uses ``validators.optional`` above: the
        # ``None`` placeholder must get past validation so we can replace it
        # here with a default that depends on ``self.accelerator``.
        if self.devices is None:
            if self.accelerator == "cpu":
                # a single CPU process
                self.devices = 1
            elif self.accelerator in ("gpu", "tpu", "ipu"):
                # we can only use a single device, assume it is the first one
                self.devices = [0]

        if self.accelerator in ("gpu", "tpu", "ipu"):
            # only a single device is supported by the CLI, see issue #691
            if not (
                (isinstance(self.devices, int) and self.devices == 1)
                or (
                    isinstance(self.devices, list)
                    and len(self.devices) == 1
                    and all(isinstance(el, int) for el in self.devices)
                )
            ):
                raise ValueError(
                    "Setting a value for the `lightning.pytorch.Trainer` parameter `devices` that is not either 1 "
                    '(meaning "use a single GPU") or a list with a single number '
                    '(meaning "use this exact GPU") currently '
                    "breaks functionality for the command-line interface of `vak`. "
                    # trailing newline keeps the URL from running into the next sentence
                    "Please see this issue: https://github.com/vocalpy/vak/issues/691\n"
                    "If you need to use multiple GPUs, please use `vak` directly in a script."
                )
        elif self.accelerator == "cpu":
            if isinstance(self.devices, list):
                raise ValueError(
                    f"Value for `devices` cannot be a list when `accelerator` is `cpu`. Value was: {self.devices}\n"
                    # the check below rejects values < 1 (including -1),
                    # so only advise values that will actually be accepted
                    "When `accelerator` is `cpu`, please set `devices` to an int > 0, e.g. 1."
                )
            if self.devices < 1:
                raise ValueError(
                    "When value for 'accelerator' is 'cpu', value for `devices` "
                    f"should be an int > 0, but was: {self.devices}"
                )

    def asdict(self):
        """Convert this :class:`TrainerConfig` instance
        to a :class:`dict` that can be passed
        into functions that take a ``trainer_config`` argument,
        like :func:`vak.train` and :func:`vak.predict`.
        """
        # This is the module-level ``asdict`` imported from attrs, not this
        # method: names in a class body are not visible inside method bodies.
        return asdict(self)