torch.nn.Module and torch.nn.Parameter
In this video, we'll be discussing some of the tools PyTorch makes available for building deep learning networks.

Except for Parameter, the classes we discuss in this video are all subclasses of torch.nn.Module. This is the PyTorch base class meant to encapsulate behaviors specific to PyTorch models and their components.

One important behavior of torch.nn.Module is registering parameters. If a given Module subclass has learning weights, these weights are expressed as instances of torch.nn.Parameter. The Parameter class is a subclass of torch.Tensor, with the special behavior that when Parameters are assigned as attributes of a Module, they are added to that module's list of parameters. These parameters may be accessed through the parameters() method on the Module class.
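To see the registration mechanism in isolation, here is a minimal sketch (the class name and shapes are illustrative, not from this tutorial): a torch.nn.Parameter assigned as an attribute is picked up by parameters(), while a plain tensor attribute is not.

```python
import torch

class ParamDemo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A Parameter assigned as an attribute is registered automatically
        self.weight = torch.nn.Parameter(torch.randn(3, 3))
        # A plain tensor attribute is NOT registered as a parameter
        self.scratch = torch.randn(3, 3)

demo = ParamDemo()
names = [name for name, _ in demo.named_parameters()]
print(names)  # only 'weight' appears; 'scratch' is not registered
```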
As a simple example, here's a very simple model with two linear layers and an activation function. We'll create an instance of it and ask it to report on its parameters:
import torch

class TinyModel(torch.nn.Module):

    def __init__(self):
        super(TinyModel, self).__init__()

        self.linear1 = torch.nn.Linear(100, 200)
        self.activation = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(200, 10)
        self.softmax = torch.nn.Softmax()

    def forward(self, x):
        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)
        x = self.softmax(x)
        return x
tinymodel = TinyModel()
print('The model:')
print(tinymodel)
print('\n\nJust one layer:')
print(tinymodel.linear2)
print('\n\nModel params:')
for param in tinymodel.parameters():
    print(param)

print('\n\nLayer params:')
for param in tinymodel.linear2.parameters():
    print(param)
The model:
TinyModel(
(linear1): Linear(in_features=100, out_features=200, bias=True)
(activation): ReLU()
(linear2): Linear(in_features=200, out_features=10, bias=True)
(softmax): Softmax(dim=None)
)
Just one layer:
Linear(in_features=200, out_features=10, bias=True)
Model params:
Parameter containing:
tensor([[ 0.0540, 0.0889, -0.0109, ..., -0.0959, 0.0421, 0.0948],
[-0.0192, 0.0616, 0.0658, ..., 0.0874, -0.0139, 0.0918],
[ 0.0931, 0.0411, 0.0915, ..., -0.0565, 0.0179, -0.0704],
...,
[-0.0601, -0.0116, 0.0308, ..., -0.0949, -0.0367, 0.0736],
[ 0.0848, -0.0630, -0.0730, ..., -0.0832, -0.0086, -0.0087],
[ 0.0875, 0.0732, -0.0594, ..., 0.0169, 0.0162, 0.0542]],
requires_grad=True)
Parameter containing:
tensor([-2.8391e-02, 8.8786e-02, -6.4435e-03, 1.9568e-02, 6.6545e-02,
-3.8073e-02, 4.0056e-02, 9.8252e-02, 6.0742e-02, 2.6323e-02,
-6.3688e-02, 9.5054e-02, 8.1455e-02, 2.7224e-03, 2.7485e-02,
-5.3290e-02, 8.9486e-02, -3.0375e-02, -1.6629e-02, -9.4276e-02,
6.3886e-02, -1.7389e-02, 1.6478e-03, -6.8702e-02, -2.5034e-02,
-2.9890e-02, 1.2130e-02, 7.0402e-02, -2.6131e-02, 3.0848e-02,
-2.3914e-03, -6.8471e-02, -1.6653e-02, 3.0541e-02, 7.3755e-02,
-4.1249e-02, 9.4892e-02, -9.2014e-02, -9.5326e-02, 6.7583e-03,
-4.8404e-02, 7.3692e-02, -9.5953e-03, 2.0520e-02, 9.6995e-02,
-9.6371e-02, -9.3585e-02, 8.1368e-02, 6.1899e-02, -1.9492e-03,
-2.7659e-02, -2.4900e-03, 1.0500e-02, -8.0740e-02, -6.1757e-02,
7.2164e-02, 6.2586e-02, -7.9982e-02, -5.4769e-02, -4.9737e-02,
-6.4661e-02, 4.1963e-02, -8.7076e-02, -5.0482e-03, -3.0410e-05,
8.8162e-02, -5.6084e-02, 9.3488e-02, 8.9329e-02, 1.5383e-02,
-5.5996e-03, -9.7878e-02, 8.8348e-02, -7.0886e-02, 5.7076e-02,
8.5237e-02, 6.7058e-02, -4.5111e-02, 3.6577e-02, -8.0919e-02,
-2.8820e-02, 6.7889e-02, 1.8501e-02, -8.4626e-02, 1.0139e-02,
-5.2166e-02, 8.8196e-03, 3.7661e-02, 3.5405e-02, -5.7670e-02,
-3.9214e-02, 9.2920e-02, 9.1581e-02, 9.5697e-02, -6.1620e-02,
-9.0639e-02, -2.7645e-02, 5.5318e-02, 5.2429e-02, 4.9890e-02,
-8.5084e-02, -6.8121e-04, 1.6863e-02, -5.6012e-03, -9.4513e-02,
4.7324e-02, -1.6331e-03, -5.7407e-03, -4.8910e-02, 2.7390e-02,
-2.9120e-02, 5.2268e-02, 7.9739e-03, 5.9733e-02, 1.4329e-02,
5.4806e-02, -9.2461e-02, -4.2292e-02, 7.1391e-02, -9.3267e-03,
2.5865e-02, -3.2159e-02, -3.5534e-02, 4.5665e-03, 4.3144e-03,
1.6937e-02, -6.3085e-03, 4.5387e-03, -8.1251e-02, 2.7151e-02,
-9.3098e-02, -3.0626e-02, -1.6267e-02, 6.1479e-02, 9.2800e-02,
4.5886e-02, 7.1244e-02, -6.4789e-02, -9.4300e-02, -8.9892e-02,
-9.6265e-02, 5.7603e-02, 2.7417e-02, -9.3216e-02, -2.9369e-02,
-9.0568e-02, 5.2199e-02, -5.3580e-02, 5.1615e-02, -6.1951e-02,
1.7894e-02, -7.9597e-02, -3.8138e-02, -2.8243e-02, 2.8240e-03,
-6.0696e-02, 4.4213e-02, -4.6199e-02, 6.5946e-02, 1.4723e-02,
8.3900e-02, 8.1386e-02, 1.3186e-02, -3.9898e-02, -8.6006e-02,
8.7549e-02, -7.3356e-02, 7.0558e-02, 1.7812e-02, 6.3452e-02,
-6.6243e-02, -7.6435e-02, 5.1467e-02, 7.3187e-03, -4.1000e-02,
9.1473e-03, -4.3123e-02, 4.6625e-02, -3.0680e-02, 2.0004e-02,
-3.2730e-02, 7.6111e-03, 5.6459e-02, -5.9493e-02, -6.5789e-02,
8.8485e-02, -5.5954e-03, 3.0834e-02, -1.7522e-02, 8.6342e-02,
-8.5151e-02, -9.9866e-02, -2.2536e-02, 5.8566e-02, -7.6556e-02,
9.1213e-02, 9.7890e-02, -2.7655e-02, -2.7763e-02, 8.5908e-02],
requires_grad=True)
Parameter containing:
tensor([[ 0.0287, -0.0437, -0.0418, ..., 0.0395, 0.0280, -0.0323],
[-0.0242, 0.0524, -0.0388, ..., -0.0188, 0.0374, -0.0056],
[-0.0486, 0.0385, -0.0122, ..., 0.0675, 0.0428, -0.0242],
...,
[-0.0644, -0.0628, -0.0046, ..., -0.0388, 0.0258, 0.0546],
[ 0.0386, 0.0101, 0.0022, ..., 0.0001, -0.0164, -0.0397],
[ 0.0271, 0.0234, 0.0067, ..., -0.0335, -0.0107, 0.0539]],
requires_grad=True)
Parameter containing:
tensor([ 0.0093, -0.0178, -0.0259, 0.0465, -0.0456, 0.0262, -0.0185, -0.0208,
-0.0189, -0.0548], requires_grad=True)
Layer params:
Parameter containing:
tensor([[ 0.0287, -0.0437, -0.0418, ..., 0.0395, 0.0280, -0.0323],
[-0.0242, 0.0524, -0.0388, ..., -0.0188, 0.0374, -0.0056],
[-0.0486, 0.0385, -0.0122, ..., 0.0675, 0.0428, -0.0242],
...,
[-0.0644, -0.0628, -0.0046, ..., -0.0388, 0.0258, 0.0546],
[ 0.0386, 0.0101, 0.0022, ..., 0.0001, -0.0164, -0.0397],
[ 0.0271, 0.0234, 0.0067, ..., -0.0335, -0.0107, 0.0539]],
requires_grad=True)
Parameter containing:
tensor([ 0.0093, -0.0178, -0.0259, 0.0465, -0.0456, 0.0262, -0.0185, -0.0208,
-0.0189, -0.0548], requires_grad=True)
This shows the fundamental structure of a PyTorch model: there is an __init__() method that defines the layers and other components of the model, and a forward() method where the computation gets done. Note that we can print the model, or any of its submodules, to learn about its structure.
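As a side note, when a model's parameter dump is as long as the one above, named_parameters() is often more convenient than parameters(): it yields (name, Parameter) pairs, so you can print names and shapes instead of full tensor values. A minimal sketch (the Sequential layout below mirrors the layer sizes of TinyModel but is an assumption, not code from this tutorial):

```python
import torch

# Same layer sizes as the TinyModel example, expressed as a Sequential
model = torch.nn.Sequential(
    torch.nn.Linear(100, 200),
    torch.nn.ReLU(),
    torch.nn.Linear(200, 10),
)

# named_parameters() yields (name, Parameter) pairs; printing shapes
# gives a compact structural summary instead of a wall of numbers
for name, param in model.named_parameters():
    print(name, tuple(param.shape))
```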