TCN代码详解-Torch (误导纠正)

news2025/7/19 9:00:30

1. 绪论

TCN网络由Shaojie Bai, J. Zico Kolter, Vladlen Koltun 三人于2018提出。对于序列预测而言,通常考虑循环神经网络结构,例如RNN、LSTM、GRU等。他们三个人的研究建议我们,对于某些序列预测(音频合成、字级语言建模和机器翻译),可以考虑使用卷积网络结构。

关于TCN基本构成和他们的原理有相当多的博客已经解释的很详细的了。总结一句话:TCN = 1D FCN + 因果卷积。下面的博客对因果卷积和孔洞卷积有详细的解释。

  • 时间卷积网络(TCN):结构+pytorch代码
  • TCN论文及代码解读总结
  • 时间序列分析(5) TCN

但是,包括TCN原文作者,上面这些博客对TCN网络结构的阐释无一例外都是使用下面这张图片。而问题在于,如果不熟悉Torch操作基本的卷积网络操作,这张图片具有很大的误导性。


图1 膨胀因果卷积(膨胀因子d = 1,2,4,滤波器大小k = 3)

结合上图和上面列举的博客,我们可以大致理解到,TCN就是在序列上使用一维卷积核,沿着时间方向,按照空洞卷积的方式,依次计算。
例如,上图中,

  1. 第一个hidden层是由 d=1d=1 的空洞卷积,卷积而来,退化为基本的一维卷积操作;
  2. 第二个hidden层是由 d=2d=2 的空洞卷积,卷积而来,卷积每个值时隔开了一个值;
  3. 第二个hidden层是由 d=4d=4 的空洞卷积,卷积而来,卷积每个值时隔开了三个值;

由此,上图中网络深度为3,每一层有1个卷积操作。

如果你也是这么理解,恭喜你,成功的被我带跑偏了😈。

2. TCN结构再次图解

上图中网络深度确实为3,但是每一层并不是只有1个卷积操作。这时候就要拿出原论文中第2个图了。

图2 TCN核心结构

这张图左边展示了TCN结构的核心,卷积+残差,作者把它命名为Residual block。我这里简称为block。
可以发现一个block有两个卷积操作和一个残差操作。因此,图1中每到下一层,都会有两个卷积操作和一个残差操作,并不是一个卷积操作。再次提醒,当 d=1d=1 时,空洞卷积退化为普通的卷积,正如图2右图展示的。

因此,对于图1中由原始序列到第一层hidden的真实结构为:

3. 结合原文的torch代码解释

很多博客再源代码解释时,基本都是一个模子,没有真正解释关键参数的含义,以及他们如何通过torch的tensor作用的。

预了解TCN结构,须明白原论文中作者描述的这样一句话:

Since a TCN’s receptive field depends on the network depth n as well as filter size k and dilation factor d, stabilization of deeper and larger TCNs becomes important.

翻译是:

由于TCN的感受野依赖于网络深度n滤波器大小k扩张因子d,因此更大更深的TCN的稳定变得很重要。

下面结合作者源代码,对这三个参数解释。

3.1 TemporalConvNet

网络深度n就是有多少个block,反应到源代码的变量为num_channels的长度,即 len(numchannels)len(numchannels)。

class TemporalConvNet(nn.Module):
    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super(TemporalConvNet, self).__init__()
        """
        :param num_inputs: int,  输入通道数或者特征数
        :param num_channels: list, 每层的hidden_channel数. 例如[5,12,3], 代表有3个block, 
                                block1的输出channel数量为5; 
                                block2的输出channel数量为12;
                                block3的输出channel数量为3.
        :param kernel_size: int, 卷积核尺寸
        :param dropout: float, drop_out比率
        """
        layers = []
        num_levels = len(num_channels)
		# 可见,如果num_channels=[5,12,3],那么
		# block1的dilation_size=1
		# block2的dilation_size=2
		# block3的dilation_size=4
        for i in range(num_levels):
            dilation_size = 2 ** i
            in_channels = num_inputs if i == 0 else num_channels[i-1]
            out_channels = num_channels[i]
            layers += [TemporalBlock(in_channels, out_channels, kernel_size, stride=1, dilation=dilation_size,
                                     padding=(kernel_size-1) * dilation_size, dropout=dropout)]

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)

3.2 TemporalBlock

参数dilation的解释,结合上面和下面的代码。

