"""Parametric UMAP loss function."""
from __future__ import annotations
import warnings
import torch
# isort: off
# Ignore warnings from Numba deprecation:
# https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit
# Numba is required by UMAP.
from numba.core.errors import NumbaDeprecationWarning
from torch.nn.functional import mse_loss
warnings.simplefilter("ignore", category=NumbaDeprecationWarning)
from umap.umap_ import find_ab_params  # noqa: E402
# isort: on
def convert_distance_to_probability(distances, a=1.0, b=1.0):
"""Convert distances to probability.
Computes equation (2.6) of Sainburg McInnes Gentner 2021,
:math:`q_{ij} = (1 + a \abs{z_i - z_j}^{2b} )^{-1}`.
The function uses torch.log1p to avoid floating point error:
``-torch.log1p(a * distances ** (2 * b))``.
See https://en.wikipedia.org/wiki/Natural_logarithm#lnp1
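
    A small usage sketch (illustrative values):

    >>> import torch
    >>> distances = torch.tensor([0.0, 1.0, 2.0])
    >>> log_q = convert_distance_to_probability(distances, a=1.0, b=1.0)
    >>> torch.allclose(log_q.exp(), 1.0 / (1.0 + distances ** 2))
    True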
"""
    # equivalent to the log of 1.0 / (1.0 + a * distances ** (2 * b)),
    # but torch.log1p avoids floating point error
    return -torch.log1p(a * distances ** (2 * b))
def compute_cross_entropy(
probabilities_graph,
probabilities_distance,
EPS=1e-4,
repulsion_strength=1.0,
):
"""Computes cross entropy as used for UMAP cost function"""
# cross entropy
attraction_term = -probabilities_graph * torch.nn.functional.logsigmoid(
probabilities_distance
)
repulsion_term = (
-(1.0 - probabilities_graph)
* (
torch.nn.functional.logsigmoid(probabilities_distance)
- probabilities_distance
)
* repulsion_strength
)
    # the total cross entropy sums the attraction and repulsion terms,
    # whose expected losses are balanced by ``repulsion_strength`` above
CE = attraction_term + repulsion_term
return attraction_term, repulsion_term, CE
def umap_loss(
embedding_to: torch.Tensor,
embedding_from: torch.Tensor,
    a: float,
    b: float,
negative_sample_rate: int = 5,
):
"""UMAP loss function
Converts distances to probabilities,
and then computes cross entropy.
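
    A minimal sketch (random embeddings; the values of ``a`` and ``b``
    are illustrative, in practice fit them from ``spread`` and
    ``min_dist`` with :func:`umap.umap_.find_ab_params`):

    >>> import torch
    >>> emb_to, emb_from = torch.rand(8, 2), torch.rand(8, 2)
    >>> loss = umap_loss(emb_to, emb_from, a=1.577, b=0.895)
    >>> loss.ndim  # a scalar tensor
    0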
"""
# get negative samples by randomly shuffling the batch
embedding_neg_to = embedding_to.repeat(negative_sample_rate, 1)
repeat_neg = embedding_from.repeat(negative_sample_rate, 1)
embedding_neg_from = repeat_neg[torch.randperm(repeat_neg.shape[0])]
    # call ``to(device)`` on the result to avoid the error
    # `Expected all tensors to be on the same device`
    distance_embedding = torch.cat(
        (
            (embedding_to - embedding_from).norm(dim=1),
            (embedding_neg_to - embedding_neg_from).norm(dim=1),
        ),
        dim=0,
    ).to(embedding_to.device)
    # convert distances to probabilities
probabilities_distance = convert_distance_to_probability(
distance_embedding, a, b
)
# set true probabilities based on negative sampling
batch_size = embedding_to.shape[0]
    # call ``to(device)`` on the result to avoid the error
    # `Expected all tensors to be on the same device`
    probabilities_graph = torch.cat(
        (
            torch.ones(batch_size),
            torch.zeros(batch_size * negative_sample_rate),
        ),
        dim=0,
    ).to(embedding_to.device)
# compute cross entropy
    (attraction_loss, repulsion_loss, ce_loss) = compute_cross_entropy(
probabilities_graph,
probabilities_distance,
)
loss = torch.mean(ce_loss)
return loss
class UmapLoss(torch.nn.Module):
""""""
def __init__(
self,
spread: float = 1.0,
min_dist: float = 0.1,
negative_sample_rate: int = 5,
beta: float = 1.0,
):
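        """Initialize a new :class:`UmapLoss` instance.

        ``a`` and ``b`` are fit from ``spread`` and ``min_dist``
        with :func:`umap.umap_.find_ab_params`;
        ``beta`` weights the reconstruction loss
        relative to the UMAP loss.
        """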
super().__init__()
self.min_dist = min_dist
self.a, self.b = find_ab_params(spread, min_dist)
self.negative_sample_rate = negative_sample_rate
self.beta = beta
def forward(
self,
embedding_to: torch.Tensor,
embedding_from: torch.Tensor,
reconstruction: torch.Tensor | None = None,
before_encoding: torch.Tensor | None = None,
):
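        """Compute UMAP loss, plus reconstruction loss if
        ``reconstruction`` is not None
        (``before_encoding`` must be passed in along with it).

        Returns a tuple ``(loss_umap, loss_reconstruction, loss)``,
        where ``loss_reconstruction`` is ``None`` when no
        ``reconstruction`` is given, and ``loss`` is the total loss.
        """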
loss_umap = umap_loss(
embedding_to,
embedding_from,
self.a,
self.b,
self.negative_sample_rate,
)
if reconstruction is not None:
loss_reconstruction = mse_loss(reconstruction, before_encoding)
loss = loss_umap + self.beta * loss_reconstruction
else:
loss_reconstruction = None
loss = loss_umap
return loss_umap, loss_reconstruction, loss