DeepSpeed简介及加速模型训练

news2025/5/22 23:05:27

DeepSpeed是由微软开发的开源深度学习优化框架,专注于大规模模型的高效训练与推理。其核心目标是通过系统级优化技术降低显存占用、提升计算效率,并支持千亿级参数的模型训练。

官网链接:deepspeed
训练代码下载:git代码

一、DeepSpeed的核心作用

  1. 显存优化与高效内存管理

    • ZeRO(Zero Redundancy Optimizer)技术:通过分片存储模型状态(参数、梯度、优化器状态)至不同GPU或CPU,显著减少单卡显存占用。例如,ZeRO-2可将显存占用降低8倍,支持单卡训练130亿参数模型。
      在这里插入图片描述

    • Offload技术:将优化器状态卸载到CPU或NVMe硬盘,扩展至TB级内存,支持万亿参数模型训练。

    • 激活值重计算(Activation Checkpointing):牺牲计算时间换取显存节省,适用于长序列输入。

  2. 灵活的并行策略

    • 3D并行:融合数据并行(DP)、模型并行(张量并行TP、流水线并行PP),支持跨节点与节点内并行组合,适应不同硬件架构。

    • 动态批处理与梯度累积:减少通信频率,支持超大Batch Size训练。

  3. 训练加速与混合精度支持

    • 混合精度训练:支持FP16/BF16,结合动态损失缩放平衡效率与数值稳定性。

    • 稀疏注意力机制:针对长序列任务优化,执行效率提升6倍。

    • 通信优化:支持MPI、NCCL等协议,降低分布式训练通信开销。

  4. 推理优化与模型压缩

    • 低精度推理:通过INT8/FP16量化减少模型体积,提升推理速度。

    • 模型剪枝与蒸馏:压缩模型参数,降低部署成本。


二、与pytorch 对比分析

1. 优势

  • 显存效率:相比PyTorch DDP,单卡80GB GPU可训练130亿参数模型(传统方法仅支持约10亿)。

  • 并行灵活性:支持3D并行组合,优于Horovod(侧重数据并行)和Megatron(侧重模型并行)。

  • 生态集成:与Hugging Face Transformers、PyTorch无缝兼容,简化现有项目迁移。

  • 全流程覆盖:同时优化训练与推理,而vLLM仅专注推理优化。

2. 局限性

  • 配置复杂度:分布式训练需手动调整通信策略和分片参数,学习曲线陡峭(需编写JSON配置文件)。

  • 硬件依赖:部分高级功能(如ZeRO-Infinity)依赖NVMe硬盘或特定GPU架构。

  • 推理效率:纯推理场景下,vLLM的吞吐量更高(连续批处理优化更专精)。


三、训练用例

1、ds_config.json(deepspeed执行训练时,使用的配置文件)
  • deepspeed训练模型时,不需要在代码中定义优化器,只需要在 json 文件中进行配置即可, json文件内容如下:
{
  "train_batch_size": 128, //所有GPU上的 单个训练批次大小 之和
  "gradient_accumulation_steps": 1, //梯度累积 步数
  "optimizer": {
    "type": "Adam", //选择的 优化器
    "params": {
      "lr": 0.00015 //相关学习率大小
    }
  },
  "zero_optimization": { //加速策略
      "stage":2
  }
}

2、训练函数

  • 将模型包装成 deepspeed 形式
#将模型 包装成 deepspeed 形式
model_engine, _, _, _ = deepspeed.initialize(args=args,
                                                     model=model,
                                                     model_parameters=model.parameters())
  • 使用 deepspeed 包装后的模型 进行 反向传播和梯度更新
#使用 deepspeed 进行 反向传播和梯度更新
#反向传播
model_engine.backward(loss)

#梯度更新
model_engine.step()
  • 完整训练代码如下:
'''
使用命令行进行启动

启动命令如下:
deepspeed ds_train.py --epochs 10 --deepspeed --deepspeed_config ds_config.json
'''

import argparse
import torch
import torchvision
import deepspeed
from model_definition import load_data, CustomModel



if __name__ == '__main__':
    #读取命令行 传递的参数
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", help = "local device id on current node", type = int, default=0)
    parser.add_argument("--epochs", type = int, default=1)
    parser = deepspeed.add_config_arguments(parser)

    args = parser.parse_args()

    #获取数据集
    train_loader, test_loader = load_data() #数据集加载器中的 batch_size的大小 = (ds_config.json中 train_batch_size/gpu数量)

    #获取原始模型
    model = CustomModel().cuda()

    #将模型 包装成 deepspeed 形式
    model_engine, _, _, _ = deepspeed.initialize(args=args,
                                                     model=model,
                                                     model_parameters=model.parameters())
    loss_fn = torch.nn.CrossEntropyLoss().cuda() # 损失函数(分类任务常用)
    
    for i in range(args.epochs):
        for inputs, labels in train_loader:
            #前向传播
            inputs = inputs.cuda()
            labels = labels.cuda()

            outputs = model_engine(inputs)
            loss = loss_fn(outputs, labels)

            #使用 deepspeed 进行 反向传播和梯度更新
            #反向传播
            model_engine.backward(loss)
        
            #梯度更新
            model_engine.step()
        model_engine.save_checkpoint('./ds_models', i)

    #模型保存
    torch.save(model_engine.module.state_dict(),'deepspeed_train_model.pth')

