DeepLabV3+:对预测处理的详解

news2025/7/31 18:30:17

相信大家对于这一部分才是最感兴趣的,能够实实在在的看到效果。这里我们就只需要两个.py文件(deeplab.py、predict_img.py)。

创建DeeplabV3类

deeplab.py的作用是为了创建一个DeeplabV3类,提供一个检测图片的方法,而predict_img.py则是为了单独检测图片的效果。

在这里我需要一个defaults字典用来包含我在这个类要使用的变量,而需要把数据类型转换成字典数据再做存储,这时候就需要用到类的内置属性__dict__。

这里简单说一下字典defaults的key和value。

1、model_path=model_date/deeplab_mobilenetv2.pth

此文件是基于VOC拓展数据集训练的权重,放心使用,附上下载地址的权值文件。

2、num_classes=2

对于需要区分的类数+1,比如我这里是识别裂缝,所以我的num_classes为1+1,再比如经典的猫狗分类问题,那么它们的num_classes为2+1=3。

3、backbone=mobilenet

这里是使用的主干网络,有mobilenet和xception可供选择。

4、input_shape=[512,512]

输入图片的大小

5、downsample_factor=16

下采样的倍数,可选的有8和16,但8训练要求更大的内存,这里要与训练时相同。

6、mix_type=0

0代表原图与生成的图进行混合;1代表仅保留生成的图;2代表扣去背景,仅保留原图中的目标。

7、cude=False

有cuda就是Ture,没有就用cpu。

在这里,请看detect_image函数里面,首先要用cvtColor函数对图片进行一个转化,因为RGB图像才有权重。

在对图像的大小修改时,需要增添一个灰度框,想想这样的请况,如果图像比输入大小小就会使图像被强行放大,可能会伸长也可能会扩展。那么为了避免这种请况,所以要添加这个灰度边界,后期因为要与原图大小匹配,会将这部分去掉。

然后,对图像的每个像素点进行分类。

# deeplab.py

import colorsys
import copy

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn

from Deeplabv3_plus.deeplabv3plus import DeepLab
from utily.utils import cvtColor, preprocess_input, resize_image, show_config



