深度学习【Logistic回归模型】

news2025/7/19 12:50:14

回归和分类

回归问题得到的结果都是连续的,比如通过学习时间预测成绩

分类问题是将数据分成几类,比如根据邮件信息将邮件分成垃圾邮件和有效邮件两类。

相比于基础的线性回归其实就是增加了一个sigmod函数。

代码

import matplotlib.pyplot as plt
import torch
import pandas as pd
import numpy as np
import torch.nn as nn

# 设定随机种子
torch.manual_seed(2017)

# 从 data.txt 中读入点
with open('./data.txt', 'r') as f:
    data_list = [i.split('\n')[0].split(',') for i in f.readlines()]
    data = [(float(i[0]), float(i[1]), float(i[2])) for i in data_list]

# 标准化
x0_max = max([i[0] for i in data])
x1_max = max([i[1] for i in data])
data = [(i[0]/x0_max, i[1]/x1_max, i[2]) for i in data]

#把两个类别的点分别保存起来,后续方便绘图
x0 = list(filter(lambda x: x[-1] == 0.0, data)) # 选择第一类的点
x1 = list(filter(lambda x: x[-1] == 1.0, data)) # 选择第二类的点

plot_x0 = [i[0] for i in x0] #第0类点的x坐标
plot_y0 = [i[1] for i in x0]#第0类点的y坐标
plot_x1 = [i[0] for i in x1]#第1类点的x坐标
plot_y1 = [i[1] for i in x1]#第1类点的y坐标


#数据转化为numpy,再转化为tensor
np_data = np.array(data, dtype='float32') # 转换成 numpy array
x_data = torch.from_numpy(np_data[:, 0:2]) # 转换成 Tensor, 大小是 [100, 2]
y_data = torch.from_numpy(np_data[:, 2]).unsqueeze(1) # 转换成 Tensor,大小是 [100, 1]

#定义模型
class LogisticRegression(torch.nn.Module):
    def __init__(self):
        super(LogisticRegression, self).__init__()
        self.linear = torch.nn.Linear(2,1)  #这里会自动初始化w和b,所以不需要自己再设置

    def forward(self, x):
        return self.linear(x)   #这里自己第一次写的时候,写成torch.sigmoid(self.linear(x)),其实在后续损失函数中,已经包含了sigmoid函数,所以这里可以直接返回线性层的输出
#实例化
model = LogisticRegression()
#损失函数
criterion = nn.BCEWithLogitsLoss()


optimizer = torch.optim.SGD(model.parameters(), lr=1)  #优化器,自己也是入门,这里学习率往往按照经验设置,也可以多多尝试设置


# 训练函数
def train(epochs):
    model.train()
    for e in range(epochs):
        optimizer.zero_grad()  #梯度归零
        y_pred = model(x_data) #前向传播
        loss = criterion(y_pred, y_data) #计算损失
        loss.backward() #反向传播
        optimizer.step() #更新参数
        if e % 1000 == 0:
            print(f"Epoch {e}, Loss: {loss.item()}")

if __name__ == '__main__':
    train(10000)

    # 获取训练后的参数
    w = model.linear.weight.detach().numpy().flatten()
    b = model.linear.bias.detach().numpy().item()

    # 计算决策边界 (w1*x1 + w2*x2 + b = 0)
    plot_x = np.linspace(0.2, 1, 100)
    plot_y = (-w[0] * plot_x - b) / w[1]



    # 绘制结果
    plt.figure(figsize=(8, 6)) #设置画布
    #描点
    plt.scatter(plot_x0, plot_y0, c='red', label='Class 0', edgecolors='k')
    plt.scatter(plot_x1, plot_y1, c='blue', label='Class 1', edgecolors='k')

    #划分割线
    plt.plot(plot_x, plot_y, 'g--', linewidth=2, label='Decision Boundary')


    plt.title("Logistic Regression Classification")
    plt.xlabel("Normalized Feature 1")
    plt.ylabel("Normalized Feature 2")
    plt.legend()
    plt.show()


 训练:

 最后效果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2355990.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据科学与计算

