Table of Contents
- 1. Introduction
- Task: Few-shot Classification

- 2. Experiments
- 1. simple
- 2. medium
- 3. strong
- 4. boss

- 3. Code
- Model-building preparation
 
1. Introduction
The task is few-shot classification on the Omniglot dataset; the goal is to use meta-learning to find a good parameter initialization.
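For reference, this is the standard MAML objective (Finn et al., 2017): find an initialization θ whose adapted parameters, after one inner gradient step with learning rate α, do well on each task 𝒯ᵢ:

$$\min_{\theta} \sum_{\mathcal{T}_i} \mathcal{L}_{\mathcal{T}_i}\big(\theta - \alpha \nabla_{\theta}\, \mathcal{L}_{\mathcal{T}_i}(\theta)\big)$$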
Task: Few-shot Classification
The Omniglot dataset
 
- Omniglot dataset: background set of 30 alphabets, evaluation set of 20 alphabets
- Problem setting: 5-way 1-shot classification
 
 Training MAML on Omniglot classification task.
 
Training / validation set: 30 alphabets
- multiple characters per alphabet
- 20 images per character
  
Testing set: 640 support and query pairs
- 5 support images
- 5 query images
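To make the 5-way 1-shot setup concrete, here is a minimal sketch of how one episode could be assembled; the helper name sample_episode and the images_by_class container are illustrative assumptions, not code from the notebook. The layout (support images first, query images after) matches how the solvers below slice meta_batch.

import torch

def sample_episode(images_by_class, n_way=5, k_shot=1, q_query=1):
    """Assemble one n-way k-shot episode: support images first, then query images."""
    class_ids = torch.randperm(len(images_by_class))[:n_way].tolist()
    support, query = [], []
    for c in class_ids:
        imgs = images_by_class[c]  # e.g. a tensor of shape [20, 1, 28, 28]
        idx = torch.randperm(imgs.shape[0])[: k_shot + q_query]
        support.append(imgs[idx[:k_shot]])
        query.append(imgs[idx[k_shot:]])
    # Shape: [n_way * (k_shot + q_query), 1, 28, 28]
    return torch.cat(support + query, dim=0)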
  
2. Experiments
1. simple
A simple transfer-learning baseline.
Training: run ordinary classification training on 5 randomly selected tasks.
Validation / testing: fine-tune on the 5 support images, then run inference on the query images.
 

2、medium
Fill in the TODO blocks for the meta-learning inner and outer loops, using FO-MAML. Set solver = 'meta' and raise the number of epochs to 120. FO-MAML is a simplified version of MAML that saves training time: it ignores the second-order effect of the inner-loop gradients on the meta-update.
# TODO: Finish the inner loop update rule
grads = torch.autograd.grad(loss, fast_weights.values())
fast_weights = OrderedDict(
    (name, param - inner_lr * grad)
    for ((name, param), grad) in zip(fast_weights.items(), grads)
)
# raise NotImplementedError  # template placeholder, removed once implemented

# TODO: Finish the outer loop update
meta_batch_loss.backward()
optimizer.step()
# raise NotImplementedError  # template placeholder, removed once implemented
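Note that the only difference from full MAML is the absence of create_graph=True in the torch.autograd.grad call: by default, the inner-loop gradients come back detached from the computation graph, so when meta_batch_loss.backward() runs, each fast weight is treated as θ minus a constant and no second-order derivatives are computed. That is exactly the first-order approximation.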


 
3. strong
Use full MAML, which computes higher-order gradients: by keeping the graph of the inner-loop update (create_graph=True), the outer loop can take gradients of the inner-loop gradients.
# TODO: Finish the inner loop update rule
grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)
fast_weights = OrderedDict(
    (name, param - inner_lr * grad)
    for ((name, param), grad) in zip(fast_weights.items(), grads)
)
# raise NotImplementedError  # template placeholder, removed once implemented
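A tiny, self-contained toy example (not from the notebook) makes the difference visible; the quadratic losses here are arbitrary stand-ins for the inner (support-set) and outer (query-set) objectives:

import torch

def meta_grad(first_order: bool) -> float:
    w = torch.tensor(1.0, requires_grad=True)  # the "initialization" theta
    inner_loss = (3.0 * w) ** 2                # toy support-set loss: 9 w^2
    (g,) = torch.autograd.grad(inner_loss, w, create_graph=not first_order)
    w_fast = w - 0.1 * g                       # one inner SGD step
    outer_loss = (2.0 * w_fast) ** 2           # toy query-set loss
    (meta_g,) = torch.autograd.grad(outer_loss, w)
    return meta_g.item()

print(meta_grad(first_order=True))   # -6.4: g is a detached constant (FO-MAML)
print(meta_grad(first_order=False))  #  5.12: differentiates through g (full MAML)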

 
4. boss
Task augmentation (for meta-learning): what is a reasonable way to create new tasks?
Apply task augmentation to increase the diversity of the training tasks: with probability 40%, rotate a task's images by 90 or 270 degrees.
 
# Modification inside MetaSolver
for meta_batch in x:
    # Get data
    if torch.rand(1).item() > 0.6:  # augment with probability 0.4
        times = 1 if torch.rand(1).item() > 0.5 else 3
        meta_batch = torch.rot90(meta_batch, times, [-1, -2])
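Rotating the entire meta_batch applies the same rotation to every image in the task, so each rotated character effectively acts as a brand-new character class. This is the classic trick for enlarging Omniglot's class pool with rotations, and it gives the meta-learner more distinct tasks to sample from.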

 
3. Code
Model-building preparation
Since our task is image classification, we need to build a CNN-based model. To implement the MAML algorithm, however, some of the usual nn.Module code has to be adjusted: on line 10 (of the solver code), the gradients we take are with respect to θ, the original model parameters of the outer loop, rather than the inner loop's fast weights θ'. We therefore need functional_forward, which computes the output logits from an explicit parameter dictionary, instead of nn.Module's built-in forward. These functions are defined below.
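The snippets below assume the notebook's common imports and device setup; for completeness, a minimal set that makes them self-contained:

from collections import OrderedDict

import numpy as np
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"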
def functional_forward(self, x, params):
        """Forward pass using an explicit parameter dict instead of the module's own parameters."""
        for block in [1, 2, 3, 4]:
            x = ConvBlockFunction(
                x,
                params[f"conv{block}.0.weight"],
                params[f"conv{block}.0.bias"],
                params.get(f"conv{block}.1.weight"),
                params.get(f"conv{block}.1.bias"),
            )
        x = x.view(x.shape[0], -1)
        x = F.linear(x, params["logits.weight"], params["logits.bias"])
        return x
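functional_forward relies on ConvBlockFunction, which is defined elsewhere in the notebook. A sketch consistent with the parameter names above (a conv weight/bias at index 0 and batch-norm weight/bias at index 1 of each block) could look like this; treat it as a plausible reconstruction rather than the notebook's exact code:

import torch.nn.functional as F

def ConvBlockFunction(x, w, b, w_bn, b_bn):
    # Conv -> BatchNorm (batch statistics, since no running stats are carried) -> ReLU -> MaxPool
    x = F.conv2d(x, w, b, padding=1)
    x = F.batch_norm(x, running_mean=None, running_var=None,
                     weight=w_bn, bias=b_bn, training=True)
    x = F.relu(x)
    return F.max_pool2d(x, kernel_size=2, stride=2)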
Create labels for a 5-way 2-shot setting
def create_label(n_way, k_shot):
    return torch.arange(n_way).repeat_interleave(k_shot).long()
# Try to create labels for 5-way 2-shot setting
create_label(5, 2)
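This returns tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4]): each of the n_way classes contributes k_shot consecutive identical labels, matching the order in which the images are stacked.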
Computing accuracy
def calculate_accuracy(logits, labels):
    """utility function for accuracy calculation"""
    acc = np.asarray(
        [(torch.argmax(logits, -1).cpu().numpy() == labels.cpu().numpy())]
    ).mean()
    return acc
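As a quick sanity check using the two helpers above, random logits should score near chance level (the tensor shapes here are illustrative):

logits = torch.randn(100, 5)               # 100 fake predictions over 5 classes
labels = create_label(5, 20)               # 100 labels: 5 classes x 20 each
print(calculate_accuracy(logits, labels))  # roughly 0.2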
The base solver first samples five tasks from the training set, then runs ordinary classification training on those five tasks. At inference time, the model is fine-tuned on the support-set images for inner_train_step steps and then run on the query-set images. To stay consistent with the meta-learning solver, the base solver takes exactly the same inputs and returns the same output format.
def BaseSolver(
    model,
    optimizer,
    x,
    n_way,
    k_shot,
    q_query,
    loss_fn,
    inner_train_step=1,
    inner_lr=0.4,
    train=True,
    return_labels=False,
):
    criterion, task_loss, task_acc = loss_fn, [], []
    labels = []
    for meta_batch in x:
        # Get data
        support_set = meta_batch[: n_way * k_shot]
        query_set = meta_batch[n_way * k_shot :]
        if train:
            """ training loop """
            # Use the support set to calculate loss
            labels = create_label(n_way, k_shot).to(device)
            logits = model.forward(support_set)
            loss = criterion(logits, labels)
            task_loss.append(loss)
            task_acc.append(calculate_accuracy(logits, labels))
        else:
            """ validation / testing loop """
            # First update model with support set images for `inner_train_step` steps
            fast_weights = OrderedDict(model.named_parameters())
            for inner_step in range(inner_train_step):
                # Simply training
                train_label = create_label(n_way, k_shot).to(device)
                logits = model.functional_forward(support_set, fast_weights)
                loss = criterion(logits, train_label)
                grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)  # create_graph is not required here (no outer-loop update at test time)
                # Perform SGD
                fast_weights = OrderedDict(
                    (name, param - inner_lr * grad)
                    for ((name, param), grad) in zip(fast_weights.items(), grads)
                )
            if not return_labels:
                """ validation """
                val_label = create_label(n_way, q_query).to(device)
                logits = model.functional_forward(query_set, fast_weights)
                loss = criterion(logits, val_label)
                task_loss.append(loss)
                task_acc.append(calculate_accuracy(logits, val_label))
            else:
                """ testing """
                logits = model.functional_forward(query_set, fast_weights)
                labels.extend(torch.argmax(logits, -1).cpu().numpy())
    if return_labels:
        return labels
    batch_loss = torch.stack(task_loss).mean()
    task_acc = np.mean(task_acc)
    if train:
        # Update model
        model.train()
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
    return batch_loss, task_acc
Meta-learning
def MetaSolver(
    model,
    optimizer,
    x,
    n_way,
    k_shot,
    q_query,
    loss_fn,
    inner_train_step=1,
    inner_lr=0.4,
    train=True,
    return_labels=False
):
    criterion, task_loss, task_acc = loss_fn, [], []
    labels = []
    for meta_batch in x:
        # Get data
        if torch.rand(1).item() > 0.6:  # task augmentation with probability 0.4
            times = 1 if torch.rand(1).item() > 0.5 else 3
            meta_batch = torch.rot90(meta_batch, times, [-1, -2])  # rot90(A, k) rotates A counter-clockwise by k*90 degrees
        support_set = meta_batch[: n_way * k_shot]
        query_set = meta_batch[n_way * k_shot :]
        # Copy the params for inner loop
        fast_weights = OrderedDict(model.named_parameters())
        ### ---------- INNER TRAIN LOOP ---------- ###
        for inner_step in range(inner_train_step):
            # Simply training
            train_label = create_label(n_way, k_shot).to(device)
            logits = model.functional_forward(support_set, fast_weights)
            loss = criterion(logits, train_label)
            # Inner gradients update! vvvvvvvvvvvvvvvvvvvv #
            """ Inner Loop Update """
            # TODO: Finish the inner loop update rule
            grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)
            fast_weights = OrderedDict(
                (name, param - inner_lr * grad)
                for ((name, param), grad) in zip(fast_weights.items(), grads)
            )
            # raise NotImplementedError  # template placeholder, removed once implemented
            # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ #
        ### ---------- INNER VALID LOOP ---------- ###
        if not return_labels:
            """ training / validation """
            val_label = create_label(n_way, q_query).to(device)
            # Collect gradients for outer loop
            logits = model.functional_forward(query_set, fast_weights)
            loss = criterion(logits, val_label)
            task_loss.append(loss)
            task_acc.append(calculate_accuracy(logits, val_label))
        else:
            """ testing """
            logits = model.functional_forward(query_set, fast_weights)
            labels.extend(torch.argmax(logits, -1).cpu().numpy())
    if return_labels:
        return labels
    # Update outer loop
    model.train()
    optimizer.zero_grad()
    meta_batch_loss = torch.stack(task_loss).mean()
    if train:
        """ Outer Loop Update """
        # TODO: Finish the outer loop update
        meta_batch_loss.backward()
        optimizer.step()
        # raise NotImplementedError  # template placeholder, removed once implemented
    task_acc = np.mean(task_acc)
    return meta_batch_loss, task_acc
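With both solvers defined, the training script just dispatches on a solver flag (the medium baseline sets solver = 'meta' with 120 epochs). The sketch below shows the shape of that loop; helper names such as get_meta_batch, train_loader, and meta_batch_size follow the original notebook but are assumptions here:

solver = "meta"  # "base" for the simple baseline
Solver = MetaSolver if solver == "meta" else BaseSolver
max_epoch = 120

for epoch in range(max_epoch):
    train_iter = iter(train_loader)
    for _ in range(len(train_loader) // meta_batch_size):
        # x stacks meta_batch_size tasks: [meta_batch_size, n_way * (k_shot + q_query), 1, 28, 28]
        x, train_iter = get_meta_batch(
            meta_batch_size, k_shot, q_query, train_loader, train_iter
        )
        meta_loss, acc = Solver(
            model, optimizer, x, n_way, k_shot, q_query, loss_fn, train=True
        )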