A Summary of Knowledge Distillation Algorithms


Knowledge distillation methods fall into two broad categories: logits distillation and feature distillation. Logits distillation applies a higher temperature to the softmax to amplify the information carried by the negative (non-target) labels, and then uses the KL divergence between the student's and the teacher's logits under this high-temperature softmax as the loss. Intermediate-feature distillation instead forces the student to learn selected intermediate-layer features of the teacher, either by matching those features directly or by learning the transformation between features. For example, the knowledge between feature map No.1 and No.2 can be expressed as the transformation that maps one to the other; this transformation can be captured by a matrix, and the student is trained to reproduce the teacher's matrix.
This article collects the papers and reference code of commonly used knowledge distillation methods, for convenient later study and research.
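
As a quick illustration of the temperature effect (a minimal sketch; the logit values below are made up), raising the temperature T flattens the softmax so that the non-target classes carry visible probability mass:

import torch
import torch.nn.functional as F

logits = torch.tensor([[10.0, 3.0, 1.0]])   # made-up teacher logits for 3 classes
print(F.softmax(logits, dim=1))             # T=1: nearly one-hot, negative labels are ~0
print(F.softmax(logits / 4.0, dim=1))       # T=4: softer distribution, the similarity structure between classes becomes visible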

1. Logits

Paper: https://proceedings.neurips.cc/paper/2014/file/ea8fcd92d59581717e06eb187f10666d-Paper.pdf
Code:

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F


class Logits(nn.Module):
	'''
	Do Deep Nets Really Need to be Deep?
	http://papers.nips.cc/paper/5484-do-deep-nets-really-need-to-be-deep.pdf
	'''
	def __init__(self):
		super(Logits, self).__init__()

	def forward(self, out_s, out_t):
		loss = F.mse_loss(out_s, out_t)

		return loss

2. ST

Paper: https://arxiv.org/pdf/1503.02531.pdf
Code:

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F


class SoftTarget(nn.Module):
	'''
	Distilling the Knowledge in a Neural Network
	https://arxiv.org/pdf/1503.02531.pdf
	'''
	def __init__(self, T):
		super(SoftTarget, self).__init__()
		self.T = T

	def forward(self, out_s, out_t):
		loss = F.kl_div(F.log_softmax(out_s/self.T, dim=1),
						F.softmax(out_t/self.T, dim=1),
						reduction='batchmean') * self.T * self.T

		return loss
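
A minimal usage sketch (the student/teacher networks and the weight alpha below are assumptions, not part of the original snippet): the soft-target loss is normally combined with the ordinary cross-entropy on the hard labels.

import torch
import torch.nn as nn

student = nn.Linear(32, 10)   # hypothetical tiny student/teacher heads, just to make the sketch runnable
teacher = nn.Linear(32, 10)

criterion_kd = SoftTarget(T=4.0)
criterion_ce = nn.CrossEntropyLoss()
alpha = 0.9                   # weight of the distillation term, an assumption

img    = torch.randn(8, 32)
target = torch.randint(0, 10, (8,))

out_s = student(img)
with torch.no_grad():
	out_t = teacher(img)      # the teacher is kept frozen during distillation

loss = alpha * criterion_kd(out_s, out_t) + (1 - alpha) * criterion_ce(out_s, target)
loss.backward()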


3. AT

Paper: https://arxiv.org/pdf/1612.03928.pdf
Code:

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F


'''
AT with sum of absolute values with power p
'''
class AT(nn.Module):
	'''
	Paying More Attention to Attention: Improving the Performance of Convolutional
	Neural Networks via Attention Transfer
	https://arxiv.org/pdf/1612.03928.pdf
	'''
	def __init__(self, p):
		super(AT, self).__init__()
		self.p = p

	def forward(self, fm_s, fm_t):
		loss = F.mse_loss(self.attention_map(fm_s), self.attention_map(fm_t))

		return loss

	def attention_map(self, fm, eps=1e-6):
		am = torch.pow(torch.abs(fm), self.p)
		am = torch.sum(am, dim=1, keepdim=True)
		norm = torch.norm(am, dim=(2,3), keepdim=True)
		am = torch.div(am, norm+eps)

		return am
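
Because the attention map sums the p-th power of absolute activations over the channel dimension, the student and teacher feature maps may have different channel widths as long as their spatial sizes match; a small sketch with assumed shapes:

import torch

at = AT(p=2)
fm_s = torch.randn(8, 32, 16, 16)   # student feature map, 32 channels (assumed)
fm_t = torch.randn(8, 64, 16, 16)   # teacher feature map, 64 channels is fine
loss = at(fm_s, fm_t)               # attention maps are (B, 1, H, W), so only H and W must agree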

4. FitNet

Paper: https://arxiv.org/pdf/1412.6550.pdf
Code:

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F


class Hint(nn.Module):
	'''
	FitNets: Hints for Thin Deep Nets
	https://arxiv.org/pdf/1412.6550.pdf
	'''
	def __init__(self):
		super(Hint, self).__init__()

	def forward(self, fm_s, fm_t):
		loss = F.mse_loss(fm_s, fm_t)

		return loss
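
Hint simply MSE-matches the two feature maps, so it assumes fm_s and fm_t share the same shape. In practice the student's hint layer is usually thinner than the teacher's guided layer, and the FitNets paper trains a regressor on the student side to bridge the gap; a minimal sketch of such an adapter (channel numbers are assumptions):

import torch
import torch.nn as nn

regressor = nn.Conv2d(32, 64, kernel_size=1, bias=False)   # learned jointly with the student
criterion_hint = Hint()

fm_s = torch.randn(8, 32, 16, 16)   # student hint layer, 32 channels (assumed)
fm_t = torch.randn(8, 64, 16, 16)   # teacher guided layer, 64 channels (assumed)
loss = criterion_hint(regressor(fm_s), fm_t)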

5. NST

Paper: https://arxiv.org/pdf/1707.01219.pdf
Code:

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F


'''
NST with Polynomial Kernel, where d=2 and c=0
'''
class NST(nn.Module):
	'''
	Like What You Like: Knowledge Distill via Neuron Selectivity Transfer
	https://arxiv.org/pdf/1707.01219.pdf
	'''
	def __init__(self):
		super(NST, self).__init__()

	def forward(self, fm_s, fm_t):
		fm_s = fm_s.view(fm_s.size(0), fm_s.size(1), -1)
		fm_s = F.normalize(fm_s, dim=2)

		fm_t = fm_t.view(fm_t.size(0), fm_t.size(1), -1)
		fm_t = F.normalize(fm_t, dim=2)

		loss = self.poly_kernel(fm_t, fm_t).mean() \
			 + self.poly_kernel(fm_s, fm_s).mean() \
			 - 2 * self.poly_kernel(fm_s, fm_t).mean()

		return loss

	def poly_kernel(self, fm1, fm2):
		fm1 = fm1.unsqueeze(1)
		fm2 = fm2.unsqueeze(2)
		out = (fm1 * fm2).sum(-1).pow(2)

		return out

6. PKT

Paper: http://openaccess.thecvf.com/content_ECCV_2018/papers/Nikolaos_Passalis_Learning_Deep_Representations_ECCV_2018_paper.pdf
Code:

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F


'''
Adopted from https://github.com/passalis/probabilistic_kt/blob/master/nn/pkt.py
'''
class PKTCosSim(nn.Module):
	'''
	Learning Deep Representations with Probabilistic Knowledge Transfer
	http://openaccess.thecvf.com/content_ECCV_2018/papers/Nikolaos_Passalis_Learning_Deep_Representations_ECCV_2018_paper.pdf
	'''
	def __init__(self):
		super(PKTCosSim, self).__init__()

	def forward(self, feat_s, feat_t, eps=1e-6):
		# Normalize each vector by its norm
		feat_s_norm = torch.sqrt(torch.sum(feat_s ** 2, dim=1, keepdim=True))
		feat_s = feat_s / (feat_s_norm + eps)
		feat_s[feat_s != feat_s] = 0

		feat_t_norm = torch.sqrt(torch.sum(feat_t ** 2, dim=1, keepdim=True))
		feat_t = feat_t / (feat_t_norm + eps)
		feat_t[feat_t != feat_t] = 0

		# Calculate the cosine similarity
		feat_s_cos_sim = torch.mm(feat_s, feat_s.transpose(0, 1))
		feat_t_cos_sim = torch.mm(feat_t, feat_t.transpose(0, 1))

		# Scale cosine similarity to [0,1]
		feat_s_cos_sim = (feat_s_cos_sim + 1.0) / 2.0
		feat_t_cos_sim = (feat_t_cos_sim + 1.0) / 2.0

		# Transform them into probabilities
		feat_s_cond_prob = feat_s_cos_sim / torch.sum(feat_s_cos_sim, dim=1, keepdim=True)
		feat_t_cond_prob = feat_t_cos_sim / torch.sum(feat_t_cos_sim, dim=1, keepdim=True)

		# Calculate the KL-divergence
		loss = torch.mean(feat_t_cond_prob * torch.log((feat_t_cond_prob + eps) / (feat_s_cond_prob + eps)))

		return loss

7. FSP

Paper: http://openaccess.thecvf.com/content_cvpr_2017/papers/Yim_A_Gift_From_CVPR_2017_paper.pdf
Code:

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F


class FSP(nn.Module):
	'''
	A Gift from Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning
	http://openaccess.thecvf.com/content_cvpr_2017/papers/Yim_A_Gift_From_CVPR_2017_paper.pdf
	'''
	def __init__(self):
		super(FSP, self).__init__()

	def forward(self, fm_s1, fm_s2, fm_t1, fm_t2):
		loss = F.mse_loss(self.fsp_matrix(fm_s1,fm_s2), self.fsp_matrix(fm_t1,fm_t2))

		return loss

	def fsp_matrix(self, fm1, fm2):
		if fm1.size(2) > fm2.size(2):
			fm1 = F.adaptive_avg_pool2d(fm1, (fm2.size(2), fm2.size(3)))

		fm1 = fm1.view(fm1.size(0), fm1.size(1), -1)
		fm2 = fm2.view(fm2.size(0), fm2.size(1), -1).transpose(1,2)

		fsp = torch.bmm(fm1, fm2) / fm1.size(2)

		return fsp
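
The FSP loss takes two feature maps from the same network (typically the input and output of one stage) for both the student and the teacher, and matches the resulting flow matrices; the student's and teacher's FSP matrices must have the same channel dimensions. A small sketch with assumed shapes:

import torch

fsp = FSP()
fm_s1 = torch.randn(8, 32, 32, 32)   # student, beginning of a stage (assumed shapes)
fm_s2 = torch.randn(8, 32, 16, 16)   # student, end of the stage
fm_t1 = torch.randn(8, 32, 32, 32)   # teacher, beginning of the corresponding stage
fm_t2 = torch.randn(8, 32, 16, 16)   # teacher, end of the corresponding stage
loss = fsp(fm_s1, fm_s2, fm_t1, fm_t2)   # fm_s1 is average-pooled to fm_s2's spatial size internally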

8. FT

Paper: http://papers.nips.cc/paper/7541-paraphrasing-complex-network-network-compression-via-factor-transfer.pdf
Code:

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F


class FT(nn.Module):
	'''
	Paraphrasing Complex Network: Network Compression via Factor Transfer
	http://papers.nips.cc/paper/7541-paraphrasing-complex-network-network-compression-via-factor-transfer.pdf
	'''
	def __init__(self):
		super(FT, self).__init__()

	def forward(self, factor_s, factor_t):
		loss = F.l1_loss(self.normalize(factor_s), self.normalize(factor_t))

		return loss

	def normalize(self, factor):
		norm_factor = F.normalize(factor.view(factor.size(0),-1))

		return norm_factor

9. RKD

Paper: https://arxiv.org/pdf/1904.05068.pdf
Code:

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F


'''
From https://github.com/lenscloth/RKD/blob/master/metric/loss.py
'''
class RKD(nn.Module):
	'''
	Relational Knowledge Distillation
	https://arxiv.org/pdf/1904.05068.pdf
	'''
	def __init__(self, w_dist, w_angle):
		super(RKD, self).__init__()

		self.w_dist  = w_dist
		self.w_angle = w_angle

	def forward(self, feat_s, feat_t):
		loss = self.w_dist * self.rkd_dist(feat_s, feat_t) + \
			   self.w_angle * self.rkd_angle(feat_s, feat_t)

		return loss

	def rkd_dist(self, feat_s, feat_t):
		feat_t_dist = self.pdist(feat_t, squared=False)
		mean_feat_t_dist = feat_t_dist[feat_t_dist>0].mean()
		feat_t_dist = feat_t_dist / mean_feat_t_dist

		feat_s_dist = self.pdist(feat_s, squared=False)
		mean_feat_s_dist = feat_s_dist[feat_s_dist>0].mean()
		feat_s_dist = feat_s_dist / mean_feat_s_dist

		loss = F.smooth_l1_loss(feat_s_dist, feat_t_dist)

		return loss

	def rkd_angle(self, feat_s, feat_t):
		# N x C --> N x N x C
		feat_t_vd = (feat_t.unsqueeze(0) - feat_t.unsqueeze(1))
		norm_feat_t_vd = F.normalize(feat_t_vd, p=2, dim=2)
		feat_t_angle = torch.bmm(norm_feat_t_vd, norm_feat_t_vd.transpose(1, 2)).view(-1)

		feat_s_vd = (feat_s.unsqueeze(0) - feat_s.unsqueeze(1))
		norm_feat_s_vd = F.normalize(feat_s_vd, p=2, dim=2)
		feat_s_angle = torch.bmm(norm_feat_s_vd, norm_feat_s_vd.transpose(1, 2)).view(-1)

		loss = F.smooth_l1_loss(feat_s_angle, feat_t_angle)

		return loss

	def pdist(self, feat, squared=False, eps=1e-12):
		feat_square = feat.pow(2).sum(dim=1)
		feat_prod   = torch.mm(feat, feat.t())
		feat_dist   = (feat_square.unsqueeze(0) + feat_square.unsqueeze(1) - 2 * feat_prod).clamp(min=eps)

		if not squared:
			feat_dist = feat_dist.sqrt()

		feat_dist = feat_dist.clone()
		feat_dist[range(len(feat)), range(len(feat))] = 0

		return feat_dist


10. AB

Paper: https://arxiv.org/pdf/1811.03233.pdf
Code:

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F


class AB(nn.Module):
	'''
	Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons
	https://arxiv.org/pdf/1811.03233.pdf
	'''
	def __init__(self, margin):
		super(AB, self).__init__()

		self.margin = margin

	def forward(self, fm_s, fm_t):
		# fm before activation (both feature maps should be pre-activation values)
		loss = ((fm_s + self.margin).pow(2) * ((fm_s > -self.margin) & (fm_t <= 0)).float() +
			    (fm_s - self.margin).pow(2) * ((fm_s <= self.margin) & (fm_t > 0)).float())
		loss = loss.mean()

		return loss

11. SP

Paper: https://arxiv.org/pdf/1907.09682.pdf
Code:

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F


class SP(nn.Module):
	'''
	Similarity-Preserving Knowledge Distillation
	https://arxiv.org/pdf/1907.09682.pdf
	'''
	def __init__(self):
		super(SP, self).__init__()

	def forward(self, fm_s, fm_t):
		fm_s = fm_s.view(fm_s.size(0), -1)
		G_s  = torch.mm(fm_s, fm_s.t())
		norm_G_s = F.normalize(G_s, p=2, dim=1)

		fm_t = fm_t.view(fm_t.size(0), -1)
		G_t  = torch.mm(fm_t, fm_t.t())
		norm_G_t = F.normalize(G_t, p=2, dim=1)

		loss = F.mse_loss(norm_G_s, norm_G_t)

		return loss

12. Sobolev

Paper: https://arxiv.org/pdf/1706.04859.pdf
Code:

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad


class Sobolev(nn.Module):
	'''
	Sobolev Training for Neural Networks
	https://arxiv.org/pdf/1706.04859.pdf

	Knowledge Transfer with Jacobian Matching
	http://de.arxiv.org/pdf/1803.00443
	'''
	def __init__(self):
		super(Sobolev, self).__init__()

	def forward(self, out_s, out_t, img, target):
		target_out_s = torch.gather(out_s, 1, target.view(-1, 1))
		grad_s       = grad(outputs=target_out_s, inputs=img,
							grad_outputs=torch.ones_like(target_out_s),
							create_graph=True, retain_graph=True, only_inputs=True)[0]
		norm_grad_s  = F.normalize(grad_s.view(grad_s.size(0), -1), p=2, dim=1)

		target_out_t = torch.gather(out_t, 1, target.view(-1, 1))
		grad_t       = grad(outputs=target_out_t, inputs=img,
							grad_outputs=torch.ones_like(target_out_t),
							create_graph=True, retain_graph=True, only_inputs=True)[0]
		norm_grad_t  = F.normalize(grad_t.view(grad_t.size(0), -1), p=2, dim=1)

		loss = F.mse_loss(norm_grad_s, norm_grad_t.detach())

		return loss
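
Since the loss differentiates the target-class score with respect to the input, the input tensor must have requires_grad enabled before the forward pass; a minimal sketch (the models below are assumptions):

import torch
import torch.nn as nn

student = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))   # hypothetical models
teacher = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
criterion = Sobolev()

img    = torch.randn(4, 3, 8, 8).requires_grad_(True)   # gradients w.r.t. the input are required
target = torch.randint(0, 10, (4,))

out_s, out_t = student(img), teacher(img)
loss = criterion(out_s, out_t, img, target)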

13. BSS

Paper: https://arxiv.org/pdf/1805.05532.pdf
Code:

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
	from torch.autograd.gradcheck import zero_gradients
except ImportError:
	# zero_gradients was removed in newer PyTorch releases; use an equivalent local helper
	def zero_gradients(x):
		if x.grad is not None:
			x.grad.detach_()
			x.grad.zero_()
'''
Modified by https://github.com/bhheo/BSS_distillation
'''

def reduce_sum(x, keepdim=True):
	for d in reversed(range(1, x.dim())):
		x = x.sum(d, keepdim=keepdim)
	return x


def l2_norm(x, keepdim=True):
	norm = reduce_sum(x*x, keepdim=keepdim)
	return norm.sqrt()


class BSS(nn.Module):
	'''
	Knowledge Distillation with Adversarial Samples Supporting Decision Boundary
	https://arxiv.org/pdf/1805.05532.pdf
	'''
	def __init__(self, T):
		super(BSS, self).__init__()
		self.T = T

	def forward(self, attacked_out_s, attacked_out_t):
		loss = F.kl_div(F.log_softmax(attacked_out_s/self.T, dim=1),
						F.softmax(attacked_out_t/self.T, dim=1),
						reduction='batchmean') #* self.T * self.T

		return loss


class BSSAttacker():
	def __init__(self, step_alpha, num_steps, eps=1e-4):
		self.step_alpha = step_alpha
		self.num_steps = num_steps
		self.eps = eps

	def attack(self, model, img, target, attack_class):
		img = img.detach().requires_grad_(True)

		step = 0
		while step < self.num_steps:
			zero_gradients(img)
			_, _, _, _, _, output = model(img)

			score = F.softmax(output, dim=1)
			score_target = score.gather(1, target.unsqueeze(1))
			score_attack_class = score.gather(1, attack_class.unsqueeze(1))

			loss = (score_attack_class - score_target).sum()
			loss.backward()

			step_alpha = self.step_alpha * (target == output.max(1)[1]).float()
			step_alpha = step_alpha.unsqueeze(1).unsqueeze(1).unsqueeze(1)
			if step_alpha.sum() == 0:
				break

			pert = (score_target - score_attack_class).unsqueeze(1).unsqueeze(1)
			norm_pert = step_alpha * (pert + self.eps) * img.grad / l2_norm(img.grad)

			step_adv = img + norm_pert
			step_adv = torch.clamp(step_adv, -2.5, 2.5)
			img.data = step_adv.data

			step += 1

		return img

14. CC

Paper: http://openaccess.thecvf.com/content_ICCV_2019/papers/Peng_Correlation_Congruence_for_Knowledge_Distillation_ICCV_2019_paper.pdf
Code:

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


'''
CC with P-order Taylor Expansion of Gaussian RBF kernel
'''
class CC(nn.Module):
	'''
	Correlation Congruence for Knowledge Distillation
	http://openaccess.thecvf.com/content_ICCV_2019/papers/
	Peng_Correlation_Congruence_for_Knowledge_Distillation_ICCV_2019_paper.pdf
	'''
	def __init__(self, gamma, P_order):
		super(CC, self).__init__()
		self.gamma = gamma
		self.P_order = P_order

	def forward(self, feat_s, feat_t):
		corr_mat_s = self.get_correlation_matrix(feat_s)
		corr_mat_t = self.get_correlation_matrix(feat_t)

		loss = F.mse_loss(corr_mat_s, corr_mat_t)

		return loss

	def get_correlation_matrix(self, feat):
		feat = F.normalize(feat, p=2, dim=-1)
		sim_mat  = torch.matmul(feat, feat.t())
		corr_mat = torch.zeros_like(sim_mat)

		for p in range(self.P_order+1):
			corr_mat += math.exp(-2*self.gamma) * (2*self.gamma)**p / \
						math.factorial(p) * torch.pow(sim_mat, p)

		return corr_mat

15. LwM

Paper: https://arxiv.org/pdf/1811.08051.pdf
Code:

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad

'''
LwM is originally an incremental learning method with 
classification/distillation/attention distillation losses.

Here, LwM is only defined as the Grad-CAM based attention distillation.
'''
class LwM(nn.Module):
	'''
	Learning without Memorizing
	https://arxiv.org/pdf/1811.08051.pdf
	'''
	def __init__(self):
		super(LwM, self).__init__()

	def forward(self, out_s, fm_s, out_t, fm_t, target):
		target_out_t = torch.gather(out_t, 1, target.view(-1, 1))
		grad_fm_t    = grad(outputs=target_out_t, inputs=fm_t,
							grad_outputs=torch.ones_like(target_out_t),
							create_graph=True, retain_graph=True, only_inputs=True)[0]
		weights_t = F.adaptive_avg_pool2d(grad_fm_t, 1)
		cam_t = torch.sum(torch.mul(weights_t, grad_fm_t), dim=1, keepdim=True)
		cam_t = F.relu(cam_t)
		cam_t = cam_t.view(cam_t.size(0), -1)
		norm_cam_t = F.normalize(cam_t, p=2, dim=1)

		target_out_s = torch.gather(out_s, 1, target.view(-1, 1))
		grad_fm_s    = grad(outputs=target_out_s, inputs=fm_s,
							grad_outputs=torch.ones_like(target_out_s),
							create_graph=True, retain_graph=True, only_inputs=True)[0]
		weights_s = F.adaptive_avg_pool2d(grad_fm_s, 1)
		cam_s = torch.sum(torch.mul(weights_s, grad_fm_s), dim=1, keepdim=True)
		cam_s = F.relu(cam_s)
		cam_s = cam_s.view(cam_s.size(0), -1)
		norm_cam_s = F.normalize(cam_s, p=2, dim=1)

		loss = F.l1_loss(norm_cam_s, norm_cam_t.detach())

		return loss

16. IRG

Paper: http://openaccess.thecvf.com/content_CVPR_2019/papers/Liu_Knowledge_Distillation_via_Instance_Relationship_Graph_CVPR_2019_paper.pdf
Code:

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F


class IRG(nn.Module):
	'''
	Knowledge Distillation via Instance Relationship Graph
	http://openaccess.thecvf.com/content_CVPR_2019/papers/
	Liu_Knowledge_Distillation_via_Instance_Relationship_Graph_CVPR_2019_paper.pdf

	The official code is written by Caffe
	https://github.com/yufanLIU/IRG
	'''
	def __init__(self, w_irg_vert, w_irg_edge, w_irg_tran):
		super(IRG, self).__init__()

		self.w_irg_vert = w_irg_vert
		self.w_irg_edge = w_irg_edge
		self.w_irg_tran = w_irg_tran

	def forward(self, irg_s, irg_t):
		fm_s1, fm_s2, feat_s, out_s = irg_s
		fm_t1, fm_t2, feat_t, out_t = irg_t

		loss_irg_vert = F.mse_loss(out_s, out_t)

		irg_edge_feat_s = self.euclidean_dist_feat(feat_s, squared=True)
		irg_edge_feat_t = self.euclidean_dist_feat(feat_t, squared=True)
		irg_edge_fm_s1  = self.euclidean_dist_fm(fm_s1, squared=True)
		irg_edge_fm_t1  = self.euclidean_dist_fm(fm_t1, squared=True)
		irg_edge_fm_s2  = self.euclidean_dist_fm(fm_s2, squared=True)
		irg_edge_fm_t2  = self.euclidean_dist_fm(fm_t2, squared=True)
		loss_irg_edge = (F.mse_loss(irg_edge_feat_s, irg_edge_feat_t) +
						 F.mse_loss(irg_edge_fm_s1,  irg_edge_fm_t1 ) +
						 F.mse_loss(irg_edge_fm_s2,  irg_edge_fm_t2 )) / 3.0

		irg_tran_s = self.euclidean_dist_fms(fm_s1, fm_s2, squared=True)
		irg_tran_t = self.euclidean_dist_fms(fm_t1, fm_t2, squared=True)
		loss_irg_tran = F.mse_loss(irg_tran_s, irg_tran_t)

		# print(self.w_irg_vert * loss_irg_vert)
		# print(self.w_irg_edge * loss_irg_edge)
		# print(self.w_irg_tran * loss_irg_tran)
		# print()

		loss = (self.w_irg_vert * loss_irg_vert +
				self.w_irg_edge * loss_irg_edge +
				self.w_irg_tran * loss_irg_tran)

		return loss

	def euclidean_dist_fms(self, fm1, fm2, squared=False, eps=1e-12):
		'''
		Calculating the IRG Transformation, where fm1 precedes fm2 in the network.
		'''
		if fm1.size(2) > fm2.size(2):
			fm1 = F.adaptive_avg_pool2d(fm1, (fm2.size(2), fm2.size(3)))
		if fm1.size(1) < fm2.size(1):
			fm2 = (fm2[:,0::2,:,:] + fm2[:,1::2,:,:]) / 2.0

		fm1 = fm1.view(fm1.size(0), -1)
		fm2 = fm2.view(fm2.size(0), -1)
		fms_dist = torch.sum(torch.pow(fm1-fm2, 2), dim=-1).clamp(min=eps)

		if not squared:
			fms_dist = fms_dist.sqrt()

		fms_dist = fms_dist / fms_dist.max()

		return fms_dist

	def euclidean_dist_fm(self, fm, squared=False, eps=1e-12): 
		'''
		Calculating the IRG edge of feature map. 
		'''
		fm = fm.view(fm.size(0), -1)
		fm_square = fm.pow(2).sum(dim=1)
		fm_prod   = torch.mm(fm, fm.t())
		fm_dist   = (fm_square.unsqueeze(0) + fm_square.unsqueeze(1) - 2 * fm_prod).clamp(min=eps)

		if not squared:
			fm_dist = fm_dist.sqrt()

		fm_dist = fm_dist.clone()
		fm_dist[range(len(fm)), range(len(fm))] = 0
		fm_dist = fm_dist / fm_dist.max()

		return fm_dist

	def euclidean_dist_feat(self, feat, squared=False, eps=1e-12):
		'''
		Calculating the IRG edge of feat.
		'''
		feat_square = feat.pow(2).sum(dim=1)
		feat_prod   = torch.mm(feat, feat.t())
		feat_dist   = (feat_square.unsqueeze(0) + feat_square.unsqueeze(1) - 2 * feat_prod).clamp(min=eps)

		if not squared:
			feat_dist = feat_dist.sqrt()

		feat_dist = feat_dist.clone()
		feat_dist[range(len(feat)), range(len(feat))] = 0
		feat_dist = feat_dist / feat_dist.max()

		return feat_dist

17. VID

Paper: https://openaccess.thecvf.com/content_CVPR_2019/papers/Ahn_Variational_Information_Distillation_for_Knowledge_Transfer_CVPR_2019_paper.pdf
Code:

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


def conv1x1(in_channels, out_channels):
	return nn.Conv2d(in_channels, out_channels,
					 kernel_size=1, stride=1,
					 padding=0, bias=False)

'''
Modified from https://github.com/HobbitLong/RepDistiller/blob/master/distiller_zoo/VID.py
'''
class VID(nn.Module):
	'''
	Variational Information Distillation for Knowledge Transfer
	https://zpascal.net/cvpr2019/Ahn_Variational_Information_Distillation_for_Knowledge_Transfer_CVPR_2019_paper.pdf
	'''
	def __init__(self, in_channels, mid_channels, out_channels, init_var, eps=1e-6):
		super(VID, self).__init__()
		self.eps = eps
		self.regressor = nn.Sequential(*[
				conv1x1(in_channels, mid_channels),
				# nn.BatchNorm2d(mid_channels),
				nn.ReLU(),
				conv1x1(mid_channels, mid_channels),
				# nn.BatchNorm2d(mid_channels),
				nn.ReLU(),
				conv1x1(mid_channels, out_channels),
			])
		self.alpha = nn.Parameter(
				np.log(np.exp(init_var-eps)-1.0) * torch.ones(out_channels)
			)

		for m in self.modules():
			if isinstance(m, nn.Conv2d):
				nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
				if m.bias is not None:
					nn.init.constant_(m.bias, 0)
			# elif isinstance(m, nn.BatchNorm2d):
			# 	nn.init.constant_(m.weight, 1)
			# 	nn.init.constant_(m.bias, 0)

	def forward(self, fm_s, fm_t):
		pred_mean = self.regressor(fm_s)
		pred_var  = torch.log(1.0+torch.exp(self.alpha)) + self.eps
		pred_var  = pred_var.view(1, -1, 1, 1)
		neg_log_prob = 0.5 * (torch.log(pred_var) + (pred_mean-fm_t)**2 / pred_var)
		loss = torch.mean(neg_log_prob)

		return loss
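
Note that VID itself contains trainable parameters (the regressor and the variance parameter alpha), so they must be optimized together with the student; a minimal sketch (channel sizes and init_var are assumptions):

import torch
import torch.nn as nn

student_layer = nn.Conv2d(3, 32, 3, padding=1)   # hypothetical student layer producing fm_s
vid = VID(in_channels=32, mid_channels=64, out_channels=64, init_var=5.0)

# the VID regressor and alpha are updated jointly with the student
optimizer = torch.optim.SGD(list(student_layer.parameters()) + list(vid.parameters()), lr=0.05)

fm_s = student_layer(torch.randn(8, 3, 16, 16))
fm_t = torch.randn(8, 64, 16, 16)                # assumed teacher feature map
loss = vid(fm_s, fm_t)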

18. OFD

Paper: http://openaccess.thecvf.com/content_ICCV_2019/papers/Heo_A_Comprehensive_Overhaul_of_Feature_Distillation_ICCV_2019_paper.pdf
Code:

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


'''
Modified from https://github.com/clovaai/overhaul-distillation/blob/master/CIFAR-100/distiller.py
'''
class OFD(nn.Module):
	'''
	A Comprehensive Overhaul of Feature Distillation
	http://openaccess.thecvf.com/content_ICCV_2019/papers/
	Heo_A_Comprehensive_Overhaul_of_Feature_Distillation_ICCV_2019_paper.pdf
	'''
	def __init__(self, in_channels, out_channels):
		super(OFD, self).__init__()
		self.connector = nn.Sequential(*[
				nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
				nn.BatchNorm2d(out_channels)
			])

		for m in self.modules():
			if isinstance(m, nn.Conv2d):
				nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
				if m.bias is not None:
					nn.init.constant_(m.bias, 0)
			elif isinstance(m, nn.BatchNorm2d):
				nn.init.constant_(m.weight, 1)
				nn.init.constant_(m.bias, 0)

	def forward(self, fm_s, fm_t):
		margin = self.get_margin(fm_t)
		fm_t = torch.max(fm_t, margin)
		fm_s = self.connector(fm_s)

		mask = 1.0 - ((fm_s <= fm_t) & (fm_t <= 0.0)).float()
		loss = torch.mean((fm_s - fm_t)**2 * mask)

		return loss

	def get_margin(self, fm, eps=1e-6):
		mask = (fm < 0.0).float()
		masked_fm = fm * mask

		margin = masked_fm.sum(dim=(0,2,3), keepdim=True) / (mask.sum(dim=(0,2,3), keepdim=True)+eps)

		return margin

19. AFD

Paper: https://openreview.net/pdf?id=ryxyCeHtPB
Code:

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

'''
In the original paper, AFD is one of components of AFDS.
AFDS: Attention Feature Distillation and Selection
AFD:  Attention Feature Distillation
AFS:  Attention Feature Selection

We find the original implementation of attention is unstable, thus we replace it with a SE block.
'''
class AFD(nn.Module):
	'''
	Pay Attention to Features, Transfer Learn Faster CNNs
	https://openreview.net/pdf?id=ryxyCeHtPB
	'''
	def __init__(self, in_channels, att_f):
		super(AFD, self).__init__()
		mid_channels = int(in_channels * att_f)

		self.attention = nn.Sequential(*[
				nn.Conv2d(in_channels, mid_channels, 1, 1, 0, bias=True),
				nn.ReLU(inplace=True),
				nn.Conv2d(mid_channels, in_channels, 1, 1, 0, bias=True)
			])

		for m in self.modules():
			if isinstance(m, nn.Conv2d):
				nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
				if m.bias is not None:
					nn.init.constant_(m.bias, 0)
		
	def forward(self, fm_s, fm_t, eps=1e-6):
		fm_t_pooled = F.adaptive_avg_pool2d(fm_t, 1)
		rho = self.attention(fm_t_pooled)
		# rho = F.softmax(rho.squeeze(), dim=-1)
		rho = torch.sigmoid(rho.squeeze())
		rho = rho / torch.sum(rho, dim=1, keepdim=True)

		fm_s_norm = torch.norm(fm_s, dim=(2,3), keepdim=True)
		fm_s      = torch.div(fm_s, fm_s_norm+eps)
		fm_t_norm = torch.norm(fm_t, dim=(2,3), keepdim=True)
		fm_t      = torch.div(fm_t, fm_t_norm+eps)

		loss = rho * torch.pow(fm_s-fm_t, 2).mean(dim=(2,3))
		loss = loss.sum(1).mean(0)

		return loss


20. CRD

Paper: https://openreview.net/pdf?id=SkgpBJrtvS
Code:

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


'''
Modified from https://github.com/HobbitLong/RepDistiller/tree/master/crd
'''
class CRD(nn.Module):
	'''
	Contrastive Representation Distillation
	https://openreview.net/pdf?id=SkgpBJrtvS

	includes two symmetric parts:
	(a) using teacher as anchor, choose positive and negatives over the student side
	(b) using student as anchor, choose positive and negatives over the teacher side

	Args:
		s_dim: the dimension of student's feature
		t_dim: the dimension of teacher's feature
		feat_dim: the dimension of the projection space
		nce_n: number of negatives paired with each positive
		nce_t: the temperature
		nce_mom: the momentum for updating the memory buffer
		n_data: the number of samples in the training set, which is the M in Eq.(19)
	'''
	def __init__(self, s_dim, t_dim, feat_dim, nce_n, nce_t, nce_mom, n_data):
		super(CRD, self).__init__()
		self.embed_s = Embed(s_dim, feat_dim)
		self.embed_t = Embed(t_dim, feat_dim)
		self.contrast = ContrastMemory(feat_dim, n_data, nce_n, nce_t, nce_mom)
		self.criterion_s = ContrastLoss(n_data)
		self.criterion_t = ContrastLoss(n_data)

	def forward(self, feat_s, feat_t, idx, sample_idx):
		feat_s = self.embed_s(feat_s)
		feat_t = self.embed_t(feat_t)
		out_s, out_t = self.contrast(feat_s, feat_t, idx, sample_idx)
		loss_s = self.criterion_s(out_s)
		loss_t = self.criterion_t(out_t)
		loss = loss_s + loss_t

		return loss


class Embed(nn.Module):
	def __init__(self, in_dim, out_dim):
		super(Embed, self).__init__()
		self.linear = nn.Linear(in_dim, out_dim)

	def forward(self, x):
		x = x.view(x.size(0), -1)
		x = self.linear(x)
		x = F.normalize(x, p=2, dim=1)

		return x


class ContrastLoss(nn.Module):
	'''
	contrastive loss, corresponding to Eq.(18)
	'''
	def __init__(self, n_data, eps=1e-7):
		super(ContrastLoss, self).__init__()
		self.n_data = n_data
		self.eps = eps

	def forward(self, x):
		bs = x.size(0)
		N  = x.size(1) - 1
		M  = float(self.n_data)

		# loss for positive pair
		pos_pair = x.select(1, 0)
		log_pos  = torch.div(pos_pair, pos_pair.add(N / M + self.eps)).log_()

		# loss for negative pair
		neg_pair = x.narrow(1, 1, N)
		log_neg  = torch.div(neg_pair.clone().fill_(N / M), neg_pair.add(N / M + self.eps)).log_()

		loss = -(log_pos.sum() + log_neg.sum()) / bs

		return loss


class ContrastMemory(nn.Module):
	def __init__(self, feat_dim, n_data, nce_n, nce_t, nce_mom):
		super(ContrastMemory, self).__init__()
		self.N = nce_n
		self.T = nce_t
		self.momentum = nce_mom
		self.Z_t = None
		self.Z_s = None

		stdv = 1. / math.sqrt(feat_dim / 3.)
		self.register_buffer('memory_t', torch.rand(n_data, feat_dim).mul_(2 * stdv).add_(-stdv))
		self.register_buffer('memory_s', torch.rand(n_data, feat_dim).mul_(2 * stdv).add_(-stdv))

	def forward(self, feat_s, feat_t, idx, sample_idx):
		bs = feat_s.size(0)
		feat_dim = self.memory_s.size(1)
		n_data = self.memory_s.size(0)

		# using teacher as anchor
		weight_s = torch.index_select(self.memory_s, 0, sample_idx.view(-1)).detach()
		weight_s = weight_s.view(bs, self.N + 1, feat_dim)
		out_t = torch.bmm(weight_s, feat_t.view(bs, feat_dim, 1))
		out_t = torch.exp(torch.div(out_t, self.T)).squeeze().contiguous()

		# using student as anchor
		weight_t = torch.index_select(self.memory_t, 0, sample_idx.view(-1)).detach()
		weight_t = weight_t.view(bs, self.N + 1, feat_dim)
		out_s = torch.bmm(weight_t, feat_s.view(bs, feat_dim, 1))
		out_s = torch.exp(torch.div(out_s, self.T)).squeeze().contiguous()

		# set Z if haven't been set yet
		if self.Z_t is None:
			self.Z_t = (out_t.mean() * n_data).detach().item()
		if self.Z_s is None:
			self.Z_s = (out_s.mean() * n_data).detach().item()

		out_t = torch.div(out_t, self.Z_t)
		out_s = torch.div(out_s, self.Z_s)

		# update memory
		with torch.no_grad():
			pos_mem_t = torch.index_select(self.memory_t, 0, idx.view(-1))
			pos_mem_t.mul_(self.momentum)
			pos_mem_t.add_(torch.mul(feat_t, 1 - self.momentum))
			pos_mem_t = F.normalize(pos_mem_t, p=2, dim=1)
			self.memory_t.index_copy_(0, idx, pos_mem_t)

			pos_mem_s = torch.index_select(self.memory_s, 0, idx.view(-1))
			pos_mem_s.mul_(self.momentum)
			pos_mem_s.add_(torch.mul(feat_s, 1 - self.momentum))
			pos_mem_s = F.normalize(pos_mem_s, p=2, dim=1)
			self.memory_s.index_copy_(0, idx, pos_mem_s)

		return out_s, out_t
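
A minimal sketch of how the index arguments can be built (batch size, n_data and nce_n are assumptions): idx holds the dataset index of every sample in the batch, and sample_idx holds, for each sample, its own index in the first column followed by nce_n negative indices (drawn uniformly at random here as a simplification).

import torch

n_data, feat_dim, nce_n = 5000, 128, 16
crd = CRD(s_dim=256, t_dim=512, feat_dim=feat_dim, nce_n=nce_n, nce_t=0.07, nce_mom=0.5, n_data=n_data)

bs = 8
feat_s = torch.randn(bs, 256)
feat_t = torch.randn(bs, 512)

idx = torch.randint(0, n_data, (bs,))                    # dataset index of each batch sample
neg = torch.randint(0, n_data, (bs, nce_n))              # random negatives (a simplification)
sample_idx = torch.cat([idx.unsqueeze(1), neg], dim=1)   # first column must be the positive index

loss = crd(feat_s, feat_t, idx, sample_idx)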


21. DML

Paper: https://openaccess.thecvf.com/content_cvpr_2018/papers/Zhang_Deep_Mutual_Learning_CVPR_2018_paper.pdf
Code:

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F


'''
DML with only two networks
'''
class DML(nn.Module):
	'''
	Deep Mutual Learning
	https://zpascal.net/cvpr2018/Zhang_Deep_Mutual_Learning_CVPR_2018_paper.pdf
	'''
	def __init__(self):
		super(DML, self).__init__()

	def forward(self, out1, out2):
		loss = F.kl_div(F.log_softmax(out1, dim=1),
						F.softmax(out2, dim=1),
						reduction='batchmean')

		return loss
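
In mutual learning each peer keeps its own cross-entropy loss and additionally matches the other peer's predictions, with the KL term applied in both directions; a minimal two-network sketch (the networks below are assumptions):

import torch
import torch.nn as nn

net1, net2 = nn.Linear(32, 10), nn.Linear(32, 10)   # two hypothetical peer networks trained together
criterion_dml = DML()
criterion_ce  = nn.CrossEntropyLoss()

img    = torch.randn(8, 32)
target = torch.randint(0, 10, (8,))

out1, out2 = net1(img), net2(img)
loss1 = criterion_ce(out1, target) + criterion_dml(out1, out2.detach())   # used to update net1
loss2 = criterion_ce(out2, target) + criterion_dml(out2, out1.detach())   # used to update net2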
