抽象方法和实例调用方法
对比表格:
特性 | 抽象方法 (forward) | 实例调用方法 (call) |
---|---|---|
定义方式 | @abc.abstractmethod 装饰器 | 特殊方法名 __call__ |
调用方式 | 不能直接调用,必须通过子类实现 | 可以直接调用对象:controller(attn, ...) |
实现要求 | 必须由子类实现,否则无法实例化 | 可以在基类中实现,子类可以覆盖 |
主要功能 | 定义具体业务逻辑(如存储注意力) | 处理通用流程(如计数、步骤控制) |
代码位置 | 在子类中实现(如 AttentionStore ) | 在基类中实现(AttentionControl ) |
使用场景 | 需要子类提供不同实现时 | 需要统一接口时 |
错误处理 | 在类定义时检查 | 在运行时检查 |
代码复用 | 每个子类都需要实现 | 所有子类共享基类实现 |
以下面代码片段为例
class AttentionControl(abc.ABC):
def step_callback(self, x_t):
return x_t
def between_steps(self):
return
@property
def num_uncond_att_layers(self):
return 0
@abc.abstractmethod
def forward(self, attn, is_cross: bool, place_in_unet: str):
raise NotImplementedError
def __call__(self, attn, is_cross: bool, place_in_unet: str):
if self.cur_att_layer >= self.num_uncond_att_layers:
h = attn.shape[0]
attn[h // 2 :] = self.forward(attn[h // 2 :], is_cross, place_in_unet)
self.cur_att_layer += 1
if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
self.cur_att_layer = 0
self.cur_step += 1
self.between_steps()
return attn
def reset(self):
self.cur_step = 0
self.cur_att_layer = 0
def __init__(self):
self.cur_step = 0
self.num_att_layers = -1
self.cur_att_layer = 0
-
抽象方法:
- 是接口定义
- 强制子类实现
- 提供具体功能
-
实例调用方法:
- 是统一接口
- 处理通用逻辑
- 调用抽象方法
具体的调用差异
forward
和 __call__
的区别和调用时机:
- 基本区别:
# forward 是抽象方法,需要子类实现
@abc.abstractmethod
def forward(self, attn, is_cross: bool, place_in_unet: str):
raise NotImplementedError
# __call__ 是实例调用方法,可以直接调用对象
def __call__(self, attn, is_cross: bool, place_in_unet: str):
# 调用 forward 并处理其他逻辑
if self.cur_att_layer >= self.num_uncond_att_layers:
h = attn.shape[0]
attn[h // 2 :] = self.forward(attn[h // 2 :], is_cross, place_in_unet)
# ... 其他逻辑
- 调用时机:
__call__
的调用:
controller = AttentionStore()
# 直接调用对象时会触发 __call__
result = controller(attn, is_cross=True, place_in_unet="down")
forward
的调用:
# forward 是在 __call__ 内部被调用的
# 不会直接调用 forward
- 工作流程:
# 1. 当调用 controller(attn, ...) 时,会触发 __call__
# 2. __call__ 方法会:
# - 检查条件
# - 调用 forward
# - 更新计数器
# - 处理步骤间逻辑
# 3. forward 方法(由子类实现)处理具体的注意力计算
- 实际应用示例:
# 在 AttentionStore 中
def forward(self, attn, is_cross: bool, place_in_unet: str):
# 存储注意力值
key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
if attn.shape[1] <= 32**2:
self.step_store[key].append(attn)
return attn
# 在 AttentionReplace 中
def forward(self, attn, is_cross: bool, place_in_unet: str):
# 替换注意力值
return torch.einsum("hpw,bwn->bhpn", attn_base, self.mapper)
- 为什么这样设计:
a) 分离关注点:
__call__
处理通用逻辑(计数、步骤等)forward
处理具体实现(存储、替换等)
b) 提供统一接口:
- 所有控制器都可以像函数一样调用
- 内部实现可以不同
c) 代码复用:
- 通用逻辑在
__call__
中实现一次 - 子类只需要实现
forward
- 调用顺序:
# 1. 外部调用
controller(attn, is_cross=True, place_in_unet="down")
# 2. 触发 __call__
def __call__(self, attn, is_cross: bool, place_in_unet: str):
# 3. 条件检查
if self.cur_att_layer >= self.num_uncond_att_layers:
# 4. 调用 forward
attn[h // 2 :] = self.forward(attn[h // 2 :], is_cross, place_in_unet)
# 5. 处理其他逻辑
这种设计的好处是:
- 提供了清晰的接口
- 分离了通用逻辑和具体实现
- 方便代码复用和维护
- 使代码结构更加清晰
所以,在这个例子中 forward
和 __call__
的区别在于:
__call__
是外部接口,处理通用逻辑forward
是内部实现,处理具体功能- 它们共同工作,提供了灵活且统一的接口