从代码学习深度学习 - 风格迁移 PyTorch版

news2025/5/20 16:36:03

文章目录

  • 前言
  • 方法 (Methodology)
  • 阅读内容和风格图像
  • 预处理和后处理
  • 抽取图像特征
  • 定义损失函数
    • 内容损失 (Content Loss)
    • 风格损失 (Style Loss)
    • 全变分损失 (Total Variation Loss)
    • 总损失函数
  • 初始化合成图像
  • 训练模型
  • 总结


前言

大家好!欢迎来到我们的深度学习代码学习系列。今天,我们将深入探讨一个非常有趣且富有创意的计算机视觉领域——风格迁移 (Style Transfer)

想象一下,你能否将梵高的《星夜》的独特笔触和色彩应用到你拍摄的一张城市风景照片上?或者将一幅著名油画的风格赋予你心爱的宠物照片?风格迁移技术正是致力于实现这种艺术融合的魔法。

简单来说,风格迁移的目标是生成一张新的图像,这张图像既保留了内容图像 (Content Image) 的主要结构和物体,又融入了风格图像 (Style Image) 的艺术纹理、色彩和笔触特点。这背后是深度学习,特别是卷积神经网络 (CNN) 的强大能力,它们能够从图像中学习并分离出内容表示和风格表示。

在本篇博客中,我们将一起:

  1. 理解风格迁移的基本原理和核心思想。
  2. 逐步分析并实现一个基于 PyTorch 的风格迁移模型。
  3. 学习如何定义和使用内容损失、风格损失以及全变分损失来指导模型的优化过程。
  4. 通过代码实践,将一张内容图像和一张风格图像融合成一张全新的艺术作品。

我们将详细解读每一个代码块,确保即使是初学者也能跟上节奏。希望通过这篇博客,你能不仅理解风格迁移的理论,更能亲手实现它,感受深度学习在创意应用上的魅力。

让我们开始这场艺术与代码的探索之旅吧!

完整代码:下载链接

方法 (Methodology)

风格迁移的核心思想是利用预训练的卷积神经网络 (CNN) 来分别提取内容图像的内容特征和风格图像的风格特征。然后,我们以内容图像(或随机噪声)为起点,生成一张初始的合成图像。这个合成图像是整个过程中唯一需要更新和优化的部分。

迭代优化的目标是让合成图像在内容上接近内容图像,在风格上接近风格图像。这是通过定义一个总损失函数来实现的,该损失函数通常包含三个部分:

  1. 内容损失 (Content Loss):衡量合成图像与内容图像在内容特征上的差异。我们希望合成图像能“看清”内容图像中的物体和场景。
  2. 风格损失 (Style Loss):衡量合成图像与风格图像在风格特征(如纹理、笔触、色彩分布)上的差异。我们希望合成图像能“模仿”风格图像的艺术风格。
  3. 全变分损失 (Total Variation Loss):作为一种正则化项,用于减少合成图像中的噪点,使其更加平滑自然。

下图展示了风格迁移的基本流程:

在这里插入图片描述

我们首先初始化合成图像,可以将其初始化为内容图像。该合成图像是风格迁移过程中唯一需要更新的变量,即风格迁移所需迭代的模型参数。然后,选择一个预训练的卷积神经网络来抽取图像的特征,其中的模型参数在训练中无须更新。这个深度卷积神经网络凭借多个层逐级抽取图像的特征,我们选择其中某些层的输出作为内容特征或风格特征。例如,上图中,预训练的神经网络含有多个卷积层和池化层,我们选择其中某些层的输出(如 relu4_2 作为内容特征,relu1_1, relu2_1, relu3_1, relu4_1, relu5_1 作为风格特征)。

接下来,通过前向传播计算风格迁移的损失函数,并通过反向传播迭代模型参数,即不断更新合成图像。当模型训练结束时,输出风格迁移的模型参数,即得到最终的合成图像。

阅读内容和风格图像

首先,我们需要加载我们的内容图像和风格图像。这里我们使用 PIL (Pillow) 库来加载图像,并用 Matplotlib 来显示它们。

