# Pytorch矩阵乘法总结

PyTorch 一共有 5 种乘法：

- `*` 乘，对应元素相乘（支持广播）
- torch.mul()，和*乘完全一样
- torch.mm()，二维矩阵乘法（不支持广播）
- torch.bmm()，三维矩阵乘法，一般用于mini-batch训练中
- torch.matmul()，支持广播的通用矩阵乘法

``````python
import torch

# Example operands for the multiplication demos. The suffix of each name
# records the shape of the value it holds.

# Rank-2 (matrix) operands.
tensorA_2x3 = torch.tensor([
    [1, 2, 3],
    [3, 2, 1],
])
tensorB_1x3 = torch.tensor([[1, 2, 3]])
tensorD_2x1 = torch.tensor([[2], [3]])
tensorE_3x1 = torch.tensor([[2], [3], [4]])
tensorF_3x2 = torch.tensor([
    [1, 2],
    [3, 4],
    [5, 6],
])

# A plain Python int (not a tensor), used for scalar multiplication.
tensorC_scalar = 5

# Rank-3 operands: two stacked matrices each.
tensorG_2x3x2 = torch.tensor([
    [[1, 2], [3, 4], [2, 1]],
    [[4, 5], [5, 4], [1, 2]],
])
tensorH_2x2x3 = torch.tensor([
    [[1, 2, 3], [3, 2, 1]],
    [[4, 5, 6], [6, 5, 4]],
])
``````

# *乘、torch.mul()

``````python
print(tensorA_2x3 * tensorA_2x3) # shape相同，对应元素相乘
print(tensorA_2x3 * tensorC_scalar) # multiply every element by a scalar
``````

``````python
print(torch.mul(tensorA_2x3, tensorA_2x3))
# Operands with different shapes are broadcast before the element-wise product.
print(torch.mul(tensorA_2x3, tensorB_1x3))  # (2, 3) * (1, 3): the single row is applied to every row
print(torch.mul(tensorA_2x3, tensorD_2x1))  # (2, 3) * (2, 1): each row is scaled by its own column entry
``````

# torch.mm()

``````python
print(torch.mm(tensorA_2x3, tensorF_3x2))
``````

# torch.bmm()

``````python
print(torch.bmm(tensorH_2x2x3, tensorG_2x3x2))
``````

# torch.matmul()

``````python
print(torch.matmul(tensorH_2x2x3, tensorF_3x2))
``````

