目录

一、概述

  • Meta Learning = Learn to learn

    让机器去学习如何进行学习:使用一系列的任务来训练模型,模型根据在这些任务上汲取的经验,成为了一个强大的学习者,能够更快的学习新任务。

  • Meta Learning VS Lifelong Learning

    • 终身学习:着眼于用同一个模型去学习不同的任务。
    • 元学习:不同任务使用不同的模型,元学习者积累经验后,在新任务上训练的更快更好。
  • Meta Learning VS Machine Learning

    • 机器学习:核心是通过人为设计的学习算法(Learning Algorithm),利用训练数据训练得到一个函数f,这个函数可以用于新数据的预测分类。

    • 元学习:让机器自己学习找出最优的学习算法。根据提供的训练数据找到一个可以找到函数f的函数F的能力。

二、元学习的实现框架

  1. 定义一系列的学习算法

    不同的网络结构、参数初始化策略、参数更新策略决定了不同学习算法。

  2. 定义学习算法函数F的评价标准

    综合考虑学习算法F针对不同任务产生的函数f在进行测试时得到的损失。

  3. 选取最好的学习算法F*=argminL(F)

    最佳学习算法一般可以通过梯度下降方法来确定。

三、元学习的训练数据

  • 机器学习

    机器学习的训练数据和测试数据来自同一分布的数据集。

  • 元学习:

    • 元学习的训练数据是由一个个的训练任务构成的,一个训练任务对应一个传统的机器学习的应用实例。

      • 需要大批次数据的训练任务显然难以进行元学习训练,因此常规的元学习的训练任务一般是Few Shot Learning类型的任务,即通过少量数据就能构建一个任务,进行快速的学习与训练。

        • 考虑到运算性能,现阶段的元学习经常是与Few Shot Learning绑定在一起。
    • 训练数据分为训练任务集和测试任务集。

    • 任务集中的每一个任务的训练数据即传统的机器学习应用实例中的训练数据集和测试数据集,不过为了区分训练任务(Training Set)和测试任务(Testing Test) ,这里将它们命名为支持集(Support Set)和查询集(Query Set).

