查看TensorFlow已训模型的结构和网络参数

news2025/6/26 16:40:53

文章目录

    • 概要
    • 流程

概要

通过以下实例,你将学会如何查看神经网络结构并打印出训练参数。

流程

  • 准备一个简易的二分类数据集,并编写一个单层的神经网络
train_data = np.array([[1, 2, 3, 4, 5], 
                       [7, 7, 2, 4, 10], 
                       [1, 9, 3, 6, 5], 
                       [6, 7, 8, 9, 10]])

train_label = np.array([1, 0, 1, 0])  #标签与样本一一对齐


""" 定义一个单层的神经网络 """
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(1, activation=None)
])
  • 编译,训练,并保存模型
model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    optimizer='adam'
)
model.fit(train_data,
          train_label,
          epochs=2750)

tf.saved_model.save(model, "model_dir")  #保存到当前目录中,目录名为model_dir
  • 模型保存形式

模型节点和矩阵参数集中保存在 .data-00000-of-00001和 .index文件中,利用这两个文件中创建CheckpointReader对象。

  • 利用模型的Checkpoint对象查看模型结构和参数

Checkpoint对象存储了模型中所有可tracable追踪的对象,并记录保存着这些对象的参数及名称。可通过 tf.train.load_checkpoint()方法获得一个CheckpointReader对象,该对象可以读取Checkpoint内的所有信息。

"""  最后的variables是.data-00000-of-00001和 .index文件去掉后缀后的表达形式,
     从而统一代表着这两个文件"""
save_path = './model_dir/variables/variables'  # 

reader = tf.train.load_checkpoint(save_path)  # 得到CheckpointReader

"""  打印Checkpoint中存储的所有参数名和参数shape """
for variable_name, variable_shape in reader.get_variable_to_shape_map().items():
    print(f'{variable_name} : {variable_shape}')
 

optimizer/_variables/2/.ATTRIBUTES/VARIABLE_VALUE : [5, 1]
optimizer/_iterations/.ATTRIBUTES/VARIABLE_VALUE : []
_CHECKPOINTABLE_OBJECT_GRAPH : []
keras_api/metrics/0/count/.ATTRIBUTES/VARIABLE_VALUE : []
keras_api/metrics/0/total/.ATTRIBUTES/VARIABLE_VALUE : []
layer_with_weights-0/bias/.ATTRIBUTES/VARIABLE_VALUE : [1]
layer_with_weights-0/kernel/.ATTRIBUTES/VARIABLE_VALUE : [5, 1]
optimizer/_variables/1/.ATTRIBUTES/VARIABLE_VALUE : [5, 1]
optimizer/_learning_rate/.ATTRIBUTES/VARIABLE_VALUE : []
optimizer/_variables/3/.ATTRIBUTES/VARIABLE_VALUE : [1]
optimizer/_variables/4/.ATTRIBUTES/VARIABLE_VALUE : [1]

其中Dense层的权重参数和偏差bias的显示信息为,

layer_with_weights-0/bias/.ATTRIBUTES/VARIABLE_VALUE : [1]
layer_with_weights-0/kernel/.ATTRIBUTES/VARIABLE_VALUE : [5, 1]

接着利用刚刚打印出的参数名即可查看其参数值,

print(reader.get_tensor('layer_with_weights-0/kernel/.ATTRIBUTES/VARIABLE_VALUE'))
print(reader.get_tensor("layer_with_weights-0/bias/.ATTRIBUTES/VARIABLE_VALUE"))


[[-1.7741445 ]
 [-0.07314294]
 [-0.07213379]
 [ 1.1694099 ]
 [-0.36803177]]

[1.7487208]

  • 验证
model = tf.saved_model.load('model_dir')
print(model([[1, 2, 3, 4, 5]]))
output = -1.7741445 - 2*0.07314294 - 3*0.07213379 + 4*1.1694099 - 5*0.36803177+1.7487208
print(output)


tf.Tensor([[2.4493697]], shape=(1, 1), dtype=float32)

2.4493698000000004

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1584135.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【opencv】示例-essential_mat_reconstr.cpp 从两幅图像中恢复3D场景的几何信息

导入OpenCV的calib3d, highgui, imgproc模块以及C的vector, iostream, fstream库。定义了getError2EpipLines函数,这个函数用来计算两组点相对于F矩阵(基础矩阵)的投影误差。定义了sgn函数,用于返回一个双精度浮点数的符号。定义了…

SQLite超详细的编译时选项(十六)

返回:SQLite—系列文章目录 上一篇:SQLite数据库文件格式(十五) 下一篇:SQLite 在Android安装与定制方案(十七) 1. 概述 对于大多数目的,SQLite可以使用默认的 编译选项。但是…

2.HTML常用标签之表单标签

1.HTML常用标签之表单标签 w3c所有标签列表 HTML常用标签之表单标签

结合 tensorflow.js 、opencv.js 与 Ant Design 创建美观且高性能的人脸动捕组件并发布到InsCode

系列文章目录 如何在前端项目中使用opencv.js | opencv.js入门如何使用tensorflow.js实现面部特征点检测tensorflow.js 如何从 public 路径加载人脸特征点检测模型tensorflow.js 如何使用opencv.js通过面部特征点估算脸部姿态并绘制示意图tensorflow.js 使用 opencv.js 将人脸…

【STM32篇】DRV8425驱动步进电机

【STM32篇】4988驱动步进电机_hr4988-CSDN博客 在上篇文章中使用了HR4988实现了步进电机的驱动,在实际运用过程,HR4988或者A4988驱动步进电机会存在电机噪音太大的现象。本次将向各位友友介绍一个驱动简单且非常静音的一款步进电机驱动IC。 1.DRV8425简介…