1.设计目标与安装 Seaborn 是一个建立在 Matplotlib 基础之上的 Python 数据可视化库,专注于绘制各种统计图形,以便更轻松地呈现和理解数据。Seaborn 的设计目标是简化统计数据可视化的过程,提供高级接口和美观的默认主题,使得用…

怎样给MP3音频重命名?是时候管理下电脑中的音频文件名了

在处理大量音频文件时,给这些文件起一个有意义的名字可以帮助我们更高效地管理和查找所需的内容。通过使用专业的文件重命名工具如简鹿文件批量重命名工具,可以极大地简化这一过程。本文将详细介绍如何利用该工具对 MP3 音频文件进行重命名。 步骤一&am…

快速上手非关系型数据库-MongoDB

简介 MongoDB 是一个基于文档的 NoSQL 数据库,由 MongoDB Inc. 开发。 NoSQL,指的是非关系型的数据库。NoSQL有时也称作Not Only SQL的缩写,是对不同于传统的关系型数据库的数据库管理系统的统称。 MongoDB 的设计理念是为了应对大数据量、…

【C++学习笔记】深入理解虚函数和多态

文章目录 1. 基本概念1.1 虚函数1.2 虚函数表1.3 虚函数表指针1.4 虚函数表在支持多态方面的工作原理 2. 类对象在内存中的布局参考 1. 基本概念 1.1 虚函数 类的成员函数,并不占用类对象的内存空间。 类中有虚函数,编译器会向类中插入一个看不见的成…

Node.js CSRF 保护指南:示例及启用方法

解释 CSRF 跨站请求伪造 (CSRF/XSRF) 是一种利用用户权限劫持会话的攻击。这种攻击策略允许攻击者通过诱骗用户以攻击者的名义提交恶意请求,从而绕过我们的安全措施。 CSRF 攻击之所以可能发生,是因为两个原因。首先,CSRF 攻击利用了用户无法辨别看似合法的 HTML 元素是否…

【Linux】VSCode用法

描述 部分图片和经验来源于网络,若有侵权麻烦联系我删除,主要是做笔记的时候忘记写来源了,做完笔记很久才写博客。 专栏目录:记录自己的嵌入式学习之路-CSDN博客 目录 1 安装环境及运行C/C 1.1 安装及配置步骤 1.2 运…

来聊聊JVM中安全点的概念

文章目录 写在文章开头详解safepoint基本概念什么是安全点?为什么需要安全点JVM如何让线程跑到最近的安全点线程什么时候需要进入安全点JVM如何保证线程高效进入安全点如何设置安全点用一次GC解释基于安全点的STW实践-基于主线程休眠了解安全点的工作过程代码示例基于日志印证…

Nginx — http、server、location模块下配置相同策略优先级问题

一、配置优先级简述 在 Nginx 中,http、server、location 模块下配置相同策略时是存在优先级的,一般遵循 “范围越小,优先级越高” 的原则,下面为你详细介绍: 1. 配置继承关系 http 块:作为全局配置块&…

线性代数—向量与矩阵的范数(Norm)

