Pytorch矩阵乘法总结

pytorch一共有5种乘法

  • *乘,element-wise乘法,支持broadcast操作
  • torch.mul(),和*乘完全一样
  • torch.mm(),矩阵叉乘,即对应元素相乘相加,不支持broadcast操作
  • torch.bmm(),三维矩阵乘法,一般用于mini-batch训练中
  • torch.matmul(),叉乘,支持broadcast操作

先定义下面的tensor(本文不展示print结果):

import torch

tensorA_2x3 = torch.tensor(
    [[1,2,3],
     [3,2,1]]
)

tensorB_1x3 = torch.tensor(
    [[1,2,3]]
)

tensorC_scalar = 5

tensorD_2x1 = torch.tensor(
    [[2],
     [3]]
)

tensorE_3x1 = torch.tensor(
    [[2],
     [3],
     [4]]
)

tensorF_3x2 = torch.tensor(
    [[1,2],
     [3,4],
     [5,6]]
)

tensorG_2x3x2 = torch.tensor(
    [[[1,2],
      [3,4],
      [2,1]],
     [[4,5],
      [5,4],
      [1,2]]]
)

tensorH_2x2x3 = torch.tensor(
    [[[1,2,3],
      [3,2,1]],
     [[4,5,6],
      [6,5,4]]]
)

*乘、torch.mul()

允许以下操作:

print(tensorA_2x3 * tensorA_2x3) # shape相同,对应元素相乘
print(tensorA_2x3 * tensorB_1x3) # shape不相同,broadcast
print(tensorA_2x3 * tensorD_2x1) # shape不相同,broadcast
print(tensorA_2x3 * tensorC_scalar) # 和标量相乘

上面的操作等价于:

print(torch.mul(tensorA_2x3, tensorA_2x3))
print(torch.mul(tensorA_2x3, tensorB_1x3))
print(torch.mul(tensorA_2x3, tensorD_2x1))

星乘和torch.mul()方法一样,支持broadcast操作。任意一维相同或者都相同,例如2×3矩阵能够和2xN、Nx3以及2×3的矩阵做*乘、torch.mul()操作。

torch.mm()

二维矩阵乘法,不支持broadcast操作,实现的是数学中的矩阵叉乘。

print(torch.mm(tensorA_2x3, tensorF_3x2))

torch.bmm()

三维矩阵乘法,b表示的是batch,要求两个矩阵都是三维矩阵,mini-batch训练的时候可能会用到。

print(torch.bmm(tensorH_2x2x3, tensorG_2x3x2))

用字母表示就是,BxMxN大小的矩阵能够和BxNxY的矩阵做叉乘,结果为BxMxY。

torch.matmul()

二维矩阵乘法,支持broadcast操作。

print(torch.matmul(tensorH_2x2x3, tensorF_3x2))

同理,对于三维矩阵来说,BxMxN大小的矩阵能够和NxY的矩阵做叉乘,NxY的矩阵能够广播成BxNxY,然后做torch.bmm()操作。
举个例子,从CNN中得到了一个16x4x512的特征A,其中16为batch size,4×512表示4个512维特征,要将512维特征降为64维,那么我们要乘上一个512×64的weight矩阵,这个时候torch.matmul就排上用场了,直接用torch.matmul(A, weight),最后得到16x4x64的特征。

    原文作者:Wanncye
    原文地址: https://blog.csdn.net/weixin_42065178/article/details/119517404
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