苹果开发者后台添加udid后,xcode中 Devices 数量没有更新问题

删除 文件夹 /Users/…/Library/MobileDevice/Provisioning Profiles 如何打开:https://zhuanlan.zhihu.com/p/563928113 回到Xcode刷新包名下面的警告验证(可能需要翻墙) 完毕!

Java异常处理机制详解:多层方法调用与异常传播(day23)

1.数组下标越界 2.多个处理异常 上面这两个代码的区别就是有无 System.out.println("抛出了NumberFormatException"); System.out.println("抛出了ArrayIndexOutOfBoundsException"); 第一种是不论捕获到哪种异常,都只会调用e.printStack…

探索GlusterFS:开源分布式文件系统

目录 引言 一、GlusterFS简介 (一)基本介绍 (二)GlusterFS特点 (三)GlusterFS术语 (四)GlusterFS工作流程 二、GlusterFs的卷类型 (一)卷类型 &…

【面试题】微博、百度等大厂的排行榜如何实现?

背景 现如今每个互联网平台都会提供一个排行版的功能,供人们预览最新最有热度的一些消息,比如百度: 再比如微博: 我们要知道,这些互联网平台每天产生的数据是非常大,如果我们使用MySQL的话,db实…

使用R语言计算矩形分布(均匀分布)并绘制图形

理论部分 矩形分布(均匀分布),是指在某一区间内,随机变量取任何值的概率都是相同的。这种分布的概率密度函数在一个特定的区间内是一个常数,因此其图形呈现出一个矩形的形状,故得名为“矩形分布”。在概率…

智能边缘自动化:HDMI接口钡铼ARM工业电脑实践案例

一款具备HDMI接口的高性能ARM工业计算机应运而生,为实现在工业4.0时代的关键数据实时处理与可视化管理提供了强有力的硬件支撑。这款计算机依托其独特的边缘计算能力,完美解决了工业环境中大规模数据传输至云端的高延迟问题,成功实现了OT&…

酷开科技在大数据及人工智能推动下,成功将酷开系统与AI融合

随着科技的不断发展,以及大数据这个概念的出现,让看似冷冰冰的数字开始具备了温度,开始让数字产生了温暖的价值,也让各个行业看到了大数据的作用。酷开科技生态的核心场景是家庭、是客厅,无论是以酷开科技为代表的OTT&…

电压继电器SRMUVS-220VAC-2H2D 导轨安装 JOSEF约瑟

系列型号: SRMUVS-58VAC-2H欠电压监视继电器;SRMUVS-100VAC-2H欠电压监视继电器; SRMUVS-110VAC-2H欠电压监视继电器;SRMUVS-220VAC-2H欠电压监视继电器; SRMUVS-58VAC-2H2D欠电压监视继电器;SRMUVS-100…

找不到vcruntime140.dll怎么办,vcruntime140.dll丢失的多种解决方法

在我们日常频繁地与电脑打交道、依赖其处理各种工作、学习乃至娱乐任务的过程中,偶尔会遭遇一些令人困扰的技术问题。其中一种颇为常见的情况便是,当您正全神贯注于某个重要应用的操作,或是满怀期待地试图启动一款新安装的游戏时,…

蓝桥杯刷题 二分-[2145]求阶乘(C++)

问题描述 满足 N! 的末尾恰好有 K 个 0 的最小的 N 是多少? 如果这样的 N 不存在输出 −1。 输入格式 一个整数 K。 输出格式 一个整数代表答案。 样例输入 2 样例输出 10 评测用例规模与约定 对于 30% 的数据,1 ≤ K ≤ 10的6次方 对于 100% 的数据&…

ES6对于Class类的基本语法详解(2024-04-10)

目录 1、传统ES5写法 2、ES6 的class语法 3、ES5与ES6行为对比 4、类的constructor() 方法 5、类的实例 new 6、类的对象属性(新写法) 7、类的取值函数(getter)和存值函数(setter) 8、Class类的表达…

用vue3写一个AI聊天室

效果图如下&#xff1a; 1、页面布局&#xff1a; <template><div class"body" style"background-color: rgb(244, 245, 248); height: 730px"><div class"container"><div class"right"><div class"…

SpringBoot3 + uniapp 对接 阿里云0SS 实现上传图片视频到 0SS 以及 0SS 里删除图片视频的操作(最新)

SpringBoot3 uniapp 对接 阿里云0SS 实现上传图片视频到 0SS 以及 0SS 里删除图片视频的操作 最终效果图uniapp 的源码UpLoadFile.vuedeleteOssFile.jshttp.js SpringBoot3 的源码FileUploadController.javaAliOssUtil.java 最终效果图 uniapp 的源码 UpLoadFile.vue <tem…

Netty出坑记

NIO&#xff1a; 一个线程处理多个请求 BIO&#xff1a; 阻塞 netty 编码解码 TFO&#xff1a; 校验cookie合法性&#xff0c;不合法 TCP流程 设计QQ&#xff1a; 登录过程&#xff0c;client TCP协议向server发送信息&#xff0c;HTTP协议下载信息 发消息&#xff1a;clie…

Win10系统VScode远程连接VirtualBox安装的Ubuntu20.04.5

1.打开虚拟机&#xff0c;在中端中输入命令: sudo apt-get install openssh-server 安装ssh 我这里已经安装完成&#xff0c;故显示是这样 2.输入命令&#xff1a;sudo systemctl start ssh 启动远程连接 注意&#xff0c;如果使用VirtualBox安装的虚拟机&#xff0c;需要启用…