0. 引言
前几天分几篇博文精细地讲述了《von Mises-Fisher 分布》, 以及相应的 PyTorch 实现《von Mises-Fisher Distribution (代码解析)》, 其中以 Uniform 分布为例简要介绍了 torch.distributions 包的用法. 本以为已经可以了, 但这两天看到论文 The Power Spherical distribution 的代码, 又被其实现分布的方式所吸引.
Power Spherical 分布与 von Mises Fisher 分布类似, 只不过将后者概率密度函数中的指数函数换成了多项式函数: f p ( x ; μ , κ ) ∝ e x p ( κ μ ⊺ x ) ⇓ f p ( x ; μ , κ ) ∝ ( 1 + μ ⊺ x ) κ \begin{aligned} f_p(\bm{x}; \bm{\mu}, \kappa) &\propto exp(\kappa \bm{\mu}^\intercal \bm{x}) \\ &\Downarrow\\ f_p(\bm{x}; \bm{\mu}, \kappa) &\propto (1+\bm{\mu}^\intercal \bm{x})^\kappa \\ \end{aligned} fp(x;μ,κ)fp(x;μ,κ)∝exp(κμ⊺x)⇓∝(1+μ⊺x)κ 采样框架基本一致, 且这么做可以使边缘 t t t 的线性变换 t + 1 2 ∼ B e t a ( p − 1 2 + κ , p − 1 2 ) \frac{t+1}{2} \sim Beta(\frac{p-1}{2}+\kappa, \frac{p-1}{2}) 2t+1∼Beta(2p−1+κ,2p−1), 从而避免了接受-拒绝采样过程.
当然, 按照之前的 VonMisesFisher 的写法, 这个 t 的采样大概是这样:
z = beta.sample(sample_shape)
t = 2 * z - 1
但现在我遇到了这种写法:
class MarginalTDistribution(tds.TransformedDistribution):
	arg_constraints = {
		'dim': constraints.positive_integer,
		'scale': constraints.positive,
	}
	has_rsample = True
	def __init__(self, dim, scale, validate_args=None):
		self.dim = dim
		self.scale = scale
		super().__init__(
			tds.Beta(  # 用 Beta 分布转换, z 服从 Beta(α+κ,β)
				(dim - 1) / 2 + scale, (dim - 1) / 2, validate_args=validate_args
			),
			transforms=tds.AffineTransform(loc=-1, scale=2),  # t=2z-1 是想要的边缘分布随机数
		)
然后就可以进行对  
     
      
       
       
         t 
        
       
      
        t 
       
      
    t 的采样了.
 
我们可以看到其基本架构, 本文将详细解析其内部的具体细节, 包括:
1. Distribution
 
在之前的 <von Mises-Fisher Distribution (代码解析)> 中, 已经通过 Uniform 简单介绍了 Distribution 的用法. 它是实现各种分布的抽象基类. 本文将以解析源码的方式详细介绍.
1.1 参数验证 validate_args
打开源码, 首先映入眼帘的是关于参数验证的代码:
# true if Python was not started with an -O option. See also the assert statement.
_validate_args = __debug__
@staticmethod
def set_default_validate_args(value: bool) -> None:
	"""
	设置 validation 是否开启.
	validation 通常是耗时的, 所以最好在模型 work 后关闭它.
	"""
	if value not in [True, False]:
		raise ValueError
	Distribution._validate_args = value
Distribution 有一个类属性叫 _validate_args, 默认值是 __debug__(见附录1), 可以通过类静态方法 set_default_validate_args(value: bool) 来修改此值.
构造方法 __init__(...) 中的验证逻辑:
def __init__(self, ..., validate_args: Optional[bool]=None):
	...
	if validate_args is not None:
		self._validate_args = validate_args
也就是说, 你可以在创建 Distribution 实例的时候设置是否进行参数验证. 如果不设置, 则按照类的属性 Distribution._validate_args 来.
if self._validate_args:  # validate_args=False 就不用设置 arg_constraints 了
	try:  # 尝试获取字典 arg_constraints
		arg_constraints = self.arg_constraints
	except NotImplementedError:  # 如果没设置, 则设置为 {}, 抛出警告
		arg_constraints = {}
		warnings.warn(...)
