"""Registry for models.Makes it possible to register a model declared outside of ``vak``with a decorator, so that the model can be used at runtime."""from__future__importannotationsimportinspectfromtypingimportAny,Type,TYPE_CHECKINGimportlightningifTYPE_CHECKING:from.factoryimportModelFactoryMODEL_FAMILY_REGISTRY={}
def model_family(family_class: Type) -> Type:
    """Decorator that adds a :class:`lightning.LightningModule` class
    to the registry of model families.

    The class is stored in ``MODEL_FAMILY_REGISTRY`` under its
    ``__name__``. It is then returned unchanged, so the decorated
    name still refers to the class.

    Parameters
    ----------
    family_class : type
        The class to register. Must be a subclass of
        :class:`lightning.LightningModule`.

    Returns
    -------
    family_class : type
        The same class that was passed in.

    Raises
    ------
    TypeError
        If ``family_class`` is not a class, or is not a subclass of
        :class:`lightning.LightningModule`.
    ValueError
        If a model family with the same name is already registered.
    """
    # Check ``inspect.isclass`` first. ``issubclass`` raises its own,
    # much less helpful TypeError when given a non-class (e.g. an instance).
    if not inspect.isclass(family_class) or not issubclass(
        family_class, lightning.LightningModule
    ):
        raise TypeError(
            "The ``family_class`` provided to the `vak.models.model_family` decorator "
            "must be a subclass of `lightning.LightningModule`, "
            f"but the class specified is not: {family_class}. "
            f"Subclasses of `lightning.LightningModule` are: {lightning.LightningModule.__subclasses__()}"
        )

    model_family_name = family_class.__name__
    if model_family_name in MODEL_FAMILY_REGISTRY:
        raise ValueError(
            f"Attempted to register a model family with the name '{model_family_name}', "
            f"but this name is already in the registry:\n{MODEL_FAMILY_REGISTRY}"
        )

    MODEL_FAMILY_REGISTRY[model_family_name] = family_class
    # Return the class after registering it. Otherwise, when this function is
    # used as a decorator, the decorated name would be replaced with None.
    return family_class
# Maps a model's name (its ``__name__``) to the registered model object.
# Populated by :func:`register_model`.
MODEL_REGISTRY: dict[str, "ModelFactory"] = {}
def register_model(model: ModelFactory) -> ModelFactory:
    """Function that registers a model in the model registry.

    This function is called by :func:`vak.models.decorator.model`,
    which creates an instance of a :class:`vak.models.ModelFactory`
    given a :class:`vak.models.definition.ModelDefinition`
    and a :class:`lightning.LightningModule` class that has been
    registered as a model family with :func:`model_family`.

    You will not usually need to use this function directly;
    prefer :func:`vak.models.decorator.model` instead.

    Parameters
    ----------
    model : ModelFactory
        The model to register. Its ``family`` attribute must be one of
        the classes in ``MODEL_FAMILY_REGISTRY``. It is stored in
        ``MODEL_REGISTRY`` under ``model.__name__``.

    Returns
    -------
    model : ModelFactory
        The same object that was passed in, so this can be used as a decorator.

    Raises
    ------
    TypeError
        If ``model.family`` is not a registered model family.
    ValueError
        If a model with the same name is already registered.
    """
    model_family_classes = list(MODEL_FAMILY_REGISTRY.values())
    # Named ``family`` (not ``model_family``) to avoid shadowing the
    # module-level ``model_family`` decorator.
    family = model.family
    if family not in model_family_classes:
        raise TypeError(
            "The family of `model` passed to the `register_model` decorator "
            f"is not recognized as a model family. Class was '{model}' and "
            f"its family is '{family}'. "
            f"Please specify a valid model family. "
            f"Valid model family classes are: {model_family_classes}"
        )

    model_name = model.__name__
    if model_name in MODEL_REGISTRY:
        raise ValueError(
            f"Attempted to register a model with the name '{model_name}', "
            f"but this name is already in the registry:\n{list(MODEL_REGISTRY.keys())}"
        )

    MODEL_REGISTRY[model_name] = model
    # Return the model after registering it. Otherwise, when this function is
    # used as a decorator, the decorated name would be replaced with None.
    return model
def __getattr__(name: str) -> Any:
    """Resolve module attributes that are computed from the model registry.

    Python calls a module-level ``__getattr__`` only when normal attribute
    lookup on the module fails. The names below are therefore rebuilt from
    ``MODEL_REGISTRY`` on every access, and reflect whatever is registered
    at that moment.

    Supported names:

    * ``MODEL_NAMES``: list of registered model names, in the registry's
      insertion order.
    * ``MODEL_FAMILY_FROM_NAME``: dict mapping each registered model name
      to the ``__name__`` of that model's ``family`` class.

    Raises
    ------
    AttributeError
        For any other name.
    """
    if name == "MODEL_NAMES":
        return [registered_name for registered_name in MODEL_REGISTRY]

    if name == "MODEL_FAMILY_FROM_NAME":
        family_by_model = {}
        for registered_name, registered_model in MODEL_REGISTRY.items():
            family_by_model[registered_name] = registered_model.family.__name__
        return family_by_model

    raise AttributeError(f"Not an attribute of `vak.models.registry`: {name}")