# 图像风格迁移预处理代码
# 配置 matplotlib 行内显示|
%matplotlib inline

# 导入必要的库
import torch  # PyTorch 深度学习框架
import torchvision  # PyTorch 的计算机视觉工具包
from torch import nn  # 神经网络模块
import utils_for_huitu  # 自定义的绘图工具模块
import matplotlib.pyplot as plt  # 用于创建和操作 Matplotlib 图表
from PIL import Image  # Python 图像处理库

# 设置 matplotlib 图表的默认尺寸
# 该函数可能会设置全局的图表显示参数
utils_for_huitu.set_figsize()

# 加载内容图像
# content_img: PIL Image 对象,维度为 (height, width, channels)
# 其中 channels 通常为 3(RGB)或 4(RGBA)
content_img = Image.open('img/test.jpeg')

# 显示内容图像
# plt.imshow() 接受形状为 (H, W, C) 的数组,其中:
# H: 图像高度
# W: 图像宽度  
# C: 颜色通道数(RGB为3)
plt.imshow(content_img)
plt.show()  # 显示图像

# 获取图像信息(可选)
if hasattr(content_img, 'size'):
    # 获取图像的宽度和高度
    img_width, img_height = content_img.size  # PIL图像的尺寸为 (width, height)
    print(f"内容图像尺寸:宽度={
     img_width}px, 高度={
     img_height}px")

内容图像 img/test.jpeg 显示如下:

内容图像尺寸:宽度=600px, 高度=333px

[跳过图片具体内容]

style_img = Image.open('img/06_autumn-oak.jpg')
plt.imshow(style_img);

风格图像 img/06_autumn-oak.jpg 显示如下:

在这里插入图片描述

预处理和后处理

接下来,我们定义两个函数:preprocesspostprocess。这两个函数负责将我们加载的 PIL 图像转换为神经网络可以处理的张量格式,以及在训练完成后将张量转换回图像格式以便显示。

预处理函数 preprocess 主要执行以下操作:

  1. 将输入图像的大小调整为指定的 image_shape
  2. 将 PIL 图像转换为 PyTorch 张量。这个过程会自动将图像的维度从 (H, W, C) 转换为 (C, H, W),并将像素值从 [0, 255] 的范围归一化到 [0, 1] 的范围。
  3. 对图像在 RGB 三个通道上分别进行标准化,使用的是 ImageNet 数据集的均值和标准差。
  4. 在张量的最前面添加一个批次维度,最终输出格式为 (1, C, H, W)。

后处理函数 postprocess 则执行相反的操作:

  1. 移除批次维度。
  2. 对标准化的张量进行反标准化,恢复其原始的像素值范围。
  3. 将像素值裁剪到 [0, 1] 之间,以确保可以正确显示。
  4. 将张量转换回 PIL 图像格式。
# 图像预处理和后处理函数

import torch
import torchvision

# ImageNet 数据集的 RGB 通道均值
# 维度: [3] - 分别对应 R、G、B 三个通道
rgb_mean = torch.tensor([0.485, 0.456, 0.406])

# ImageNet 数据集的 RGB 通道标准差
# 维度: [3] - 分别对应 R、G、B 三个通道
rgb_std = torch.tensor([0.229, 0.224, 0.225])


def preprocess(img, image_shape):
    """
    预处理函数:将 PIL 图像转换为标准化的张量
    
    参数:
    - img: PIL Image 对象,维度为 (H, W, C)
    - image_shape: 目标图像尺寸,元组格式 (H, W)
    
    返回:
    - 处理后的张量,维度为 [1, C, H, W]
      其中:1 是批次维度,C=3(RGB通道),H 是高度,W 是宽度
    """
    # 定义图像变换流水线
    transforms = torchvision.transforms.Compose([
        # 1. 调整图像大小到指定尺寸
        # 输入: PIL Image (H, W, C)
        # 输出: PIL Image (image_shape[0], image_shape[1], C)
        torchvision.transforms.Resize(image_shape),
        
        # 2. 将 PIL 图像转换为张量
        # 输入: PIL Image (H, W, C),像素值范围 [0, 255]
        # 输出: torch.Tensor [C, H, W],像素值范围 [0, 1]
        torchvision.transforms.ToTensor(),
        
        # 3. 标准化:(x - mean) / std
        # 输入: torch.Tensor [C, H, W]
        # 输出: torch.Tensor [C, H, W],标准化后的值
        torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)
    ])
    
    # 应用变换并添加批次维度
    # transforms(img): 维度 [C, H, W]
    # unsqueeze(0): 维度 [1, C, H, W]
    return transforms(img).unsqueeze(0)


