Tensorboard、Dataset、Transforms、Dataloader
该文档主要参考【土堆】的视频教程:pytorch入门教程–土堆
一、Tensorboard
安装tensorboard:pip install tensorboard
使用步骤:
- 引入相关库:
from torch.utils.tensorboard import SummaryWriter - 构建
SummaryWriter对象:writer = SummaryWriter(log_dir="logs")- 在工程目录下创建一个名为
logs的文件夹,用于存放Tensorboard绘图所用的文件
- 在工程目录下创建一个名为
- 打开
tensorboard- 命令行执行:
tensorboard --logdir=logs- 如果有错误,使用
logs的绝对地址
- 如果有错误,使用
- 点击链接,即可查看
- 命令行执行:
[常用函数]
add_scalar、add_images、add_graph
1.1、add_scalar
功能:添加标量数据(例如记录训练epoch及对应的loss)
常用参数:
- 标题:
tag - 标量数据(Y轴):
scalar_value - 计步数据(X轴):
global_step
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(log_dir="logs")
for i in range(100):
writer.add_scalar(tag="y=3x", scalar_value=3 * i, global_step=i)
注意:使用同一个SummaryWriter对象且tag相同时,会绘制在同一幅图上,为避免该情况可以删除logs中的内容,或者每次都新建文件夹(log_dir="logs_1")
图片如下(左图为只绘制一次的结果y=3x,右图为在同一幅图上分别绘制y=2x以及y=10x的结果)


1.2、add_image
功能:添加图片数据(例如记录每个batch的输入图片)
常用参数:
标题:tag
图片数据:img_tensor(一般是torch.Tensor或者numpy.array类型)
计步数据:global_step
图片格式:dataformats(CHW, HWC, HW, WH等,默认为'CHW')
from torch.utils.tensorboard import SummaryWriter
import os
import cv2
project_path = os.getcwd()
file_name_1 = 'dog_1.png'
file_name_2 = 'dog_2.png'
file_path = os.path.join(project_path, r'data\dog')
full_file_path_1 = os.path.join(file_path, file_name_1)
full_file_path_2 = os.path.join(file_path, file_name_2)
# 使用cv2(即OpenCV库)读取图片时,图片通常是以HWC(高度、宽度、通道)格式存储的,并且每个像素的颜色值(对于RGB图像)都是0到255之间的整数
image_data_1 = cv2.imread(full_file_path_1)
image_data_2 = cv2.imread(full_file_path_2)
# 将图片从BGR转换为RGB(因为cv2默认通道顺序为BGR,使用PIL读取图片的通道顺序为RGB):如果不进行调整,则图片颜色会失真
image_data_1 = cv2.cvtColor(image_data_1, cv2.COLOR_BGR2RGB)
image_data_2 = cv2.cvtColor(image_data_2, cv2.COLOR_BGR2RGB)
writer = SummaryWriter(log_dir="logs")
writer.add_image(tag='dog', img_tensor=image_data_1, global_step=0, dataformats='HWC')
writer.add_image(tag='dog', img_tensor=image_data_2, global_step=1, dataformats='HWC')
writer.close()
图片如下(通过拖动进度条可以查看不同step对应的图片)


1.3、add_graph
功能:添加模型结构数据(例如记录神经网络的结构)
常用参数:
模型:model(可以构建自己的模型或者使用公开的经典模型,例如VGG-16)
模型输入:input_to_model(要求图片数据是Tensor类型)
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
import cv2
from torchvision import transforms
if __name__ == '__main__':
# 使用内置的分类模型VGG-16(后续会更新如何搭建模型的文章)
vgg_16 = torchvision.models.vgg16(progress=False)
print(vgg_16)
image = cv2.imread('data\\dog\\dog_1.png')
# 图片类型转换为Tensor并调整尺寸为224*224(vgg16的标准输入尺寸)
image = transforms.Compose([transforms.ToTensor(), transforms.Resize([224, 224])])(image)
# 添加一个批次维度(模型通常期望输入具有批次维度),使得形状从(C, H, W)变为(1, C, H, W),其中1表示批次大小。
image = torch.unsqueeze(image, 0)
writer = SummaryWriter('model')
# image提供模型在前向传播过程中所需的输入数据,TensorBoard据此生成模型的计算图
writer.add_graph(model=vgg_16, input_to_model=image)
writer.close()


