dataset.py篇

news2025/7/14 20:19:37

dataset.py

目录:

  • 前言
  • 观察数据
  • 书写代码
  • 函数解释

前言

在步骤中需要写自己的dataset类,并将label和image一一对应后返回。

观察数据

在书写dataset前最重要的就是要观察数据集,对数据集进行分析,比如了解图片大小,通道数目,他的ndarray的dtype类型等等。甚至可以自己书写一个脚本,对数据本身进行分析。

如下以果蝇电镜图为例,我们观察数据集,知道Images中存放30张原图,Labels中存放30张已经分割好的图片,每张图片以.png方式进行存储;Images和Labels中的图片通过文件名进行一一对应,如Labels中0.png对应Images中0.png,进行共有30组对应数据;根据model.py文件了解需要传入的Tensor类型,思考如何将现有数据转为需要的Tensor返回。

请添加图片描述

知道以上信息后,我们书写dataset.py将image和label一一对应返回。

书写代码

在本步骤中,我们需要告诉程序如何读入你的数据,并且做一些预处理。我们的DIYDataset类继承自Dataset类,并重写__init____len____getitem__三个魔法方法。网络训练权重为float32,所以传入数据也一般要为float32。

  1. __init__方法向外索取3个输入数据:
  • 读取数据路径
  • 是训练集还是验证集(因为你训练集和验证集往往是在两个不同的文件夹)
  • 你使用的预处理方法(以transform为主,transform也可以根据验证集还是测试集调用不同的trans)
  1. __getitem__需要返回image和对应label的Tensor
  2. __len__用来返回集合中图片个数

以S-BIAD25为例,代码如下:

'''
该类用于返回dataset
input和label都是[512, 512]的图片,需要将其转换为[3, 512, 512]才能transforms
'''
# --- add path and DIY package
import sys, os
root_path = os.path.dirname(os.path.dirname(__file__))
project_path = os.path.dirname(__file__)
sys.path.append(root_path)
sys.path.append(project_path)
from _utils import tensor_info  # 耦合了
# ---
from torch.utils.data.dataset import Dataset
from PIL import Image
from utils.utils import cvtColor,resize_image
from torchvision import transforms


class CellDataset(Dataset):
    """返回细胞分隔的dataset"""

    def __init__(self, path:str, transforms:object=None) -> None:
        super().__init__()
        self.path       = path
        self.labels     = os.listdir(os.path.join(path,'Labels'))
        self.transforms = transforms
        

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        yield_name  = self.labels[index]
        # 返回label图片
        label_path  = os.path.join(self.path, 'Labels', yield_name)
        label_image = Image.open(label_path)
        label_image = cvtColor(label_image)                     # 将单通道灰度图片, 四通道png转换为三通道RGB图片。因为此处torchvision.transforms只接受三通道图片,pytorch中模型一般只能训练dtype=float32的Tensor。
        label_image = resize_image(label_image,(512, 512))[0]   # 裁剪图片大小为512 * 512
        # 返回训练图片
        image_path  = os.path.join(self.path, 'Images', yield_name)
        image       = Image.open(image_path)
        image       = cvtColor(image)
        image       = resize_image(image,(512, 512))[0]
        # 返回单张input, label
        return self.transforms(image), self.transforms(label_image)


