0

I have a multi-label classification problem, where I was initially using a binary cross-entropy loss and my labels are one-hot encoded. I found a paper similar to my application that used a contrastive loss function, but I am not sure how to use it in my code. I came across an implementation of supervised contrastive loss, but I didn't understand what the inputs to the function are. One of the inputs is the labels, and the other is 'projections'. What are the projections in my case?

import torch
import torch.nn as nn

from math import log  # NOTE(review): unused in this snippet — safe to drop if nothing else needs it


class SupervisedContrastiveLoss(nn.Module):
    """Supervised Contrastive Learning loss.

    Implementation of the loss described in the paper
    "Supervised Contrastive Learning": https://arxiv.org/abs/2004.11362

    Expects a batch of projection vectors (one per sample) and integer class
    ids. Samples sharing a class id are treated as positives for each other.
    NOTE(review): the paper assumes L2-normalized projections — normalize
    before calling; also, `targets` here are single-label integer ids, NOT
    one-hot / multi-label vectors (a multi-label setting needs a custom
    positive-pair mask).
    """

    def __init__(self, temperature=0.07):
        """
        :param temperature: float, scaling applied to the similarity logits
        """
        super(SupervisedContrastiveLoss, self).__init__()
        self.temperature = temperature

    def forward(self, projections, targets):
        """
        :param projections: torch.Tensor, shape [batch_size, projection_dim]
        :param targets: torch.Tensor, shape [batch_size], integer class ids
        :return: torch.Tensor, scalar loss
        """
        device = projections.device

        # Pairwise similarity logits, scaled by the temperature.
        dot_product_tempered = torch.mm(projections, projections.T) / self.temperature
        # Subtract the row max for numerical stability with the exponential
        # (same trick as in cross entropy). Epsilon avoids log(0).
        exp_dot_tempered = (
            torch.exp(dot_product_tempered - torch.max(dot_product_tempered, dim=1, keepdim=True)[0]) + 1e-5
        )

        # mask_similar_class[i, j] is True when samples i and j share a label.
        mask_similar_class = (targets.unsqueeze(1) == targets.unsqueeze(0)).to(device)
        # Exclude self-similarity (the diagonal) from numerator and denominator.
        mask_anchor_out = 1 - torch.eye(exp_dot_tempered.shape[0], device=device)
        mask_combined = mask_similar_class * mask_anchor_out
        # Number of positives per anchor; clamp to 1 so an anchor with no
        # positive in the batch contributes 0 instead of NaN (0/0).
        cardinality_per_samples = torch.clamp(torch.sum(mask_combined, dim=1), min=1)

        # -log of the softmax probability of each pair, diagonal excluded
        # from the partition function.
        log_prob = -torch.log(exp_dot_tempered / (torch.sum(exp_dot_tempered * mask_anchor_out, dim=1, keepdim=True)))
        # Average over the positives of each anchor, then over the batch.
        supervised_contrastive_loss_per_sample = torch.sum(log_prob * mask_combined, dim=1) / cardinality_per_samples
        supervised_contrastive_loss = torch.mean(supervised_contrastive_loss_per_sample)

        return supervised_contrastive_loss

This is the model I have used

class CV_CNN_Net(nn.Module):
    """Complex-valued CNN returning both class scores and a feature vector.

    forward() returns (Outputs, features):
      - Outputs: sigmoid scores from the real-valued head, shape [batch, 20]
      - features: real and imaginary parts of the complex backbone output
        concatenated along dim 1, shape [batch, 40]. This is the tensor that
        can serve as the "projections" input to SupervisedContrastiveLoss
        (L2-normalize it first, as the contrastive formulation assumes unit
        vectors).

    NOTE(review): `ComplexConv2d`, `ComplexBatchNorm2d`, `ComplexLinear`,
    `ComplexDropout` and `C_CSELU` are project-local complex-valued layers —
    presumed to operate on complex tensors (the `.real`/`.imag` access in
    forward() requires a complex dtype output).
    """

    def __init__(self, device):
        super(CV_CNN_Net, self).__init__()
        self.device = device
        # Complex-valued backbone: four conv stages followed by a
        # fully-connected projection down to 20 complex features.
        self.nn = nn.Sequential(
            ComplexConv2d(1, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            ComplexBatchNorm2d(128),
            C_CSELU(),

            ComplexConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            ComplexBatchNorm2d(128),
            C_CSELU(),

            ComplexConv2d(128, 128, kernel_size=(2, 2), stride=(1, 1)),
            ComplexBatchNorm2d(128),
            C_CSELU(),

            ComplexConv2d(128, 128, kernel_size=(2, 2), stride=(1, 1)),
            ComplexBatchNorm2d(128),
            C_CSELU(),

            nn.Flatten(),
            ComplexLinear(4608, 2048),
            C_CSELU(),

            ComplexDropout(p=0.3, device=self.device),
            ComplexLinear(2048, 1024),
            C_CSELU(),

            ComplexDropout(p=0.3, device=self.device),
            ComplexLinear(1024, 512),
            C_CSELU(),

            ComplexDropout(p=0.3, device=self.device),
            ComplexLinear(512, 20),
            C_CSELU(),
            ComplexDropout(p=0.5, device=self.device),
        )

        # Real-valued classification head over the [real, imag] concatenation
        # (20 + 20 = 40 inputs -> 20 sigmoid scores).
        self.FNN = nn.Sequential(
            nn.Linear(40, 20, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, data_train):
        # Complex backbone output (complex dtype presumed — see class note).
        complex_out = self.nn(data_train)
        # Real feature vector [batch, 40]: the candidate "projections" for a
        # supervised contrastive loss.
        features = torch.cat([complex_out.real, complex_out.imag], dim=1)
        Outputs = self.FNN(features)
        return Outputs, features

I thought the `features` tensor in my code could be used as the projections, but I am getting an error during training:

C:\cb\pytorch_1000000000000\work\aten\src\ATen\native\cuda\Loss.cu:106: block: [0,0,0], thread: [96,0,0] Assertion input_val >= zero && input_val <= one failed.

Can anyone guide me on how to correctly use the supervised contrastive loss? Thank you in advance

ThinkPad
  • 41
  • 6

0 Answers0