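[[Cosine Similarity]]
Cosine similarity measures how closely two vectors point in the same direction, independent of their magnitudes. It is widely used in text retrieval, contrastive learning, and similar settings.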
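To make "direction, not magnitude" concrete, here is a minimal dependency-free sketch. The helper name `cosine` and the sample vectors are illustrative assumptions, and the helper skips the zero-vector guard that the fuller implementations use.

```python
import math

def cosine(u, v):
    """Cosine of the angle between two equal-length, non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

# Illustrative inputs (assumed values).
same_direction = cosine([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])  # second vector is a scaled copy
orthogonal = cosine([1.0, 0.0], [0.0, 1.0])                # perpendicular vectors
opposite = cosine([1.0, 2.0], [-1.0, -2.0])                # reversed direction
```

Up to floating-point rounding, the three values are 1, 0, and -1: scaling a vector leaves its score unchanged, while rotating it does not.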
Basic implementation (NumPy)
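```python
import numpy as np

def cosine_similarity_np(vec1, vec2, eps=1e-8):
    # Dot product and L2 norms along the last axis, so batched inputs broadcast.
    dot_product = np.sum(vec1 * vec2, axis=-1)
    norm1 = np.linalg.norm(vec1, axis=-1)
    norm2 = np.linalg.norm(vec2, axis=-1)
    # eps guards against division by zero; clip removes floating-point overshoot.
    return np.clip(dot_product / (norm1 * norm2 + eps), -1.0, 1.0)
```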
PyTorch implementation
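```python
import torch

def cosine_similarity_torch(vec1, vec2, eps=1e-8):
    # L2-normalize along the last dimension; eps avoids division by zero.
    vec1_norm = vec1 / (vec1.norm(dim=-1, keepdim=True) + eps)
    vec2_norm = vec2 / (vec2.norm(dim=-1, keepdim=True) + eps)
    # Two 2-D inputs: return the pairwise (N, M) similarity matrix.
    if vec1.ndim == 2 and vec2.ndim == 2:
        return torch.matmul(vec1_norm, vec2_norm.t())
    # Otherwise: element-wise similarity with broadcasting.
    return (vec1_norm * vec2_norm).sum(dim=-1)
```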
Applications
| Scenario | Method |
|---|---|
| Semantic text retrieval | Word2Vec + cosine similarity |
| Contrastive learning loss | InfoNCE + cosine similarity |
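The retrieval row of the table can be sketched as ranking documents by their cosine score against a query. This is only an illustration: `rank_by_cosine` is a hypothetical helper, and the 3-dimensional vectors are made-up stand-ins for real Word2Vec embeddings.

```python
import numpy as np

def rank_by_cosine(query, docs, eps=1e-8):
    """Return document indices sorted from most to least similar, plus the scores."""
    q = query / (np.linalg.norm(query) + eps)
    d = docs / (np.linalg.norm(docs, axis=-1, keepdims=True) + eps)
    scores = d @ q                  # one cosine score per document
    order = np.argsort(-scores)     # descending by similarity
    return order, scores

# Toy embeddings (assumed values, not produced by a real Word2Vec model).
docs = np.array([
    [0.9, 0.1, 0.0],   # doc 0: almost aligned with the query
    [0.0, 1.0, 0.0],   # doc 1: orthogonal to the query
    [0.5, 0.5, 0.0],   # doc 2: 45 degrees from the query
])
query = np.array([1.0, 0.0, 0.0])

order, scores = rank_by_cosine(query, docs)
```

Normalizing both sides first turns the ranking into a single matrix-vector product, which is one common way to structure this kind of lookup when the document set is large.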
InfoNCE loss
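```python
import torch
import torch.nn as nn

class InfoNCELoss(nn.Module):
    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, anchor, positive, negatives):
        # Assumed shapes: anchor, positive (B, D); negatives (B, K, D).
        # Add a length-1 axis so the helper takes its element-wise path
        # instead of returning a (B, B) pairwise matrix for two 2-D inputs.
        pos_sim = cosine_similarity_torch(anchor.unsqueeze(1), positive.unsqueeze(1))  # (B, 1)
        neg_sim = cosine_similarity_torch(anchor.unsqueeze(1), negatives)  # (B, K)
        logits = torch.cat([pos_sim, neg_sim], dim=1) / self.temperature
        # The positive sits at index 0 of every row.
        labels = torch.zeros(anchor.shape[0], dtype=torch.long, device=anchor.device)
        return nn.CrossEntropyLoss()(logits, labels)
```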
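To see what this loss rewards, the same quantity can be computed by hand for a single anchor. The NumPy sketch below is an illustration under assumptions: `info_nce_single` is a hypothetical helper, the 2-dimensional vectors are made up, and the positive is placed at index 0 to match the all-zero labels in the class.

```python
import numpy as np

def info_nce_single(anchor, positive, negatives, temperature=0.07, eps=1e-8):
    """InfoNCE for one anchor: cross-entropy with the positive at index 0."""
    def unit(x):
        return x / (np.linalg.norm(x, axis=-1, keepdims=True) + eps)

    a, p, n = unit(anchor), unit(positive), unit(negatives)
    sims = np.concatenate([[a @ p], n @ a])     # shape (1 + K,), positive first
    logits = sims / temperature
    logits = logits - logits.max()              # stabilise the softmax
    log_prob = logits - np.log(np.exp(logits).sum())
    return -log_prob[0]

anchor = np.array([1.0, 0.0])

# Positive aligned with the anchor, negatives pointing elsewhere.
aligned = info_nce_single(anchor, np.array([1.0, 0.0]),
                          np.array([[0.0, 1.0], [-1.0, 0.0]]))

# Positive opposite to the anchor, one negative aligned with it.
mismatched = info_nce_single(anchor, np.array([-1.0, 0.0]),
                             np.array([[0.0, 1.0], [1.0, 0.0]]))

# Every candidate identical: the loss reduces to log(1 + K), here log(3).
uniform = info_nce_single(anchor, np.array([1.0, 0.0]),
                          np.array([[1.0, 0.0], [1.0, 0.0]]))
```

The loss is near zero when the positive is the closest candidate and grows as negatives become more similar to the anchor than the positive is. A smaller temperature sharpens that contrast.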