I have a multi-label classification problem, where I was initially using a binary cross entropy loss and my labels are one-hot encoded. I found a paper similar to my application that used a contrastive loss function, but I am not sure how to use it in my code. I came across an implementation of supervised contrastive loss, but I didn't understand what the inputs to the function are. One of the inputs is the labels and the other is 'projections'. What is the projection in my case?
import torch
import torch.nn as nn
from math import log
class SupervisedContrastiveLoss(nn.Module):
    """Supervised contrastive loss (Khosla et al., https://arxiv.org/abs/2004.11362).

    Pulls together projections of samples that share a label and pushes apart
    all others. Projections should be L2-normalized so that the dot product is
    a cosine similarity; the default temperature of 0.07 assumes this.
    """

    def __init__(self, temperature=0.07):
        """
        :param temperature: float, softmax temperature scaling the similarities
        """
        # BUG FIX: the original defined `def init(...)`, which nn.Module never
        # calls, so `self.temperature` was missing when `forward` ran.
        super(SupervisedContrastiveLoss, self).__init__()
        self.temperature = temperature

    def forward(self, projections, targets):
        """
        :param projections: torch.Tensor, shape [batch_size, projection_dim];
            should be L2-normalized along dim 1 (see class docstring)
        :param targets: torch.Tensor, shape [batch_size] of integer class ids
        :return: torch.Tensor, scalar loss
        """
        device = projections.device
        # Pairwise similarities scaled by temperature.
        dot_product_tempered = torch.mm(projections, projections.T) / self.temperature
        # Subtract the row max for numerical stability with the exponential
        # (same trick as in cross entropy); epsilon avoids log(0).
        exp_dot_tempered = (
            torch.exp(dot_product_tempered - torch.max(dot_product_tempered, dim=1, keepdim=True)[0]) + 1e-5
        )
        # mask_similar_class[i, j] is True when samples i and j share a label.
        mask_similar_class = (targets.unsqueeze(1).repeat(1, targets.shape[0]) == targets).to(device)
        # Zero the diagonal so a sample is never its own positive; create the
        # identity directly on the target device instead of moving it.
        mask_anchor_out = 1 - torch.eye(exp_dot_tempered.shape[0], device=device)
        mask_combined = mask_similar_class * mask_anchor_out
        # Number of positives per anchor. BUG FIX: clamp to 1 so an anchor
        # with no positive in the batch contributes 0 instead of NaN (0/0),
        # which would otherwise poison the whole batch mean and the gradients.
        cardinality_per_samples = torch.clamp(torch.sum(mask_combined, dim=1), min=1)
        # -log softmax over all other samples in the batch.
        log_prob = -torch.log(
            exp_dot_tempered / torch.sum(exp_dot_tempered * mask_anchor_out, dim=1, keepdim=True)
        )
        # Average over positives per anchor, then over the batch.
        supervised_contrastive_loss_per_sample = (
            torch.sum(log_prob * mask_combined, dim=1) / cardinality_per_samples
        )
        return torch.mean(supervised_contrastive_loss_per_sample)
This is the model I have used:
class CV_CNN_Net(nn.Module):
    """Complex-valued CNN returning (multi-label probabilities, projections).

    ``forward`` returns both the sigmoid output for the BCE classification
    loss and a real-valued, L2-normalized feature vector that can be used as
    the ``projections`` argument of a supervised contrastive loss.

    NOTE(review): the layer types (ComplexConv2d, ComplexBatchNorm2d,
    C_CSELU, ComplexLinear, ComplexDropout) are project-local; the
    ComplexLinear(4608, 2048) assumes input spatial dims that flatten to
    4608 — confirm against the actual input size.
    """

    def __init__(self, device):
        """
        :param device: torch.device the ComplexDropout layers should run on
        """
        super(CV_CNN_Net, self).__init__()
        self.device = device
        # Complex-valued convolutional backbone followed by complex FC layers,
        # ending in a 20-dim complex embedding.
        self.nn = nn.Sequential(
            ComplexConv2d(1, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            ComplexBatchNorm2d(128),
            C_CSELU(),
            ComplexConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            ComplexBatchNorm2d(128),
            C_CSELU(),
            ComplexConv2d(128, 128, kernel_size=(2, 2), stride=(1, 1)),
            ComplexBatchNorm2d(128),
            C_CSELU(),
            ComplexConv2d(128, 128, kernel_size=(2, 2), stride=(1, 1)),
            ComplexBatchNorm2d(128),
            C_CSELU(),
            nn.Flatten(),
            ComplexLinear(4608, 2048),
            C_CSELU(),
            ComplexDropout(p=0.3, device=self.device),
            ComplexLinear(2048, 1024),
            C_CSELU(),
            ComplexDropout(p=0.3, device=self.device),
            ComplexLinear(1024, 512),
            C_CSELU(),
            ComplexDropout(p=0.3, device=self.device),
            ComplexLinear(512, 20),
            C_CSELU(),
            ComplexDropout(p=0.5, device=self.device),
        )
        # Real-valued classifier head: consumes [real, imag] concatenation
        # (2 * 20 = 40 dims) and emits per-class probabilities via Sigmoid
        # (multi-label setup, paired with BCE loss).
        self.FNN = nn.Sequential(
            nn.Linear(40, 20, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, data_train):
        """
        :param data_train: complex-valued input batch for the backbone
        :return: tuple (Outputs, features) — Outputs are sigmoid
            probabilities in [0, 1] for BCE; features are L2-normalized
            projections for the supervised contrastive loss.
        """
        complex_out = self.nn(data_train)
        # Make the embedding real by concatenating real and imaginary parts.
        raw_features = torch.cat([complex_out.real, complex_out.imag], dim=1)
        # Classifier head sees the raw features, exactly as before.
        Outputs = self.FNN(raw_features)
        # BUG FIX: supervised contrastive loss with temperature 0.07 assumes
        # unit-norm projections (cosine similarity). Feeding unnormalized
        # features makes exp(sim / T) overflow, producing NaN gradients that
        # eventually drive the sigmoid outputs out of [0, 1] and trigger the
        # CUDA BCE assertion `input_val >= zero && input_val <= one`.
        features = torch.nn.functional.normalize(raw_features, dim=1)
        return Outputs, features
I thought the features in my code could be used as the projections, but I am getting an error during training:
C:\cb\pytorch_1000000000000\work\aten\src\ATen\native\cuda\Loss.cu:106: block: [0,0,0], thread: [96,0,0] Assertion `input_val >= zero && input_val <= one` failed.
Can anyone guide me on how to correctly use the supervised contrastive loss? Thank you in advance