def postprocess(img):
    """
    后处理函数:将标准化的张量转换回 PIL 图像
    
    参数:
    - img: 标准化的张量,维度为 [B, C, H, W],其中 B 是批次大小
    
    返回:
    - PIL Image 对象,维度为 (H, W, C)
    """
    
    # 1. 提取第一个样本(移除批次维度)
    # img[0]: 维度从 [B, C, H, W] 变为 [C, H, W]
    img = img[0].to(rgb_std.device)
    
    # 2. 反标准化:x * std + mean
    # 2.1 permute(1, 2, 0): 维度从 [C, H, W] 变为 [H, W, C]
    # 2.2 乘以标准差并加上均值,恢复原始值范围
    # 2.3 torch.clamp(..., 0, 1): 将值限制在 [0, 1] 范围内
    # 结果维度: [H, W, C]
    img = torch.clamp(img.permute(1, 2, 0) * rgb_std + rgb_mean, 0, 1)
    
    # 3. 转换回 PIL 图像
    # 3.1 permute(2, 0, 1): 维度从 [H, W, C] 变为 [C, H, W]
    # 3.2 ToPILImage(): 将张量转换为 PIL Image
    # 输出: PIL Image 对象,维度为 (H, W, C)
    return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))


# 使用示例:
# 预处理
pil_image = Image.open('img/05_rainier.jpg')  # PIL Image (H, W, C)
print(pil_image)
tensor = preprocess(pil_image, (224, 224))  # 输出 [1, 3, 224, 224]
print(tensor.shape)

关于 torchvision.transforms.ToTensor() 的关键点总结:

  1. 维度转换:
    • 输入:PIL Image (H, W, C) 或 numpy array。
    • 输出:PyTorch Tensor [C, H, W]
    • 它会自动将通道维度从最后一维移动到第一维。
  2. 设计原因:
    • PyTorch 的标准张量格式是 [B, C, H, W](批次、通道、高度、宽度),这是卷积神经网络期望的输入格式。
    • 这种内存布局在 GPU 计算时更为高效。
  3. 不同图像类型处理:
    • RGB图像:(H, W, 3)[3, H, W]
    • 灰度图像:(H, W)[1, H, W] (自动添加通道维度)
    • RGBA图像:(H, W, 4)[4, H, W]
  4. 同时进行的操作:
    • 数据类型转换:uint8float32
    • 数值范围转换:[0, 255][0, 1]

抽取图像特征

我们将使用基于 ImageNet 数据集预训练的 VGG-19 模型来抽取图像特征。VGG-19 是一个深度卷积神经网络,因其良好的特征提取能力而常用于风格迁移任务。

首先,我们加载预训练的 VGG-19 模型。为了方便管理模型权重文件,我们先设置了权重的下载路径。PyTorch 会自动从网络下载预训练权重(如果本地不存在的话)。

# VGG19 预训练模型加载与配置

import os
import torch
import torchvision

# 设置预训练模型权重的下载路径
# 注意:这必须在加载模型之前设置,否则权重会下载到默认路径
download_path = './model_weights'  # 自定义权重保存路径

# 确保下载目录存在,如果不存在则创建
# exist_ok=True 表示如果目录已存在不会报错
os.makedirs(download_path, exist_ok=True)

# 方法1:使用 torch.hub.set_dir() 设置下载缓存目录
# 这会设置 PyTorch Hub 的默认下载目录
torch.hub.set_dir(download_path)

# 方法2:通过环境变量设置

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2380106.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

