[GTCRN 48 kHz] Causal-Stream Model 的演进思路
GTCRN 演进路径记录 v1 → v2 → v3 → v3.1/v3.2 → v4 → v4.1 的改动和原因。版本概览版本改动点参数量质量指标内存实时v1 baseline基线139KDNSMOS 3.15—×v2 transient换损失函数139KDNSMOS 3.15—×v3 causal因果化改造145KDNSMOS 2.98—√v3.1 precisionKD QAT 压缩41.6KPESQ 2.041228 KB (INT8)√v3.2 transient宽度1.5× 瞬态损失~83KPESQ ~2.15~355 KB (INT8)√v4 network opt架构精简 (4层GTConv)~87KPESQ 2.147683 KB (FP32)√v4.1 int8INT8 混合精度 C 推理~87KPESQ 2.037464 KB√网络结构 (v1/v2 共用)输入 spec (B, 513, T, 2) │ ├─ 可学习频带权重 (513,) │ ▼ ERB_48k.bm(): 513 → 219 │ 低频171保留高频342→48 ERB band │ ▼ SFE_Lite: DWConv(1×5) → PWConv → BN │ ▼ ┌─ Encoder ─────────────────────────────┐ │ DSConv: 219→110 (stride2) ← skip1 │ │ DSConv: 110→55 (stride2) ← skip2 │ │ GTConvLite×6 (d1,2,4,8,4,2) ← skip3-8 │ SubbandAttention │ └───────────────────────────────────────┘ │ ▼ DPGRNN_Enhanced × 2 │ intra: 双向GRU (频率轴) │ inter: 单向GRU (时间轴) │ ▼ ┌─ Decoder ─────────────────────────────┐ │ GTConvLite×6 skip (逆序) │ │ DSDeconv: 55→110 skip2 │ │ DSDeconv: 110→219 skip1 │ └───────────────────────────────────────┘ │ ▼ ERB_48k.bs(): 219 → 513 │ ▼ CRM掩码 → 输出GTConvLite 内部x → DWConv(3×3, dilation) → PWConv → BN → PReLU → TRALite (时序注意力) → SEBlock (通道注意力) → x (残差)DPGRNN 内部x (B,C,T,F) → reshape (B*T, F, C) → Linear → 双向GRU (频率轴) → Linear → reshape LayerNorm → reshape (B*F, T, C) → Linear → 单向GRU (时间轴) → Linear → reshape LayerNorm → 输出v1 → v2: 换损失函数问题v1 用的是标准 SpecRIMAGLoss对所有帧一视同仁。但实际听感上键盘敲击、鼠标点击这类突发噪音处理得不好。DNSMOS 是整段平均掩盖了这个问题。方案不改网络只改损失函数。加了瞬态检测# 检测能量突变energy_diff|energy[t]-energy[t-1]|transientenergy_diffthreshold*mean_energy# 瞬态帧损失放大5倍lossΣ weight[t]*frame_loss[t]weight[t]5.0iftransient[t]else1.0结果DNSMOS 基本持平 (3.1474 → 3.147)瞬态噪音主观听感明显改善训练时间变长 (29 → 71 epochs)为什么不改网络能用损失函数解决的问题就不动架构。改架构的代价要重新验证各模块交互可能引入新bug推理时有额外开销改损失函数只影响训练推理零开销。v2 → v3: 因果化问题v1/v2 是离线模型要看完整段音频才能处理。没法用在实时场景通话、直播。延迟分析非因果模型需要看未来帧感受野决定最小延迟v2大概要200-500ms实时通话要求50ms方案把所有偷看未来的操作改掉模块v2 (非因果)v3 (因果)GTConvLitepadding(d,1) 对称pad_t(k-1)*d 左边TRALiteConv1d padding2F.pad(x,(4,0))DPGRNN inter双向GRU单向GRU频率轴的操作不用改因为频率轴不涉及时间因果。v3 网络结构输入 spec (B, 513, T, 2) │ ▼ ERB_48k.bm(): 513 → 219 │ ▼ in_conv: Conv2d(2→3) │ ▼ ┌─ CausalEncoder ───────────────────────┐ │ DSConv: 219→110 ← skip1 │ │ DSConv: 110→55 ← skip2 │ │ CausalGTConvLite×6 ← skip3-8 │ SubbandAttention │ └───────────────────────────────────────┘ │ ▼ CausalDPGRNN × 2 │ intra: 双向GRU (频率轴) ← 不用改 │ inter: 单向GRU (时间轴) ← 改成单向 │ ▼ ┌─ CausalDecoder ───────────────────────┐ │ CausalGTConvLite×6 skip │ │ Fuse DSDeconv: 55→110 │ │ DSDeconv: 110→219 skip1 │ └───────────────────────────────────────┘ │ ▼ out_conv → ERB_48k.bs() → CRM → 输出因果模块对比GTConvLite → CausalGTConvLite离线: padding(dilation, 1)前后各看dilation帧 因果: F.pad(x, (0,0,pad_t,0))只看过去pad_t帧 pad_t (kernel-1) * dilationTRALite → CausalTRA离线: Conv1d(k5, padding2)前后各看2帧 因果: Conv1d(k5, padding0) F.pad(x,(4,0))只看过去4帧DPGRNN → CausalDPGRNN离线: inter用双向GRU能看整个时间序列 因果: inter改单向GRU只能看到当前和过去其他改动激活函数: PReLU → SiLUDSConv: 加了中间BN顺序调整参数量: 139K → 145K (4%)参数量增加是因为单向GRU要增大hidden_size才能保持建模能力。结果DNSMOS: 3.15 → 2.98 (-5%)延迟: 10ms (单帧)RTF: 0.21 (还有4.7倍余量)掉了0.17分是预期内的。因果模型看不到未来信息量必然少于非因果模型。流式状态实时推理要维护帧间状态GTConv缓存: 12层不同dilation长度不同TRA历史: 12层每层4帧GRU hidden: 2×DPGRNN × 2层Skip缓存: 8组v3 → v3.1: 精度剪枝 (KD QAT)问题v3 teacher 模型 (width_mult2.0, 145K 参数) 运行时占用 711 KB无法部署到 500 KB 内存限制的嵌入式设备。方案两步压缩知识蒸馏 (KD) 缩小模型 → 量化感知训练 (QAT) 压缩权重。Step 1: 知识蒸馏Teacher: v3 (width_mult2.0, CH32)Student: width_mult1.0 (CH16, 41.6K 参数)训练 30 epochs 后收敛到容量极限Step 2: QAT INT8 量化Conv2d / Linear → INT8 per-channel 对称量化GRU / BN / LN → 保持 FP32导出 INT8 权重 per-channel scale结果指标值PESQ2.041SI-SNR14.22 dBDNS OVR2.778FP32 内存888 KBINT8 内存228 KB✅发现的问题inter GRU 的weight_hh在长序列 (1000 帧) 上累积量化误差导致噪底漂移。解决方案inter GRU 权重保持 FP32。v3.1 → v3.2: 宽度扩展 瞬态感知问题v3.1 (width_mult1.0, CH16) 容量有限对键盘敲击、鼠标点击等瞬态噪音抑制不足。方案改动v3.1v3.2宽度width_mult1.0 (CH16)width_mult1.5 (CH24)参数量41.6K~83K瞬态权重transient_weight1.0 (无效)transient_weight8.0瞬态检测无频谱平坦度 (flatness_threshold0.3)KD 瞬态均匀权重瞬态帧 ×5关键改进: 引入TransientAwareLoss_v2通过频谱平坦度区分噪声瞬态和语音瞬态避免误伤语音起始段。结果指标v3.1v3.2PESQ2.041~2.15-2.25INT8 内存228 KB~355 KB瞬态抑制弱明显改善内存从 228 KB 增到 355 KB仍在 500 KB 限制内。v3 → v4: 架构精简问题v3 使用 6 层 GTConv (dilation[1,2,4,8,4,2])encoder decoder 共 12 层。dilation8 的层感受野过大对实时场景贡献有限但占用大量因果缓存。方案精简架构减少层数和通道数参数v3v4GTConv 层数6 (enc) 6 (dec)4 (enc) 4 (dec)Dilation 序列[1,2,4,8,4,2][1,2,4,2]通道数 CH3220DPGRNN hidden3220SE Block选择性启用全部启用Skip 连接8 组6 组v4 网络结构输入 spec (B, 513, T, 2) │ ▼ ERB_48k.bm(): 513 → 219 │ ▼ in_conv: Conv2d(2→3) │ ▼ ┌─ CausalEncoder ───────────────────────┐ │ DSConv: 219→110 ← skip1 │ │ DSConv: 110→55 ← skip2 │ │ CausalGTConv×4 (d1,2,4,2) ← skip3-6 │ SubbandAttention │ └───────────────────────────────────────┘ │ ▼ CausalDPGRNN × 2 │ intra: 双向GRU (频率轴) │ inter: 单向GRU (时间轴) │ ▼ ┌─ CausalDecoder ───────────────────────┐ │ CausalGTConv×4 skip │ │ Fuse DSDeconv: 55→110 │ │ DSDeconv: 110→219 skip1 │ └───────────────────────────────────────┘ │ ▼ out_conv → ERB_48k.bs() → CRM → 输出训练KD: v3 teacher (CH32, 6层) → v4 student (CH20, 4层)QAT: Scheme 1b 混合精度 (22层FP32 45层INT8)结果指标v3v4PESQ (KD)—2.147PESQ (QAT)—2.037参数量145K~87KFP32 内存—683 KB内存分解 (FP32)类别大小占比Core (权重)348.82 KB51.1%— ERB 滤波器128.25 KB36.8% of Core— DPGRNN ×2150.02 KB43.0% of Core— GTConv ×855.81 KB16.0% of CoreState (状态)216.46 KB31.7%Workspace97.20 KB14.2%STFT Handle20.17 KB2.9%总计682.64 KB主要瓶颈: ERB 滤波器 (128 KB) 和 GRU 权重 (137 KB) 无法量化占 Core 的 80%。v4 → v4.1: INT8 混合精度 C 推理问题v4 的 C 推理管线所有权重以 FP32 存储。需要将 QAT 训练结果迁移到 C 端实现 INT8 混合精度推理。方案量化策略 (Scheme 1b):FP32 保留 (22层): in_conv, down1.pw, subband_attn, 所有 TRA, up1/up2.dw, GRU, LayerNorm, alpha/betaINT8 量化 (45层): GTConv dw/pw, SE fc1/fc2, DSConv/DSDeconv, DPGRNN pre/post/post2, fuse, out_conv量化方式: 对称 per-channel,q round(clamp(w / scale, -127, 127))BN 折叠: 24 个 BatchNorm 层在导出时折叠进 Conv 权重:W_folded W * (gamma / sqrt(var eps)) b_folded beta - mean * gamma / sqrt(var eps)实施阶段:Python 导出脚本 (export_qat_weights.py) — 提取 QAT 权重BN 折叠量化导出C 端权重结构体修改 —int8_t weight[]float scale[]float bias[]C 端层计算修改 — INT8 反量化计算移除 BN 计算权重加载 (GTC5 格式)、流式推理管线更新、Demo 更新发现并修复的 Bugexport_fp32_weights.py缺少out_convbias: bias[0.5724, 0.0033]未导出导致 C 端 mask 偏移严重复数 mask 乘法错误: 修正为标准复数乘法out spec * mask(实部虚部交叉相乘)窗函数问题模型训练使用sqrt(hann)centerTrue但 C 流式处理无法做 center padding。sqrt(hann): 第 961 帧 (~20秒) 产生 NaN 溢出普通hann: 全程 30 秒稳定无 NaNC 端当前使用普通 hann 窗。结果指标值INT8 vs FP32 缓存 SNR19.87 dBINT8 vs FP32 相关系数0.995RTF0.032总内存464 KBNaN / Clipping无C 流式 vs Python 批处理相关系数仅 0.37这是预期内的Python 使用双向时间上下文 (非因果 DPGRNN 非因果 GTConv)C 端是单向因果推理。提升一致性需要在训练时加入 causal 约束。演进路线图v1 (离线基线, DNSMOS 3.15) │ ▼ 换损失函数 v2 (瞬态感知, DNSMOS 3.15) │ ▼ 因果化改造 v3 (因果流式, DNSMOS 2.98, 145K params) │ ├──────────────────────┐ │ │ ▼ KD压缩 ▼ 架构精简 v3.1 (41.6K, 228KB) v4 (87K, 683KB FP32) │ │ ▼ 宽度扩展瞬态 ▼ INT8混合精度 v3.2 (83K, 355KB) v4.1 (87K, 464KB)文件结构archived_models/ ├── v1_baseline/ │ ├── original_export/ │ │ └── gtrcn_light_v3_48k_enhanced.py │ └── best_model_epoch29_score3.1474.tar │ ├── v2_transient/ │ ├── config.yaml │ ├── best_model_epoch71_score3.147.tar │ └── full_training_run/ │ ├── v3_causal_stream/ │ ├── models/ │ │ └── gtcrn_light_v3_48k_causal_v2.py │ ├── checkpoints/ │ │ └── best_model_epoch35_score2.983.tar │ └── QAT/ # QAT训练脚本 │ ├── v3.1_precision_pruning/ │ ├── PLAN.md │ ├── RESULTS.md │ └── runs/ # KD QAT 训练记录 │ ├── v3.2_width1.5_transient/ │ ├── PLAN.md │ └── runs/ # 宽度1.5 瞬态训练记录 │ ├── v4_network_opt/ │ ├── Streaming/ # FP32 C流式推理 │ ├── MEMORY_REPORT.md │ └── runs/ # KD QAT 训练记录 │ └── v4.1_int8_quantization/ ├── PLAN.md ├── Streaming/ # INT8混合精度 C流式推理 ├── export_qat_weights.py └── tmp/ # 调试输出和对比脚本选型建议场景推荐原因离线处理v1质量最高 (DNSMOS 3.15)办公环境v2瞬态处理好实时通话 (资源充足)v3低延迟质量较高极限内存 (256KB)v3.1228 KB最小体积瞬态噪音 嵌入式v3.2355 KB瞬态抑制好均衡部署v4.1464 KB架构精简INT8 推理
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2423502.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!