"""Parametric UMAP model, as described in [1]_.
Code adapted from implementation by @elyxlz
https://github.com/elyxlz/umap_pytorch
with changes made by Tim Sainburg:
https://github.com/lmcinnes/umap/issues/580#issuecomment-1368649550.
"""
from __future__ import annotations
import pathlib
from typing import Callable, ClassVar, Type
import lightning
import torch
import torch.utils.data
from .definition import ModelDefinition
from .registry import model_family
@model_family
class ParametricUMAPModel(lightning.LightningModule):
"""Parametric UMAP model, as described in [1]_.
Notes
-----
Code adapted from implementation by @elyxlz
https://github.com/elyxlz/umap_pytorch
with changes made by Tim Sainburg:
https://github.com/lmcinnes/umap/issues/580#issuecomment-1368649550.
References
----------
.. [1] Sainburg, T., McInnes, L., & Gentner, T. Q. (2021).
Parametric UMAP embeddings for representation and semisupervised learning.
Neural Computation, 33(11), 2881-2907.
https://direct.mit.edu/neco/article/33/11/2881/107068.
"""
definition: ClassVar[ModelDefinition]
def __init__(
self,
network: dict,
loss: torch.nn.Module | Callable,
optimizer: torch.optim.Optimizer,
        metrics: dict[str, Type],
):
super().__init__()
self.network = torch.nn.ModuleDict(
network
)
        self.loss = loss
        self.optimizer = optimizer
        self.metrics = metrics

    def configure_optimizers(self):
        """Return this model's optimizer, as ``lightning`` requires.

        Without this hook, Lightning has no optimizer to step during training.
        """
        return self.optimizer
def training_step(self, batch, batch_idx):
(edges_to_exp, edges_from_exp) = batch
embedding_to = self.network['encoder'](edges_to_exp)
embedding_from = self.network['encoder'](edges_from_exp)
if 'decoder' in self.network:
reconstruction = self.network['decoder'](embedding_to)
before_encoding = edges_to_exp
else:
reconstruction = None
before_encoding = None
loss_umap, loss_reconstruction, loss = self.loss(
embedding_to, embedding_from, reconstruction, before_encoding
)
self.log("train_umap_loss", loss_umap, on_step=True)
if loss_reconstruction:
self.log(
"train_reconstruction_loss", loss_reconstruction, on_step=True
)
# note if there's no ``loss_reconstruction``, then ``loss`` == ``loss_umap``
self.log("train_loss", loss, on_step=True)
return loss
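
    # The ``loss`` callable passed to ``__init__`` is expected to take
    # ``(embedding_to, embedding_from, reconstruction, before_encoding)`` and
    # return a 3-tuple ``(loss_umap, loss_reconstruction, loss)``.
    # A minimal stub satisfying that contract (an illustration only, not vak's
    # actual UMAP loss):
    #
    #     def toy_loss(embedding_to, embedding_from, reconstruction, before_encoding):
    #         loss_umap = torch.tensor(0.0)  # stand-in for the real UMAP loss term
    #         if reconstruction is None:
    #             return loss_umap, None, loss_umap
    #         loss_recon = torch.nn.functional.mse_loss(reconstruction, before_encoding)
    #         return loss_umap, loss_recon, loss_umap + loss_recon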
def validation_step(self, batch, batch_idx):
(edges_to_exp, edges_from_exp) = batch
embedding_to = self.network['encoder'](edges_to_exp)
embedding_from = self.network['encoder'](edges_from_exp)
        if 'decoder' in self.network:
reconstruction = self.network['decoder'](embedding_to)
before_encoding = edges_to_exp
else:
reconstruction = None
before_encoding = None
loss_umap, loss_reconstruction, loss = self.loss(
embedding_to, embedding_from, reconstruction, before_encoding
)
self.log("val_umap_loss", loss_umap, on_step=True)
if loss_reconstruction:
self.log(
"val_reconstruction_loss", loss_reconstruction, on_step=True
)
# note if there's no ``loss_reconstruction``, then ``loss`` == ``loss_umap``
self.log("val_loss", loss, on_step=True)
def load_state_dict_from_path(self, ckpt_path):
"""Loads a model from the path to a saved checkpoint.
Loads the checkpoint and then calls
``self.load_state_dict`` with the ``state_dict``
in that chekcpoint.
This method allows loading a state dict into an instance.
It's necessary because `lightning.pytorch.LightningModule.load`` is a
``classmethod``, so calling that method will trigger
``LightningModule.__init__`` instead of running
``vak.models.Model.__init__``.
Parameters
----------
ckpt_path : str, pathlib.Path
Path to a checkpoint saved by a model in ``vak``.
This checkpoint has the same key-value pairs as
any other checkpoint saved by a
``lightning.pytorch.LightningModule``.
Returns
-------
None
This method modifies the model state by loading the ``state_dict``;
it does not return anything.
"""
ckpt = torch.load(ckpt_path)
self.load_state_dict(ckpt["state_dict"])
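
# Minimal sketch of restoring trained weights with ``load_state_dict_from_path``.
# ``network``, ``loss``, ``optimizer``, and ``metrics`` are assumed to match the
# configuration used when the checkpoint was saved; ``"checkpoint.pt"`` is a
# hypothetical path.
#
#     model = ParametricUMAPModel(network, loss, optimizer, metrics)
#     model.load_state_dict_from_path("checkpoint.pt")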
class ParametricUMAPDatamodule(lightning.pytorch.LightningDataModule):
    """A ``LightningDataModule`` that wraps a dataset for parametric UMAP
    in a :class:`torch.utils.data.DataLoader` for training."""
def __init__(
self,
dataset,
batch_size,
num_workers,
):
super().__init__()
self.dataset = dataset
self.batch_size = batch_size
self.num_workers = num_workers
def train_dataloader(self) -> torch.utils.data.DataLoader:
return torch.utils.data.DataLoader(
dataset=self.dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
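
# Sketch of wiring the datamodule to a trainer (``dataset`` stands for any
# map-style ``torch.utils.data.Dataset`` whose batches are the
# ``(edges_to_exp, edges_from_exp)`` pairs that ``training_step`` expects):
#
#     datamodule = ParametricUMAPDatamodule(dataset, batch_size=64, num_workers=4)
#     trainer = lightning.pytorch.Trainer(max_epochs=200)
#     trainer.fit(model, datamodule=datamodule)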
class ParametricUMAP:
    """A scikit-learn-style wrapper around :class:`ParametricUMAPModel`:
    builds the model from an encoder (and optional decoder), fits it with
    a ``lightning`` trainer, and transforms data with the trained encoder."""
def __init__(
self,
encoder: torch.nn.Module,
decoder: torch.nn.Module | None = None,
n_neighbors: int = 10,
min_dist: float = 0.1,
metric: str = "euclidean",
num_epochs: int = 200,
lr: float = 1e-3,
batch_size: int = 64,
num_workers: int = 16,
random_state: int | None = None,
):
self.encoder = encoder
self.decoder = decoder
self.n_neighbors = n_neighbors
self.min_dist = min_dist
self.metric = metric
self.lr = lr
self.num_epochs = num_epochs
self.batch_size = batch_size
self.num_workers = num_workers
self.random_state = random_state
        network = {"encoder": self.encoder}
        if self.decoder is not None:
            network["decoder"] = self.decoder
        # ``ParametricUMAPModel`` requires a network dict, a loss, an optimizer,
        # and a metrics dict (see its ``__init__`` above); the original call here
        # passed only the encoder. The loss and optimizer built below are
        # assumptions: we assume ``vak.nn.loss.UmapLoss`` accepts ``min_dist``,
        # and default to Adam with ``self.lr``; swap in whatever loss and
        # optimizer you actually train with.
        from vak.nn.loss import UmapLoss

        loss = UmapLoss(min_dist=self.min_dist)
        optimizer = torch.optim.Adam(
            (param for module in network.values() for param in module.parameters()),
            lr=self.lr,
        )
        self.model = ParametricUMAPModel(network, loss, optimizer, metrics={})
    def fit(
        self,
        trainer: lightning.pytorch.Trainer,
        dataset_path: str | pathlib.Path,
        transform=None,
    ):
        """Fit the parametric UMAP model on the dataset at ``dataset_path``,
        running training with ``trainer``."""
from vak.datasets.parametric_umap import ParametricUMAPDataset
dataset = ParametricUMAPDataset.from_dataset_path(
dataset_path,
"train",
self.n_neighbors,
self.metric,
self.random_state,
self.num_epochs,
transform,
)
trainer.fit(
model=self.model,
datamodule=ParametricUMAPDatamodule(
dataset, self.batch_size, self.num_workers
),
)
    @torch.no_grad()
    def transform(self, X):
        """Embed samples ``X`` with the trained encoder."""
        embedding = self.model.network["encoder"](X).detach().cpu().numpy()
        return embedding
    @torch.no_grad()
    def inverse_transform(self, Z):
        """Map embeddings ``Z`` back to the data space with the trained decoder."""
        return self.model.network["decoder"](Z).detach().cpu().numpy()
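
# End-to-end usage sketch. ``MyEncoder`` and ``"prep/output/dataset"`` are
# hypothetical stand-ins, and the trainer settings are illustrative defaults,
# not vak's.
#
#     encoder = MyEncoder()  # any torch.nn.Module mapping samples to embeddings
#     pumap = ParametricUMAP(encoder, n_neighbors=10, min_dist=0.1)
#     trainer = lightning.pytorch.Trainer(max_epochs=pumap.num_epochs)
#     pumap.fit(trainer, dataset_path="prep/output/dataset")
#     embedding = pumap.transform(X)  # X: tensor of samples to embed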