class TemporalBlock(nn.Module):
    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        super(TemporalBlock, self).__init__()
        """
        构成TCN的核心Block, 原作者在图中成为Residual block, 是因为它存在残差连接.
        但注意, 这个模块包含了2个Conv1d.

        :param n_inputs: int, 输入通道数或者特征数
        :param n_outputs: int, 输出通道数或者特征数
        :param kernel_size: int, 卷积核尺寸
        :param stride: int, 步长, 在TCN固定为1
        :param dilation: int, 膨胀系数. 与这个Residual block(或者说, 隐藏层)所在的层数有关系. 
                                例如, 如果这个Residual block在第1层, dilation = 2**0 = 1;
                                      如果这个Residual block在第2层, dilation = 2**1 = 2;
                                      如果这个Residual block在第3层, dilation = 2**2 = 4;
                                      如果这个Residual block在第4层, dilation = 2**3 = 8 ......
        :param padding: int, 填充系数. 与kernel_size和dilation有关. 
        :param dropout: float, dropout比率
        """
        self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding, dilation=dilation))

        # 因为 padding 的时候, 在序列的左边和右边都有填充, 所以要裁剪
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding, dilation=dilation))
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1,
                                 self.conv2, self.chomp2, self.relu2, self.dropout2)

        # 1×1的卷积. 只有在进入Residual block的通道数与出Residual block的通道数不一样时使用.
        # 一般都会不一样, 除非num_channels这个里面的数, 与num_inputs相等. 例如[5,5,5], 并且num_inputs也是5
        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None

        # 在整个Residual block中有非线性的激活. 这个容易忽略!
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        self.conv1.weight.data.normal_(0, 0.01)
        self.conv2.weight.data.normal_(0, 0.01)
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        return self.relu(out + res)

3.3 Chomp1d

裁剪模块。这里注意,padding的时候对数据列首尾都添加了,torch官方解释如下:

padding controls the amount of padding applied to the input. It can be either a string {‘valid’, ‘same’} or a tuple of ints giving the amount of implicit padding applied on both sides.

注意这里是both sides。例如,还是上述代码中的例子,kernel_size = 3,在第一层(对于第一个block),padding = 2。对于长度为20的序列,先padding,长度为20+2×2=2420+2×2=24,再卷积,长度为(24−3)+1=22(24−3)+1=22。所以要裁掉,保证输出序列与输入序列相等。

class Chomp1d(nn.Module):
    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        return x[:, :, :-self.chomp_size].contiguous()

4. 验证TCN的输入输出

根据上述代码的解释和理解,我们可以方便的验证其输入和输出。

# 输入27个通道,或者特征
# 构建1层的TCN,最后输出一个通道,或者特征
model2 = TemporalConvNet(num_inputs=27, num_channels=[32,16,4,1], kernel_size=3, dropout=0.3)

import torch

# 检测输出
with torch.no_grad():
	# 模型输入一定是 (batch_size, channels, length)
    model2.eval() 
    print(model2(torch.randn(16,27,20)).shape) 

打印结果为(16, 1, 20) 。通道数降为1。输入序列长度20, 输出序列长度也是20。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/17444.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于Matlab仿真极化双基地雷达系统(附源码)

目录 一、系统设置 二、系统仿真 三、使用圆极化接收阵列 四、总结 五、程序 此示例演示如何仿真极化双基地雷达系统以估计目标的范围和速度。发射器、接收器和目标运动学被考虑在内。 一、系统设置 该系统以 300 MHz 的频率运行,使用线性 FM 波形&#xff0…

Devkit代码迁移工具——smartdenovo源码迁移

smartdenovo源码迁移 迁移前准备工作 1、服务器和操作系统正常运行。 2、PC端已经安装SSH远程登录工具。 3、Porting Advisor已在准备好的x86平台环境和鲲鹏平台环境中完成安装并正常运行。 4、待迁移的相关软件包、源代码已准备就绪。 迁移步骤 1、利用Porting Advisor的源码…

人工智能学习:ResNet神经网络(8)

ResNet是一种非常有效的图像分类识别的模型,可以参考如下的链接 https://blog.csdn.net/qq_45649076/article/details/120494328 ResNet网络由残差(Residual)结构的基本模块构成,每一个基本模块包含几个卷积层。其中,…

【MySQL数据库笔记 - 进阶篇】(五)锁

✍个人博客:https://blog.csdn.net/Newin2020?spm1011.2415.3001.5343 📚专栏地址:暂定 📝视频地址:黑马程序员 MySQL数据库入门到精通 📣专栏定位:这个专栏我将会整理 B 站黑马程序员的 MySQL…

硬件科普系列之显示篇——LCD与OLED知多少

前言 无论是手机还是电脑,作为机器与人交互最为频繁的硬件设备,显示屏一直是决定用户体验最为关键的因素之一。大家近几年在购买手机的时候,可以发现目前大部分手机都在使用OLED屏幕,那么你有没有思考过为什么各大厂商都在大力推…

jupuyter的背景主题

jupuyter的背景主题一.背景主题安装查看可用主题1.主题安装2. **查看可用主题**3.更换主题,字体等其他设置4.其他命令,还原原本主题二.每个主题的效果1.chesterish2. grade33.gruvboxd4.oceans165.onedork6.solarizedd7.solarizedl一.背景主题安装查看可…

上帝视角看Vue源码整体架构+相关源码问答

前言 这段时间利用课余时间夹杂了很多很多事把 Vue2 源码学习了一遍,但很多都是跟着视频大概过了一遍,也都画了自己的思维导图。但还是对详情的感念模糊不清,故这段时间对源码进行了总结梳理。 本篇文章更合适于已看过 Vue2 源码&#xff0c…

