nn.Linear

torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)

这个函数主要是进行空间的线性映射

  • in_features:输入数据的数据维度
  • out_features:输出数据的数据维度

函数执行过程:

假设我们有一批数据的维度为20维,这一批数据一共有128个,我们要将20维的映射到30维空间的中,下面是计算过程,其中Linear函数的weight权重

其中,

一个简单的例子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch


x = torch.randn(128, 20) # 输入的维度是(128,20)
linear = torch.nn.Linear(20, 30) # 20, 30是指维度
output = linear(x)

print('linear.weight.shape: ', linear.weight.shape)
print('linear.bias.shape: ', linear.bias.shape)
print('output.shape: ', output.shape)

# ans = torch.mm(input,torch.t(m.weight))+m.bias 等价于下面的
# .t就是w转置之后的部分
ans = torch.mm(x, linear.weight.t()) + linear.bias
print('ans.shape: ', ans.shape)
print(torch.equal(ans, output))


'''output:
linear.weight.shape: torch.Size([30, 20])
linear.bias.shape: torch.Size([30])
output.shape: torch.Size([128, 30])
ans.shape: torch.Size([128, 30])
True
'''