| info | |
|---|---|
| paper | https://arxiv.org/abs/2205.13147 | 
| code | https://github.com/RAIVNLab/MRL | 
| org | 华盛顿大学、Google、哈弗大学 | 
| 个人博客位置 | http://www.myhz0606.com/article/mrl | 
Motivation
我们平时做retrieval相关的工作,很多时候需要根据业务场景和计算资源对向量进行降维。受限开发周期,我们往往不会通过重新训练特征提取模型来调整向量维度,而是用PCA等方法来实现。但是当降维的scale较大时,PCA等方法的效果较差。Matryoshka Representation Learning (MRL)这篇paper介绍了一个很简单但有效的方法能实现一次训练,获取不同维度的表征提取。下面来看它具体是怎么做的吧。
Method
文中只描述MRL最核心的部分,详细介绍请看原论文。
我们以一个图像分类任务为例,其pipeline如下。图片首先通过一个Feature extractor提取特征,flatten后用一个FC来映射到表征空间,再接入一个classifier(也是个全连接层)得到该图片在类别上的概率分布。用这个方法训练,一次训练我们只能得到一种维度的图片表征(如图中是2048维)

为了一次训练获得不同维度的图片表征,最简单粗暴的方法就是我们可以用多个FC及对应的Classifier进行联合训练。这无疑是有效的,但由于FC和classifier多了,模型会大一些。

MRL对上面做了一个优化,它能通过一组FC和Classifier实现多种尺度的特征训练。pipeline如下图所示(图中同个颜色表示共享权重)。MRL实现的核心就是:对同一组FC和Classifier进行分片,从而实现不同维度的表征训练。
论文公式中的 
     
      
       
       
         F 
        
       
         ( 
        
        
        
          x 
         
        
          i 
         
        
       
         ; 
        
        
        
          θ 
         
        
          F 
         
        
       
         ) 
        
       
      
        F(x_i; \theta_{F}) 
       
      
    F(xi;θF)是我图中的Feature_extractor + FC。
min  { W ( m ) } m ∈ M , θ F 1 N ∑ i ∈ [ N ] ∑ m ∈ M c m ⋅ L ( W ( m ) ⋅ F ( x i ; θ F ) 1 : m ; y i ) , \min _ { \{ { \boldsymbol W } ^ { ( m ) } \} _ { m \in { \mathcal M } } , \, \theta _ { F } } \frac { 1 } { N } \sum _ { i \in [ N ] } \sum _ { m \in { \mathcal M } } c _ { m } \cdot { \mathcal L } ( { \boldsymbol W } ^ { ( m ) } \cdot F ( x _ { i } ; \theta _ { F } ) _ { 1 : m } \, ; \, y _ { i } ) \; , {W(m)}m∈M,θFminN1i∈[N]∑m∈M∑cm⋅L(W(m)⋅F(xi;θF)1:m;yi),

MRL的实现源码如下图所示:
class MRL_Linear_Layer(nn.Module):
	def __init__(self, nesting_list: List, num_classes=1000, efficient=False, **kwargs):
		super(MRL_Linear_Layer, self).__init__()
		self.nesting_list = nesting_list
		self.num_classes = num_classes # Number of classes for classification
		self.efficient = efficient
		if self.efficient:
			setattr(self, f"nesting_classifier_{0}", nn.Linear(nesting_list[-1], self.num_classes, **kwargs))		
		else:	
			for i, num_feat in enumerate(self.nesting_list):
				setattr(self, f"nesting_classifier_{i}", nn.Linear(num_feat, self.num_classes, **kwargs))	
	def reset_parameters(self):
		if self.efficient:
			self.nesting_classifier_0.reset_parameters()
		else:
			for i in range(len(self.nesting_list)):
				getattr(self, f"nesting_classifier_{i}").reset_parameters()
	def forward(self, x):
		nesting_logits = ()
		for i, num_feat in enumerate(self.nesting_list):
			if self.efficient:
				if self.nesting_classifier_0.bias is None:
					nesting_logits += (torch.matmul(x[:, :num_feat], (self.nesting_classifier_0.weight[:, :num_feat]).t()), )
				else:
					nesting_logits += (torch.matmul(x[:, :num_feat], (self.nesting_classifier_0.weight[:, :num_feat]).t()) + self.nesting_classifier_0.bias, )
			else:
				nesting_logits +=  (getattr(self, f"nesting_classifier_{i}")(x[:, :num_feat]),)
		return nesting_logits
Result
该图对比了MRL不同维度的表征在imagenet1K上linear classification和1-NN的准确率。

下图给出了scale model和dataset时MRL依旧有效,并且MRL提取的表征具备良好的插值性能。

更多实验结果见原论文。
小结
这篇文章虽然idea很简单,但很适合工程应用。
参考文献
Matryoshka Representation Learning



















