MVP / mvp /utils /loss.py
yzhouchen001's picture
model code
d9df210
raw
history blame
2.73 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
def contrastive_loss(v1, v2, tau=1.0) -> torch.Tensor:
v1_norm = torch.norm(v1, dim=1, keepdim=True)
v2_norm = torch.norm(v2, dim=1, keepdim=True)
v2T = torch.transpose(v2, 0, 1)
inner_prod = torch.matmul(v1, v2T)
v2_normT = torch.transpose(v2_norm, 0, 1)
norm_mat = torch.matmul(v1_norm, v2_normT)
loss_mat = torch.div(inner_prod, norm_mat)
loss_mat = loss_mat * (1/tau)
loss_mat = torch.exp(loss_mat)
numerator = torch.diagonal(loss_mat)
numerator = torch.unsqueeze(numerator, 0)
Lv1_v2_denom = torch.sum(loss_mat, dim=1, keepdim=True)
Lv1_v2_denom = torch.transpose(Lv1_v2_denom, 0, 1)
#Lv1_v2_denom = Lv1_v2_denom - numerator
Lv2_v1_denom = torch.sum(loss_mat, dim=0, keepdim=True)
#Lv2_v1_denom = Lv2_v1_denom - numerator
Lv1_v2 = torch.div(numerator, Lv1_v2_denom)
Lv1_v2 = -1 * torch.log(Lv1_v2)
Lv1_v2 = torch.mean(Lv1_v2)
Lv2_v1 = torch.div(numerator, Lv2_v1_denom)
Lv2_v1 = -1 * torch.log(Lv2_v1)
Lv2_v1 = torch.mean(Lv2_v1)
return Lv1_v2 + Lv2_v1 , torch.mean(numerator), torch.mean(Lv1_v2_denom+Lv2_v1_denom)
def cand_spec_sim_loss(spec_enc, cand_enc):
cand_enc = torch.transpose(cand_enc, 0, 1) # C x B x d
spec_enc = spec_enc.unsqueeze(0) # 1 x B x d
sim = nn.functional.cosine_similarity(spec_enc, cand_enc, dim=2)
loss = torch.mean(sim)
return loss
class cons_spec_loss:
def __init__(self, loss_type) -> None:
self.loss_compute = {'cosine': self.cos_loss,
'l2':torch.nn.MSELoss()}[loss_type]
def __call__(self,cons_spec, ind_spec):
return self.loss_compute(cons_spec, ind_spec)
def cos_loss(self, cons_spec, ind_spec):
sim = nn.functional.cosine_similarity(cons_spec, ind_spec)
loss = 1-torch.mean(sim)
return loss
class fp_loss:
def __init__(self, loss_type) -> None:
self.loss_compute = {'cosine': self.fp_loss_cos,
'bce': nn.BCELoss()}[loss_type]
def __call__(self, predicted_fp, target_fp):
return self.loss_compute(predicted_fp, target_fp)
def fp_loss_cos(self, predicted_fp, target_fp):
sim = nn.functional.cosine_similarity(predicted_fp, target_fp)
return 1 - torch.mean(sim)