torch.matmul
是 PyTorch 中用于执行矩阵乘法的函数,它根据输入张量的维度自动选择适当的矩阵乘法方式,包括:
- 向量内积(1D @ 1D)
- 矩阵乘向量(2D @ 1D)
- 向量乘矩阵(1D @ 2D)
- 矩阵乘矩阵(2D @ 2D)
- 批量矩阵乘法(>2D)
函数原型
torch.matmul(input, other, *, out=None) → Tensor
- input:第一个张量
- other:第二个张量
- out(可选):指定输出张量
详细说明
torch.matmul(a, b)
根据 a
和 b
的维度规则如下:
a 维度 | b 维度 | 操作类型 |
---|---|---|
1D | 1D | 向量点积 |
2D | 1D | 矩阵和向量相乘 |
1D | 2D | 向量和矩阵相乘 |
2D | 2D | 标准矩阵乘法 |
≥3D | ≥3D | 批量矩阵乘法(batch) |
示例代码
1. 向量点积(1D @ 1D)
import torch
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])
result = torch.matmul(a, b)
print(result) # 输出:32.0
2. 矩阵乘向量(2D @ 1D)
a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
b = torch.tensor([5.0, 6.0])
result = torch.matmul(a, b)
print(result) # 输出:[17.0, 39.0]
3. 向量乘矩阵(1D @ 2D)
a = torch.tensor([5.0, 6.0])
b = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
result = torch.matmul(a, b)
print(result) # 输出:[23.0, 34.0]
4. 矩阵乘矩阵(2D @ 2D)
a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
b = torch.tensor([[5.0, 6.0], [7.0, 8.0]])
result = torch.matmul(a, b)
print(result)
# 输出:
# [[19.0, 22.0],
# [43.0, 50.0]]
5. 批量矩阵乘法(3D @ 3D)
a = torch.randn(10, 3, 4)
b = torch.randn(10, 4, 5)
result = torch.matmul(a, b)
print(result.shape) # 输出:torch.Size([10, 3, 5])
综合示例:自定义线性层(类似 nn.Linear
)
下面是一个使用 torch.matmul
构建自定义线性层的完整示例,适合理解如何手动定义一个具有权重、偏置、支持自动求导的神经网络层,适合自定义网络结构或深入理解 PyTorch 的底层机制。
功能描述
- 实现线性变换:
y = x @ W^T + b
- 使用
torch.matmul
执行矩阵乘法 - 权重和偏置作为可训练参数
- 支持 GPU 和自动求导
代码实现
import torch
import torch.nn as nn
class MyLinear(nn.Module):
def __init__(self, in_features, out_features):
super(MyLinear, self).__init__()
self.weight = nn.Parameter(torch.randn(out_features, in_features)) # shape: [out, in]
self.bias = nn.Parameter(torch.zeros(out_features)) # shape: [out]
def forward(self, x):
# x: shape [batch_size, in_features]
# weight: shape [out_features, in_features]
# transpose weight -> shape [in_features, out_features], then matmul
out = torch.matmul(x, self.weight.t()) + self.bias
return out
使用示例
batch_size = 4
in_dim = 6
out_dim = 3
x = torch.randn(batch_size, in_dim)
layer = MyLinear(in_dim, out_dim)
output = layer(x)
print(output.shape) # torch.Size([4, 3])
与官方 nn.Linear
等效性验证(可选)
# 官方线性层
torch.manual_seed(0)
official = nn.Linear(in_dim, out_dim)
# 自定义线性层,使用相同参数初始化
custom = MyLinear(in_dim, out_dim)
custom.weight.data.copy_(official.weight.data)
custom.bias.data.copy_(official.bias.data)
# 比较输出
x = torch.randn(2, in_dim)
out1 = official(x)
out2 = custom(x)
print(torch.allclose(out1, out2)) # True
说明
项 | 内容 |
---|---|
torch.matmul | 用于实现 x @ W.T 矩阵乘法 |
nn.Parameter | 注册为可训练参数,自动加入 .parameters() 中 |
Module.forward() | 用于定义前向传播逻辑 |
注意事项
- 输入张量必须满足矩阵乘法的维度匹配规则。
- 对于 >2D 的张量,PyTorch 会自动按 batch size 广播执行多组矩阵乘法。
torch.matmul
不支持标量乘法(标量乘张量可用*
运算符)。