

import math
import wave
import array
from abc import ABC, abstractmethod
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
# ====================== Design patterns ======================
class PreprocessStrategy(ABC):
    """Base class for preprocessing strategies."""
    @abstractmethod
    def process(self, signal):
        pass
class FrameStrategy(ABC):
    """Base class for framing strategies."""
    @abstractmethod
    def frame(self, signal, sample_rate):
        pass
class PitchDetector(ABC):
    """Base class for pitch (fundamental frequency) detectors."""
    @abstractmethod
    def detect(self, frames, sample_rate):
        pass
class Visualizer(ABC):
    """Base class for visualizers."""
    @abstractmethod
    def visualize(self, analysis_data):
        pass
class ProcessorFactory:
    """Processor factory (factory pattern)."""
    @staticmethod
    def create_preprocessor(strategy_type):
        if strategy_type == "preemphasis":
            return PreEmphasisStrategy()
        raise ValueError("Unknown preprocessing strategy")
    @staticmethod
    def create_framer(strategy_type):
        if strategy_type == "fixed":
            return FixedFrameStrategy()
        raise ValueError("Unknown framing strategy")
    @staticmethod
    def create_detector(strategy_type):
        if strategy_type == "autocorrelation":
            return AutoCorrelationDetector()
        elif strategy_type == "cepstrum":
            return CepstrumDetector()
        raise ValueError("Unknown detection strategy")
    @staticmethod
    def create_visualizer(strategy_type):
        if strategy_type == "matplotlib":
            return MatplotlibVisualizer()
        elif strategy_type == "text":
            return TextVisualizer()
        raise ValueError("Unknown visualization strategy")
# ====================== Signal processing ======================
class PreEmphasisStrategy(PreprocessStrategy):
    """Pre-emphasis strategy (boosts high-frequency content)."""
    def process(self, signal):
        # y[n] = x[n] - 0.97 * x[n-1]; the output is one sample shorter than the input
        return [signal[i] - 0.97 * signal[i - 1] for i in range(1, len(signal))]
class FixedFrameStrategy(FrameStrategy):
    """Fixed-length framing strategy (strategy pattern)."""
    def __init__(self, frame_ms=25, overlap_ratio=0.5):
        self.frame_ms = frame_ms
        self.overlap_ratio = overlap_ratio
    def frame(self, signal, sample_rate):
        frame_length = int(sample_rate * self.frame_ms / 1000)
        step = int(frame_length * (1 - self.overlap_ratio))
        frames = []
        # "+ 1" so a frame that ends exactly at the last sample is not dropped
        for start in range(0, len(signal) - frame_length + 1, step):
            frames.append(signal[start:start + frame_length])
        return frames
class AutoCorrelationDetector(PitchDetector):
    """Autocorrelation-based pitch detection (core algorithm)."""
    def detect(self, frames, sample_rate):
        pitches = []
        acf_values = []  # per-frame autocorrelation, kept for visualization
        for frame in frames:
            # Compute the autocorrelation function of the frame
            acf = self._autocorrelation(frame)
            acf_values.append(acf)
            # Locate the pitch peak
            pitch = self._find_pitch(acf, sample_rate)
            pitches.append(pitch)
        return pitches, acf_values  # pitch track plus the raw ACFs
    def _autocorrelation(self, frame):
        n = len(frame)
        acf = []
        for lag in range(n // 2):  # only compute lags up to half the frame length
            total = 0.0
            for i in range(n - lag):
                total += frame[i] * frame[i + lag]
            acf.append(total)
        return acf
    def _find_pitch(self, acf, sample_rate):
        # Skip the first 10 lags so the dominant lag-0 lobe is not picked as the pitch peak
        if len(acf) < 20:  # not enough points to search
            return 0
        search_range = acf[10:]
        if not search_range:
            return 0
        # Find the largest peak in the searched range
        peak_value = max(search_range)
        if peak_value <= 1e-10:  # silent (all-zero) frame: report no pitch
            return 0
        max_index = search_range.index(peak_value) + 10
        # Parabolic (quadratic) interpolation around the peak for sub-sample accuracy
        if 1 < max_index < len(acf) - 1:
            alpha = acf[max_index - 1]
            beta = acf[max_index]
            gamma = acf[max_index + 1]
            # Guard against a zero denominator
            denominator = alpha - 2 * beta + gamma
            if abs(denominator) < 1e-10:
                peak_index = max_index
            else:
                delta = 0.5 * (alpha - gamma) / denominator
                peak_index = max_index + delta
        else:
            peak_index = max_index
        # Convert the lag of the peak into a frequency
        if peak_index > 0:
            return sample_rate / peak_index
        return 0
class CepstrumDetector(PitchDetector):
    """Cepstrum-based pitch detection (alternative strategy)."""
    def detect(self, frames, sample_rate):
        # Placeholder: a full implementation would pick the peak of each frame's
        # real cepstrum. This stub returns zero pitch and empty data per frame.
        return [0] * len(frames), [[] for _ in frames]
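# ---- Hedged sketch (not part of the original design) -------------------------
# The CepstrumDetector above is intentionally left as a stub. The class below is a
# minimal, illustrative sketch of the same idea under this file's "no third-party
# libraries" constraint: the real cepstrum is the inverse DFT of log|DFT(frame)|,
# and the pitch is read off the quefrency of the strongest cepstral peak. The class
# name CepstrumDetectorSketch, its helper names, and the 50-500 Hz search band are
# assumptions for illustration; it is O(N^2) per frame and therefore slow, and it
# is not wired into ProcessorFactory.
class CepstrumDetectorSketch(PitchDetector):
    """Illustrative real-cepstrum pitch detector (pure Python, O(N^2) per frame)."""
    def detect(self, frames, sample_rate):
        pitches = []
        cepstra = []  # returned in place of ACF values so visualizers still get per-frame data
        for frame in frames:
            cep = self._real_cepstrum(frame)
            cepstra.append(cep)
            pitches.append(self._find_pitch(cep, sample_rate))
        return pitches, cepstra
    def _real_cepstrum(self, frame):
        n = len(frame)
        # Log-magnitude spectrum via a direct DFT (no third-party libraries)
        log_mag = []
        for k in range(n):
            re = sum(frame[i] * math.cos(2 * math.pi * k * i / n) for i in range(n))
            im = -sum(frame[i] * math.sin(2 * math.pi * k * i / n) for i in range(n))
            log_mag.append(math.log(math.sqrt(re * re + im * im) + 1e-12))
        # Inverse DFT of the log spectrum; the result is real because log|X(k)| is
        # symmetric for a real input frame, so only the cosine terms survive
        return [sum(log_mag[k] * math.cos(2 * math.pi * k * q / n) for k in range(n)) / n
                for q in range(n // 2)]
    def _find_pitch(self, cep, sample_rate, f_min=50.0, f_max=500.0):
        # Search only quefrencies that correspond to plausible speech pitch
        q_min = int(sample_rate / f_max)
        q_max = min(int(sample_rate / f_min), len(cep) - 1)
        if q_max <= q_min:
            return 0
        best_q = max(range(q_min, q_max + 1), key=lambda q: cep[q])
        return sample_rate / best_q if cep[best_q] > 0 else 0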
# ====================== Visualization ======================
class MatplotlibVisualizer(Visualizer):
    """Visualization using matplotlib (requires the matplotlib library)."""
    def __init__(self):
        # Configure fonts up front so every label renders correctly
        self._configure_matplotlib()
    def _configure_matplotlib(self):
        """Set a font fallback list (keeps CJK-capable fonts) and fix minus-sign rendering."""
        plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans', 'Arial Unicode MS', 'sans-serif']
        plt.rcParams['axes.unicode_minus'] = False
    def visualize(self, analysis_data):
        # Unpack the analysis results
        original_signal = analysis_data['original_signal']
        processed_signal = analysis_data['processed_signal']
        frames = analysis_data['frames']
        windowed_frames = analysis_data['windowed_frames']
        pitches = analysis_data['pitches']
        acf_values = analysis_data['acf_values']
        sample_rate = analysis_data['sample_rate']
        # Create the figure
        fig = plt.figure(figsize=(15, 12))
        gs = GridSpec(4, 2, figure=fig)
        # 1. Original signal vs preprocessed signal
        ax1 = plt.subplot(gs[0, :])
        time_original = [i / sample_rate for i in range(len(original_signal))]
        time_processed = [i / sample_rate for i in range(len(processed_signal))]
        ax1.plot(time_original, original_signal, label='Original signal')
        ax1.plot(time_processed, processed_signal, label='Pre-emphasized signal', alpha=0.7)
        ax1.set_title('Original vs pre-emphasized signal')
        ax1.set_xlabel('Time (s)')
        ax1.set_ylabel('Amplitude')
        ax1.legend()
        ax1.grid(True)
        # 2. First frame: raw vs windowed
        ax2 = plt.subplot(gs[1, 0])
        frame_index = 0
        frame_time = [i / sample_rate * 1000 for i in range(len(frames[frame_index]))]  # milliseconds
        ax2.plot(frame_time, frames[frame_index], label='Raw frame')
        ax2.plot(frame_time, windowed_frames[frame_index], label='Windowed frame')
        ax2.set_title(f'Frame {frame_index + 1} (raw vs windowed)')
        ax2.set_xlabel('Time (ms)')
        ax2.set_ylabel('Amplitude')
        ax2.legend()
        ax2.grid(True)
        # 3. Autocorrelation of the first frame
        ax3 = plt.subplot(gs[1, 1])
        lags = [i * 1000 / sample_rate for i in range(len(acf_values[frame_index]))]  # milliseconds
        ax3.plot(lags, acf_values[frame_index])
        ax3.set_title(f'Autocorrelation of frame {frame_index + 1}')
        ax3.set_xlabel('Lag (ms)')
        ax3.set_ylabel('Autocorrelation')
        ax3.grid(True)
        # Mark the pitch period
        pitch = pitches[frame_index]
        if pitch > 0:
            period = 1000 / pitch  # period in milliseconds
            ax3.axvline(period, color='r', linestyle='--',
                        label=f'Pitch period: {period:.2f} ms')
            ax3.legend()
        # 4. Pitch track
        ax4 = plt.subplot(gs[2, :])
        # Frame start times, assuming the framer's default 50% overlap (hop = half a frame)
        frame_times = [i * (len(frames[0]) * (1 - 0.5) / sample_rate)
                       for i in range(len(pitches))]
        ax4.plot(frame_times, pitches)
        ax4.set_title('Detected pitch')
        ax4.set_xlabel('Time (s)')
        ax4.set_ylabel('Frequency (Hz)')
        ax4.grid(True)
        # 5. Spectrum of the first (windowed) frame
        ax5 = plt.subplot(gs[3, 0])
        frame_to_show = windowed_frames[frame_index]
        n = len(frame_to_show)
        # Magnitude spectrum via a direct DFT (O(n^2), acceptable for a single short frame)
        spectrum = [abs(self._dft(frame_to_show, k)) for k in range(n // 2)]
        freqs = [k * sample_rate / n for k in range(n // 2)]
        ax5.plot(freqs, spectrum)
        ax5.set_title(f'Spectrum of frame {frame_index + 1}')
        ax5.set_xlabel('Frequency (Hz)')
        ax5.set_ylabel('Magnitude')
        ax5.set_xlim(0, 2000)
        ax5.grid(True)
        # Mark the fundamental and its harmonics
        if pitch > 0:
            ax5.axvline(pitch, color='r', linestyle='--', label=f'Fundamental: {pitch:.1f} Hz')
            # Mark harmonics 2 through 5
            for i in range(2, 6):
                harmonic = i * pitch
                if harmonic < 2000:
                    ax5.axvline(harmonic, color='g', linestyle=':',
                                label=f'Harmonic {i}' if i == 2 else None)
            ax5.legend()
        # 6. Spectrum of the whole signal
        ax6 = plt.subplot(gs[3, 1])
        # Direct DFT of the full signal; only bins up to 2000 Hz are computed, since
        # that is all the plot shows and the O(N^2) DFT is slow in pure Python
        n_full = len(original_signal)
        k_max = min(n_full // 2, int(2000 * n_full / sample_rate) + 1)
        full_spectrum = [abs(self._dft(original_signal, k)) for k in range(k_max)]
        full_freqs = [k * sample_rate / n_full for k in range(k_max)]
        ax6.plot(full_freqs, full_spectrum)
        ax6.set_title('Spectrum of the whole signal')
        ax6.set_xlabel('Frequency (Hz)')
        ax6.set_ylabel('Magnitude')
        ax6.set_xlim(0, 2000)
        ax6.grid(True)
        plt.tight_layout()
        plt.show()
    def _dft(self, x, k):
        """Magnitude of the k-th DFT bin, computed directly (no third-party libraries)."""
        N = len(x)
        real = 0.0
        imag = 0.0
        for n in range(N):
            angle = 2 * math.pi * k * n / N
            real += x[n] * math.cos(angle)
            imag -= x[n] * math.sin(angle)
        return math.sqrt(real * real + imag * imag) / N
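# ---- Optional speed-up (hedged sketch, not part of the original design) ------
# The per-bin _dft above costs O(N) per bin, so a full spectrum is O(N^2) and gets
# very slow for whole recordings in pure Python. The helpers below are a minimal
# radix-2 FFT sketch; the names _fft and _fft_magnitudes and the zero-padding to
# the next power of two are assumptions for illustration. Zero-padding slightly
# changes the bin spacing, which is fine for display, and nothing in the file uses
# these helpers unless you swap them into MatplotlibVisualizer yourself.
import cmath
def _fft(x):
    """Recursive radix-2 FFT; len(x) must be a power of two."""
    n = len(x)
    if n == 1:
        return [complex(x[0])]
    even = _fft(x[0::2])
    odd = _fft(x[1::2])
    out = [0j] * n
    for k in range(n // 2):
        twiddle = cmath.exp(-2j * math.pi * k / n) * odd[k]
        out[k] = even[k] + twiddle
        out[k + n // 2] = even[k] - twiddle
    return out
def _fft_magnitudes(signal, sample_rate):
    """Return (freqs, magnitudes) of the zero-padded signal up to the Nyquist frequency."""
    n = 1
    while n < len(signal):
        n *= 2
    padded = list(signal) + [0.0] * (n - len(signal))
    bins = _fft(padded)
    mags = [abs(b) / n for b in bins[:n // 2]]
    freqs = [k * sample_rate / n for k in range(n // 2)]
    return freqs, mags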
class TextVisualizer(Visualizer):
    """Text-only visualizer (for environments without a display)."""
    def visualize(self, analysis_data):
        pitches = analysis_data['pitches']
        print("\nPitch detection summary:")
        print(f"Frames analyzed: {len(pitches)}")
        # Statistics over plausible speech pitches only
        valid_pitches = [p for p in pitches if 50 < p < 800]
        if valid_pitches:
            avg_pitch = sum(valid_pitches) / len(valid_pitches)
            min_pitch = min(valid_pitches)
            max_pitch = max(valid_pitches)
            print(f"Mean pitch: {avg_pitch:.2f} Hz")
            print(f"Minimum pitch: {min_pitch:.2f} Hz")
            print(f"Maximum pitch: {max_pitch:.2f} Hz")
        else:
            print("No valid pitch detected")
        # Print the pitch of the first five frames
        print("\nPitch of the first 5 frames:")
        for i, pitch in enumerate(pitches[:5]):
            print(f"Frame {i + 1}: {pitch:.2f} Hz")
# ====================== Utility functions ======================
def normalize_signal(signal):
    """Normalize a signal to the range [-1, 1]."""
    max_val = max(abs(x) for x in signal)
    return [x / max_val for x in signal] if max_val > 0 else signal
def hamming_window(frame):
    """Apply a Hamming window to reduce spectral leakage."""
    N = len(frame)
    if N < 2:  # avoid division by zero for degenerate frames
        return list(frame)
    return [frame[i] * (0.54 - 0.46 * math.cos(2 * math.pi * i / (N - 1)))
            for i in range(N)]
# ====================== Core processing pipeline ======================
class PitchAnalyzer:
    """Main pitch-analysis pipeline (facade pattern)."""
    def __init__(self, config):
        self.preprocessor = ProcessorFactory.create_preprocessor(config["preprocess"])
        self.framer = ProcessorFactory.create_framer(config["frame"])
        self.detector = ProcessorFactory.create_detector(config["detect"])
        self.visualizer = ProcessorFactory.create_visualizer(config.get("visualize", "text"))
        self.config = config
    def analyze(self, signal, sample_rate, visualize=False):
        # Keep a copy of the raw signal for visualization
        original_signal = signal.copy()
        # 1. Preprocessing
        processed = self.preprocessor.process(signal)
        # 2. Framing
        frames = self.framer.frame(processed, sample_rate)
        # 3. Windowing
        windowed_frames = [hamming_window(frame) for frame in frames]
        # 4. Pitch detection
        pitches, acf_values = self.detector.detect(windowed_frames, sample_rate)
        # Bundle everything the visualizers need
        analysis_data = {
            'original_signal': original_signal,
            'processed_signal': processed,
            'frames': frames,
            'windowed_frames': windowed_frames,
            'pitches': pitches,
            'acf_values': acf_values,
            'sample_rate': sample_rate
        }
        # 5. Visualization
        if visualize:
            self.visualizer.visualize(analysis_data)
        return pitches, analysis_data
# ====================== File handling ======================
def read_wav_file(filename):
    """Read a WAV file and return (samples, sample_rate). Assumes 16-bit mono PCM."""
    with wave.open(filename, 'rb') as wav:
        n_frames = wav.getnframes()
        sample_rate = wav.getframerate()
        data = wav.readframes(n_frames)
        # Interpret the bytes as signed 16-bit samples and scale to [-1, 1)
        samples = array.array('h', data)
        return [s / 32768.0 for s in samples], sample_rate
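# ---- Hedged sketch: more defensive WAV reading (not in the original code) ----
# read_wav_file above assumes 16-bit mono PCM, so a stereo or 8-bit file would be
# misread. The variant below, with the illustrative name read_wav_file_mono, checks
# the sample width and averages interleaved channels down to mono. It is a sketch of
# one reasonable way to harden the reader, not something the original requires.
def read_wav_file_mono(filename):
    """Read a PCM WAV file, downmix to mono, and return (samples, sample_rate)."""
    with wave.open(filename, 'rb') as wav:
        n_channels = wav.getnchannels()
        sample_width = wav.getsampwidth()
        sample_rate = wav.getframerate()
        data = wav.readframes(wav.getnframes())
    if sample_width == 2:
        raw = array.array('h', data)        # signed 16-bit samples
        scale = 32768.0
    elif sample_width == 1:
        raw = [s - 128 for s in array.array('B', data)]  # unsigned 8-bit, recentred
        scale = 128.0
    else:
        raise ValueError(f"Unsupported sample width: {sample_width} bytes")
    # Average interleaved channels into a single mono channel
    mono = [sum(raw[i:i + n_channels]) / n_channels
            for i in range(0, len(raw) - n_channels + 1, n_channels)]
    return [s / scale for s in mono], sample_rate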
# ====================== Tests ======================
def test_pitch_analysis(show_plot=True):
    """Test pitch detection on a synthetic signal."""
    # Generate a 440 Hz sine wave (concert pitch A4)
    sample_rate = 44100
    duration = 0.5  # seconds
    freq = 440.0
    t = [i / sample_rate for i in range(int(sample_rate * duration))]
    signal = [math.sin(2 * math.pi * freq * t_i) for t_i in t]
    # Add a small 3 kHz component to mimic a noisier, more realistic signal
    signal = [s + 0.1 * math.sin(2 * math.pi * 3000 * t_i) for s, t_i in zip(signal, t)]
    # Configure the analyzer
    config = {
        "preprocess": "preemphasis",
        "frame": "fixed",
        "detect": "autocorrelation",
        "visualize": "matplotlib" if show_plot else "text"
    }
    analyzer = PitchAnalyzer(config)
    # Run the analysis (and visualization, if requested)
    pitches, analysis_data = analyzer.analyze(signal, sample_rate, visualize=show_plot)
    # Validate the result using the mean over plausible frames
    valid_pitches = [p for p in pitches if 50 < p < 800]
    if valid_pitches:
        avg_pitch = sum(valid_pitches) / len(valid_pitches)
        # Allow roughly ±5 Hz of error
        assert 435 < avg_pitch < 445, f"Detection failed: expected 440 Hz, got {avg_pitch:.2f} Hz"
        print(f"Test passed! Detected pitch: {avg_pitch:.2f} Hz (expected 440 Hz)")
    else:
        print("Test failed: no valid pitch detected")
def test_zero_division():
    """Test that an all-zero (silent) signal is handled without division errors."""
    # A silent signal exercises the zero-energy and zero-denominator code paths
    sample_rate = 44100
    signal = [0.0] * sample_rate  # one second of silence
    # Configure the analyzer
    config = {
        "preprocess": "preemphasis",
        "frame": "fixed",
        "detect": "autocorrelation",
        "visualize": "text"
    }
    analyzer = PitchAnalyzer(config)
    # Run the analysis
    pitches, _ = analyzer.analyze(signal, sample_rate)
    # Every detected pitch should be zero (or effectively zero)
    assert all(p < 1e-6 for p in pitches), "Silent-signal handling failed"
    print("Division-by-zero test passed!")
def test_real_audio(filename=r"E:\temp\sample.wav"):
    """Test on a real audio file."""
    try:
        signal, sample_rate = read_wav_file(filename)
    except FileNotFoundError:
        print(f"File {filename} not found; falling back to the synthetic test signal")
        test_pitch_analysis(show_plot=True)
        return
    # Use only the first second of audio
    if len(signal) > sample_rate:
        signal = signal[:sample_rate]
    # Configure the analyzer
    config = {
        "preprocess": "preemphasis",
        "frame": "fixed",
        "detect": "autocorrelation",
        "visualize": "matplotlib"
    }
    analyzer = PitchAnalyzer(config)
    # Run the analysis with visualization
    analyzer.analyze(signal, sample_rate, visualize=True)
if __name__ == "__main__":
    # print("1. Testing the synthetic signal")
    # test_pitch_analysis(show_plot=True)
    # print("\n2. Testing division-by-zero handling")
    # test_zero_division()
    print("\n3. Testing a real audio file (requires sample.wav)")
    test_real_audio()