本文将介绍如何用PyTorch构建模型
torch.nn.Module和torch.nn.Parameter
除了Parameter之外,本视频中讨论的所有类都是torch.nn.Module的子类。这是PyTorch基类,用于封装PyTorch模型及其组件的特定行为。
torch.nn.Module的一个重要行为是注册参数。如果特定的Module子类具有学习权值,则这些权值表示为torch.nn.Parameter的实例。Parameter类是torch的子类。张量,具有特殊的行为,当它们被分配为模块的属性时,它们被添加到该模块的参数列表中。这些参数可以通过Module类的parameters()方法访问。
作为一个简单的例子,这里有一个非常简单的模型,有两个线性层和一个激活函数。我们将创建它的一个实例,并要求它报告其参数:
import torch
class TinyModel(torch.nn.Module):
def __init__(self):
super(TinyModel, self).__init__()
self.linear1 = torch.nn.Linear(100, 200)
self.activation = torch.nn.ReLU()
self.linear2 = torch.nn.Linear(200, 10)
self.softmax = torch.nn.Softmax()
def forward(self, x):
x = self.linear1(x)
x = self.activation(x)
x = self.linear2(x)
x = self.softmax(x)
return x
tinymodel = TinyModel()
print('The model:')
print(tinymodel)
print('\n\nJust one layer:')
print(tinymodel.linear2)
print('\n\nModel params:')
for param in tinymodel.parameters():
print(param)
print('\n\nLayer params:')
for param in tinymodel.linear2.parameters():
print(param)
输出为:
The model:
TinyModel(
(linear1): Linear(in_features=100, out_features=200, bias=True)
(activation): ReLU()
(linear2): Linear(in_features=200, out_features=10, bias=True)
(softmax): Softmax(dim=None)
)
Just one layer:
Linear(in_features=200, out_features=10, bias=True)
Model params:
Parameter containing:
tensor([[-0.0451, 0.0361, 0.0902, ..., -0.0564, -0.0323, 0.0335],
[ 0.0668, 0.0843, 0.0506, ..., 0.0162, 0.0668, -0.0089],
[-0.0505, -0.0148, 0.0485, ..., 0.0714, -0.0399, 0.0798],
...,
[-0.0639, 0.0345, -0.0766, ..., 0.0711, -0.0354, -0.0719],
[ 0.0827, -0.0614, 0.0078, ..., 0.0531, -0.0672, 0.0158],
[-0.0577, -0.0733, -0.0662, ..., -0.0263, -0.0143, -0.0904]],
requires_grad=True)
Parameter containing:
tensor([-1.8241e-02, -8.1554e-02, 3.1390e-02, -9.7299e-02, -3.9416e-02,
-3.4526e-02, 6.9457e-02, 9.3126e-02, 8.3945e-02, 2.5128e-02,
-1.9594e-02, 1.4253e-02, 7.5062e-02, -2.5254e-02, 2.5275e-02,
3.6509e-02, -5.4355e-02, 5.2070e-02, -1.1055e-02, 6.3872e-02,
-4.2867e-02, -6.9062e-02, -9.6398e-02, 6.0366e-02, 8.6856e-02,
-4.3543e-02, 7.1326e-02, 3.6623e-03, 5.4014e-02, -1.3758e-02,
4.6091e-02, 4.6796e-03, -2.9959e-02, -5.0925e-02, 1.9598e-02,
5.6875e-03, -2.5505e-02, 9.8728e-02, 4.3602e-02, 3