张芷铭's Personal Blog

Cosine Similarity

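Cosine similarity measures how closely two vectors point in the same direction, regardless of their lengths. Its value lies in [-1, 1]: 1 for identical directions, 0 for orthogonal vectors, and -1 for opposite directions. It is widely used in text retrieval, contrastive learning, and similar tasks.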
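For reference, the standard definition for two nonzero vectors $\mathbf{a}, \mathbf{b} \in \mathbb{R}^d$ (symbols chosen here for illustration) is the dot product divided by the product of the L2 norms. The implementations in this post compute this ratio, with a small $\epsilon$ added to the denominator to avoid division by zero.

```latex
\cos(\mathbf{a}, \mathbf{b})
  = \frac{\mathbf{a} \cdot \mathbf{b}}{\lVert \mathbf{a} \rVert_2 \, \lVert \mathbf{b} \rVert_2}
  = \frac{\sum_{i=1}^{d} a_i b_i}{\sqrt{\sum_{i=1}^{d} a_i^2} \, \sqrt{\sum_{i=1}^{d} b_i^2}}
```

For example, with $\mathbf{a} = (1, 0)$ and $\mathbf{b} = (1, 1)$ the value is $1 / (1 \cdot \sqrt{2}) \approx 0.7071$, the cosine of 45°.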

Basic Implementation (NumPy)

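```python
import numpy as np

def cosine_similarity_np(vec1, vec2, eps=1e-8):
    # Row-wise similarity along the last axis; inputs broadcast against each other
    dot_product = np.sum(vec1 * vec2, axis=-1)
    norm1 = np.linalg.norm(vec1, axis=-1)
    norm2 = np.linalg.norm(vec2, axis=-1)
    # eps guards against division by zero for all-zero vectors;
    # clip removes floating-point overshoot outside [-1, 1]
    return np.clip(dot_product / (norm1 * norm2 + eps), -1.0, 1.0)
```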
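`cosine_similarity_np` pairs vectors row by row (after broadcasting). When every query must be compared against every candidate, as in retrieval, a common alternative is to L2-normalize both sets once and take a single matrix product. The function name `pairwise_cosine_np` and the sample arrays are hypothetical; this is a minimal sketch, not part of the original post.

```python
import numpy as np

def pairwise_cosine_np(queries, keys, eps=1e-8):
    """Return an (N, M) matrix of cosine similarities.

    queries: (N, D) array, keys: (M, D) array.
    """
    # Normalize each row to unit length; eps avoids division by zero
    q = queries / (np.linalg.norm(queries, axis=-1, keepdims=True) + eps)
    k = keys / (np.linalg.norm(keys, axis=-1, keepdims=True) + eps)
    # Dot products of unit vectors are cosine similarities
    return q @ k.T

queries = np.array([[1.0, 0.0], [0.0, 2.0]])
keys = np.array([[1.0, 1.0], [3.0, 0.0], [0.0, -1.0]])
sims = pairwise_cosine_np(queries, keys)
print(sims.shape)  # (2, 3)
```

Each entry `sims[i, j]` matches `cosine_similarity_np(queries[i], keys[j])` up to the slightly different placement of `eps`.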

PyTorch Implementation

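```python
import torch

def cosine_similarity_torch(vec1, vec2, eps=1e-8):
    # L2-normalize along the last dimension; eps avoids division by zero
    vec1_norm = vec1 / (vec1.norm(dim=-1, keepdim=True) + eps)
    vec2_norm = vec2 / (vec2.norm(dim=-1, keepdim=True) + eps)
    if vec1.ndim == 2 and vec2.ndim == 2:
        # Two matrices (N, D) and (M, D): return the full (N, M) similarity matrix
        return torch.matmul(vec1_norm, vec2_norm.t())
    # Otherwise: row-wise similarity after broadcasting, reduced over the last dim
    return (vec1_norm * vec2_norm).sum(dim=-1)
```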
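PyTorch also ships built-ins that cover both branches: `torch.nn.functional.cosine_similarity` computes row-wise similarity along a chosen `dim`, and `torch.nn.functional.normalize` followed by a matrix product gives the pairwise matrix. A minimal comparison sketch; the tensors are made-up example data.

```python
import torch
import torch.nn.functional as F

a = torch.tensor([[1.0, 0.0], [0.0, 2.0]])
b = torch.tensor([[1.0, 1.0], [0.0, -3.0]])

# Row-wise: a[i] vs b[i], shape (2,); values ≈ [0.7071, -1.0]
row_wise = F.cosine_similarity(a, b, dim=-1)

# Pairwise: every a[i] vs every b[j], shape (2, 2); values ≈ [[0.7071, 0.0], [0.7071, -1.0]]
pairwise = F.normalize(a, dim=-1) @ F.normalize(b, dim=-1).T

print(row_wise.shape, pairwise.shape)
```

Unlike `cosine_similarity_torch`, `F.cosine_similarity` never switches to a pairwise matrix based on input rank, so the output shape is always explicit at the call site.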

Use Cases

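| Scenario | Method |
| --- | --- |
| Semantic text retrieval | Word2Vec + cosine similarity |
| Contrastive learning loss | InfoNCE + cosine similarity |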
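As a toy illustration of the retrieval row, the sketch below represents each text as the average of its word vectors (a simple, common baseline for Word2Vec-style embeddings) and ranks documents by cosine similarity to the query. The 3-dimensional `word_vectors` table and the helper names `embed` and `rank` are hypothetical; a real system would load vectors from a trained Word2Vec model.

```python
import numpy as np

# Hypothetical toy word vectors; a real system would load trained Word2Vec embeddings
word_vectors = {
    "cat":    np.array([0.9, 0.1, 0.0]),
    "kitten": np.array([0.8, 0.2, 0.1]),
    "dog":    np.array([0.7, 0.0, 0.3]),
    "stock":  np.array([0.0, 0.9, 0.4]),
    "market": np.array([0.1, 0.8, 0.5]),
}

def embed(text):
    # Average the vectors of known words (assumes at least one word is known)
    vecs = [word_vectors[w] for w in text.split() if w in word_vectors]
    return np.mean(vecs, axis=0)

def rank(query, docs, eps=1e-8):
    q = embed(query)
    scores = []
    for doc in docs:
        d = embed(doc)
        scores.append(float(q @ d / (np.linalg.norm(q) * np.linalg.norm(d) + eps)))
    # Indices sorted by descending similarity
    order = np.argsort(scores)[::-1]
    return [(docs[i], scores[i]) for i in order]

docs = ["stock market", "dog", "cat kitten"]
for doc, score in rank("kitten", docs):
    print(f"{score:.3f}  {doc}")
```

Averaging discards word order, so this baseline suits short queries; sentence-level encoders are the usual upgrade, with the cosine ranking step unchanged.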

InfoNCE Loss
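For a single anchor with cosine similarity $s^{+}$ to its positive, $s^{-}_{1}, \dots, s^{-}_{K}$ to its $K$ negatives, and temperature $\tau$ (symbols chosen here for illustration), the loss is:

```latex
\mathcal{L} = -\log \frac{\exp(s^{+}/\tau)}{\exp(s^{+}/\tau) + \sum_{k=1}^{K} \exp(s^{-}_{k}/\tau)}
```

Placing the positive logit in column 0 and using target label 0 lets a standard cross-entropy compute exactly this quantity, averaged over the batch.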

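```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Uses cosine_similarity_torch defined above
class InfoNCELoss(nn.Module):
    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, anchor, positive, negatives):
        # anchor, positive: (B, D); negatives: (B, K, D)
        # Unsqueeze to 3-D so cosine_similarity_torch takes its row-wise path -> (B, 1)
        pos_sim = cosine_similarity_torch(anchor.unsqueeze(1), positive.unsqueeze(1))
        # (B, 1, D) vs (B, K, D) broadcasts to (B, K)
        neg_sim = cosine_similarity_torch(anchor.unsqueeze(1), negatives)
        # Positive logit sits in column 0, so every target label is 0
        logits = torch.cat([pos_sim, neg_sim], dim=1) / self.temperature
        labels = torch.zeros(anchor.shape[0], dtype=torch.long, device=anchor.device)
        return F.cross_entropy(logits, labels)
```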
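A common variant drops the explicit `negatives` tensor and uses in-batch negatives: the (B, B) similarity matrix between anchors and positives serves as the logits, and the correct column for row i is i. The sketch below is a standalone illustration of that variant, not the post's implementation; the function name `in_batch_infonce` is hypothetical.

```python
import torch
import torch.nn.functional as F

def in_batch_infonce(anchor, positive, temperature=0.07):
    """InfoNCE with in-batch negatives.

    anchor, positive: (B, D). Row i of `positive` matches row i of `anchor`;
    every other row in the batch serves as a negative.
    """
    a = F.normalize(anchor, dim=-1)
    p = F.normalize(positive, dim=-1)
    # (B, B) cosine-similarity matrix scaled by temperature
    logits = a @ p.T / temperature
    # Target for row i is column i
    labels = torch.arange(anchor.shape[0], device=anchor.device)
    return F.cross_entropy(logits, labels)

anchor = torch.eye(4)
# Perfectly aligned pairs: loss ≈ 0
loss_aligned = in_batch_infonce(anchor, torch.eye(4))
# Every anchor matches the wrong row: loss ≈ 1 / 0.07 ≈ 14.29
loss_shuffled = in_batch_infonce(anchor, torch.eye(4).roll(1, dims=0))
print(loss_aligned.item(), loss_shuffled.item())
```

Compared with the class above, this variant needs no separate negatives, but the number of negatives per anchor is tied to the batch size (B - 1).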