class MouseDataset(Dataset):
    """返回细胞分隔的dataset"""

    def __init__(self, path) -> None:
        super().__init__()
        self.path       = path
        self.f_actin    = os.path.join(path, "F-actin") 
        self.labels_factin = os.listdir(self.f_actin)
        # self.labels     = os.listdir(os.path.join(path,'Labels'))
        self.transforms = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        return len(self.labels_factin)

    def __getitem__(self, index):
        yield_name  = self.labels_factin[index][-19:]
        # 返回label图片
        label_path  = os.path.join(self.path, 'F-actin', "img_568"+yield_name)
        label_image = Image.open(label_path)
        label_image = cvtColor(label_image)                     # 将单通道灰度图片, 四通道png转换为三通道RGB图片。因为此处torchvision.transforms只接受三通道图片,pytorch中模型一般只能训练dtype=float32的Tensor。
        label_image = resize_image(label_image,(512, 512))[0]   # 裁剪图片大小为512 * 512
        # 返回训练图片
        image_path  = os.path.join(self.path, 'retardance', "img_Retardance"+yield_name)
        image       = Image.open(image_path)
        image       = cvtColor(image)
        image       = resize_image(image,(512, 512))[0]
        # 返回单张input, label
        return self.transforms(image), self.transforms(label_image)


if __name__ == "__main__":
    """test"""
    data_path = "/home/yingmuzhi/_data/S-BIAD25"
    cell_dataset = MouseDataset(data_path)

    input, label = cell_dataset[0]  # 同 input, label = cell_dataset.__getitem__(0)
    tensor_info.TensorInfo(input).show_info()
    tensor_info.TensorInfo(label).show_info()

测试结果

我们在if __name__ == "__main__":中进行测试,结果如下:

请添加图片描述

函数解释

对一些常用函数进行解释,可以当做字典查看

torchvision.transforms.ToTensor

只接受PIL.Image类型的对象或者numpy.ndarray类型的对象,将上面两个对象转为torch.Tensor对象。函数定义如下:

在这里插入图片描述

参考链接https://pytorch.org/vision/stable/generated/torchvision.transforms.ToTensor.html?highlight=totensor#torchvision.transforms.ToTensor

__init__

在该方法中需要根据传入的path参数和train参数找到你的测试集或者训练集的物理地址,并将集合中的images和labels的物理地址存储在list中,以供后面方法使用。案例如下:

def __init__(self, root: str, train: bool, transforms: object=None):
    super(DriveDataset, self).__init__()
    self.flag = "training" if train else "test" # 由 train: bool 的布尔值来判断是取train还是test
    data_root = os.path.join(root, "DRIVE", self.flag)
    assert os.path.exists(data_root), f"path '{data_root}' does not exists."
    self.transforms = transforms
    img_names = [i for i in os.listdir(os.path.join(data_root, "images")) if i.endswith(".tif")]
    self.img_list = [os.path.join(data_root, "images", i) for i in img_names]               # 返回所有images地址的list
    self.manual = [os.path.join(data_root, "1st_manual", i.split("_")[0] + "_manual1.gif")  # 返回所有manuals地址的list
                   for i in img_names]
    # check manual files
    for i in self.manual:
        if os.path.exists(i) is False:
            raise FileNotFoundError(f"file {i} does not exists.")
    self.roi_mask = [os.path.join(data_root, "mask", i.split("_")[0] + f"_{self.flag}_mask.gif")    # 返回所有mask地址的list
                     for i in img_names]
    # check mask files
    for i in self.roi_mask:
        if os.path.exists(i) is False:
            raise FileNotFoundError(f"file {i} does not exists.")

__len__

返回测试集或者验证集中Image数量,因为Image和Label往往是一一对应的,所以返回哪个其实都一样。三维数据可能存在多对一情况。

案例如下:

def __len__(self):
	return len(self.img_list)

__getitem__

该方法根据image和label的物理地址,用PIL打开图片,再用transforms处理Image返回Tensor,最后返回处理过的Tensor类型元组(image, label)。

在该方法中,你可以使用PIL处理图像(mode),也可以将PIL转为numpy使用numpy处理图片(元素类型dtype),也可以使用Transforms处理图片(Normalization)等。

__getitem__案例见下:

