我们将介绍如何用PIL库实现一些简单的图像增强方法。
> [!NOTE] 初始化配置
import numpy as np  
from PIL import Image, ImageOps, ImageEnhance  
import warnings
warnings.filterwarnings('ignore')
IMAGE_SIZE = 640
> [!important] 辅助函数
> 以下函数主要用于控制增强幅度。
def int_parameter(level, maxval):
    """Scale ``level`` from a 0-10 strength scale to an integer magnitude.

    Computes ``level * maxval / 10`` and truncates it toward zero with
    ``int``; a level of 10 therefore maps to ``maxval``.
    """
    scaled = level * maxval / 10
    return int(scaled)
  
  
def float_parameter(level, maxval):
    """Scale ``level`` from a 0-10 strength scale to a float magnitude.

    Returns ``float(level) * maxval / 10``; a level of 10 maps to ``maxval``.
    """
    as_float = float(level)
    return as_float * maxval / 10.0
def sample_level(n):
    """Draw a random augmentation strength from the uniform distribution on [0.1, n).

    Randomising the strength on every call gives each augmentation a
    different magnitude each time it is applied.
    """
    lower_bound = 0.1
    return np.random.uniform(lower_bound, n)
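`level` 控制增强的强度上限:`sample_level(n)` 先从区间 [0.1, n) 上的均匀分布中随机采样一个数值,使每次增强都具有随机性;随后 `int_parameter` / `float_parameter` 按 `level * maxval / 10` 将其缩放(`level` 为 10 时恰好等于 `maxval`)。`maxval` 是各增强方法各自的最大幅度,并不固定,例如色调分离为 4、旋转为 30(度)、Solarize 为 256、错切为 0.3。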
> [!example] 增强方法
色彩反转(Invert)
def invert(pil_img, _):
    """Return ``pil_img`` with its colours inverted; the level argument is ignored."""
    inverted = ImageOps.invert(pil_img)
    return inverted
水平镜像(Mirror)
def mirror(pil_img, _):
    """Return the mirrored ``pil_img`` via ``ImageOps.mirror``; the level argument is ignored."""
    mirrored = ImageOps.mirror(pil_img)
    return mirrored
直方图均衡化(Equalize)
def equalize(pil_img, _):
    """Return ``pil_img`` after ``ImageOps.equalize``; the level argument is ignored."""
    equalized = ImageOps.equalize(pil_img)
    return equalized
色调分离(Posterize)
def posterize(pil_img, level):
    """Quantise colours by keeping fewer bits per channel.

    ``k = int_parameter(sample_level(level), 4)`` is drawn at random and the
    image is posterized to ``4 - k`` bits, so a larger ``level`` tends to
    keep fewer bits (coarser colours).
    """
    # NOTE(review): for level > 10, k can reach 4 or more, giving 0 or
    # negative bits — confirm callers keep level <= 10.
    bits_removed = int_parameter(sample_level(level), 4)
    return ImageOps.posterize(pil_img, 4 - bits_removed)
旋转(Rotate)
def rotate(pil_img, level):
    """Rotate ``pil_img`` by a random angle using bilinear resampling.

    The magnitude is ``int_parameter(sample_level(level), 30)`` degrees and
    its sign is negated when a uniform draw exceeds 0.5.
    """
    magnitude = int_parameter(sample_level(level), 30)
    sign = -1 if np.random.uniform() > 0.5 else 1
    return pil_img.rotate(sign * magnitude, resample=Image.BILINEAR)
曝光反转(Solarize)
def solarize(pil_img, level):
    """Solarize ``pil_img`` with a randomly lowered threshold.

    The threshold is ``256 - int_parameter(sample_level(level), 256)``, so a
    larger ``level`` tends to lower the threshold and invert more pixels.
    """
    shift = int_parameter(sample_level(level), 256)
    threshold = 256 - shift
    return ImageOps.solarize(pil_img, threshold)
水平错切(Shear_x)
def shear_x(pil_img, level):
    """Shear ``pil_img`` horizontally by a random factor.

    The shear factor is ``float_parameter(sample_level(level), 0.3)`` and is
    negated when a uniform draw exceeds 0.5. The result has the same size
    as the input image; uncovered areas are filled by the affine transform.
    """
    factor = float_parameter(sample_level(level), 0.3)
    if np.random.uniform() > 0.5:
        factor = -factor
    # Output at the input's own size: a fixed (IMAGE_SIZE, IMAGE_SIZE) would
    # crop or pad any image that is not IMAGE_SIZE x IMAGE_SIZE.
    return pil_img.transform(pil_img.size,
                             Image.AFFINE, (1, factor, 0, 0, 1, 0),
                             resample=Image.BILINEAR)
垂直错切(Shear_y)
def shear_y(pil_img, level):
    """Shear ``pil_img`` vertically by a random factor.

    The shear factor is ``float_parameter(sample_level(level), 0.3)`` and is
    negated when a uniform draw exceeds 0.5. The result has the same size
    as the input image.
    """
    factor = float_parameter(sample_level(level), 0.3)
    if np.random.uniform() > 0.5:
        factor = -factor
    # Output at the input's own size: a fixed (IMAGE_SIZE, IMAGE_SIZE) would
    # crop or pad any image that is not IMAGE_SIZE x IMAGE_SIZE.
    return pil_img.transform(pil_img.size,
                             Image.AFFINE, (1, 0, 0, factor, 1, 0),
                             resample=Image.BILINEAR)
水平平移(Translate_x)
def translate_x(pil_img, level):
    """Translate ``pil_img`` horizontally by a random number of pixels.

    The offset is ``int_parameter(sample_level(level), width / 3)``, where
    ``width`` is the input's own width (i.e. up to a third of the width at
    level 10), negated when a uniform draw exceeds 0.5. The result has the
    same size as the input image.
    """
    width, _ = pil_img.size
    # Scale by the actual width; for a 640-wide image this equals the old
    # IMAGE_SIZE / 3 bound, but it stays proportional for other sizes.
    offset = int_parameter(sample_level(level), width / 3)
    if np.random.random() > 0.5:
        offset = -offset
    return pil_img.transform(pil_img.size,
                             Image.AFFINE, (1, 0, offset, 0, 1, 0),
                             resample=Image.BILINEAR)
垂直平移(Translate_y)
def translate_y(pil_img, level):
    """Translate ``pil_img`` vertically by a random number of pixels.

    The offset is ``int_parameter(sample_level(level), height / 3)``, where
    ``height`` is the input's own height (i.e. up to a third of the height at
    level 10), negated when a uniform draw exceeds 0.5. The result has the
    same size as the input image.
    """
    _, height = pil_img.size
    # Scale by the actual height; for a 640-tall image this equals the old
    # IMAGE_SIZE / 3 bound, but it stays proportional for other sizes.
    offset = int_parameter(sample_level(level), height / 3)
    if np.random.random() > 0.5:
        offset = -offset
    return pil_img.transform(pil_img.size,
                             Image.AFFINE, (1, 0, 0, 0, 1, offset),
                             resample=Image.BILINEAR)
色彩饱和度(Color)
def color(pil_img, level):
    """Adjust colour balance by a random ``ImageEnhance.Color`` factor.

    The factor is ``0.1 + float_parameter(sample_level(level), 1.8)`` (below
    1.9 when level <= 10); per the Pillow docs a factor of 1.0 returns the
    original image.
    """
    factor = 0.1 + float_parameter(sample_level(level), 1.8)
    enhancer = ImageEnhance.Color(pil_img)
    return enhancer.enhance(factor)
对比度(Contrast)
def contrast(pil_img, level):
    """Adjust contrast by a random ``ImageEnhance.Contrast`` factor.

    The factor is ``0.1 + float_parameter(sample_level(level), 1.8)`` (below
    1.9 when level <= 10).
    """
    factor = 0.1 + float_parameter(sample_level(level), 1.8)
    enhancer = ImageEnhance.Contrast(pil_img)
    return enhancer.enhance(factor)
自动对比度(AutoContrast)
def autocontrast(pil_img, level):
    """Apply ``ImageOps.autocontrast`` with a randomly chosen cutoff.

    The cutoff passed to Pillow is ``10 - float_parameter(sample_level(level), 10)``,
    so a larger ``level`` tends to give a smaller cutoff.
    """
    cut = 10 - float_parameter(sample_level(level), 10)
    return ImageOps.autocontrast(pil_img, cut)
亮度(Brightness)
def brightness(pil_img, level):
    """Adjust brightness by a random ``ImageEnhance.Brightness`` factor.

    The factor is ``0.1 + float_parameter(sample_level(level), 1.8)`` (below
    1.9 when level <= 10).
    """
    factor = 0.1 + float_parameter(sample_level(level), 1.8)
    enhancer = ImageEnhance.Brightness(pil_img)
    return enhancer.enhance(factor)
锐度(Sharpness)
def sharpness(pil_img, level):
    """Adjust sharpness by a random ``ImageEnhance.Sharpness`` factor.

    The factor is ``0.1 + float_parameter(sample_level(level), 1.8)`` (below
    1.9 when level <= 10).
    """
    factor = 0.1 + float_parameter(sample_level(level), 1.8)
    enhancer = ImageEnhance.Sharpness(pil_img)
    return enhancer.enhance(factor)
> [!success] 使用案例
以下面代码中 `Image.open` 读取的原图为例(原图未随笔记导出),依次展示每种增强方法在 `level=1` 时的效果:

# Registry mapping a display name to its augmentation. Every callable is
# invoked with two positional arguments (image, level) and returns the
# augmented image; insertion order fixes the order the demo draws them in.
augmentations_all = dict(
    autocontrast=autocontrast,
    equalize=equalize,
    posterize=posterize,
    rotate=rotate,
    solarize=solarize,
    shear_x=shear_x,
    shear_y=shear_y,
    translate_x=translate_x,
    translate_y=translate_y,
    color=color,
    contrast=contrast,
    brightness=brightness,
    sharpness=sharpness,
    mirror=mirror,
    invert=invert,
)
  
import matplotlib.pyplot as plt  
  
# NOTE(review): hard-coded local Windows path; point this at your own image.
# .convert("RGB") gives every augmentation and plt.imshow an RGB image; a
# greyscale ("L") JPEG would otherwise be drawn by imshow through a
# false-colour colormap.
img = Image.open(r"C:\Users\Administrator\Downloads\result1.5\result\original_resized\class0\0.jpg").convert("RGB")


def draw(plt, idx, img, title):
    """Draw ``img`` titled ``title`` in cell ``idx`` (1-based) of a 2x4 subplot grid.

    ``plt`` is the pyplot module supplied by the caller; axis ticks are hidden.
    """
    # Three-argument form instead of building a "24<idx>" integer code from strings.
    plt.subplot(2, 4, idx)
    plt.imshow(img)
    plt.xticks([])
    plt.yticks([])
    plt.title(title)
  
# Render every augmentation, eight per figure to fill draw's 2x4 grid.
# Paging by slices keeps each figure's subplot indices in 1..8, shows every
# figure (including a final partial one) exactly once, and leaves no
# trailing empty figure behind.
per_figure = 8
aug_items = list(augmentations_all.items())
for start in range(0, len(aug_items), per_figure):
    plt.figure(figsize=(20, 16))
    page = aug_items[start:start + per_figure]
    for position, (name, aug) in enumerate(page, start=1):
        # level=1: sample_level(1) draws from [0.1, 1), a mild strength.
        # img.copy() hands each augmentation its own copy of the source image.
        draw(plt, position, aug(img.copy(), 1), name)
    plt.show()

 



















