代码解读 | Hybrid Transformers for Music Source Separation[06]

news2025/9/17 22:28:16

一、背景

        0、Hybrid Transformer 论文解读

        1、代码复现|Demucs Music Source Separation_demucs架构原理-CSDN博客

        2、Hybrid Transformer 各个模块对应的代码具体在工程的哪个地方

        3、Hybrid Transformer 各个模块的底层到底是个啥(初步感受)?

        4、Hybrid Transformer 各个模块处理后,数据的维度大小是咋变换的?

        5、Hybrid Transformer 拆解STFT模块

        6、Hybrid Transformer 拆解频域编码模块


        从模块上划分,Hybrid Transformer Demucs 共包含 (STFT模块、时域编码模块、频域编码模块、Cross-Domain Transformer Encoder模块、时域解码模块、频域解码模块ISTFT模块)7个模块。已完成解读:STFT模块、频域编码模块(时域编码和频域编码类似,后续不再解读时域编码模块),待解读:Cross-Domain Transformer Encoder模块。

        本篇目标:拆解频域解码模块ISTFT模块的底层。时域解码和频域解码原理类似(后续不再拆解时域解码模块)。

二、频域解码模块


class HDecLayer(nn.Module):
    def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False,
                 freq=True, dconv=True, norm=True, context=1, dconv_kw={}, pad=True,
                 context_freq=True, rewrite=True):
        """
        Same as HEncLayer but for decoder. See `HEncLayer` for documentation.
        """
        super().__init__()
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        if pad:
            pad = kernel_size // 4
        else:
            pad = 0
        self.pad = pad
        self.last = last
        self.freq = freq
        self.chin = chin
        self.empty = empty
        self.stride = stride
        self.kernel_size = kernel_size
        self.norm = norm
        self.context_freq = context_freq
        klass = nn.Conv1d
        klass_tr = nn.ConvTranspose1d
        if freq:
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            klass = nn.Conv2d
            klass_tr = nn.ConvTranspose2d
        self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
        self.norm2 = norm_fn(chout)
        if self.empty:
            return
        self.rewrite = None
        if rewrite:
            if context_freq:
                self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
            else:
                self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1,[0, context])
            self.norm1 = norm_fn(2 * chin)

        self.dconv = None
        if dconv:
            self.dconv = DConv(chin, **dconv_kw)

    def forward(self, x, skip, length):
       
        if self.freq and x.dim() == 3:
            B, C, T = x.shape
            x = x.view(B, self.chin, -1, T)

        if not self.empty:
            x = x + skip

            if self.rewrite:
                y = F.glu(self.norm1(self.rewrite(x)), dim=1)
            else:
                y = x
            if self.dconv:
                if self.freq:
                    B, C, Fr, T = y.shape
                    y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
                y = self.dconv(y)
                if self.freq:
                    y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
        else:
            y = x
            assert skip is None
        z = self.norm2(self.conv_tr(y))
        print('self.pad,self.last:', self.pad,self.last)
        if self.freq:
            if self.pad:
                z = z[..., self.pad:-self.pad, :]
        else:
            z = z[..., self.pad:self.pad + length]
            assert z.shape[-1] == length, (z.shape[-1], length)
        if not self.last:
            z = F.gelu(z)
        return z, y

        频域解码模块的核心代码如上所示。在上一篇频域编码模块的基础上,继续贴出完善之后的频域编解码模块全景图。

编码层:Conv2d+Norm1+GELU,  Norm1:Identity()

解码层:(Conv2d+Norm1+GLU)+(ConvTranspose2d+Norm2+倒数第二个维度裁剪+GELU),    Norm1\Norm2:Identity()

残差连接:(Conv1d+GroupNorm+GELU +Conv1d+GroupNorm+GLU+LayerScale())+(Conv2d+Norm2+GLU),Norm2:Identity() ,备注:Identity可以理解成直通

