Multiplication in PyTorch
1. dot
import torch
a = torch.tensor([1, 2, 3])
b = torch.tensor([1, -1, 1])
torch.dot(a, b) # tensor(2)
# x1*y1 + x2*y2 + x3*y3 + ... (both inputs must be 1-D and the same length)
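As a quick sanity check, the dot product should match the sum of the element-wise products. This is a minimal sketch; the names `dot_result` and `manual_result` are illustrative only.

```python
import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([1, -1, 1])

# torch.dot accepts only 1-D tensors with the same number of elements.
dot_result = torch.dot(a, b)

# The same value, computed as the sum of element-wise products.
manual_result = (a * b).sum()

print(dot_result.item(), manual_result.item())
```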
2. mul
Element-wise multiplication, with broadcasting when the shapes differ. The ‘*’ operator does the same thing.
a = torch.tensor([1, 2, 3])
b = torch.tensor([1, -1, 1])
torch.mul(a, b) # tensor([1, -2, 3])
a * b # tensor([1, -2, 3])
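The broadcasting part is easier to see when the two shapes differ. The sketch below assumes a column of shape [3, 1] and a row of shape [2]; the names `col`, `row`, and `product` are made up for illustration.

```python
import torch

col = torch.tensor([[1], [2], [3]])  # shape [3, 1]
row = torch.tensor([10, 20])         # shape [2]

# Shapes are right-aligned, so [3, 1] and [2] broadcast to [3, 2].
product = torch.mul(col, row)

print(product.shape)
print(product)
```

Each row of `product` is `row` scaled by the matching entry of `col`.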
3. mm & bmm
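Matrix multiplication of two 2-D tensors, without broadcasting. ‘bmm’ is the batched version: it multiplies two 3-D tensors of shape (b, n, m) and (b, m, p) batch by batch, giving (b, n, p).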
a = torch.ones([2, 3])
b = torch.ones([3, 4])
torch.mm(a, b) # shape [2, 4], every entry is 3.0
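The notes mention ‘bmm’ without showing it, so here is a minimal sketch. The batch size of 5 is an arbitrary choice for illustration.

```python
import torch

batch = 5  # arbitrary example batch size
a = torch.ones([batch, 2, 3])
b = torch.ones([batch, 3, 4])

# Both inputs must be 3-D with the same batch size; bmm does not broadcast.
out = torch.bmm(a, b)

# Each entry is a sum of three products of 1 * 1.
print(out.shape)
print(out[0])
```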
4. matmul
General matrix multiplication for tensors with any number of dimensions. The behavior depends on the dimensions of the inputs:
# 1d 1d -> dot
# 2d 2d -> mm
# 1d 2d -> treated as (1, n) mm (n, p), then the added dim is dropped -> (p)
# Nd Md -> broadcast the batch (non-matrix) dims, then mm on the last two dims
j, k, n, p, m = 2, 3, 4, 5, 6 # example sizes
a = torch.ones([j, 1, n, p])
b = torch.ones([k, p, m])
torch.matmul(a, b) # shape [j, k, n, m] = [2, 3, 4, 6]
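The dimension rules listed above can be checked by looking at the output shapes. This is a small sketch with arbitrarily chosen sizes; all variable names are illustrative.

```python
import torch

v = torch.ones(3)
w = torch.ones(3)
m1 = torch.ones(2, 3)
m2 = torch.ones(3, 4)

vec_vec = torch.matmul(v, w)    # 1d x 1d -> scalar (dot product)
mat_mat = torch.matmul(m1, m2)  # 2d x 2d -> [2, 4] (same as mm)
vec_mat = torch.matmul(v, m2)   # 1d x 2d -> [4]
mat_vec = torch.matmul(m1, v)   # 2d x 1d -> [2]

# The @ operator is shorthand for torch.matmul.
same_as_operator = torch.equal(mat_mat, m1 @ m2)

for name, t in [("vec_vec", vec_vec), ("mat_mat", mat_mat),
                ("vec_mat", vec_mat), ("mat_vec", mat_vec)]:
    print(name, tuple(t.shape))
```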
5. repeat
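# Right-align the repeat sizes with the tensor's dimensions, pad the tensor to the same number of dimensions, then repeat along each dimension
a = torch.ones([3, 4])
a.repeat([1, 2]) # shape [3, 8]
# the number of repeat sizes must be >= the number of tensor dimensions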
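A short sketch of the alignment rule, assuming the same [3, 4] tensor. The variable names are illustrative, and the last case only shows that passing too few sizes is rejected.

```python
import torch

a = torch.ones([3, 4])

# Same number of sizes as dimensions: [3 * 1, 4 * 2].
same_rank = a.repeat(1, 2)

# One extra size: a is treated as [1, 3, 4], giving [1 * 2, 3 * 1, 4 * 2].
extra_dim = a.repeat(2, 1, 2)

# Fewer sizes than dimensions is an error.
try:
    a.repeat(2)
    too_few_rejected = False
except RuntimeError:
    too_few_rejected = True

print(same_rank.shape, extra_dim.shape, too_few_rejected)
```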