torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)
这个函数主要是进行空间的线性映射
- in_features:输入数据的数据维度
- out_features:输出数据的数据维度
函数执行过程:
假设我们有一批数据
,
的维度为20维,这一批数据一共有128个,我们要将20维的
映射到30维空间的
中,下面是计算过程,其中
是Linear函数的weight权重

其中
, 

一个简单的例子
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
| import torch
x = torch.randn(128, 20) linear = torch.nn.Linear(20, 30) output = linear(x)
print('linear.weight.shape: ', linear.weight.shape) print('linear.bias.shape: ', linear.bias.shape) print('output.shape: ', output.shape)
ans = torch.mm(x, linear.weight.t()) + linear.bias print('ans.shape: ', ans.shape) print(torch.equal(ans, output))
'''output: linear.weight.shape: torch.Size([30, 20]) linear.bias.shape: torch.Size([30]) output.shape: torch.Size([128, 30]) ans.shape: torch.Size([128, 30]) True '''
|