如果需要验证参数, 那么首先要获取一个叫 arg_constraints 的参数验证字典, 它列出了需要验证哪些参数. 这个抽象类里面并没有给出, 需要用户继承该类时写在子类中. 以 Uniform 为例:
class Uniform(Distribution):
	...
	arg_constraints = {
		"low": constraints.dependent(is_discrete=False, event_dim=0),
		"high": constraints.dependent(is_discrete=False, event_dim=0),
	}
	...
至于 constraints.dependent 是啥, 后面会详细介绍. 值得注意的是, 如果你在创建实例时指定 validate_args=False, 那么所有关于参数验证的事就都不用管了.
for param, constraint in arg_constraints.items():
	if constraints.is_dependent(constraint):
		continue  # skip constraints that cannot be checked
	if param not in self.__dict__ and isinstance(
			getattr(type(self), param), lazy_property
	):
		continue  # skip checking lazily-constructed args
	value = getattr(self, param)  # 从当前对象获取参数 value
	valid = constraint.check(value)  # 检查参数值
	if not valid.all():  # 检查不通过
		raise ValueError(...)
这一段就是验证过程了, 包括:
- skip constraints that cannot be checked, 由 constraints.is_dependent(constraint)判断是否可验证;
- skip checking lazily-constructed args, 即参数名不在 self.__dict__中, 并属于lazy_property的跳过;
- 获得参数, 进行验证;
具体的验证细节将在后面介绍.
1.2 batch_shape & event_shape
 
除了 validate_args 参数, __init__(...) 方法中的另外两个参数就是:
def __init__(
		self,
		batch_shape: torch.Size = torch.Size(),
		event_shape: torch.Size = torch.Size(),
):
	self._batch_shape = batch_shape
	self._event_shape = event_shape
	...
这两个参数是啥? 在这个抽象类中, 我们看不到太多信息, 甚至 Uniform 中也只有 batch_shape = self.low.size() 的信息, 大概意思同时进行着一批的均匀分布, 如 low = torch.tensor([0.0, 1.0]) 时, batch_shape = torch.Size([2]), 表示一个二元的均匀分布. 看 MultivariateNormal, 里面信息量较大:
batch_shape = torch.broadcast_shapes(
	covariance_matrix.shape[:-2],  # [:-2]是去掉了协方差矩阵的维度, 剩下的可能是 batch 的维度
	loc.shape[:-1]  # [:-1]是去掉了 envent 的维度, 剩下的可能是 batch 的维度
)  # broadcast_shapes 意思是进行了广播, 如果 matrix 的 batch_shape 是 [2,1], loc 的 batch_shape 是 [1,2], 那么整个的 batch_shape 是广播后的 [2,2]
self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1))  # 之后 covariance_matrix 都被 expand 了
...
event_shape = self.loc.shape[-1:]  # 看来就是样本的 shape
从这一段来看, batch_shape 是指创建的实例在进行多少个平行的基本分布, 而 event_shape 是指基本分布的事件(支撑点)维度. 如:
locs = torch.randn(2, 3)
matrixs = torch.randn(2, 3, 3)
covariance_matrixs = torch.bmm(matrixs, matrixs.transpose(1, 2))
normal = distributions.MultivariateNormal(loc=locs, covariance_matrix=covariance_matrixs)
print(normal.batch_shape)  # 2
print(normal.event_shape)  # 3
print(normal.sample())
##### output #####
torch.Size([2])
torch.Size([3])
tensor([[ 1.8972, -0.3961, -0.1530],
		[-0.5018, -2.5110,  0.1293]])
batch 的意思还是那个 batch, 不过这里是指分布的 batch, 而不是数据的 batch. 采样时, 得到一批 samples, 对应每个分布.
还有一个 method 和这两个参数有关: expand, 因为它是一个抽象 method, 基类中并没有实现, 那就直接看 MultivariateNormal 中的:
def expand(self, batch_shape: torch.Size, _instance=None):
	"""
	Args:
		batch_shape (torch.Size): the desired expanded size.
		_instance: new instance provided by subclasses that need to override `.expand`.
	Returns:
		New distribution instance with batch dimensions expanded to `batch_size`.
	"""
	new = self._get_checked_instance(MultivariateNormal, _instance)
	batch_shape = torch.Size(batch_shape)
	loc_shape = batch_shape + self.event_shape
	cov_shape = batch_shape + self.event_shape + self.event_shape
	new.loc = self.loc.expand(loc_shape)
	new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril
	if "covariance_matrix" in self.__dict__:
		new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
	if "scale_tril" in self.__dict__:
		new.scale_tril = self.scale_tril.expand(cov_shape)
	if "precision_matrix" in self.__dict__:
		new.precision_matrix = self.precision_matrix.expand(cov_shape)
	super(MultivariateNormal, new).__init__(
		batch_shape, self.event_shape, validate_args=False
	)
	new._validate_args = self._validate_args
	return new
