Q:nn.Linear(2, 3, bias=False)是什么?如何用它进行矩阵相乘?
A:
nn.Linear(2, 3, bias=False)
是PyTorch中的一个线性层模块,表示输入维度为2,输出维度为3的线性变换。其中bias=False
表示该线性层没有偏置参数。
如果要用它进行矩阵相乘,则需要将输入数据构造成(batch_size, input_dim)
的形状,并通过调用该模块的forward
方法进行计算。例如:
import torch # 构造输入数据 x = torch.tensor([[1., 2.], [3., 4.], [5., 6.]]) # 输入数据维度为 (3, 2) # 创建线性层模型 linear = torch.nn.Linear(2, 3, bias=False) # 进行矩阵相乘 y = linear(x) # 输出数据维度为 (3, 3)
上面的代码中,输入数据x
的维度为(3, 2)
,即有3个样本,每个样本有2个特征。创建线性层模型时指定输入维度为2,输出维度为3,因此线性层的权重参数维度为(3, 2)
,可以实现从2维输入到3维输出的线性变换。通过调用linear
模块的forward
方法,将输入数据x
传入进行计算,得到输出数据y
的维度为(3, 3)
,即有3个样本,每个样本有3个特征。
Q:这里的权重是随机生成的嘛?
A:
是的,PyTorch中新建一个线性层时,其权重参数会被随机初始化。在这个例子中,由于没有指定linear
模块的权重参数,因此默认情况下,它的权重参数是随机生成的,符合标准正态分布(均值为0,标准差为1)。可以通过调用linear.weight
属性来查看权重参数的值:
print(linear.weight)
输出结果如下:
tensor([[-0.0381, -0.3997], [ 0.2302, -0.6035], [-0.4994, -0.3726]])
其中,每一行代表一个输出特征的权重向量,共有3个输出特征。可以看到,这些权重参数是随机生成的,并且每次运行代码得到的随机数是不同的。
Q:随机生成的权重形状怎么跟输入的x是一样的?如何进行矩阵相乘的?
A:
在PyTorch中,线性层的权重参数被存储为一个二维张量,其中第一维对应输出特征的数量,第二维对应输入特征的数量。例如,在这个例子中,nn.Linear(2,3,bias=False)
中的权重参数是一个形状为(3,2)
的张量,因为有3个输出特征和2个输入特征。
在矩阵相乘时,输入数据x
的形状为(batch_size, input_dim)
,因此可以将它视为一个形状为(batch_size, 1, input_dim)
的三维张量。接着,将线性层的权重参数视为一个形状为(1, output_dim, input_dim)
的三维张量。然后,通过调用PyTorch中的torch.bmm()
函数(批量矩阵相乘)来进行矩阵相乘操作,如下所示:
import torch # 构造输入数据 x = torch.tensor([[1., 2.], [3., 4.], [5., 6.]]) # 输入数据维度为 (3, 2) # 创建线性层模型 linear = torch.nn.Linear(2, 3, bias=False) # 进行矩阵相乘 x = x.unsqueeze(-1) # 转换成形状为 (batch_size, input_dim, 1) 的三维张量 w = linear.weight.unsqueeze(0) # 转换成形状为 (1, output_dim, input_dim) 的三维张量 y = torch.bmm(w, x).squeeze(-1) # 进行批量矩阵相乘,输出形状为 (batch_size, output_dim)
首先,使用unsqueeze()
函数将输入数据x
和线性层的权重参数linear.weight
转换为三维张量,并把最后一维添加一个新的维度。这里,我们使用了-1
作为参数,表示在最后一维上添加新的维度,因为它是input_dim
,而output_dim
是第二个轴。
接着,使用torch.bmm()
函数进行批量矩阵相乘,它的参数需要满足以下条件:
- 第一个参数为形状为
(batch_size, m, p)
的三维张量; - 第二个参数为形状为
(batch_size, p, n)
的三维张量; - 输出结果为形状为
(batch_size, m, n)
的三维张量。
在这个例子中,w
的形状为(1, output_dim, input_dim)
,x
的形状为(batch_size, input_dim, 1)
,因此可以看作是一个形状为(batch_size, input_dim, 1)
的三维张量,符合torch.bmm()
函数的要求。
最后,使用squeeze()
函数去掉输出张量y
中的最后一维,使其形状为(batch_size, output_dim)
。这样就完成了矩阵相乘的操作。