def __getitem__(self, idx):
    """将Image转为RGB, 将label转为L"""
    img = Image.open(self.img_list[idx]).convert('RGB')
    manual = Image.open(self.manual[idx])
    manual = manual.convert('L')
    manual = np.array(manual) / 255
    roi_mask = Image.open(self.roi_mask[idx]).convert('L')
    roi_mask = 255 - np.array(roi_mask)
    # 将manual图片和Imae进行处理
    mask = np.clip(manual + roi_mask, a_min=0, a_max=255)
    # 这里转回PIL的原因是,transforms中是对PIL数据进行处理
    mask = Image.fromarray(mask)
    if self.transforms is not None:
        img, mask = self.transforms(img, mask)
    return img, mask

使用PIL处理

参看链接https://blog.csdn.net/qq_43369406/article/details/127781871

使用Numpy处理

参看链接https://blog.csdn.net/qq_43369406/article/details/127781871

使用transforms处理

参看链接[coming soon]

这段代码我们常写在train.py中。在进行transforms累加时候,我们常将所需要的transforms全部添加至一个list中,再将这个list给transforms.Compose掉,注意一定要添加transforms.ToTensor方法。transform的更多内容可以参考笔者的transforms博客。

在调用transforms时完整逻辑如下:

# 获取dataset
train_dataset = DriveDataset(args.data_path,
                                 train=True,
                                 transforms=get_transform(train=True, mean=mean, std=std))
# 获取tranforms
def get_transform(train, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
	"""对要获取的transforms进行判断,看是测试集的dataset还是验证集"""
    base_size = 565
    crop_size = 480

    if train:
        return SegmentationPresetTrain(base_size, crop_size, mean=mean, std=std)
    else:
        return SegmentationPresetEval(mean=mean, std=std)
# 测试集transforms
class SegmentationPresetTrain:
    def __init__(self, base_size, crop_size, hflip_prob=0.5, vflip_prob=0.5,
                 mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        min_size = int(0.5 * base_size)
        max_size = int(1.2 * base_size)

        trans = [T.RandomResize(min_size, max_size)]
        if hflip_prob > 0:
            trans.append(T.RandomHorizontalFlip(hflip_prob))
        if vflip_prob > 0:
            trans.append(T.RandomVerticalFlip(vflip_prob))
        trans.extend([
            T.RandomCrop(crop_size),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ])
        self.transforms = T.Compose(trans)

    def __call__(self, img, target):
        return self.transforms(img, target)

增加了魔法方法__call__只是为了对image/input,label/target一块进行transforms。

torchvision.transforms.Compose()和torchvision.transforms.[functions]

Compose()类可以将多个transforms对象合在一起给数据进行预处理,常以多个transforms对象的list形式传入Compose()中,如transforms.Compose([])。transforms.[functions]则可以对多个输入数据进行变换,Compose()函数原型如下:

# 原型
torchvision.transforms.Compose(transforms)
# example
data_transform = {
    "train": transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]),
    "val": transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
}

torchvision.datasets.ImageFolder()类

ImageFolder()用于加载数据集,它说到底还是继承自torch.utils.data.Dataset,后续可以作为Dataset对象直接传入torch.utils.data.DataLoader中。对于ImageFolder(),root是要加载数据集的路径,transforms是对数据进行预处理的方式,函数原型和example如下:

# 原型
torchvision.datasets.ImageFolder(root: str, transform: Optional[Callable] = None, 
									target_transform: Optional[Callable] = None, loader: Callable[[str], Any] = <function default_loader>, is_valid_file: Optional[Callable[[str], bool]] = None)
# example
train_dataset = datasets.ImageFolder(root=project_path + "/flower_data/train", transform=data_transform["train"])

python 列表生成式和生成器

python 常见列表生成式 和 生成器,其中列表生成式以[]圈起,生成器以()圈起,常见列表生成器如下:

在这里插入图片描述

而生成器则是把中括号改成小括号。

# 列表生成式
generate_list = [i for i in range(10)  if i < 5]
# 生成器
cla_dict = dict((val, key) for key, val in flower_list.items())	# 解包取出键值对,生成dict字典,赋值给cla_dict

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/17814.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

maven基础入门

