PyTorch张量操作实战:从基础运算到CNN应用
1. PyTorch张量基础从概念到创建第一次接触PyTorch张量时我完全被各种术语搞晕了。什么标量、向量、矩阵还有这个奇怪的张量词。后来才发现其实张量就是多维数组的另一种说法只不过在深度学习中我们习惯用这个数学术语。1.1 张量的本质与属性张量有三个核心属性让我花了些时间才真正理解秩(Rank)表示张量的维度数量。比如标量是0维向量是1维矩阵是2维轴(Axis)张量的具体维度比如一个2D张量有行轴和列轴形状(Shape)每个轴上的元素数量比如(3,4)表示3行4列# 创建一个3x4的2D张量 t torch.tensor([[1,2,3,4], [5,6,7,8], [9,10,11,12]]) print(t.shape) # 输出: torch.Size([3, 4]) print(t.ndim) # 输出: 2 (表示2维)1.2 创建张量的四种正确姿势在项目中踩过坑后我总结出创建张量的最佳实践torch.tensor()最常用的方法会复制数据torch.as_tensor()共享数据内存适合NumPy数组转换torch.from_numpy()专为NumPy数组设计共享内存特殊初始化方法像torch.zeros(), torch.ones(), torch.rand()import numpy as np # 从Python列表创建 data [[1,2], [3,4]] t1 torch.tensor(data) # 从NumPy数组创建(共享内存) arr np.array(data) t2 torch.as_tensor(arr) # 特殊初始化 t3 torch.zeros(2,3) # 2行3列的全0张量 t4 torch.rand(2,2) # 2x2的随机张量注意在GPU计算时要注意数据是在CPU还是GPU上。用.cuda()方法可以将张量移到GPU但数据传输有开销小数据量时可能得不偿失。2. 张量变形像玩橡皮泥一样操作数据在实际CNN项目中我经常需要调整张量形状来适配网络层。reshape、view、squeeze这些操作刚开始容易混淆直到我找到了它们的本质区别。2.1 reshape vs view内存布局的玄机两者都能改变张量形状但有个关键区别reshape可能返回原始张量的视图或副本view必须返回视图要求内存连续t torch.arange(12) print(t.reshape(3,4)) # 3行4列 print(t.view(4,3)) # 4行3列 # 先转置会使内存不连续此时不能用view t_t t.reshape(3,4).T try: t_t.view(12) # 会报错 except RuntimeError as e: print(e) # view size is not compatible with input tensors...2.2 squeeze和unsqueeze维度的增删这两个操作在处理CNN输入时特别有用squeeze删除长度为1的维度unsqueeze在指定位置增加长度为1的维度# 模拟CNN输入批次(批量大小1, 1通道, 28x28图像) input torch.rand(1,1,28,28) # 删除批量维度 squeezed input.squeeze(dim0) # 形状变为(1,28,28) # 增加通道维度 t torch.rand(28,28) # 灰度图像 unsqueezed t.unsqueeze(0) # 形状变为(1,28,28)2.3 张量拼接构建批次的神器在准备训练数据时我常用torch.cat和torch.stackcat沿现有维度拼接stack创建新维度拼接# 两个28x28的图像张量 img1 torch.rand(28,28) img2 torch.rand(28,28) # 沿高度维度拼接(56x28) cat_result torch.cat([img1, img2], dim0) # 创建批次维度(2x28x28) stack_result torch.stack([img1, img2], dim0)3. 张量运算从元素操作到广播机制在实现自定义层时深刻理解了PyTorch的运算规则。有些坑只有踩过才知道怎么避开。3.1 元素级运算的三种形式PyTorch支持三种等价的元素操作方式运算符重载t1 t2函数形式torch.add(t1, t2)原地操作t1.add_(t2)(节省内存)t1 torch.tensor([[1,2], [3,4]]) t2 torch.tensor([[5,6], [7,8]]) # 三种加法等价 result1 t1 t2 result2 torch.add(t1, t2) t1.add_(t2) # 会修改t1本身3.2 广播机制形状不同的张量如何运算广播规则让我又爱又恨。它自动扩展小张量来匹配大张量形状但理解不透容易出错。广播三步走从最后一个维度开始比较维度大小相同或其中一个为1才能广播缺失维度视为1# 矩阵(2,3) 向量(3,) → (2,3) matrix torch.tensor([[1,2,3], [4,5,6]]) vector torch.tensor([10,20,30]) result matrix vector # [[11,22,33], [14,25,36]] # 不匹配的例子会报错 bad_vector torch.tensor([10,20]) try: matrix bad_vector except RuntimeError as e: print(e) # The size of tensor a (3) must match...3.3 约简操作降维与统计在计算损失和指标时约简操作必不可少t torch.rand(3,4) # 3行4列 # 常用约简操作 print(t.sum()) # 所有元素和 print(t.mean(dim0)) # 每列均值 print(t.max(dim1)) # 每行最大值及位置 print(t.argmin()) # 最小值的扁平索引4. CNN实战张量操作的综合应用在构建CNN时我真正体会到张量操作的重要性。从输入预处理到特征图变换每一步都离不开张量操作。4.1 CNN输入张量的标准格式PyTorch使用NCHW格式N批量大小C通道数H高度W宽度# 模拟一个批次(2张RGB图像224x224) batch_size 2 channels 3 height width 224 images torch.rand(batch_size, channels, height, width) # 归一化处理(广播的应用) mean torch.tensor([0.485, 0.456, 0.406]).reshape(1,3,1,1) std torch.tensor([0.229, 0.224, 0.225]).reshape(1,3,1,1) normalized (images - mean) / std4.2 卷积层中的张量变换理解卷积输入输出形状变化很关键import torch.nn as nn # 定义卷积层 conv nn.Conv2d(in_channels3, out_channels64, kernel_size3, stride1, padding1) # 前向传播 output conv(normalized) print(output.shape) # torch.Size([2, 64, 224, 224]) # 手动计算输出形状 def conv_output_size(input_size, kernel_size, stride1, padding0): return (input_size - kernel_size 2*padding) // stride 1 h_out conv_output_size(224, 3, padding1) # 2244.3 全连接层前的展平操作从卷积到全连接的过渡需要展平操作# 方法1使用view flattened output.view(output.size(0), -1) # 保持批量维度 # 方法2PyTorch的flatten flattened torch.flatten(output, start_dim1) # 从第1维开始展平 print(flattened.shape) # torch.Size([2, 64*224*224])4.4 张量操作性能优化技巧在大规模训练中我总结了几个性能要点避免不必要的拷贝尽量使用as_tensor而不是tensor利用原地操作add_、mul_等后缀为_的操作合理使用torch.no_grad()在推理时禁用梯度计算注意设备一致性确保所有张量在相同设备(CPU/GPU)上# 性能对比示例 def slow_version(t): result t 1 result result * 2 return result def fast_version(t): result t.add_(1) result.mul_(2) return result # 使用no_grad加速推理 with torch.no_grad(): output model(input_tensor)在真实项目中这些张量操作构成了深度学习的基础。记得刚开始时我因为不理解广播规则调试了一整天。现在回头看掌握PyTorch张量就像学会了乐高积木的基本拼法之后搭建任何复杂模型都游刃有余了。
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2471549.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!