二、Dataset
2.1、使用公开数据集
-
常用的数据集:
MNIST、CIFAR10等 -
这些封装好的数据集都继承了
torch.utils.data中的Dataset类,该类有两个重要的方法:getitem()、len(); -
可以通过参数
transform以及target_transform在加载数据时进行实时的数据增强操作(如旋转、裁剪、缩放等);- 对图像数据的增强操作详见章节三
Transforms
- 对图像数据的增强操作详见章节三
-
可以通过继承
Dataset类并重写getitem()、len()方法创建自己的数据集类(使用自己的数据)
from torchvision import datasets
if __name__ == '__main__':
# 指定数据集路径(下载好的数据集会自动解压到路径下)
data_path = 'common_dataset'
# train=True表示为训练集,download=True表示下载数据集(若已经下载好则自动加载本地数据集)
# 若在线下载速度慢,可进入CIFAR10类中,直接通过数据集的下载链接下载(下载好放在data_path下即可)
train_data = datasets.CIFAR10(root=data_path, train=True, download=True)
test_data = datasets.CIFAR10(root=data_path, train=False, download=True)
# 打印数据集所包含的数据个数
print(len(train_data), len(test_data))
# 获取第一个数据的图片(PIL Image类型)及标签(类别)
img, label = train_data[0]
# 打印类别索引及真实的类别
print(label, train_data.classes[label])
img.show()
2.2、使用自己的数据
from torch.utils.data import Dataset
import cv2
import os
class MyDataset(Dataset):
def __init__(self, data_path, label):
# super.__init__()
self.data_path = data_path
self.label = label
self.full_path = os.path.join(self.data_path, self.label)
self.images_name = os.listdir(self.full_path)
def __getitem__(self, item):
image_data = cv2.imread(os.path.join(self.full_path, self.images_name[item]))
# BGR转换为RGB,不然会失真
image_data = cv2.cvtColor(image_data, cv2.COLOR_BGR2RGB)
return image_data, self.label
def __len__(self):
return len(self.images_name)
if __name__ == '__main__':
data_path = os.path.join(os.getcwd(), 'data')
label = 'dog'
dataset_instance = MyDataset(data_path, label)
print(len(dataset_instance))
image, label = dataset_instance[0]
print(image.shape, label)
print(type(image))
三、Transforms
Transforms是用于处理图片的库,内置的类基本可以满足图片处理的需求,例如图片类型转换(PIL Image、ndarray、tensor)、尺寸调整、裁剪等
-
若没有
torchvision则需要先安装:pip install torchvision -
[常用功能(类)]
ToTensor、Normalize、Resize、RandomCrop、Compose -
使用方式:根据需求选择类,创建类的实例,使用类的实例完成图片处理
3.1、ToTensor
-
功能:将
PIL Image、ndarray类型的图片转换为Tensor类型(Convert a PIL Image or ndarray to tensor and scale the values accordingly) -
输入:
PIL Image、ndarray类型的图片(PIL Image or numpy.ndarray (H x W x C) in the range [0, 255]) -
输出:
Tensor类型,shape为CHW,每个元素均为[0.0, 1.0]之间的数(torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0])
from torch.utils.data import Dataset
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
import cv2
import os
# 创建Dataset的子类
class MyDataset(Dataset):
def __init__(self, data_path, label):
# super.__init__()
self.data_path = data_path
self.label = label
self.full_path = os.path.join(self.data_path, self.label)
self.images_name = os.listdir(self.full_path)
def __getitem__(self, item):
image_data = cv2.imread(os.path.join(self.full_path, self.images_name[item]))
# BGR转换为RGB
image_data = cv2.cvtColor(image_data, cv2.COLOR_BGR2RGB)
return image_data, self.label
def __len__(self):
return len(self.images_name)
if __name__ == '__main__':
data_path = os.path.join(os.getcwd(), 'data')
label = 'dog'
# 创建dataset子类实例,用于读取图片
dataset_instance = MyDataset(data_path, label)
# 输出图片数量
print(len(dataset_instance))
image, label = dataset_instance[0]
# 根据索引获取图片
print(image.shape, label)
writer = SummaryWriter(log_dir='transforms_logs')
# 使用ToTensor
to_tensor = transforms.ToTensor()
image_tensor = to_tensor(image)
writer.add_image(tag='dog', img_tensor=image_tensor, global_step=0)
3.2、Normalize
-
功能:对每一个通道(
channel)分别根据其均值、标准差进行标准化(Normalize a tensor image with mean and standard deviation) -
输入:
Tensor类型的图片(This transform does not support PIL Image) -
输出:标准化后的
Tensor类型的图片,output[channel] = (input[channel] - mean[channel]) / std[channel]
# 使用Normalize
# 创建对象的时候给定均值、标准差
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
image_tensor = normalize(image_tensor)
writer.add_image(tag='dog', img_tensor=image_tensor, global_step=1)
3.3、Resize
-
功能:调整图片
H、W(Resize the input image to the given size)-
若
size为序列(例如size=[500, 800]),则调整后的图片H=500,W=800; -
若
size为整数(size=500),则根据H、W中较小的值确定调整后的尺寸- 例如
H=600,W=1200,则调整后的H=500,调整后的W为 1200 / 600 ∗ 500 = 1000 1200/600*500=1000 1200/600∗500=1000
- 例如
-
-
输入:
PIL Image、Tensor类型的图片 -
输出:与输入的类型相同
# 使用Resize
resize = transforms.Resize(size=(500, 1000))
image_tensor = resize(image_tensor)
writer.add_image(tag='dog', img_tensor=image_tensor, global_step=2)
3.4、RandomCrop
- 功能:对图片进行随机裁剪(Crop the given image at a random location)
- 若
size为序列(例如size=[500, 800]),则裁剪后的图片H=500,W=800; - 若
size为整数(size=500),则裁剪后的图片H=500,W=500; - 裁剪后的图片H、W均不大于原有图片的H、W,否则会报错
- 若
- 输入:
PIL Image、Tensor类型的图片 - 输出:与输入的类型相同
# 使用RandomCrop
random_crop = transforms.RandomCrop(size=(300, 800))
image_tensor = random_crop(image_tensor)
writer.add_image(tag='dog', img_tensor=image_tensor, global_step=3)
3.5、结果展示
从上到下分别是Normalize、Resize、RandomCrop顺序执行后的结果