3、模型评估

import argparse
import torch
import torchvision
import deepspeed
from model_definition import load_data, CustomModel
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt

# 1. 定义数据转换(预处理)
transform = transforms.Compose([
    transforms.ToTensor(),          # 转为Tensor格式(自动归一化到0-1)
    transforms.Normalize((0.1307,), (0.3081,))  # 标准化(MNIST的均值和标准差)
])

test_data = datasets.MNIST(
        root='./data',
        train=False,          # 测试集
        transform=transform
    )

#获取数据集
train_loader, test_loader = load_data()

model = CustomModel()
model.load_state_dict(torch.load('deepspeed_train_model.pth'))

#评估
model.eval()  # 设置为评估模式
correct = 0
total = 0

with torch.no_grad():  # 不计算梯度(节省内存)
    for images, labels in test_loader:
        images, labels = images, labels
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)  # 取概率最大的类别
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"测试集准确率: {100 * correct / total:.2f}%")


# 随机选择一张测试图片
index = np.random.randint(0,1000)  # 可以修改这个数字试不同图片
test_image, true_label = test_data[index]
test_image = test_image.unsqueeze(0)  # 增加批次维度

# 预测
with torch.no_grad():
    output = model(test_image)
predicted_label = torch.argmax(output).item()

print(f"预测: {predicted_label}, 真实: {true_label}")

# 显示结果
plt.imshow(test_image.cpu().squeeze(), cmap='gray')
plt.title(f"预测: {predicted_label}, 真实: {true_label}")
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2383448.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

openlayer:10点击地图上某些省份利用Overlay实现提示省份名称

实现点击地图上的省份,在点击经纬度坐标位置附近利用Overlay实现提示框提示相关省份名称。本文介绍了如何通过OpenLayers库实现点击地图上的省份,并在点击的经纬度坐标位置附近显示提示框,提示相关省份名称。首先,定义了两个全局变…

upload-labs通关笔记-第13关 文件上传之白名单POST法

目录 一、白名单过滤 二、%00截断 1.截断原理 2、截断条件 &#xff08;1&#xff09;PHP版本 < 5.3.4 &#xff08;2&#xff09;magic_quotes_gpc配置为Off &#xff08;3&#xff09;代码逻辑存在缺陷 三、源码分析 1、代码审计 &#xff08;1&#xff09;文件…

数据库健康监测器(BHM)实战:如何通过 HTML 报告识别潜在问题

在数据库运维中,健康监测是保障系统稳定性与性能的关键环节。通过 HTML 报告,开发者可以直观查看数据库的运行状态、资源使用情况与潜在风险。 本文将围绕 数据库健康监测器(Database Health Monitor, BHM) 的核心功能展开分析,结合 Prometheus + Grafana + MySQL Export…

Oracle 11g 单实例使用+asm修改主机名导致ORA-29701 故障分析

解决 把服务器名修改为原来的&#xff0c;重启服务器。 故障 建表空间失败。 分析 查看告警日志 ORA-1119 signalled during: create tablespace splex datafile ‘DATA’ size 2000M… Tue May 20 18:04:28 2025 create tablespace splex datafile ‘DATA/option/dataf…

OpenCV CUDA模块图像过滤------用于创建一个最大值盒式滤波器(Max Box Filter)函数createBoxMaxFilter()

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 createBoxMaxFilter()函数创建的是一个 最大值滤波器&#xff08;Maximum Filter&#xff09;&#xff0c;它对图像中每个像素邻域内的像素值取最…

Redis数据库-消息队列

一、消息队列介绍 二、基于List结构模拟消息队列 总结&#xff1a; 三、基于PubSub实现消息队列 (1)PubSub介绍 PubSub是publish与subscribe两个单词的缩写&#xff0c;见明知意&#xff0c;PubSub就是发布与订阅的意思。 可以到Redis官网查看通配符的书写规则&#xff1a; …

破解充电安全难题:智能终端的多重防护体系构建

随着智能终端的普及&#xff0c;充电安全问题日益凸显。从电池过热到短路起火&#xff0c;充电过程中的安全隐患不仅威胁用户的生命财产安全&#xff0c;也制约了行业的发展。如何构建一套高效、可靠的多重防护体系&#xff0c;成为破解充电安全难题的关键。通过技术创新和系统…

apptrace 三大策略,助力电商 App 在 618 突围

随着 5 月 13 日 “618” 电商大促预售战的打响&#xff0c;各大平台纷纷祭出百亿补贴、消费券等大招&#xff0c;投入超百亿流量与数十亿现金&#xff0c;意图在这场年度商战中抢占先机。但这场流量争夺战远比想象中艰难&#xff0c;中国互联网络信息中心数据显示&#xff0c;…