这个 method 会创建一个新的 instance 或调用的时候用户提供, 并设置 batch_shape 为参数提供的形状, 然后把参数 expand 到新的 batch_shape. 用法:
mean = torch.randn(3)
matrix = torch.randn(3, 3)
covariance_matrix = torch.mm(matrix, matrix.t())
mvn = MultivariateNormal(mean, covariance_matrix)
bmvn = mvn.expand(torch.Size([2]))
print(bmvn.batch_shape)
print(bmvn.event_shape)
print(bmvn.sample())
##### output #####
torch.Size([2])
torch.Size([3])
tensor([[-4.0891, -4.2424,  6.2574],
		[ 0.7656, -0.2199, -0.9836]])
1.3 一些属性
包括: m e a n mean mean, m o d e mode mode, s t d std std, v a r i a n c e variance variance, e n t r o p y entropy entropy 等基本属性, 都需要用户在子类中自己实现. 还有一些相关的函数:
- cumulative density/mass function cdf(value);
- inverse cumulative density/mass function icdf(value);
 这个函数非常有用, Inverse Transform Sampling 中用其进行采样. 从 U ( 0 , 1 ) U(0,1) U(0,1) 中采样一个 u u u, 然后令 x = F − 1 ( u ) x = F^{-1}(u) x=F−1(u) 就是所求随机变量 X X X 的一个采样.
- log of the probability density/mass function log_prob(value), 对数概率.
注意, 目前看到的只有 log_prob, 并没有 prob, 一些示例要么只算 log_prob, 要么计算后通过 exp(log_prob) 得到 prob.
2. constraints.Constraint
 
前面在1.1参数验证中已经遇到 constraints.dependent(is_discrete=False, event_dim=0) 和 constraint.check(value), 但没有讲具体细节. 本节将详细剖析.
2.1 抽象基类 Constraint
 
先看源码:
class Constraint:
	"""
	一个 constraint 对象, 表示变量在某区域内有效, 即变量可优化的范围.
	"""
	is_discrete = False  # Default to continuous.
	event_dim = 0  # Default to univariate.
	def check(self, value):
		"""
		结果的形状为"sample_shape + batch_shape", 指示 each event 值是否满足此限制.
		"""
		raise NotImplementedError
这是抽象基类 Constraint, 比较简单, 只有两个类属性和一个 method check(value). is_discrete 表示待验证值是否为离散; 联想前面的 event_shape, 大概可以知道 event_dim 是指 len(event_shape).(不过目前看只是为了验证参数, 还能验证采样的 event?)
2.2 _Dependent() 不被验证
 
这个基类信息太少, 对我们理解前面的内容毫无用处, 还是直接观察一些子类吧. 从 dependent = _Dependent() 开始, 它是 constraints.py 中定义好的 placeholder(这个倒是可以学一学):
class _Dependent(Constraint):  # 看"_", 应该是不希望用户直接创建实例
	"""
	Placeholder for variables whose support depends on other variables.
	These variables obey no simple coordinate-wise constraints.
	"""
	def __init__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
		self._is_discrete = is_discrete
		self._event_dim = event_dim
		super().__init__()
	def __call__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
		"""
		Support for syntax to customize static attributes::
			constraints.dependent(is_discrete=True, event_dim=1)
		"""
		if is_discrete is NotImplemented:  # 未提供就是默认
			is_discrete = self._is_discrete
		if event_dim is NotImplemented:
			event_dim = self._event_dim
		return _Dependent(is_discrete=is_discrete, event_dim=event_dim)
	def check(self, x):
		raise ValueError("Cannot determine validity of dependent constraint")
闹了半天, 我们并不能看到 constraints.dependent(is_discrete=False, event_dim=0) 有什么卵用, 只知道 “Cannot determine validity of dependent constraint”, 这也呼应了前面的:
if constraints.is_dependent(constraint):
	continue  # skip constraints that cannot be checked
