文章目录
- 解释
- 代码举例
解释
torch.max 是 PyTorch 中的一个函数,用于在张量中沿指定维度计算最大值。它有两种用法:
    ① 如果只提供一个输入张量,则返回该张量中的最大值和对应的索引。
     ② 如果提供两个输入张量,则返回两个张量中对应位置的较大值。
深度学习中主要使用第一种用法,下面对该用法举例说明:
代码举例
import torch
# 创建一个张量
# tensor = torch.rand(1, 4, 3, 3)
tensor = torch.tensor(
       [[[[2, 2, 0.7944],
          [2, 0.6368, 0.6928],
          [0.9620, 0.5716, 0.3827]],
         [[0.6216, 0, 1],
          [0.0588, 1, 0.0718],
          [1, 0.1084, 0.0462]],
         [[0.3117, 0.3333, 0.655],
          [0.8207, 0.5918, 3],
          [0.6565, 3, 0.2866]],
         [[0.6613, 0.1222, 0.0590],
          [0.4555, 0.0166, 0.0838],
          [0.3797, 0.6666, 4]]]])
# print(tensor)
print("原张量的shape为:", tensor.shape, '\n')
# 计算整个张量中的最大值和对应的索引
max_value, max_indices = torch.max(tensor, dim=1)
print("max_value:\n", max_value)  # 输出第二个维度上的最大值
print("max_indices:\n", max_indices, '\n')  # 输出第二个维度上最大值的索引
print("max_value.shape为:", max_value.shape)  # 输出每行的最大值
print("max_indices.shape为:", max_indices.shape)  # 输出每行最大值的索引
运行结果:



![[数据集][目标检测]道路圆石墩检测数据集VOC+YOLO格式461张1类别](https://img-blog.csdnimg.cn/direct/624be800869041f8836676c2c20aacb5.png)
















