Sinkhorn算法实战:用Python手把手教你解决最优传输问题(附完整代码)
Sinkhorn算法实战用Python手把手教你解决最优传输问题附完整代码最优传输理论在机器学习领域正掀起一场静默的革命。想象一下这样的场景你需要将一组资源从A地运往B地同时希望运输成本最低或者你需要将一幅图像的颜色分布调整为另一幅图像的风格。这些看似不同的问题背后都隐藏着同一个数学框架——最优传输。而Sinkhorn算法正是让这个理论走出数学殿堂、走进工程师电脑的魔法钥匙。1. 最优传输与Sinkhorn算法基础最优传输问题最早由法国数学家蒙日(Gaspard Monge)在1781年提出核心思想是如何以最小的成本将一个概率分布转换为另一个概率分布。这里的成本可以是我们熟悉的欧式距离也可以是任何自定义的度量标准。传统的最优传输解法面临两大挑战计算复杂度高精确解法的时间复杂度通常为O(n³)数值不稳定当分布中存在零值时容易产生数值问题Sinkhorn算法通过引入熵正则化巧妙地解决了这些问题。其核心公式可以表示为P diag(u) K diag(v) # 运输矩阵分解形式其中K exp(-C/ε)是经过指数变换的成本矩阵u和v是通过迭代更新的缩放向量ε是控制正则化强度的超参数提示熵正则化参数ε的选择至关重要——较大的ε使问题更平滑但结果更模糊较小的ε更精确但计算更困难2. Python实现详解让我们从零开始实现Sinkhorn算法。首先确保安装必要的库pip install numpy matplotlib POT2.1 算法核心实现import numpy as np def sinkhorn(a, b, C, epsilon0.1, max_iter1000, tol1e-9): Sinkhorn算法实现 参数: a: (n,) 源分布 b: (m,) 目标分布 C: (n,m) 成本矩阵 epsilon: 正则化参数 max_iter: 最大迭代次数 tol: 收敛阈值 返回: P: (n,m) 最优传输矩阵 # 初始化 u np.ones_like(a) v np.ones_like(b) K np.exp(-C / epsilon) # 迭代更新 for _ in range(max_iter): u_prev, v_prev u.copy(), v.copy() u a / (K v) v b / (K.T u) # 检查收敛 if (np.max(np.abs(u - u_prev)) tol and np.max(np.abs(v - v_prev)) tol): break return np.diag(u) K np.diag(v)2.2 可视化分析理解算法行为的最佳方式是可视化。我们创建两个高斯分布并观察它们的传输过程import matplotlib.pyplot as plt from ot.datasets import make_1D_gauss # 生成数据 n 100 x np.arange(n) a make_1D_gauss(n, m20, s5) # 源分布 b make_1D_gauss(n, m70, s10) # 目标分布 # 成本矩阵 C (x.reshape(-1,1) - x.reshape(1,-1))**2 / n**2 # 计算传输 P sinkhorn(a, b, C, epsilon0.05) # 可视化 plt.figure(figsize(12,8)) plt.subplot(2,2,1) plt.plot(x, a, r, labelSource) plt.plot(x, b, b, labelTarget) plt.legend() plt.subplot(2,2,2) plt.imshow(P, cmapviridis) plt.title(Transport Matrix) plt.subplot(2,2,3) plt.plot(P.sum(1), r--, labelMarginal a) plt.plot(a, r, alpha0.3) plt.plot(P.sum(0), b--, labelMarginal b) plt.plot(b, b, alpha0.3) plt.legend() plt.tight_layout()3. 参数调优与性能优化3.1 正则化参数ε的影响ε的选择直接影响结果的质量和计算效率ε值计算速度结果精度适用场景较大(0.1)快低初步探索/可视化中等(0.01-0.1)中等中等一般应用较小(0.01)慢高精确计算# 测试不同ε值 epsilons [0.5, 0.1, 0.05, 0.01] results {} for eps in epsilons: results[eps] sinkhorn(a, b, C, epsiloneps)3.2 加速技巧实际应用中可以采用以下优化策略对数域计算避免数值下溢批处理同时计算多个传输对GPU加速使用CuPy替代NumPy改进后的对数域实现def sinkhorn_log(a, b, C, epsilon0.1, max_iter1000): 对数域实现 log_a np.log(a) log_b np.log(b) log_K -C / epsilon u np.zeros_like(a) v np.zeros_like(b) for _ in range(max_iter): u log_a - np.log(np.exp(log_K v).sum(1)) v log_b - np.log(np.exp(log_K.T u).sum(1)) return np.exp(u[:,None] log_K v)4. 实战应用案例4.1 图像颜色迁移将一张图像的色彩分布迁移到另一张图像from skimage import data, transform import cv2 # 加载图像 src_img data.astronaut() tgt_img data.chelsea() # 预处理 src cv2.cvtColor(src_img, cv2.COLOR_RGB2LAB).reshape(-1,3)/255. tgt cv2.cvtColor(tgt_img, cv2.COLOR_RGB2LAB).reshape(-1,3)/255. # 采样 n_samples 1000 src_samples src[np.random.choice(len(src), n_samples)] tgt_samples tgt[np.random.choice(len(tgt), n_samples)] # 计算成本矩阵 C np.sum(src_samples**2, 1)[:,None] np.sum(tgt_samples**2, 1)[None,:] - 2*src_samples tgt_samples.T # 计算传输 P sinkhorn(np.ones(n_samples)/n_samples, np.ones(n_samples)/n_samples, C, epsilon0.01) # 应用传输 transported tgt_samples[np.argmax(P, axis1)]4.2 文档嵌入对齐对齐不同语言模型的词嵌入空间from sklearn.datasets import fetch_20newsgroups from sklearn.feature_extraction.text import TfidfVectorizer # 加载双语语料 en_data fetch_20newsgroups(subsettrain, categories[sci.space]).data[:100] fr_data [...] # 假设有对应的法语翻译 # 创建嵌入 en_vec TfidfVectorizer(max_features500).fit_transform(en_data) fr_vec TfidfVectorizer(max_features500).fit_transform(fr_data) # 计算分布 en_dist np.array(en_vec.mean(axis0)).flatten() fr_dist np.array(fr_vec.mean(axis0)).flatten() # 计算词-词成本矩阵 en_emb [...] # 英语词向量 fr_emb [...] # 法语词向量 C np.linalg.norm(en_emb[:,None] - fr_emb[None,:], axis2) # 对齐 P sinkhorn(en_dist, fr_dist, C, epsilon0.1) aligned_emb P fr_emb / en_dist[:,None]5. 高级技巧与问题排查5.1 常见问题解决方案问题现象可能原因解决方案结果全为NaNε太小导致数值溢出使用对数域实现或增大ε收敛慢成本矩阵尺度不一致标准化成本矩阵边缘约束不满足迭代次数不足增加max_iter或降低tol5.2 扩展到不平衡传输当总质量不相等时可以使用部分传输def unbalanced_sinkhorn(a, b, C, epsilon0.1, tau1.0): 不平衡传输 K np.exp(-C / epsilon) u np.ones_like(a) v np.ones_like(b) for _ in range(1000): u (a / (K v)) ** tau v (b / (K.T u)) ** tau return np.diag(u) K np.diag(v)5.3 多尺度加速对于大规模问题可以采用多尺度方法对分布进行粗粒度化在粗粒度上计算传输将结果作为细粒度初始值逐步细化def multiscale_sinkhorn(a, b, C, levels3): 多尺度Sinkhorn # 构建金字塔 a_pyramid [a] b_pyramid [b] C_pyramid [C] for _ in range(levels-1): a_pyramid.append(a_pyramid[-1][::2] a_pyramid[-1][1::2]) b_pyramid.append(b_pyramid[-1][::2] b_pyramid[-1][1::2]) C_pyramid.append(C_pyramid[-1][::2,::2] C_pyramid[-1][1::2,::2] C_pyramid[-1][::2,1::2] C_pyramid[-1][1::2,1::2]) # 从粗到细计算 P np.ones_like(C_pyramid[-1]) for l in reversed(range(levels)): a_l, b_l, C_l a_pyramid[l], b_pyramid[l], C_pyramid[l] if l levels-1: P np.kron(P, np.ones((2,2))) # 上采样 P P * (a_l.sum() / P.sum()) # 重新归一化 P sinkhorn(a_l, b_l, C_l, initialP) return P在实际项目中我发现多尺度方法可以将计算时间从数小时缩短到几分钟特别是在处理高分辨率图像或大规模嵌入时效果显著。另一个实用技巧是预热初始化——先用较大的ε值计算然后逐步减小ε并使用前一次结果作为初始值这样通常能获得更好的收敛性。
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2459300.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!