参考链接: 范数(Norm)——定义、原理、分类、作用与应用 - 知乎 带你秒懂向量与矩阵的范数(Norm)_矩阵norm-CSDN博客 什么是范数(norm)?以及L1,L2范数的简单介绍_l1 norm-CSDN博客 范数(Norm…

【业务领域】电脑主板芯片电路结构

前言 由前几期视频合集(零基础自学计算机故障排除—7天了解计算机开机过程),讲解了POST的主板软启动过程;有不少网友留言、私信来问各种不开机的故障,但大多网友没能能过我们的这合集视频,很好的理清思路,那这样的情况…

pandas读取Excel数据(.xlsx和.xls)到treeview

对于.xls文件,xlrd可能更合适,但需要注意新版本的xlrd可能不支持xlsx,不过用户可能同时需要处理两种格式,所以可能需要结合openpyxl和xlrd?或者直接用pandas,因为它内部会处理这些依赖。 然后,…

JVM——垃圾收集策略

GC的基本问题 什么是GC? GC 是 garbage collection 的缩写,意思是垃圾回收——把内存(特别是堆内存)中不再使用的空间释放掉;清理不再使用的对象。 为什么要GC? 堆内存是各个线程共享的空间&#xff0c…

马克·雷伯特:用算法让机器人飞奔的人

名人说:路漫漫其修远兮,吾将上下而求索。—— 屈原《离骚》 创作者:Code_流苏(CSDN)(一个喜欢古诗词和编程的Coder😊) 马克雷伯特:用算法让机器人飞奔的人 一、天才的起点 在机器人领域,有一个名字如雷贯耳——马克雷伯特(Marc Raibert)。作为波士顿动力公司(Boston…

信创系统资产清单采集脚本:主机名+IP+MAC 一键生成 CSV

原文链接:信创系统资产清单采集脚本:主机名IPMAC 一键生成 CSV Hello,大家好啊!今天给大家带来一篇在信创终端操作系统上自动批量采集主机名、IP 和 MAC 并导出为 CSV 表格的实战文章!本方案使用 sshpass 和 Bash 脚本…

SpringBoot获取用户信息常见问题(密码屏蔽、驼峰命名和下划线命名的自动转换)

文章目录 一、不返回password字段二、返回的createTime和updateTime为空原因解决:开启驼峰命名和下划线命名的自动转换 设置返回的日期格式 一、不返回password字段 在字段上面添加JsonIgnore注解即可 JsonIgnore // 在把对象序列化成json字符串时,忽略…

Mac下安装Python3,并配置环境变量设置为默认

下载Python 访问Python官方网站 https://www.python.org/ 首先获得python3安装路径 执行命令: which python3 以我这台电脑为例,路径为:/Library/Frameworks/Python.framework/Versions/3.9/bin/python3 编辑 bash_profile 文件 然后用 vim 打…

Linux-04-用户管理命令

一、useradd添加新用户: 基本语法: useradd 用户名:添加新用户 useradd -g 组名 用户:添加新用户到某个组二、passwd设置用户密码: 基本语法: passwd 用户名:设置用户名密码 三、id查看用户是否存在: 基本语法: id 用户名 四、su切换用户: 基本语法: su 用户名称:切换用…

【进阶】--函数栈帧的创建和销毁详解

目录 一.函数栈帧的概念 二.理解函数栈帧能让我们解决什么问题 三.相关寄存器和汇编指令知识点补充 四.函数栈帧的创建和销毁 4.1.调用堆栈 4.2.函数栈帧的创建 4.3 函数栈帧的销毁 一.函数栈帧的概念 --在C语言中,函数栈帧是指在函数调用过程中,…

【一】 基本概念与应用领域【数字图像处理】

考纲 文章目录 1 概念2005甄题【名词解释】2008、2012甄题【名词解释】可考题【简答题】可考题【简答题】 2 应用领域【了解】2.1 伽马射线成像【核医学影像】☆2.2 X射线成像2.3 紫外波段成像2.4 可见光和红外波段成像2.5 微波波段成像2.6 无线电波段成像2.7 电子显微镜成像2…

NU1680低成本、无固件、高集成度无线充电电源接收器

无线充电 电子产品具有无线充电功能使用会更便利,介绍一款低成本、无固件、高集成度无线充电电源接收器NU1680 原理图和BOM可点绑定资源下载,LC部分电容建议X7R。 Load空载切满载测试 (CC Mode) – 尽量保证电子负载没有过冲 – 电子负载不要从0到满…