用 tensorflow.js 做了一个动漫分类的功能(二)

news2025/7/13 1:55:19

前言:

前面已经通过采集拿到了图片,并且也手动对图片做了标注。接下来就要通过 Tensorflow.js 基于 mobileNet 训练模型,最后就可以实现在采集中对图片进行自动分类了。

这种功能在应用场景里就比较多了,比如图标素材站点,用户通过上传一个图标,系统会自动匹配出相似的图标,还有二手平台,用户通过上传闲置物品图片,平台自动给出分类等,这些也都是前期对海量图片进行了标注训练而得到一个损失率极低的模型。下面就通过简答的代码实现一个小的动漫分类。

环境:

Node

Http-Server

Parcel

Tensorflow

编码:

1. 训练模型

1.1. 创建项目,安装依赖包

npm install @tensorflow/tfjs --legacy-peer-deps
npm install @tensorflow/tfjs-node-gpu --legacy-peer-deps

1.2. 全局安装 Http-Server

npm install i http-server

1.3. 下载 mobileNet 模型文件 (网上有下载)

1.4. 根目录下启动 Http 服务 (开启跨域),用于 mobileNet 和训练结果的模型可访问

http-server--cors-p8080

1.5. 创建训练执行脚本 run.js

const tf = require('@tensorflow/tfjs-node-gpu');

const getData = require('./data');
const TRAIN_PATH = './动漫分类/train';
const OUT_PUT = 'output';
const MOBILENET_URL = 'http://127.0.0.1:8080/data/mobilenet/web_model/model.json';

(async () => {
  const { ds, classes } = await getData(TRAIN_PATH, OUT_PUT);
  console.log(ds, classes);
  //引入别人训练好的模型const mobilenet = await tf.loadLayersModel(MOBILENET_URL);
  //查看模型结构
  mobilenet.summary();

  const model = tf.sequential();
  //截断模型,复用了86个层for (let i = 0; i < 86; ++i) {
    const layer = mobilenet.layers[i];
    layer.trainable = false;
    model.add(layer);
  }
  //降维,摊平数据
  model.add(tf.layers.flatten());
  //设置全连接层
  model.add(tf.layers.dense({
    units: 10,
    activation: 'relu'//设置激活函数,用于处理非线性问题
  }));

  model.add(tf.layers.dense({
    units: classes.length,
    activation: 'softmax'//用于多分类问题
  }));
  //设置损失函数,优化器
  model.compile({
    loss: 'sparseCategoricalCrossentropy',
    optimizer: tf.train.adam(),
    metrics:['acc']
  });

  //训练模型await model.fitDataset(ds, { epochs: 20 });
  //保存模型await model.save(`file://${process.cwd()}/${OUT_PUT}`);
})();

1.6. 创建图片与 Tensor 转换库 data.js

const fs = require('fs');
const tf = require("@tensorflow/tfjs-node-gpu");

const img2x = (imgPath) => {
  const buffer = fs.readFileSync(imgPath);
  //清除数据return tf.tidy(() => {
    //把图片转成tensorconst imgt = tf.node.decodeImage(newUint8Array(buffer), 3);
    //调整图片大小const imgResize = tf.image.resizeBilinear(imgt, [224, 224]);
    //归一化return imgResize.toFloat().sub(255 / 2).div(255 / 2).reshape([1, 224, 224, 3]);
  });
}