class DeeplabV3(object):
    defaults = {
        "model_path": 'model_data/deeplab_mobilenetv2.pth',
        "num_classes": 2,
        "backbone": "mobilenet",
        "input_shape": [512, 512],  
        "downsample_factor": 16,  
        "mix_type": 0,
        "cuda": False
    }

    def __init__(self, **kwargs):
        self.__dict__.update(self.defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)

        if self.num_classes <= 21:
            self.colors = [ (0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128), (0, 128, 128), 
                            (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0), (64, 0, 128), (192, 0, 128), 
                            (64, 128, 128), (192, 128, 128), (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128), 
                            (128, 64, 12)]
            # 画框设置不同的颜色
        else:
            hsv_tuples = [(x / self.num_classes, 1., 1.) for x in range(self.num_classes)]
            self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
            self.colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), self.colors))

        # 获得模型
        self.generate()
        show_config(**self.defaults)

    # 获得所有的分类
    def generate(self, onnx=False):
        # 载入模型与权值
        self.net = DeepLab(num_classes=self.num_classes, backbone=self.backbone, downsample_factor=self.downsample_factor, pretrained=False)

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.net.load_state_dict(torch.load(self.model_path, map_location=device))
        self.net = self.net.eval()
        print('{} model, and classes loaded.'.format(self.model_path))
        if not onnx:
            if self.cuda:
                self.net = nn.DataParallel(self.net)
                self.net = self.net.cuda()

    def detect_image(self, image, count=False, name_classes=None):
        """
        * 在这里将图像转换成RGB图像,防止灰度图在预测时报错。
        * 代码仅仅支持RGB图像的预测,所以其它类型的图像都会转化成RGB
        :param image: 图片
        :param count: 计数
        :param name_classes:
        :return:
        """
        image = cvtColor(image)

        # 对输入图像进行一个备份,后面用于绘图
        old_img = copy.deepcopy(image)
        orininal_h = np.array(image).shape[0]
        orininal_w = np.array(image).shape[1]

        # 给图像增加灰条,实现不失真的resize
        # 也可以直接resize进行识别
        image_data, nw, nh  = resize_image(image, (self.input_shape[1],self.input_shape[0]))

        # 添加上batch_size维度
        image_data  = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()

            # 图片传入网络进行预测
            pr = self.net(images)[0]

            # 取出每一个像素点的种类
            pr = F.softmax(pr.permute(1,2,0),dim = -1).cpu().numpy()

            # 将灰条部分截取掉
            pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \
                    int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]

            # 进行图片的resize
            pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation = cv2.INTER_LINEAR)

            # 取出每一个像素点的种类
            pr = pr.argmax(axis=-1)

        if count:
            classes_nums        = np.zeros([self.num_classes])
            total_points_num    = orininal_h * orininal_w
            print('-' * 63)
            print("|%25s | %15s | %15s|"%("Key", "Value", "Ratio"))
            print('-' * 63)
            for i in range(self.num_classes):
                num     = np.sum(pr == i)
                ratio   = num / total_points_num * 100
                if num > 0:
                    print("|%25s | %15s | %14.2f%%|"%(str(name_classes[i]), str(num), ratio))
                    print('-' * 63)
                classes_nums[i] = num
            print("classes_nums:", classes_nums)
    
        if self.mix_type == 0:
            # seg_img = np.zeros((np.shape(pr)[0], np.shape(pr)[1], 3))
            # for c in range(self.num_classes):
            #     seg_img[:, :, 0] += ((pr[:, :] == c ) * self.colors[c][0]).astype('uint8')
            #     seg_img[:, :, 1] += ((pr[:, :] == c ) * self.colors[c][1]).astype('uint8')
            #     seg_img[:, :, 2] += ((pr[:, :] == c ) * self.colors[c][2]).astype('uint8')
            seg_img = np.reshape(np.array(self.colors, np.uint8)[np.reshape(pr, [-1])], [orininal_h, orininal_w, -1])

            # 将新图片转换成Image的形式
            image   = Image.fromarray(np.uint8(seg_img))

            # 将新图与原图及进行混合
            image   = Image.blend(old_img, image, 0.7)

        elif self.mix_type == 1:
            # seg_img = np.zeros((np.shape(pr)[0], np.shape(pr)[1], 3))
            # for c in range(self.num_classes):
            #     seg_img[:, :, 0] += ((pr[:, :] == c ) * self.colors[c][0]).astype('uint8')
            #     seg_img[:, :, 1] += ((pr[:, :] == c ) * self.colors[c][1]).astype('uint8')
            #     seg_img[:, :, 2] += ((pr[:, :] == c ) * self.colors[c][2]).astype('uint8')
            seg_img = np.reshape(np.array(self.colors, np.uint8)[np.reshape(pr, [-1])], [orininal_h, orininal_w, -1])

            # 将新图片转换成Image的形式
            image   = Image.fromarray(np.uint8(seg_img))

        elif self.mix_type == 2:
            seg_img = (np.expand_dims(pr != 0, -1) * np.array(old_img, np.float32)).astype('uint8')

            # 将新图片转换成Image的形式
            image = Image.fromarray(np.uint8(seg_img))
        
        return image

单张图片的预测

由于我想要将图片放在PyQt5设计的ui中,所以要单张单张的显示。

from PIL import Image
from deeplab import DeeplabV3

if __name__ == "__main__":

    deeplab = DeeplabV3()
    mode = "predict" 
    count = False    #指定了是否进行目标的像素点计数(即面积)与比例计算
    # name_classes    = ["background","aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
    name_classes = ["background","crack"]

    if mode == "predict":

        while True:
            img = input('Input image filename:')
            try:
                image = Image.open(img)
            except:
                print('Open Error! Try again!')
                continue
            else:
                r_image = deeplab.detect_image(image, count=count, name_classes=name_classes)
                r_image.show()

我们来看看效果:

原图

效果图

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/367795.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据结构与算法入门

