pytorch小记(十九):深入理解 PyTorch 的 `torch.randint` 与 `.long` 转换
- 一、`torch.randint()` 基本概念
- 示例:生成一个二维随机整型张量
- 二、为什么需要调用 `.long()`
- 三、典型场景示例
- 1. 随机索引采样
- 2. 伪标签生成
- 3. 直接在 GPU 上生成 LongTensor
- 四、`.long()` 的几种等价写法
- 五、小结
在使用 PyTorch 进行深度学习建模或数据处理时,常常需要生成随机整数张量作为索引、伪标签或其它用途。本文将深入讲解 PyTorch 中的 torch.randint() 函数,以及为什么/如何结合 .long() 方法将张量转换为 64 位整型(LongTensor)。文末还会给出多种典型场景的实战示例,帮助你在项目中快速上手。
一、torch.randint() 基本概念
torch.randint() 用来在指定范围内均匀随机生成整数张量。它的函数签名如下:
torch.randint(
low: int = 0,
high: int,
size: Tuple[int, ...],
*,
dtype: torch.dtype = torch.int64,
layout: torch.layout = torch.strided,
device: Optional[torch.device] = None,
requires_grad: bool = False
) → Tensor
low:随机整数的下界(包含),默认为 0。high:随机整数的上界(不包含),必须指定。size:输出张量的形状,例如(batch_size,)、(2, 3)、(B, C, H, W)。dtype:输出张量的数据类型,默认是torch.int64(LongTensor)。device:生成张量所在设备,如'cpu'或者'cuda'。
示例:生成一个二维随机整型张量
import torch
# 在 [0, 10) 范围内,生成 2×3 的随机整数张量
x = torch.randint(0, 10, (2, 3))
print(x)
# 可能输出:
# tensor([[2, 7, 1],
# [5, 0, 9]])
print(x.dtype) # torch.int64 (默认 LongTensor)
二、为什么需要调用 .long()
虽然 torch.randint 默认即可生成 torch.int64 的张量,但在以下场景中,我们仍常见到 .long() 的调用:
-
确保索引类型
PyTorch 中,张量索引用的必须是 LongTensor(torch.int64)。如果手动指定了其它整型(如torch.int32或torch.uint8),则需要.long()转换:idx32 = torch.randint(0, 100, (16,), dtype=torch.int32) print(idx32.dtype) # torch.int32 idx64 = idx32.long() print(idx64.dtype) # torch.int64 # 这样才能用 idx64 在其它张量上进行索引 -
满足损失函数要求
例如torch.nn.CrossEntropyLoss要求标签(targets)是 LongTensor:num_classes = 10 batch_size = 32 labels = torch.randint(0, num_classes, (batch_size,)) # 默认就是 int64 # labels = labels.long() # 如果你不确定 dtype,可以显式调用 logits = torch.randn(batch_size, num_classes) loss_fn = torch.nn.CrossEntropyLoss() loss = loss_fn(logits, labels) -
统一数据类型
在复杂模型或数据管道中,手动控制 dtype 能避免莫名的类型不一致错误。显式地在生成后调用.long(),可以给下游代码带来更好的可读性和健壮性。
三、典型场景示例
1. 随机索引采样
在自定义采样、数据重排或分批时,需要一组随机索引:
import torch
num_samples = 1000
batch_size = 64
# 生成 [0, num_samples) 范围内,大小为 batch_size 的随机索引
indices = torch.randint(0, num_samples, (batch_size,)).long()
# 假设 data 是一个形状为 [num_samples, ...] 的张量
data = torch.randn(num_samples, 3, 224, 224)
batch = data[indices] # 用 long 类型索引
2. 伪标签生成
在无监督或对抗训练中,有时需要生成伪标签(fake labels):
import torch
import torch.nn as nn
num_classes = 5
batch_size = 16
# 随机生成伪标签
fake_labels = torch.randint(0, num_classes, (batch_size,)).long()
# 用 CrossEntropyLoss 计算损失
logits = torch.randn(batch_size, num_classes, requires_grad=True)
criterion = nn.CrossEntropyLoss()
loss = criterion(logits, fake_labels)
loss.backward()
3. 直接在 GPU 上生成 LongTensor
如果希望生成的随机张量直接存放在 GPU 上,同样可以指定 device,并明确 dtype:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size, num_classes = 32, 10
# 一步到位生成 GPU 上的 LongTensor
labels = torch.randint(0, num_classes, (batch_size,),
device=device, dtype=torch.int64)
print(labels.device, labels.dtype) # cuda:0 torch.int64
四、.long() 的几种等价写法
tensor.long()tensor.to(torch.int64)tensor.type(torch.int64)
它们的效果相同,大家可根据个人或团队习惯任选其一。通常推荐使用 .long(),因为更简洁。
五、小结
-
torch.randint(low, high, size):生成位于[low, high)的均匀随机整数张量,默认 dtype 是torch.int64。 -
.long():将任意整型或浮点型张量转换为torch.int64(LongTensor),常用于索引、标签或保证数据类型一致。 -
典型用途:
- 随机采样索引
- 生成分类伪标签
- 在 GPU 上直接生成 long 型张量
-
最佳实践:在不确定 dtype 时显式调用
.long(),或通过dtype=torch.int64与device='cuda'一次性完成生成。



