const getData = async (traindir, output) => {
  let classes = fs.readdirSync(traindir, 'utf-8');
  fs.writeFileSync(`./${output}/classes.json`, JSON.stringify(classes));
  const data = [];
  classes.forEach((dir, dirIndex) => {
    fs.readdirSync(`${traindir}/${dir}`)
      .filter(n => n.match(/jpg$/))
      .slice(0, 1000)
      .forEach(filename => {
        const imgPath = `${traindir}/${dir}/${filename}`;

        data.push({ imgPath, dirIndex });
      });
  });

  console.log(data);

  //打乱训练顺序,提高准确度
  tf.util.shuffle(data);

  const ds = tf.data.generator(function* () {
    const count = data.length;
    const batchSize = 32;
    for (let start = 0; start < count; start += batchSize) {
      const end = Math.min(start + batchSize, count);
      console.log('当前批次', start);
      yield tf.tidy(() => {
        const inputs = [];
        const labels = [];
        for (let j = start; j < end; ++j) {
          const { imgPath, dirIndex } = data[j];
          const x = img2x(imgPath);
          inputs.push(x);
          labels.push(dirIndex);
        }
        const xs = tf.concat(inputs);
        const ys = tf.tensor(labels);
        return { xs, ys };
      });
    }
  });

  return { ds, classes };
}

module.exports = getData;

1.7. 运行执行文件

noderun.js

2. 调用模型

2.1. 全局安装 parcel

npminstall i parcel

2.2. 创建页面 index.html

<scriptsrc="script.js"></script><inputtype="file"onchange="predict(this.files[0])"><br>

2.3. 创建模型调用预测脚本 script.js

import * as tf from'@tensorflow/tfjs';
import { img2x, file2img } from'./utils';

const MODEL_PATH = 'http://127.0.0.1:8080/t7';
const CLASSES = ["假面骑士","奥特曼","海贼王","火影忍者","龙珠"];


window.onload = async () => {
    const model = await tf.loadLayersModel(MODEL_PATH + '/output/model.json');

    window.predict = async (file) => {
        const img = await file2img(file);
        document.body.appendChild(img);
        const pred = tf.tidy(() => {
            const x = img2x(img);
            return model.predict(x);
        });

        const index = pred.argMax(1).dataSync()[0];
        console.log(pred.argMax(1).dataSync());

        let predictStr = "";
        if (typeof CLASSES[index] == 'undefined') {
            predictStr = BRAND_CLASSES[index];
        } else {
            predictStr = CLASSES[index];
        }

        setTimeout(() => {
            alert(`预测结果:${predictStr}`);
        }, 0);
    };
};

2.4. 创建图片 tensor 格式转换库 utils.js

import * as tf from'@tensorflow/tfjs';

exportfunctionimg2x(imgEl){
    return tf.tidy(() => {
        const input = tf.browser.fromPixels(imgEl)
            .toFloat()
            .sub(255 / 2)
            .div(255 / 2)
            .reshape([1, 224, 224, 3]);
        return input;
    });
}

exportfunctionfile2img(f) {
    returnnewPromise(resolve => {
        const reader = new FileReader();
        reader.readAsDataURL(f);
        reader.onload = (e) => {
            const img = document.createElement('img');
            img.src = e.target.result;
            img.width = 224;
            img.height = 224;
            img.onload = () => resolve(img);
        };
    });
}

2.5. 打包项目并运行

parcelindex.html

2.6. 运行效果

注意:

1. 模型训练过程报错

Input to reshape is a tensor with 50176 values, but the requested shape has 150528

1.1. 原因

张量 reshape 不对,实际输入元素个数与所需矩阵元素个数不一致,就是采集过来的图片有多种图片格式,而不同格式的通道不同 (jpg3 通道,png4 通道,灰色图片 1 通道),在将图片转换 tensor 时与代码里的张量形状不匹配。

1.2. 解决方法

一种方法是删除灰色或 png 图片,其二是修改代码 tf.node.decodeImage (new Uint8Array (buffer), 3)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/362494.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

java集成redis简单存储

这里主要将数据存redis并设置过期时间、通过key删除redis、通过key更新redis&#xff08;续期&#xff09; 将数据存redis并设置过期时间 引入redis依赖 import org.springframework.data.redis.core.StringRedisTemplate;AutowiredStringRedisTemplate stringRedisTemplate…

【基础教程】Appium自动化测试,太详细了!