SuperVINS:应对挑战性成像条件的实时视觉-惯性SLAM框架【全流程配置与测试!!!】【2025最新版!!!!】

一、项目背景及意义 SuperVINS是一个改进的视觉-惯性SLAM&#xff08;同时定位与地图构建&#xff09;框架&#xff0c;旨在解决在挑战性成像条件下的定位和地图构建问题。该项目基于经典的VINS-Fusion框架&#xff0c;但通过引入深度学习方法进行了显著改进。 视觉-惯性导航系…

Node-Red通过开疆智能Profinet转ModbusTCP采集西门子PLC数据配置案例

一、内容简介 本篇内容主要介绍Node-Red通过node-red-contrib-modbus插件与开疆智能ModbusTCP转Profinet设备进行通讯&#xff0c;这里Profinet转ModbusTCP网关作为从站设备&#xff0c;Node-Red作为主站分别从0地址开始读取10个线圈状态和10个保持寄存器&#xff0c;分别用Mo…

【性能测试】jvm监控

使用本地jvisualvm远程监控服务器 参考文章&#xff1a;https://blog.csdn.net/yeyuningzi/article/details/140261411 jvisualvm工具默认是监控本地jvm&#xff0c;如果需要监控远程就要修改配置参数 1、先查看是否打开 ps -ef|java 如果打开杀掉进程 2、进入项目服务路径下…

Uniapp开发鸿蒙应用时如何运行和调试项目

经过前几天的分享&#xff0c;大家应该应该对uniapp开发鸿蒙应用的开发语法有了一定的了解&#xff0c;可以进行一些简单的应用开发&#xff0c;今天分享一下在使用uniapp开发鸿蒙应用时怎么运行到鸿蒙设备&#xff0c;并且在开发中怎么调试程序。 运行 Uniapp项目支持运行到…

QT+RSVisa控制LXI仪器

1.下载并安装visa R&SVISA - Rohde & Schwarz China 2.安装后的目录说明 安装了64位visa会默认把32位的安装上&#xff1b; 64位库和头文件目录为&#xff1a;C:\Program Files\IVI Foundation 32位库和头文件目录为&#xff1a;C:\Program Files (x86)\IVI Foundation…

springboot3+vue3融合项目实战-大事件文章管理系统-文章分类也表查询(条件分页)

在pojo实体类中增加pagebean实体类 Data NoArgsConstructor AllArgsConstructor public class PageBean <T>{private Long total;//总条数private List<T> items;//当前页数据集合 }articlecontroller增加代码 GetMappingpublic Result<PageBean<Article&g…

Canvas进阶篇:鼠标交互动画

Canvas进阶篇&#xff1a;鼠标交互动画 前言获取鼠标坐标鼠标事件点击事件监听代码示例效果预览 拖动事件监听代码示例效果预览 结语 前言 在上一篇文章Canvas进阶篇&#xff1a;基本动画详解 中&#xff0c;我们讲述了在Canvas中实现动画的基本步骤和动画的绘制方法。本文将进…

【Node.js】Web开发框架

个人主页&#xff1a;Guiat 归属专栏&#xff1a;node.js 文章目录 1. Node.js Web框架概述1.1 Web框架的作用1.2 Node.js主要Web框架生态1.3 框架选择考虑因素 2. Express.js2.1 Express.js概述2.2 基本用法2.2.1 安装Express2.2.2 创建基本服务器 2.3 路由2.4 中间件2.5 请求…

使用Vite创建一个动态网页的前端项目

1. 引言 虽然现在的前端更新换代的速度很快&#xff0c;IDE和工具一批批的换&#xff0c;但是我们始终要理解一点基本的程序构建的思维&#xff0c;这些环境和工具都是为了帮助我们更快的发布程序。笔者还记得以前写前端代码的时候&#xff0c;只使用文本编辑器&#xff0c;然…

系统架构设计师案例分析题——web篇

软考高项系统架构设计师&#xff0c;其中的科二案例分析题为5选3&#xff0c;总分75达到45分即合格。本贴来归纳web设计题目中常见的知识点即细节&#xff1a; 目录 一.核心知识 1.常见英文名词 2.私有云 3.面向对象三模型 4.计网相关——TCP和UDP的差异 5.MQTT和AMQP协…

MySQL--day5--多表查询

&#xff08;以下内容全部来自上述课程&#xff09; 多表查询 1. 为什么要用多表查询 # 如果不用多表查询 #查询员工名为Abel的人在哪个城市工作? SELECT* FROM employees WHERE last_name Abel;SELECT * FROM departments WHERE department_id 80;SELECT * FROM locati…

leetcode hot100刷题日记——7.最大子数组和

class Solution { public:int maxSubArray(vector<int>& nums) {//方法一&#xff1a;动态规划//dp[i]表示以i下标结尾的数组的最大子数组和//那么在i0时&#xff0c;dp[0]nums[0]//之后要考虑的就是我们要不要把下一个数加进来&#xff0c;如果下一个数加进来会使结…