Skip to content

Commit

Permalink
Implemented option for paired negative keys.
Browse files Browse the repository at this point in the history
  • Loading branch information
RElbers committed Oct 23, 2021
1 parent f840990 commit 9bdd791
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 17 deletions.
61 changes: 45 additions & 16 deletions info_nce/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,18 @@ class InfoNCE(nn.Module):
reduction: Reduction method applied to the output.
Value must be one of ['none', 'sum', 'mean'].
See torch.nn.functional.cross_entropy for more details about each option.
negative_mode: Determines how the (optional) negative_keys are handled.
Value must be one of ['paired', 'unpaired'].
If 'paired', then each query sample is paired with a number of negative keys.
Comparable to a triplet loss, but with multiple negatives per sample.
If 'unpaired', then the negative keys are shared by all queries and are unrelated to any particular positive key.
Input shape:
query: (N, D) Tensor with query samples (e.g. embeddings of the input).
positive_key: (N, D) Tensor with positive samples (e.g. embeddings of augmented input).
negative_keys (optional): (M, D) Tensor with negative samples (e.g. embeddings of other inputs).
negative_keys (optional): Tensor with negative samples (e.g. embeddings of other inputs).
If negative_mode = 'paired', then negative_keys is a (N, M, D) Tensor.
If negative_mode = 'unpaired', then negative_keys is a (M, D) Tensor.
If None, then the negative keys for a sample are the positive keys for the other samples.
Returns:
Expand All @@ -40,39 +47,61 @@ class InfoNCE(nn.Module):
>>> output = loss(query, positive_key, negative_keys)
"""

def __init__(self, temperature=0.1, reduction='mean'):
def __init__(self, temperature=0.1, reduction='mean', negative_mode='unpaired'):
super().__init__()
self.temperature = temperature
self.reduction = reduction
self.negative_mode = negative_mode

def forward(self, query, positive_key, negative_keys=None):
return info_nce(query, positive_key, negative_keys, temperature=self.temperature, reduction=self.reduction)


def info_nce(query, positive_key, negative_keys=None, temperature=0.1, reduction='mean'):
# Inputs all have 2 dimensions.
if query.dim() != 2 or positive_key.dim() != 2 or (negative_keys is not None and negative_keys.dim() != 2):
raise ValueError('query, positive_key and negative_keys should all have 2 dimensions.')
return info_nce(query, positive_key, negative_keys,
temperature=self.temperature,
reduction=self.reduction,
negative_mode=self.negative_mode)


def info_nce(query, positive_key, negative_keys=None, temperature=0.1, reduction='mean', negative_mode='unpaired'):
# Check input dimensionality.
if query.dim() != 2:
raise ValueError('<query> must have 2 dimensions.')
if positive_key.dim() != 2:
raise ValueError('<positive_key> must have 2 dimensions.')
if negative_keys is not None:
if negative_mode == 'unpaired' and negative_keys.dim() != 2:
raise ValueError("<negative_keys> must have 2 dimensions if <negative_mode> == 'unpaired'.")
if negative_mode == 'paired' and negative_keys.dim() != 3:
raise ValueError("<negative_keys> must have 3 dimensions if <negative_mode> == 'paired'.")

# Each query sample is paired with exactly one positive key sample.
# Check matching number of samples.
if len(query) != len(positive_key):
raise ValueError('query and positive_key must have the same number of samples.')
raise ValueError('<query> and <positive_key> must must have the same number of samples.')
if negative_keys is not None:
if negative_mode == 'paired' and len(query) != len(negative_keys):
raise ValueError("If negative_mode == 'paired', then <negative_keys> must have the same number of samples as <query>.")

# Embedding vectors should have same number of components.
if query.shape[1] != positive_key.shape[1] != (positive_key.shape[1] if negative_keys is None else negative_keys.shape[1]):
raise ValueError('query, positive_key and negative_keys should have the same number of components.')
if query.shape[-1] != positive_key.shape[-1]:
raise ValueError('Vectors of <query> and <positive_key> should have the same number of components.')
if negative_keys is not None:
if query.shape[-1] != negative_keys.shape[-1]:
raise ValueError('Vectors of <query> and <negative_keys> should have the same number of components.')

# Normalize to unit vectors
query, positive_key, negative_keys = normalize(query, positive_key, negative_keys)

if negative_keys is not None:
# Explicit negative keys

# Cosine between positive pairs
positive_logit = torch.sum(query * positive_key, dim=1, keepdim=True)

# Cosine between all query-negative combinations
negative_logits = query @ transpose(negative_keys)
if negative_mode == 'unpaired':
# Cosine between all query-negative combinations
negative_logits = query @ transpose(negative_keys)

elif negative_mode == 'paired':
query = query.unsqueeze(1)
negative_logits = query @ transpose(negative_keys)
negative_logits = negative_logits.squeeze(1)

# The first index along the last dimension holds the positive logit
logits = torch.cat([positive_logit, negative_logits], dim=1)
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = info-nce-pytorch
version = 0.1.2
version = 0.1.3
description = PyTorch implementation of the InfoNCE loss for self-supervised learning.
long-description = file: README.rst
author = Robin Elbers
Expand Down

0 comments on commit 9bdd791

Please sign in to comment.