目录数据结构概述逻辑结构存储结构算法概述如何理解“大O记法”时间复杂度空间复杂度数据结构概述 数据结构可以简单的理解为数据与数据之间所存在的一些关系&#xff0c;数据的结构分为数据的存储结构和数据的逻辑结构。 逻辑结构 集合结构&#xff1a;数据元素同属于一个集…

Codeforces Round #848 (Div. 2)A-C

传送门 目录 A. Flip Flop Sum 代码&#xff1a; B. The Forbidden Permutation 代码&#xff1a; C. Flexible String 代码&#xff1a; A. Flip Flop Sum 题意&#xff1a;给你一个长度为n的数组&#xff08;数组元素只为1或者-1&#xff09;&#xff0c;你要且只能进行…

掌握lombok简化Java编码完成后端提效

Lombok安装 –>添加依赖 <dependency><groupId>org.projectlombok</groupId><artifactId>lombok</artifactId><version>1.18.16</version><scope>provided</scope> </dependency>scopeprovided&#xff0c;说…

LinkSLA智能运维技术派-Redis的监控

Redis是一个开源&#xff0c;内存存储的数据服务器&#xff0c;可用作数据库、高速缓存和消息队列代理等场景。 首先我们对内存进行监控&#xff0c;主要指标如下&#xff1a; - used_memory:使用内存 - used_memory_rss:从操作系统分配的内存 - mem_fragmentation_ratio:内…

如何成为一名黑客?小白必学的6个基本步骤

黑客攻防是一个极具魅力的技术领域&#xff0c;但成为一名黑客毫无疑问也并不容易。你必须拥有对新技术的好奇心和积极的学习态度&#xff0c;具备很深的计算机系统、编程语言和操作系统知识&#xff0c;并乐意不断地去学习和进步。 如果你想成为一名优秀的黑客&#xff0c;下…

交叉编译 acl

交叉编译 acl 概述 访问控制列表&#xff08;Access Control Lists&#xff0c;ACL&#xff09;是应用在路由器接口的指令列表。在 Linux 系统中&#xff0c;ACL 用于设定用户针对文件的权限&#xff0c;而不是在交换路由器中用来控制数据访问的功能&#xff08;类似于防火墙…

跑步耳机怎么选、最好用的跑步专用耳机分享

跑步时候戴着的耳机一直往下滑&#xff0c;跑步的步伐也不敢快起来&#xff0c;生怕耳机掉下去。除此之外&#xff0c;还担心跑步时流的汗水渗入到耳机里面&#xff0c;生怕因此被电到。因为没有合适的耳机在跑步时听歌&#xff0c;不但没能缓解跑步时的枯燥还徒增了些烦恼&…

力扣-市场分析

大家好&#xff0c;我是空空star&#xff0c;本篇带大家了解一道简单的力扣sql练习题。 文章目录前言一、题目&#xff1a;1158. 市场分析二、解题1.错误示范①提交SQL运行结果2.正确示范①提交SQL运行结果3.错误示范②提交SQL运行结果4.正确示范②提交SQL运行结果5.其他总结前…

华为OD机试题,用 Java 解【查找接口成功率最优时间段】问题

最近更新的博客 华为OD机试 - 猴子爬山 | 机试题算法思路 【2023】华为OD机试 - 分糖果(Java) | 机试题算法思路 【2023】华为OD机试 - 非严格递增连续数字序列 | 机试题算法思路 【2023】华为OD机试 - 消消乐游戏(Java) | 机试题算法思路 【2023】华为OD机试 - 组成最大数…

Allegro如何在关闭飞线模式下查看网络连接位置操作指导

Allegro如何在关闭飞线模式下查看网络连接位置操作指导 在用Allegro做PCB设计的时候,有时会因为设计需要,关闭飞线显示。 如何在关闭飞线显示模式下查看网络连接的位置,如下图 除了能看到网络连接的点位以外,还能看到器件的pin Number 如何显示出这种效果,具体操作如下 …

代码随想录算法训练营 || 回溯算法 491 46 47