也就是说, dependent 类型的限制是不会执行参数验证的. 那这个 _Dependent 到底有何用处? 先不管了.
2.3 _IndependentConstraint 重新解释 event_dim
 
我们看点复杂的, MultivariateNormal.arg_constraints:
arg_constraints = {
	"loc": constraints.real_vector,
	"covariance_matrix": constraints.positive_definite,
	"precision_matrix": constraints.positive_definite,
	"scale_tril": constraints.lower_cholesky,
}
这些都是 constraints.py 中定义好的实例, 对于大多情况, 这些预定义好的实例已经够用, 但如果需要, 你也可以自定义. 先看 real_vector:
independent = _IndependentConstraint
real_vector = independent(real, 1)
class _IndependentConstraint(Constraint):
	"""
	封装一个 constraint,  通过 aggregating over ``reinterpreted_batch_ndims``-many dims in :meth:`check`,
	an event is valid 当且仅当它依赖的所有 entries 是 valid 的.
	"""
	def __init__(self, base_constraint, reinterpreted_batch_ndims):
		self.base_constraint = base_constraint
		self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
		super().__init__()
	@property
	def event_dim(self):
		# real.event_dim 是 0, + real_vector(reinterpreted_batch_ndims=1) = 1
		return self.base_constraint.event_dim + self.reinterpreted_batch_ndims
	def check(self, value):
		result = self.base_constraint.check(value)  # 首先要符合 base.check
		if result.dim() < self.reinterpreted_batch_ndims:
			# 给 batch 留够 dim
			expected = self.base_constraint.event_dim + self.reinterpreted_batch_ndims
			raise ValueError(
				f"Expected value.dim() >= {expected} but got {value.dim()}"
			)
		result = result.reshape(  # 减掉 event
			result.shape[: result.dim() - self.reinterpreted_batch_ndims] + (-1,)
		)
		result = result.all(-1)  # 减少一个 dim
		return result
意思很明了了, real_vector 是依赖于 real(base_constraint) 的, reinterpreted_batch_ndims=1 是说把原来的 value 重新解释, event_dim 加上 reinterpreted_batch_ndims, 比如
value = [[1, 2, 3],
		 [4, 5, 6]]
本来 real 的 event_dim=0, 验证结果为(sample_shape + batch_shape = (2,2)):
value = [[True, True, True],
		 [True, True, True]]
现在重新解释为 event_dim=1, 验证结果为:
result = result.reshape(  # 减掉 event
   	result.shape[: result.dim() - self.reinterpreted_batch_ndims] + (-1,)  # (-1,) 表示新 event 内的所有 entries 展平
)
result = result.all(-1)  # 新 event 内的所有 entries 为 True, 则新 event 为 True
================>
value = [True, True]
3. Transform & _InverseTransform
 
上一节介绍了 constraints.Constraint, 明白了在构建 Distribution 实例时进行的参数验证, 以保证用户提供的参数符合要求. 但还留下了一个疑问: Constraint 中的 event_dim 是指 len(event_shape), 难道还能验证采样的 event? 再者, check(value) 返回值的形状是 sample_shape + batch_shape, 进一步说明它是会被用于采样结果检查的. 让我们看一看能否在 Transform 中找到答案.
Transform & _InverseTransform 是一对互逆的操作, 看一看里面都有什么:
3.1 Attributes
class Transform:
	"""
	Attributes:
		domain (constraints.Constraint):
			Transform 的有效输入范围.
		codomain (constraints.Constraint):
			Transform 的有效输出范围.  # 输出是 inverse transform 的输入.
		bijective (bool): Transform 是否双射.
			即使不是双射, Transforms 也应是弱伪可逆的:
				t(t.inv(t(x)) == t(x) and t.inv(t(t.inv(y))) == t.inv(y).
		sign (int or Tensor): 对于双射单变量 transforms, +1 or -1
			取决于 transform 单调增还是单调减.
	"""
	bijective = False  # 默认 False
	domain: constraints.Constraint
	codomain: constraints.Constraint