3.6、Compose
-
功能:指定一系列图片处理步骤,对图片进行流式处理(Composes several transforms together)
- 使用列表指定需要对图片进行的处理,
[transform_1, transform_2,...]
- 使用列表指定需要对图片进行的处理,
-
输入:
PIL Image、ndarray、Tensor类型的图片 -
输出:由图片类型、指定的处理步骤决定
# 使用
compose = transforms.Compose([to_tensor, normalize, resize, random_crop])
image_tensor = compose(image_tensor)
writer.add_image(tag='dog_compose', img_tensor=image_tensor, global_step=0)
四、Dataloader
Dataloader用于批量加载和处理数据,能数据集分成小批量,并在训练过程中按需加载这些小批量数据,以提高训练效率并节省内存。
- 批量加载数据:参数
batch_size,每次加载batch_size个数据,而不是一次性加载整个数据集; - 数据“洗牌”:参数
shuffle,在每个训练周期开始时随机打乱数据顺序,防止模型过拟合; - 并行处理:参数
num_workers,利用多个线程或进程加快数据加载过程;
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
if __name__ == '__main__':
data_path = 'common_dataset'
# 为了方便使用tensorboard进行展示,使用transform=transforms.ToTensor()将图片由PIL类型转换为Tensor类型
train_data = datasets.CIFAR10(root=data_path, train=True, transform=transforms.ToTensor(), download=True)
test_data = datasets.CIFAR10(root=data_path, train=False, transform=transforms.ToTensor(), download=True)
# 打印数据集所包含的数据个数
print(len(train_data), len(test_data))
# 获取第一个数据的图片及标签(类别)
img, label = train_data[0]
# 打印类别索引及真实的类别
print(label, train_data.classes[label])
# img.show()
writer = SummaryWriter(log_dir='CIFAR10_logs')
# dataloader示例
# drop_last=True可以舍弃最后的不足一批(batch_size)的图片
data_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, drop_last=False)
# 指定训练的最大epoch
max_epoch = 1
for epoch in range(max_epoch):
i = 0
for images, labels in data_loader:
writer.add_images(tag=f'CIFAR10_{epoch}', img_tensor=images, global_step=i)
i += 1
writer.close()
tensorboard中部分batch的图片如下:











![[翻译] Asset Administration Shells](https://i-blog.csdnimg.cn/direct/41ce485b4c6d41ffa720e73b4d5ec559.png)