maven 1、maven简介 Apache Maven 是一个项目管理和构建工具&#xff0c;它基于项目对象模型(POM)的概念&#xff0c;通过一小段描述信息来管理项目的构建、报告和文档。官网 &#xff1a;http://maven.apache.org/什么是Maven&#xff1f;这里先引用知乎的一个回答 我先不说…

第五届“传智杯”全国大学生计算机大赛(练习赛) [传智杯 #5 练习赛] 复读

[传智杯 #5 练习赛] 复读 题目描述 给定若干个字符串&#xff0c;不定数量&#xff0c;每行一个。有些字符串可能出现了多次。如果读入一个字符串后&#xff0c;发现这个字符串以前被读入过&#xff0c;则这个字符串被称为前面相同的字符串的复读&#xff0c;这个字符串被称为…

Redis分布式锁剖析和几种客户端的实现

1. 背景 在传统的单体项目中&#xff0c;即部署到单个IIS上&#xff0c;针对并发问题&#xff0c;比如进销存中的出库和入库问题&#xff0c;多个人同时操作&#xff0c;属于一个IIS进程中多个线程并发操作的问题&#xff0c;这个时候可以引入线程锁lock/Monitor等&#xff0c;…

信息论随笔(三)交互信息量

之前讨论了一个事件的自信息量&#xff0c;但是实际情况下往往有多个事件发生&#xff0c;而且这些事件之间相互是有联系的。比如知道一个人踢足球&#xff0c;那么这个人很有可能会看世界杯。也就是说&#xff0c;我们可以通过一个事件获得另外一个事件的信息&#xff0c;或者…

解决Android Studio等开发软件出现更新TKK失败的两种方案

解决Android Studio等开发软件出现更新TKK失败的两种方案方案一 配置hosts1. 配置域名与IP2.扫描国内可用的IP方案二 替换翻译引擎百度翻译引擎在Android Studio等开发软件中利用Translation等翻译插件时&#xff0c;出现无法翻译的提示&#xff1a;更新TKK失败&#xff0c;请检…

数据结构之栈的实现及相关OJ题

&#x1f57a;作者启明星使 &#x1f383;专栏&#xff1a;《数据库》《C语言》 &#x1f3c7;分享一句话&#xff1a; 对的人会站在你的前途里 志同道合的人才看得懂同一片风景 大家一起加油&#x1f3c4;‍♂️&#x1f3c4;‍♂️&#x1f3c4;‍♂️ 希望得到大家的支持&am…

【毕业设计】新闻分类系统 - 深度学习 机器学习

文章目录0 前言1 简介2 参与及比较算法3 先说结论4 实现过程4.1 数据爬取4.2 数据预处理5 CNN文本分类6 最后0 前言 &#x1f525; Hi&#xff0c;大家好&#xff0c;这里是丹成学长的毕设系列文章&#xff01; &#x1f525; 对毕设有任何疑问都可以问学长哦! 这两年开始&a…

事件总线EventBus

事件总线是对发布-订阅模式的一种实现&#xff0c;是一种集中式事件处理机制&#xff0c;允许不同的组件之间进行彼此通信而又不需要相互依赖&#xff0c;达到一种解耦的目的。 什么是“总线”&#xff1a;一个集中式的事件处理机制。同时服务多个事件和多个观察者。相当于一个…

C#编程深入研究变量,类型和方法

编写正确的C#代码 简单的调试技术 变量的语法 声明类型和值 仅声明类型 访问修饰符 使用类型 通用内置类型 类型转换 推断式声明 自定义类型 类型综述 命名变量 变量的作用域 运算符 定义方法 指定参数 指定返回值 常见的Unity方法 Start方法 Update方法 …

金山云:基于 JuiceFS 的 Elasticsearch 温冷热数据管理实践

01 Elasticsearch 广泛使用带来的成本问题 Elasticsearch&#xff08;下文简称“ES”&#xff09;是一个分布式的搜索引擎&#xff0c;还可作为分布式数据库来使用&#xff0c;常用于日志处理、分析和搜索等场景&#xff1b;在运维排障层面&#xff0c;ES 组成的 ELK&#xff…

MMDetection3D库中的一些模块介绍

本文目前仅包含2个体素编码器、2个中间编码器、1个主干网络、1个颈部网络和1个检测头。如果有机会&#xff0c;会继续补充更多模型。 若发现内容有误&#xff0c;欢迎指出。 MMDetection3D的点云数据一般会经历如下步骤/模块&#xff1a; #mermaid-svg-q9Wy2NQvFHfuPWKs {font-…

骨传导原理是什么,佩戴骨传导耳机的过程中对于耳道有无损害

随着新时代的到来&#xff0c;我们周围的数码产品逐渐被新产物所替代&#xff0c;以往在耳机市面上&#xff0c;普遍都是入耳式耳机&#xff0c;但长时间佩戴这种耳机的话对于我们耳道来说是有着不可逆的伤害&#xff0c;而在近几年骨传导耳机的出现&#xff0c;打破了传统耳机…

18.Redis系列之AOF方式持久化

本文学习redis7两大持久化技术之一&#xff1a;AOF&#xff08;Append Only File&#xff09;日志追加方式持久化备份与还原&#xff0c;重写以及AOF方式的优缺点 1. AOF相关配置 首先我们先简单了解下Redis7中AOF相关配置 // 开启AOF方式持久化&#xff0c;默认no appendon…

基于真实场景解读 K8s Pod 的各种异常

在 K8s 中&#xff0c;Pod 作为工作负载的运行载体&#xff0c;是最为核心的一个资源对象。Pod 具有复杂的生命周期&#xff0c;在其生命周期的每一个阶段&#xff0c;可能发生多种不同的异常情况。K8s 作为一个复杂系统&#xff0c;异常诊断往往要求强大的知识和经验储备。结合…

骚戴独家笔试---SQL笔试

SQL笔试训练 查询结果去重 两种答案 查找某个年龄段的用户信息 查找除复旦大学的用户信息 三种答案 用where过滤空值练习 三种答案 查询NULL时&#xff0c;不能使用比较运算符(或者< >)&#xff0c;需要使用IS NULL运算符或者IS NOT NULL运算符。 操作符混合运用 我这里…

力扣 792. 匹配子序列的单词数

题目 给定字符串 s 和字符串数组 words, 返回 words[i] 中是s的子序列的单词个数 。 字符串的 子序列 是从原始字符串中生成的新字符串&#xff0c;可以从中删去一些字符(可以是none)&#xff0c;而不改变其余字符的相对顺序。 例如&#xff0c; “ace” 是 “abcde” 的子序…

java spring引用外部jar包并使用

spring引用外部jar包并使用1、将jar包放到src/main/resources/lib2、编辑pom.xml文件build下面加入resources&#xff0c;不加话的打包会找不到资源3、project structure中引入该lib1、将jar包放到src/main/resources/lib 2、编辑pom.xml文件 打开pom文件&#xff0c;找到相应…

计算机网络基本知识

计算机网络基本知识 计算机网络定义&#xff1a;是一个将分散的、具有独立功能的计算机系统&#xff0c;通过通信设备与线路连接起来&#xff0c;由功能完善的软件实现资源共享和信息传递的系统。 1.1计算机网络在信息时代作用 1.2因特网概述 1.2.1网络、互联网、因特网 网…

DeepLab V1学习笔记

DeepLab V1摘要相关的工作遇到的问题和解决的方法信号下采样空间不变性(spatial insensitivity/invariance)论文的优点(贡献)网络的模型空洞卷积CRF多尺度预测模型总结实验结果Semantic Image Segmentation with Deep Convolutional Nets and Fully Connected CRFs论文地址 : D…

[附源码]java毕业设计乒乓球俱乐部管理系统

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…