Day25491.递增子序列力扣题目链接给定一个整型数组, 你的任务是找到所有该数组的递增子序列&#xff0c;递增子序列的长度至少是2。示例:输入: [4, 6, 7, 7]输出: [[4, 6], [4, 7], [4, 6, 7], [4, 6, 7, 7], [6, 7], [6, 7, 7], [7,7], [4,7,7]]思路有难度的一道题需要注意两个…

Google员工说出了我不敢说的心里话!

前言&#xff1a;本文来自Beyond的投稿&#xff0c;码农翻身做了修改。今天在Medium上看到一篇文章《The maze is in the mouse》&#xff0c;是一个刚从Google离职的员工写的&#xff0c;揭开了Google内部的各种问题&#xff0c;引发了很多人的共鸣&#xff0c;到目前为止&…

ChatGPT能看到图片,太神了!

闲来无事&#xff0c;给ChatGPT提供了一张图片的地址&#xff0c;他说他能看到&#xff0c;并且还描述了出来&#xff0c;真的是太神奇了。以下是对话&#xff1a; 我用Midjourney帮我生成了树世界的主界面图片了&#xff0c;很美 很好&#xff01;如果你想要分享图片&#xf…

如果不使用时钟同步工具,linux如何解决时钟同步问题?仅需要一行命令即可。

这是一篇日记&#xff0c;记录了上帝下凡出手&#xff0c;解救苍生与水火之中的神奇文章&#xff0c;如果你也有过类似的经历&#xff0c;留言关注&#xff0c;咱们交流一下~ 目录 背景&#xff08;如果不想知道可以跳过&#xff09; 一行神奇的命令 一段一段的研究 总结 背…

实现“第 24”种设计模式

传统方案 if-else 在我们编程时出现的频率&#xff0c;无需我多赘述。当逻辑复杂时&#xff0c;我们会写出很多 if-else 语句&#xff0c;于是网络上充斥着大量的相关文章&#xff0c;教我们如何去除if-else&#xff0c;大多大同小异。 归结下来&#xff0c;无非是策略模式、…

凌恩生物资讯|抗性宏基因组又一力作|抗性基因+可移动元件研究新成果!

凌恩生物合作客户&#xff1a;合肥工业大学崔康平老师团队利用凌恩生物宏基因组抗性基因研究解决方案&#xff0c;对污水处理厂活性污泥中的钆&#xff08;Gd&#xff08;III&#xff09;&#xff09;和抗生素磺胺甲噁唑&#xff08;SMX&#xff09;的联合污染情况进行了调查&a…

华为OD机试题,用 Java 解【滑动窗口最大和】问题

最近更新的博客 华为OD机试 - 猴子爬山 | 机试题算法思路 【2023】华为OD机试 - 分糖果(Java) | 机试题算法思路 【2023】华为OD机试 - 非严格递增连续数字序列 | 机试题算法思路 【2023】华为OD机试 - 消消乐游戏(Java) | 机试题算法思路 【2023】华为OD机试 - 组成最大数…

前端无障碍适配

无障碍简介&#xff1a; 帮助一些视障群体使用手机&#xff0c;点击的热区会增加配合文字识别增加一些语音播报的功能&#xff0c;手机一般可以通过&#xff1a;设置—》辅助功能—》无障碍功能菜单 体验无障碍功能 IOS&#xff1a;设置–》辅助功能----》旁白 需求背景 会有…

Allegro如何显示层叠Options和Find操作界面

Allegro如何显示层叠Options和Find操作界面 Allegro常规有三大操作界面,层叠,Options和Find,如下图 软件第一次启动的时候,三大界面是关闭的,下面介绍如何把它们打开,具体操作步骤如下 点击菜单上的View点击Windows

JavaScript 进阶(面试必备)--charater4

文章目录前言一、深浅拷贝:one: 浅拷贝:two:深拷贝二、异常处理:one: throw 抛异常:two: try /catch 捕获异常:three:debugger三、处理thisthis指向 :one:普通函数this指向this指向 :two: 箭头函数this指向3.2 改变this:one: call():two: apply():three: bind()四、性能优化:on…