前作 [1] 介绍了一种用 pytorch 模仿 MONAI 实现多幅图(如:image 与 label)同用 random seed 保证一致变换的写法,核心是 MultiCompose 类和 to_multi 包装函数。不过 [1] 没考虑各图用不同 augmentation 的情况,如:
- ColorJitter 只对 image 做,而不对 label 做;
- image 的 resize interpolation 可任选,但 label 只能用 nearest。
本篇更新写法,支持各图同用、独用 augmentation。
Code
- 对比 [1],主要改变是改写 MultiCompose类,并将to_multi吸收入内。
- MultiCompose的用法还是和- torchvision.transforms.Compose几乎一致,不过支持独用 augmentation:只要为各图指定各自的 augmentation 类/函数即可。见下一节例程。
def to_multi():
	"""不用单独的 to_multi 打包了,已并入 MultiCompose"""
	raise NotImplementedError
class MultiCompose:
    """扩展 torchvision.transforms.Compose:支持输入多图,
    且保证各 augmentation 中所有输入都用同一随机状态(如旋转同一随机角度),
    分割任务有用。
    """
    # numpy.random.seed range error:
    #   ValueError: Seed must be between 0 and 2**32 - 1
    MIN_SEED = 0 # - 0x8000_0000_0000_0000
    MAX_SEED = min(2**32 - 1, 0xffff_ffff_ffff_ffff)
    def __init__(self, transforms):
    	"""输入:一个 list/tuple,
    	其中每个元素可以是一个 augmentation 对象(transform)/函数,各输入同用;
    	或一个嵌套的 list/tuple,为每个输入指定独用的 augmentation。
    	"""
        # self.transforms = [to_multi(t) for t in transforms]
        no_op = lambda x: x # i.e. identity function
        self.transforms = []
        for t in transforms:
            if isinstance(t, (tuple, list)):
            	# convert `None` to `no_op` for convenience
                self.transforms.append([no_op if _t is None else _t for _t in t])
            else:
                self.transforms.append(t)
    def __call__(self, *images):
        for t in self.transforms:
            if isinstance(t, (tuple, list)): # 独用
                assert len(images) <= len(t) # allow redundant transform
            else: # 同用
                t = [t] * len(images)
            _aug_images = []
            _seed = random.randint(self.MIN_SEED, self.MAX_SEED)
            for _im, _t in zip(images, t):
                seed_everything(_seed)
                _aug_images.append(_t(_im))
            images = _aug_images
        if len(images) == 1:
            images = images[0]
        return images
Usage & Test
例程沿用 [1],但改一下 augmentation:
train_trans = MultiCompose([
	# image 用 bilinear,label 用 nearest
    (ResizeZoomPad((224, 256), "bilinear"), ResizeZoomPad((224, 256), "nearest")), # 独用
    transforms.RandomAffine(30, (0.1, 0.1)), # 同用,传一个就行
    transforms.RandomHorizontalFlip(), # 同用
    # ColorJitter 只对 image 做,label 不做(None)
    [transforms.ColorJitter(0.1, 0.2, 0.3, 0.4), None], # 独用
])
- 效果:

References
- pytorch一致数据增强


