Appium简介Appium是一款开源的Appium自动化工具, 基于Webdriver协议, 主要有以下3个特点:全能: 支持iOS/Andorid/H5/混合App/WinApp通用: 支持Win/Linux/Mac, 支持Java/Python/Ruby/Js/PHP等各种语言开源: 免费App自动化测试工具对比iOS官方:Uiautomation/XCUITest: 白盒, UI测…

(二十)、完成个人中心页面的数据统计+设置详情页点赞用户的头像组【uniapp+uinicloud多用户社区博客实战项目(完整开发文档-从零到完整项目)】

1&#xff0c;个人中心页面的数据统计 数据统计包括两项内容&#xff1a;1.当前登录用户的点赞总数量。2.当前登录用户发布文章的总数量 1.1&#xff0c;在self页面data中定义对象 data() {return {totalObj:{artNum:0,likeNum:0}};},1.2&#xff0c;获取总数量的方法&#x…

多线程(初识线程)

线程的诞生 了解进程存在的意义 实现了并发编程的效果&#xff08;并发编程&#xff1a;有可能是并发执行&#xff0c;也有可能是并行执行&#xff09; 并发编程的目的&#xff1a;充分利用上多核CPU资源&#xff0c;提升运行效率 了解进程创建和销毁的过程带来的问题 进程是…

系列二、函数

一、定义 函数 是指一段可以直接被另一段程序调用的程序或代码。 也就意味着&#xff0c;这一段程序或代码MySQL中 已经给我们提供了&#xff0c;我们要做的就是在合适的业务场景调用对应的函数完成对应的业务需求即可。二、字符串函数 2.1、案例 2.1.1、concat 字符串拼接 s…

js中?.、??的具体用法

1、?. &#xff08;可选链运算符&#xff09; 在javascript中如果一个值为null、undefined&#xff0c;直接访问下面的属性&#xff0c;会报 Uncaught TypeError: Cannot read properties of undefined 异常错误。而在真实的项目中是会出现这种情况&#xff0c;有这个值就读这…

泛型擦除(Generic erase)(内含教学视频+源代码)

泛型擦除&#xff08;Generic erase&#xff09;&#xff08;内含教学视频源代码&#xff09; 教学视频源代码下载链接地址&#xff1a;https://download.csdn.net/download/weixin_46411355/87473560 源代码中使用的泛型&#xff0c;在经过编辑后&#xff0c;代码中就看不到泛…

PX4之飞行控制框架

PX4的飞行控制程序通过模块来实现&#xff0c;与飞控相关的模块主要有commander&#xff0c;navigator&#xff0c;pos_control&#xff0c;att_control这几个&#xff0c;分别可以在src/modules目录中找到。 commander - 指令/事件处理模块&#xff0c;处理指令、遥控器输入和…

新C++(9):谈谈,翻转那些事儿

"相信羁绊&#xff0c;相信微光&#xff0c;相信一切无常。"一、AVL树翻转那些事儿(1)什么是AVL树&#xff1f;在计算机科学中&#xff0c;AVL树是最先发明的自平衡二叉查找树。在AVL树中任何节点的两个子树的高度最大差别为1&#xff0c;所以它也被称为高度平衡树。…

网上插画教学哪家质量好,汇总5大插画培训班

网上插画教学哪家质量好&#xff1f;给大家梳理了国内5家专业的插画师培训班&#xff0c;最新五大插画班排行榜&#xff0c;各有优势和特色&#xff01; 一&#xff1a;国内知名插画培训机构排名 1、轻微课&#xff08;五颗星&#xff09; 主打课程有日系插画、游戏原画、古风插…

Tencent OS下逻辑卷(LVM)创建和扩容

测试环境是一个虚拟机&#xff0c;原配置1个虚拟盘。 创建4个虚拟盘&#xff0c;每盘2G并挂载在虚拟主机上&#xff0c;启动虚拟主机开始测试。 LVM英文是Logical Volume Manager&#xff0c;直接翻译为逻辑卷管理。 这种磁盘管理模式比较灵活&#xff0c;在磁盘空间不足的时…