#频域编码层1-4的Conv2d分别是:
Conv2d(4, 48, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
Conv2d(48, 96, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
Conv2d(96, 192, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
Conv2d(192, 384, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))

#频域解码层4-1的Conv2d和ConvTranspose2d
Conv2d(384, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ConvTranspose2d(384, 192, kernel_size=(8, 1), stride=(4, 1)) 
Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ConvTranspose2d(192, 96, kernel_size=(8, 1), stride=(4, 1))
Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ConvTranspose2d(96, 48, kernel_size=(8, 1), stride=(4, 1))
Conv2d(48, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
ConvTranspose2d(48, 16, kernel_size=(8, 1), stride=(4, 1))

        残差连接模块如下所示。

#残差连接1
DConv(
  (layers): ModuleList(
    (0): Sequential(
      (0): Conv1d(48, 6, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): GroupNorm(1, 6, eps=1e-05, affine=True)
      (2): GELU(approximate=none)
      (3): Conv1d(6, 96, kernel_size=(1,), stride=(1,))
      (4): GroupNorm(1, 96, eps=1e-05, affine=True)
      (5): GLU(dim=1)
      (6): LayerScale()
    )
    (1): Sequential(
      (0): Conv1d(48, 6, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
      (1): GroupNorm(1, 6, eps=1e-05, affine=True)
      (2): GELU(approximate=none)
      (3): Conv1d(6, 96, kernel_size=(1,), stride=(1,))
      (4): GroupNorm(1, 96, eps=1e-05, affine=True)
      (5): GLU(dim=1)
      (6): LayerScale()
    )
  )
)
Conv2d(48, 96, kernel_size=(1, 1), stride=(1, 1))

#残差连接2
DConv(
  (layers): ModuleList(
    (0): Sequential(
      (0): Conv1d(96, 12, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): GroupNorm(1, 12, eps=1e-05, affine=True)
      (2): GELU(approximate=none)
      (3): Conv1d(12, 192, kernel_size=(1,), stride=(1,))
      (4): GroupNorm(1, 192, eps=1e-05, affine=True)
      (5): GLU(dim=1)
      (6): LayerScale()
    )
    (1): Sequential(
      (0): Conv1d(96, 12, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
      (1): GroupNorm(1, 12, eps=1e-05, affine=True)
      (2): GELU(approximate=none)
      (3): Conv1d(12, 192, kernel_size=(1,), stride=(1,))
      (4): GroupNorm(1, 192, eps=1e-05, affine=True)
      (5): GLU(dim=1)
      (6): LayerScale()
    )
  )
)
Conv2d(96, 192, kernel_size=(1, 1), stride=(1, 1))

#残差连接3
DConv(
  (layers): ModuleList(
    (0): Sequential(
      (0): Conv1d(192, 24, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): GroupNorm(1, 24, eps=1e-05, affine=True)
      (2): GELU(approximate=none)
      (3): Conv1d(24, 384, kernel_size=(1,), stride=(1,))
      (4): GroupNorm(1, 384, eps=1e-05, affine=True)
      (5): GLU(dim=1)
      (6): LayerScale()
    )
    (1): Sequential(
      (0): Conv1d(192, 24, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
      (1): GroupNorm(1, 24, eps=1e-05, affine=True)
      (2): GELU(approximate=none)
      (3): Conv1d(24, 384, kernel_size=(1,), stride=(1,))
      (4): GroupNorm(1, 384, eps=1e-05, affine=True)
      (5): GLU(dim=1)
      (6): LayerScale()
    )
  )
)
Conv2d(192, 384, kernel_size=(1, 1), stride=(1, 1))

#残差连接4
DConv(
  (layers): ModuleList(
    (0): Sequential(
      (0): Conv1d(384, 48, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): GroupNorm(1, 48, eps=1e-05, affine=True)
      (2): GELU(approximate=none)
      (3): Conv1d(48, 768, kernel_size=(1,), stride=(1,))
      (4): GroupNorm(1, 768, eps=1e-05, affine=True)
      (5): GLU(dim=1)
      (6): LayerScale()
    )
    (1): Sequential(
      (0): Conv1d(384, 48, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
      (1): GroupNorm(1, 48, eps=1e-05, affine=True)
      (2): GELU(approximate=none)
      (3): Conv1d(48, 768, kernel_size=(1,), stride=(1,))
      (4): GroupNorm(1, 768, eps=1e-05, affine=True)
      (5): GLU(dim=1)
      (6): LayerScale()
    )
  )
)
Conv2d(384, 768, kernel_size=(1, 1), stride=(1, 1))

三、ISTFT模块

        ISTFT模块的核心代码如下所示。

import torch as th
def ispectro(z, hop_length=None, length=None, pad=0):
    *other, freqs, frames = z.shape
    n_fft = 2 * freqs - 2
    z = z.view(-1, freqs, frames)
    win_length = n_fft // (1 + pad)
    is_mps = z.device.type == 'mps'
    if is_mps:
        z = z.cpu()
    x = th.istft(z,
                 n_fft,
                 hop_length,
                 window=th.hann_window(win_length).to(z.real),
                 win_length=win_length,
                 normalized=True,
                 length=length,
                 center=True)
    _, length = x.shape
    return x.view(*other, length)

        其中,torch.istft【逆短时傅里叶变换(Inverse Short Time Fourier Transform,ISTFT)】,该函数期望是torch.stft函数的逆过程。它具有相同的参数(加上一个可选参数length),并且应该返回原始信号的最小二乘估计。算法将根据NOLA条件(非零重叠)进行检查。

#### torch.istft接口参数####
input (Tensor): 输入张量,期望是`torch.stft`的输出,可以是复数形式(`channel`, `fft_size`, `n_frame`),或者是实数形式(`channel`, `fft_size`, `n_frame`, 2),其中`channel`维度是可选的。

       deprecated:: 1.8.0
            实数输入已废弃,请使用`stft(..., return_complex=True)`返回的复数输入代替。
n_fft (int): 傅里叶变换的大小。
hop_length (Optional[int]): 相邻滑动窗口帧之间的距离。(默认:`n_fft // 4`)
win_length (Optional[int]): 窗口帧和STFT滤波器的大小。(默认:`n_fft`)
window (Optional[torch.Tensor]): 可选的窗函数。(默认:`torch.ones(win_length)`)
center (bool): 指示输入是否在两边进行了填充,使得第`t`帧位于时间`t × hop_length`处居中。(默认:`True`)
normalized (bool): 指示STFT是否被标准化。(默认:`False`)
onesided (Optional[bool]): 指示STFT是否为单边谱。(默认:如果输入尺寸中的`n_fft != fft_size`则为`True`)
length (Optional[int]): 修剪信号的长度,即原始信号的长度。(默认:整个信号)
return_complex (Optional[bool]):指示输出是否应为复数,或者输入是否应假定源自实信号和窗函数。注意,这与`onesided=True`不兼容。(默认:`False`)

        频域解码模块和ISTFT模块解读完毕。还剩一个Cross-Domain Transformer Encoder模块没有解读。后面又来新的活了,希望能把demucs落地~。


        感谢阅读,最近开始写公众号(分享好用的AI工具),欢迎大家一起见证我的成长(桂圆学AI)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1822537.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Doris连接超时问题排查记录

文章目录 一、现象描述二、问题排查1、分析驱动包2、分析Mysql客户端(问题解决) 一、现象描述 先上官网部署地址,按照官网上一步步进行部署 https://doris.apache.org/zh-CN/docs/get-starting/quick-start 基本到最后都挺顺利的&#xff0c…

STM32学习笔记(四)--TIM定时器中断详解

(1)配置步骤1.配置RCC外设时钟2.配置时基单元的时钟3.配置初始化时基单元4.使能更新中断5.配置NVIC 选择一个合适的优先级6.启动定时器 其中涉及外设有 RCC内部时钟(EIR外部时钟 ITR其他定时器 TIx捕获通道)、TIM、NVIC 高级定时器…

大模型赛道有前景吗?普通人该如何入门大模型?(附AI大模型资源)

大模型赛道有前景吗? 这个问题,是个热门话题,但不是个好问题。 因为,它基于不同的提问人、提问意图,会有不同的答案。 对于一个职业发展初期的新人,提问的意图可能是:我要不要转行去大模型赛…

【VS】尚未配置为Web项目XXXX指定的本地IIS URL HTTP://localhost

报错原因: 我们在Web项目的属性配置中勾选了“使用本地IIS Web服务器”; 本来嘛,这也没啥,问题是当我们的电脑IP改变时,将会导致程序找不到原来的IP地址了,那么当然会报错啦。 解决办法: 其实…

南方cass专业测绘软件下载,南方cass功能强大的cad辅助测绘软件获取!

在测绘领域,南方CASS测绘软件无疑是一颗璀璨的明星,被誉为“全能选手”。这款软件在功能方面表现出了令人赞叹的多样性和专业性,为测绘工作提供了极大的便利。 ​ 首先,南方CASS测绘软件具备强大的数据兼容性,支持多种…

Blender:渲染输出

渲染输出界面 渲染设置界面: 输出设置界面: 输出文件格式 【文档】 视频导出格式: AVI JPEG 使用JPEG压缩的AVI。有损,能得到更小的文件,但大小无法与编解码器的压缩算法得到的文件相比。JPEG 压缩也是数字摄像机使用…

Vue3项目中Pinia使用详解

开篇 本文的目的是创建一个使用typescript的vue3项目,并使用pinia来管理状态。 详细步骤 创建项目 创建vue3项目,并使用vite作为打包工具 npm create vitelatest vue3_pinia // 选择vue,随后选择typesript进入项目,并按照依赖包 cd vue3_…

适用VS2019尝试生成跨平台的动态库

文章目录 1、2、步骤2.1 创建一个CMake项目2.2 写一个简单的计算加法的函数2.3 调整CMakeLists.txt2.4 Windows下编译x86的库2.4.1 配置x86-release2.4.2 选择启动项2.4.3 生成动态库 2.5 linux下编译动态库2.5.1 参考2.4.1设置Linux-GCC-Release配置X642.5.2 配置远程linux计算…

r语言数据分析案例25-基于向量自回归模型的标准普尔 500 指数长期预测与机制分析

一、背景介绍 2007 年的全球经济危机深刻改变了世界经济格局,引发了一系列连锁反应,波及各大洲。经济增长停滞不前,甚至在某些情况下出现负增长,给出口导向型发展中国家带来了不确定性。实体经济受到的冲击尤为严重,生…

Python学习笔记8:入门知识(八)

前言 本篇是元组的知识点学习以及知识点的补充 元组 概念 不可变的列表,叫做元组。 在之前列表的特性中,我们就说过列表是可变的,但是在实际使用过程中,我们有时候仍然需要一系列不可变的元素,这个时候就需要元组出…

每日5题Day24 - LeetCode 116 - 120

每一步向前都是向自己的梦想更近一步,坚持不懈,勇往直前! 第一题:116. 填充每个节点的下一个右侧节点指针 - 力扣(LeetCode) /* // Definition for a Node. class Node {public int val;public Node left;…

万字长文爆肝Spring(一)

Spring_day01 今日目标 掌握Spring相关概念完成IOC/DI的入门案例编写掌握IOC的相关配置与使用掌握DI的相关配置与使用 1,课程介绍 对于一门新技术,我们需要从为什么要学、学什么以及怎么学这三个方向入手来学习。那对于Spring来说: 1.1 为什么要学? …

the histogram of cross-entropy loss values 交叉熵损失值的直方图以及cross-entropy loss交叉熵损失

交叉熵损失值的直方图在机器学习和深度学习中有几个重要的作用和用途: 评估模型性能: 直方图可以帮助评估模型在训练数据和测试数据上的性能。通过观察损失值的分布,可以了解模型在不同数据集上的表现情况。例如,损失值分布的形状和范围可以反…

2024 Idea最新激活码

idea的激活与安装 操作如下: ① 打开网站:https://web.52shizhan.cn 切换到:激活码,点击获取 ② 这个时候就跳转到现成账号页面,点击获取体验号,如图 ③ 来到了获取现成账号的页面了。输入你的邮箱账号即…

uni app 自定义 带popup弹窗的input组件

工作需要。自定义了个带popup弹窗的input组件。此组件满足个人需求&#xff0c;不喜勿喷。应该可以看明白怎么回事&#xff0c;也能自己改改&#xff0c;所以就不要联系了&#xff0c;点赞收藏就好 <template><view class"dialog_main"><input v-mod…

【第七篇】SpringSecurity核心组件和核心过滤器

一、SpringSecurity中的核心组件 在SpringSecurity中的jar分为4个,作用分别为 jar作用spring-security-coreSpringSecurity的核心jar包,认证和授权的核心代码都在这里面spring-security-config如果使用SpringSecurity XML命名空间进行配置或者SpringSecurity的<br />J…

Python 使用 Tkinter库 设置 tkinter ttk 框架的背景颜色

Tkinter 设置 tkinter ttk 框架的背景颜色 在本文中&#xff0c;我们将介绍如何使用 Tkinter 在 tkinter ttk 框架中设置背景颜色。Tkinter 是 Python 中常用的 GUI 工具包&#xff0c;ttk 则是 Tkinter 中的一个模块&#xff0c;提供了一套更加现代化的控件。 Tkinter 简介 …

ESP32基础应用之esp32连接腾讯云并使用微信小程序控制的智能灯

文章目录 1. 项目简介1.1 功能接收1.2 使用资源1.3 测试平台 2 腾讯云物联网开发平台3 esp32设备开发3.1 准备参考例程3.2 vscode平台创建测试工程3.3 修改工程 问题总结使用PowerShell命令行终端生成的二维码不能用 1. 项目简介 1.1 功能接收 实现腾讯云创建项目与设备&…

泰坦尼克号数据集机器学习实战教程

泰坦尼克号数据集是一个公开可获取的数据集&#xff0c;源自1912年沉没的RMS泰坦尼克号事件。这个数据集被广泛用于教育和研究&#xff0c;特别是作为机器学习和数据分析的经典案例。数据集记录了船上乘客的一些信息&#xff0c;以及他们是否在灾难中幸存下来。以下是数据集中主…