class Transform:
	"""
	可逆变换的抽象基类, with computable log det jacobians.
	Caching 对于计算逆复杂或不稳定的变换非常有用.
	子类应该实现 one or both of `_call` or `_inverse`.
	如果 `bijective=True`, 则必须实现 `log_abs_det_jacobian`.
	Args:
		cache_size (int): If one, the latest single value is cached.
		Only 0 and 1 are supported.
	"""
	def __init__(self, cache_size=0):
		self._cache_size = cache_size
		self._inv = None
		if cache_size == 0:
			pass  # default behavior
		elif cache_size == 1:
			self._cached_x_y = None, None
		else:
			raise ValueError("cache_size must be 0 or 1")
		super().__init__()
	def __getstate__(self):
		state = self.__dict__.copy()
		state["_inv"] = None
		return state
	@property
	def event_dim(self):
		if self.domain.event_dim == self.codomain.event_dim:  # 当定义域和值域 event_dim 相同时, 才能简略为 event_dim
			return self.domain.event_dim
		raise ValueError("Please use either .domain.event_dim or .codomain.event_dim")
	@property
	def inv(self):
		"""
		Returns the inverse :class:`Transform` of this transform.
		This should satisfy ``t.inv.inv is t``.
		"""
		inv = None
		if self._inv is not None:
			inv = self._inv()
		if inv is None:
			inv = _InverseTransform(self)
			self._inv = weakref.ref(inv)
		return inv
	def with_cache(self, cache_size=1):
		if self._cache_size == cache_size:
			return self
		if type(self).__init__ is Transform.__init__:
			return type(self)(cache_size=cache_size)
		raise NotImplementedError(f"{type(self)}.with_cache is not implemented")
	def __call__(self, x):
		"""
		Computes the transform `x => y`.
		"""
		if self._cache_size == 0:
			return self._call(x)
		x_old, y_old = self._cached_x_y
		if x is x_old:
			return y_old
		y = self._call(x)
		self._cached_x_y = x, y
		return y
	def _inv_call(self, y):
		"""
		Inverts the transform `y => x`.
		"""
		if self._cache_size == 0:
			return self._inverse(y)
		x_old, y_old = self._cached_x_y
		if y is y_old:
			return x_old
		x = self._inverse(y)
		self._cached_x_y = x, y
		return x
	def _call(self, x):
		"""
		Abstract method to compute forward transformation.
		"""
		raise NotImplementedError
	def _inverse(self, y):
		"""
		Abstract method to compute inverse transformation.
		"""
		raise NotImplementedError
	def log_abs_det_jacobian(self, x, y):
		"""
		Computes the log det jacobian `log |dy/dx|` given input and output.
		"""
		raise NotImplementedError
	def forward_shape(self, shape):
		"""
		Infers the shape of the forward computation, given the input shape.
		Defaults to preserving shape.
		"""
		return shape
	def inverse_shape(self, shape):
		"""
		Infers the shapes of the inverse computation, given the output shape.
		Defaults to preserving shape.
		"""
		return shape
附录
1. __debug__ 和 assert (来自 Kimi)
 
__debug__ 是一个内置变量,用于指示 Python 解释器是否处于调试模式。当 Python 以调试模式运行时,__debug__ 被设置为 True;否则,在优化模式下运行时,它被设置为 False。
__debug__ 可以用于条件性地执行调试代码,例如:
if __debug__:
	print("Debug mode is on, performing extra checks...")
	# 这里可以放一些只在调试模式下运行的代码,比如详细的日志记录
	# 或者复杂的验证逻辑
else:
	print("Debug mode is off.")
在上面的例子中,如果命令行执行:
python -O myscript.py
##### output #####
Debug mode is off.
------------------------------------------------------
python myscript.py
##### output #####
Debug mode is on, performing extra checks...
assert 语句受 __debug__ 影响:
def calculate(a, b):
	# 这个 assert 在 __debug__ 为 True 时执行
	assert a > 0 and b > 0, "Both inputs must be positive."
	
	# 正常的函数逻辑
	return a * b
# 在这里,assert 会检查输入是否为正数
result = calculate(5, 3)
print(result)
# 如果我们改变条件使 assert 失败
# result = calculate(-1, 3)  # 这会触发 AssertionError,除非运行时 __debug__ 为 False




![[Halcon学习笔记]Halcon窗口进行等比例显示图像](https://img-blog.csdnimg.cn/img_convert/424adf96e08fefe296458054e277565f.webp?x-oss-process=image/format,png)