深入浅出C++ ——容器适配器

文章目录一、容器适配器二、deque类简介1. deque的原理2. deque迭代器3. deque的优点和缺陷4. 为什么选择deque作为stack和queue的底层默认容器一、容器适配器 适配器的概念 适配器是STL六大核心组件之一&#xff0c;它是一种设计模式&#xff0c;该种模式是将一个类的接口转换…

大规模 IoT 边缘容器集群管理的几种架构-2-HashiCorp 解决方案 Nomad

前文回顾 大规模 IoT 边缘容器集群管理的几种架构-0-边缘容器及架构简介大规模 IoT 边缘容器集群管理的几种架构-1-RancherK3s &#x1f4da;️Reference: IoT 边缘计算系列文章 HashiCorp 解决方案 - Nomad Docker 简介 Nomad: 一个简单而灵活的调度器和编排器&#xff0c;…

网络工程课(二)

ensp配置vlan 一、配置计算机ip地址和子网掩码 二、配置交换机LSW1 system-view [Huawei]sysname SW1 [SW1]vlan batch 10 20 [SW1]interface Ethernet0/0/1 [SW1-Ethernet0/0/1]port link-type access 将接口设为access接口 [SW1-Ethernet0/0/1]port default vlan 10 [SW1-E…

【MyBatis】源码学习 04 - 从 MapperMethod 简单分析一条 SQL 的映射操作流程

文章目录前言参考目录学习笔记1、测试代码说明2、binding 包的主要功能3、获取 Mapper 接口实例过程4、SQL 语句执行流程4.1、方法调用器4.2、MapperMethod 绑定方法4.2.1、SqlCommand4.2.2、MethodSignature4.3、MapperMethod#execute前言 本文内容对应的是书本第 13 章的内容…

【亲测2022年】网络工程师被问最多的面试笔试题

嗨罗~大家好久不见&#xff0c;主要是薄荷呢主业还是比较繁忙的啦&#xff0c;之前发了一个面试题大家都很喜欢&#xff0c;非常感谢各位大佬对薄荷的喜爱&#xff0c;嘻嘻然后呢~薄荷调研了身边的朋友和同事&#xff0c;发现我们之前去面试&#xff0c;写的面试题有很多共同的…

C++ Effictive 第6章 继承与面向对象设计 笔记

继承意味着"is-a"。如果B继承自A&#xff0c;那么B is-a A。 子类声明与父类函数同名的函数时&#xff0c;父类函数会被遮掩。 使用using Base::func(args...)&#xff1b;父类所有func的重载函数都在子类中被声明。此举下&#xff0c;如果子类函数与父类函数参数也一…

不要对chatgpt过度反思 第一部分

最近一段时间&#xff0c;chatgpt很热&#xff0c;随意翻一些文章或视频&#xff0c;一些非常整齐一致的怪论&#xff0c;时不时都会冒出来。 为什么这种革命性创新又出现美国&#xff1f; 为什么我国互联网只会电商&#xff0c;没有创新&#xff1f; 为什么我们做不出来&…

列表推导式_Python教程

内容摘要 Python中存在一种特殊的表达式&#xff0c;名为推导式&#xff0c;它的作用是将一种数据结构作为输入&#xff0c;再经过过滤计算等处理&#xff0c;最后输出另一种数据结构。根据数据结构的不同会被分为列表推导式、 文章正文 Python中存在一种特殊的表达式&#x…

股票、指数、快照、逐笔... 不同行情数据源的实时关联分析应用

在进行数据分析时经常需要对多个不同的数据源进行关联操作&#xff0c;因此在各类数据库的 SQL 语言中均包含了丰富的 join 语句&#xff0c;以支持批计算中的多种关联操作。 DolphinDB 不仅通过 join 语法支持了对于全量历史数据的关联处理&#xff0c;而且在要求低延时的实时…