torchvision
torchvision
是 PyTorch 的一个重要扩展库,专门针对计算机视觉任务设计。它提供了丰富的预训练模型、常用数据集、图像变换工具和计算机视觉组件,大大简化了视觉相关深度学习项目的开发流程。
我们可以在Pytorch的官网找到torchvision的文档
文档中提供了很多数据集
这里以CIFAR10为例,它是图像分类常用的数据集
CIFAR-10 数据集由 60,000 张 32x32 像素的彩色图像组成,分为 10 个类别,每个类别有 6,000 张图像。其中 50,000 张是训练图像,10,000 张是测试图像。
数据集分为五个训练批次和一个测试批次,每个批次包含 10,000 张图像。测试批次包含每个类别中随机选择的 1,000 张图像。训练批次包含剩余的图像,顺序随机,但某些训练批次可能包含一个类别的更多图像。所有训练批次加起来正好包含每个类别的 5,000 张图像。
除了数据集之外,还提供了模型torchvision.models
模块包含了一系列预训练的深度学习模型,广泛应用于图像分类、目标检测、语义分割等任务。
我们可以通过代码下载数据集
import torchvision
trans_set = torchvision.datasets.CIFAR10(root = "./dataset",train= True,download= True)
test_set = torchvision.datasets.CIFAR10(root = "./dataset",train= False,download= True)
参数列表
- root (str):
- 数据集存储的路径,数据将下载到此目录下。
- train (bool, optional):
- 如果为
True
,则加载训练集;如果为False
,则加载测试集。默认值为True
。
- 如果为
- transform (callable, optional):
- 一个函数/转换,用于对图像进行预处理,比如数据增强、归一化等。
- target_transform (callable, optional):
- 一个函数/转换,用于对目标(标签)进行处理。
- download (bool, optional):
- 如果为
True
,则从网上下载数据集(如果在指定路径中不存在)。默认值为False
。
- 如果为
下载完成后可以看到项目目录中的数据集
我们可以打印一下print("训练集数量:", len(trans_set))
查看训练集数量
完整代码如下,可以看到我们的第一个图片是cat
import torchvision
# 下载并加载CIFAR10训练数据集
trans_set = torchvision.datasets.CIFAR10(root = "./dataset", train= True, download= True)
# 下载并加载CIFAR10测试数据集
test_set = torchvision.datasets.CIFAR10(root = "./dataset", train= False, download= True)
# 获取测试集的第一个样本和对应的标签
img, target = test_set[0]
# 显示测试集中的类别标签
print(test_set.classes) # ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# 显示样本的图像数据
print(img) # <PIL.Image.Image image mode=RGB size=32x32 at 0x1BF002C7710>
# 显示样本的标签
print(target) # 3
# 根据标签索引对应的类别名称
print(test_set.classes[target]) # cat
# 显示图像
# 在这里使用PIL库的Image模块的show方法,直接在屏幕上展示图像
img.show()
这个数据集的图片都比较小(32x32 像素),放大以后虽然这个看起来并不像猫,反而像老鼠,但是它就是cat
上面我们得到的数据类型是PIL
,我们需要转为tensor
类型,我们只需要新增一个Compose然后修改dataset代码
# 定义数据集转换
dataset_transform = torchvision.transforms.Compose([
# 将图像数据转换为 Tensor
torchvision.transforms.ToTensor()
# 还可以对 Tensor 进行归一化,参数分别表示均值和标准差
#torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 下载并加载CIFAR10训练数据集
# 参数:
# root: 指定数据集的保存路径
# train: 指示是训练数据集(True)还是测试数据集(False)
# transform: 对数据集中的每个图像应用的转换操作
# download: 如果数据集不存在于指定路径且设置为True,则会自动下载数据集
trans_set = torchvision.datasets.CIFAR10(root = "./dataset", train= True, transform= dataset_transform,download= True)
# 下载并加载CIFAR10测试数据集,参数同上
test_set = torchvision.datasets.CIFAR10(root = "./dataset", train= False,transform= dataset_transform, download= True)
然后我们执行之后,控制台会打印图片,此时是我们想要的tensor数据类型(tensor类型图片不能使用show()
)
我们就可以显示在tensorBoard中
writer = SummaryWriter("pics")
# 获取测试集的10个样本和对应的标签
for i in range(10):
img, target = test_set[i]
writer.add_image("test_set", img, i)
writer.close()
仔细看,能够依稀辨认出第十张图片是车