软件设计师考试《综合知识》设计模式之——工厂模式与抽象工厂模式考点分析

软件设计师考试《综合知识》工厂模式与抽象工厂模式考点分析 1. 分值占比与考察趋势(75分制) 年份题量分值占总分比例核心考点2023111.33%抽象工厂模式适用场景2022222.67%工厂方法 vs 抽象工厂区别2021111.33%工厂方法模式结构2020111.33%简单工厂模式…

轻量级离线版二维码工具的技术分析与开发指南

摘要 本文介绍一款基于本地化运行的轻量级二维码处理工具。该工具采用标准QR Code规范实现,具备完整的生成与识别功能。通过实测验证其核心功能表现及适用场景。 主要功能模块分析 编码生成模块:支持文本/URL等多种数据类型转换;提供尺寸调…

机器学习--特征工程具体案例

一、数据集介绍 sklearn库中的玩具数据集,葡萄酒数据集。在前两次发布的内容《机器学习基础中》有介绍。 1.1葡萄酒列标签名: wine.feature_names 结果: [alcohol, malic_acid, ash, alcalinity_of_ash, magnesium, total_phenols, flavanoi…

Unreal 从入门到精通之SceneCaptureComponent2D实现UI层3D物体360°预览

文章目录 前言SceneCaptureComponent2D实现步骤新建渲染目标新建材质UI控件激活3DPreview鼠标拖动旋转模型最后前言 我们在(电商展示/角色预览/装备查看)等应用场景中,经常会看到这种3D展示的页面。 即使用相机捕获一个3D的模型的视图,然后把这个视图显示在一个UI画布上,…

电机控制杂谈(25)——为什么对于一般PMSM系统而言相电流五、七次谐波电流会比较大?

1. 背景 最近都在写论文回复信。有个审稿人问了一个问题——为什么对于一般PMSM系统而言相电流五、七次谐波电流会比较大?同时,为什么相电流五、七次谐波电流会在dq基波旋转坐标系构成六次谐波电流? 回答这个问题挺简单的,但在网…

多模态大语言模型arxiv论文略读(七十八)

AID: Adapting Image2Video Diffusion Models for Instruction-guided Video Prediction ➡️ 论文标题:AID: Adapting Image2Video Diffusion Models for Instruction-guided Video Prediction ➡️ 论文作者:Zhen Xing, Qi Dai, Zejia Weng, Zuxuan W…

【C语言】易错题 经典题型

出错原因&#xff1a;之前运行起来的可执行程序没有关闭 关闭即可 平均数&#xff08;average&#xff09; 输入3个整数&#xff0c;输出它们的平均值&#xff0c;保留3位小数。 #include <stdio.h> int main() {int a, b, c;scanf("%d %d %d", &a, &…

说一说Node.js高性能开发中的I/O操作

众所周知&#xff0c;在软件开发的领域中&#xff0c;输入输出&#xff08;I/O&#xff09;操作是程序与外部世界交互的重要环节&#xff0c;比如从文件读取数据、向网络发送请求等。这段时间&#xff0c;也指导项目中一些项目的开发工作&#xff0c;发现在Node.js运用中&#…

应用层协议简介:以 HTTP 和 MQTT 为例

文章目录 应用层协议简介&#xff1a;什么是应用层协议&#xff1f;为什么需要应用层协议&#xff1f;什么是应用层协议&#xff1f;为什么需要应用层协议&#xff1f; HTTP 协议详解HTTP 协议特点HTTP 工作的基本原理HTTP 请求与响应示例为什么 Web 应用基于 HTTP 请求&#x…

LeetCode 39. 组合总和 LeetCode 40.组合总和II LeetCode 131.分割回文串

LeetCode 39. 组合总和 需要注意的是题目已经明确了数组内的元素不重复&#xff08;重复的话需要执行去重操作&#xff09;&#xff0c;且元素都为正整数&#xff08;如果存在0&#xff0c;则会出现死循环&#xff09;。 思路1&#xff1a;暴力解法 对最后结果进行去重 每一…

如何在 Windows 11 或 10 上安装 Fliqlo 时钟屏保

了解如何在 Windows 11 或 10 上安装 Fliqlo,为您的 PC 或笔记本电脑屏幕添加一个翻转时钟屏保以显示时间。 Fliqlo 是一款适用于 Windows 和 macOS 平台的免费时钟屏保。它也适用于移动设备,但仅限于 iPhone 和 iPad。Fliqlo 的主要功能是在用户不活动时在 PC 或笔记本电脑…

国芯思辰| 轮速传感器AH741对标TLE7471应用于汽车车轮速度感应

在汽车应用中&#xff0c;轮速传感器可用于车轮速度感应&#xff0c;为 ABS、ESC 等安全系统提供精确的轮速信息&#xff0c;帮助这些系统更好地发挥作用&#xff0c;在紧急制动或车辆出现不稳定状态时&#xff0c;及时调整车轮的制动力或动力分配。 国芯思辰两线制差分式轮速…

小程序弹出层/抽屉封装 (抖音小程序)

最近忙于开发抖音小程序&#xff0c;最想吐槽的就是&#xff0c;既没有适配的UI框架&#xff0c;百度上还找不到关于抖音小程序的案列&#xff0c;我真的很裂开啊&#xff0c;于是我通过大模型封装了一套代码 效果如下 介绍 可以看到 这个弹出层是支持关闭和标题显示的&#xf…

电子电路原理第十六章(负反馈)

1927年8月,年轻的工程师哈罗德布莱克(Harold Black)从纽约斯塔顿岛坐渡轮去上班。为了打发时间,他粗略写下了关于一个新想法的几个方程式。后来又经过反复修改, 布莱克提交了这个创意的专利申请。起初这个全新的创意被认为像“永动机”一样愚蠢可笑,专利申请也遭到拒绝。但…

命令拼接符

Linux多命令顺序执行符号需要记住5个 【&#xff5c;】【||】【 ;】 【&】 【&&】 &#xff0c;在命令执行里面&#xff0c;如果服务器疏忽大意没做限制&#xff0c;黑客通过高命令拼接符&#xff0c;可以输入很多非法的操作。 ailx10 网络安全优秀回答者 互联网…

【通用智能体】Lynx :一款基于终端的纯文本网页浏览器

Lynx &#xff1a;一款基于终端的纯文本网页浏览器 一、Lynx简介二、应用场景及案例场景 1&#xff1a;服务器端网页内容快速查看场景 2&#xff1a;网页内容快速提取场景 3&#xff1a;表单提交与自动化交互场景 4&#xff1a;网络诊断与调试场景 5&#xff1a;辅助工具适配 三…

51单片机的lcd12864驱动程序

#include <reg51.h> #include <intrins.h>#define uchar

GStreamer (三)常⽤插件

常⽤插件 1、Source1.1、filesrc1.2. videotestsrc1.3. v4l2src1.4. rtspsrc和rtspclientsink 2、 Sink2.1. filesink2.2. fakesink2.3. xvimagesink2.4. kmssink2.5. waylandsink2.6. rkximagesink2.7. fpsdisplaysink 3 、视频推流/拉流3.1. 本地推流/拉流3.1.1 USB摄像头3.1…

软件架构风格系列(2):面向对象架构

文章目录 引言一、什么是面向对象架构风格1. 定义与核心概念2. 优点与局限性二、业务建模&#xff1a;用对象映射现实世界&#xff08;一&#xff09;核心实体抽象1. 员工体系2. 菜品体系 &#xff08;二&#xff09;封装&#xff1a;隐藏实现细节 三、继承实战&#xff1a;构建…

go-zero(十八)结合Elasticsearch实现高效数据检索

go-zero结合Elasticsearch实现高效数据检索 1. Elasticsearch简单介绍 Elasticsearch&#xff08;简称 ES&#xff09; 是一个基于 Lucene 库 构建的 分布式、开源、实时搜索与分析引擎&#xff0c;采用 Apache 2.0 协议。它支持水平扩展&#xff0c;能高效处理大规模数据的存…