四、元学习的Benchmarks

  • Omniglot数据集

    • 组成

      • 整个数据集由1623个符号(Characters)组成;

      • 每个符号有20个样例(Examples),每个样例由不同的人书写.

    • 使用:结合Few-shot Learning中的N-ways K-shot分类问题

      • 对于每一个训练任务和测试任务,样本数据分为N个类,每个类提供K个样本。
      • 整个字符集分为训练字符集(Training Set or Support Set)和测试字符集(Testing Set or Query Set)
      • 训练任务:从训练字符集中抽取N个类的字符,每种字符抽取K个样本,组成一个训练任务的训练数据
      • 测试任务:从测试字符集中抽取N个类的字符,每种字符抽取K个样本,组成一个测试任务的训练数据
  • MAML

    Finn C, Abbeel P, Levine S. Model-agnostic meta-learning for fast adaptation of deep networks[C]//Proceedings of the 34th International Conference on Machine Learning-Volume 70. JMLR. org, 2017: 1126-1135.

    • 损失函数 (Loss Function):(L(Phi)= sum_{n=1}^N{l^n(hat{theta}^n)})

      • (hat{theta}^n):第n个任务中学习到的模型参数,取决于参数(Phi)
      • (l^n(hat{theta}^n)):第n个任务在其测试集上得到的损失。
    • 损失函数最小化:使用梯度下降(Gradient Descent)

      [PhileftarrowPhi-etanabla_Phi{L(Phi)}]

    • 只考虑一次训练之后对初始化参数的梯度更新。

      • 只取进行一次梯度更新后的参数作为当前任务的最佳参数。
      • 上式求出的是元学习模型的通用参数,下式求出的是每个任务的最佳参数。
      • (L(Phi))(hat{theta})用于元学习模型的参数更新。
      • 既能加快模型的适应速度,在一定程度上还能减轻过拟合。、

      [hat{theta}=Phi-epsilonnabla_Phi{l(Phi)}]

    • 整体执行流程:

      1. 将每一个训练任务和测试任务的模型参数初始化:(Phi_0)
      2. 对每一个任务执行一次梯度更新得到新的模型参数:(hat{theta})
      3. 综合考虑所有训练任务在(hat{theta})下的损失:(L(Phi))
      4. (L(Phi))执行梯度更新,得到最优的元学习模型的参数:(Phi)
      5. 将该(Phi)用于测试任务,检验更新效果。
    • 二阶微分与一阶近似(数学推导):

      • 训练过程的参数更新公式如下:

      [ PhileftarrowPhi-etanabla_Phi{L(Phi)} \ L(Phi)= sum_{n=1}^N{l^n(hat{theta}^n)} \ hat{theta}=Phi-epsilonnabla_Phi{l(Phi)} ]

      • $ nabla_Phi{L(Phi)} $的计算
        [ nabla_Phi{L(Phi)}=nabla_Phi{sum_{n=1}^{N}l^n(hat{theta}^n)}=sum_{n=1}^{N}nabla_Phi{l^n(hat{theta}^n)} \ ]

        • 其中(nabla_Phi{l^n(hat{theta}^n)})为:
          [ nabla_Phi{l(hat{theta})}=left| begin{matrix} partial l(hat{theta})/partial Phi_1\ partial l(hat{theta})/partial Phi_2\ vdots\ partial l(hat{theta})/partial Phi_i\ end{matrix} right| ]

          (Phi_i)表示模型的各个参数(Weight),(Phi_i)决定当前任务的(hat{theta})的第j个参数(hat{theta}_j),从而影响(l(hat{theta}))

          • 根据三者之间的关系:(Phi_i rightarrow hat{theta}_j rightarrow l(hat{theta})),有:
            [ frac{partial l(hat{theta})}{partial Phi_i}=sum_jsum_i {frac{partial l(hat{theta})}{partial hat{theta}_j}frac{partial hat{theta}_j}{partial hat{Phi}_i}} ]

          • 又因为根据参数更新公式(3),取(hat{theta})的第j维为例,有:
            [ hat{theta}_j=Phi_j-epsilon frac{partial l(Phi)}{partial Phi_j} ]

          • (hat{theta}_j)(Phi_j)的偏导,有:
            [ frac{partial hat{theta}_j}{partial hat{Phi}_i}= begin{cases} -epsilon frac{partial l(Phi)}{partial Phi_j partial Phi_i},i neq j\ 1-epsilon frac{partial l(Phi)}{partial Phi_j partial Phi_i},i = j end{cases} ]

            将该式代回到(frac{partial l(hat{theta})}{partial Phi_i})中即可求出(nabla_Phi{L(Phi)})

            但实际上该式存在二次微分的计算,会极大的影响运算效率。

          • 作者用一次微分来近似代替二次微分的结果:
            [ frac{partial hat{theta}_j}{partial hat{Phi}_i}= begin{cases} -epsilon frac{partial l(Phi)}{partial Phi_j partial Phi_i} approx{0} ,i neq j\ 1-epsilon frac{partial l(Phi)}{partial Phi_j partial Phi_i} approx{1},i = j end{cases} ]

          • 所以
            [ frac{partial l(hat{theta})}{partial Phi_i}=sum_jsum_i {frac{partial l(hat{theta})}{partial hat{theta}_j}frac{partial hat{theta}_j}{partial hat{Phi}_i}} approx frac{partial l(hat{theta})}{partial hat{theta}_i}\ nabla_Phi{l(hat{theta})}=left| begin{matrix} partial l(hat{theta})/partial Phi_1\ partial l(hat{theta})/partial Phi_2\ vdots\ partial l(hat{theta})/partial Phi_i\ end{matrix} right|=left| begin{matrix} partial l(hat{theta})/partial hat{theta}_1\ partial l(hat{theta})/partial hat{theta}_2\ vdots\ partial l(hat{theta})/partial hat{theta}_i\ end{matrix} right|=nabla_hat{theta}{l(hat{theta})} ]

      • 所以$ nabla_Phi{L(Phi)} $可以化为:
        [ nabla_Phi{L(Phi)}=nabla_Phi{sum_{n=1}^{N}l^n(hat{theta}^n)}=sum_{n=1}^{N}nabla_Phi{l^n(hat{theta}^n)}=sum_{n=1}^{N}nabla_{hat{theta}^n}{l^n(hat{theta}^n)} ]

        通过将二阶微分近似为一阶微分,提升运算效率的同时对模型预测的准确率没有太大的影响。

  • Reptile

    Nichol A, Achiam J, Schulman J. On first-order meta-learning algorithms[J]. arXiv preprint arXiv:1803.02999, 2018.

    • 基本思想

      • 基于MAML进行改善,对参数更新次数不加限制。

    • Reptile VS Pretraining VS MAML

内容来源于网络如有侵权请私信删除
你还没有登录,请先登录注册
  • 还没有人评论,欢迎说说您的想法!