使用NNI对DLASeg剪枝的失败记录

本文希望对CenterNet算法的Backbone暨DLASeg进行剪枝。 剪枝试验涉及3个文件,分别为: DCN可变性卷积dcn_v2.py,因为DLASeg依赖DCN。 #!/usr/bin/env python from __future__ import absolute_import from __future__ import print_functio…

如何在 Windows 10上修复0x000006ba错误

修复0x000006ba错误 可能导致此错误代码的原因已确认的可行的解决办法运行打印机疑难解答重新启动后台打印程序服务清除 PRINTERS 文件夹运行 SFC 和 DISM 扫描启用打印机共享某些 Windows 10 在尝试在 Windows 10 上打印新文档时遇到0x000006ba错误代码。其他用户在尝试使用 W…

【面试题】line-height继承问题

1. line-height为具体数值 当父元素line-height的值为具体数值的时候&#xff0c;例如30px&#xff0c;则子元素的line-height直接继承该数值。 <style>body{font-size: 20px;line-height: 50px;}p{background-color: #ccc;font-size: 16px;} </style><body&g…

类和对象的初步介绍

文章目录面向对象的初步认识什么是面向对象面向对象与面向过程类定义和使用简单认识类类的定义格式随堂练习定义一个学生类类的实例化什么是实例化类和对象的说明this 引用为什么要有this引用什么时this引用this引用的特性对象的构造和初始化构造方法概念特性默认初始化就地初始…

Shell脚本学习指南(三)——文本处理工具

文章目录排序文本行的排序以字段的排序文本块排序sort的效率sort的稳定性sort小结删除重复重新格式化段落计算行数、字数以及字符数打印打印技术的演化其他打印软件提取开头或结尾数行排序文本 含有独立数据记录的文本文恶剪&#xff0c;通常都可以拿来排序。一个可预期的记录…

Vue3 - 组件通信(父传子)

前言 在 Vue3 中&#xff0c;父组件向子组件传参的方法。 与 Vue2 相比&#xff0c;还是有一些区别的。 基础示例 现在我们的需求是&#xff0c;要通过父组件&#xff0c;传递一个标题来让子组件显示。 子组件 Com.vue&#xff1a; <template><div>{{ title }}&l…

大数据工程师必备之数据可视化技术

可视化技术 数据&#xff1a; 偏耀明 7800 高军鹏 8000 代欣 8800 王国庆 20000 ​ 应对现在数据可视化的趋势&#xff0c;越来越多企业需要在很多场景(营销数据、生产数据、用户数据)下使用&#xff0c;可视化图表来展示体现数据&#xff0c;让数据更加直观&#xff0c;数…

tp6使用redis消息队列

尾部写入 for ($i1;$i<1000;$i){Cache::store(redis)->rpush(list,date("Y-m-d H:i:s")."消息{$i}"); }头部读取消息队列并删除 $list Cache::store(redis)->lpop(list); 1、新建个方法运行写入队列 public function hello(){for ($i1;$i<…

C++ Reference: Standard C++ Library reference: Containers: deque: deque: erase

C官网参考链接&#xff1a;https://cplusplus.com/reference/deque/deque/erase/ 公有成员函数 <deque> std::deque::erase C98 iterator erase (iterator position); iterator erase (iterator first, iterator last); C11 iterator erase (const_iterator position )…

Android 后台服务启动Actvity

一、问题背景 相机自动化测试需求&#xff0c;测试apk通过bindService绑定相机apk里面的一个服务&#xff0c;通过AIDL接口的方式向相机apk发送命令&#xff0c;服务接收到命令之后会拉起相机的Activity。原本没有人为干预的情况下是可以拉起这个Activity的&#xff0c;但是拉…

基于PYTHON游乐场服务管理系统的设计与实现

摘要 项目门票是游乐园必不可少的一个部分。在游乐园发展的整个过程中&#xff0c;项目门票担负着最重要的角色。为满足如今日益复杂的管理需求&#xff0c;各类管理系统程序也在不断改进。本课题所设计的游乐场服务管理系统&#xff0c;使用Django框架&#xff0c;Python语言进…

如何优雅部署OpenStack私有云I--Kolla

为方便大数据平台与管理工具的研发&#xff0c;在公司成本不额外增加的情况下&#xff0c;从公司仓库里拉了几台下线物理机来做大数据平台的实验环境。但整体物理机性能都偏高&#xff0c;单独安装一个大数据服务&#xff0c;很豪&#xff0c;但是也很浪费。而且主机台数不是很…

优先级队列(堆)——小记

文章目录堆概念堆的创建堆向下调整堆的插入堆的删除堆排序整体代码&#xff08;创建堆&#xff08;向下调整&#xff09;&#xff0c;堆的插入&#xff0c;堆的删除&#xff0c;堆排序&#xff09;TOPKPriorityQueue特性堆 概念 如果有一个关键码的集合Kk0&#xff0c;k1&…