【Deep Learning】Meta-Learning:训练训练神经网络的神经网络

元学习:训练训练神经网络的神经网络

本文基于清华大学《深度学习》第12节《Beyond Supervised Learning》的内容撰写,既是课堂笔记,亦是作者的一些理解。

1 Meta-Learning

在经典监督学习中,给定训练数据 { ( x i , y i ) } i \{(x_i,y_i)\}_i {(xi,yi)}i,我们需要训练一个神经网络 f f f使得 f ( x i ) = y i f(x_i)=y_i f(xi)=yi,而在测试的时候给一组新的 { x i } i \{x_i\}_i {xi}i,我们需要准确预测这组 { x i } i \{x_i\}_i {xi}i对应的 y y y.

在元学习中,问题变得不太一样。

训练数据 D \mathcal{D} D是一个任务 T T T的集合

  • 其中 T = { S T , B T } = { ( x i S , y i S ) i , ( x j B , y j B ) j } T=\{S_T,B_T\}=\{(x_i^S,y_i^S)_i,(x_j^B,y_j^B)_j\} T={ST,BT}={(xiS,yiS)i,(xjB,yjB)j}
  • T T T S T S_T ST B T B_T BT组成, S T = { ( x i S , y i S ) } i S_T=\{(x_i^S,y_i^S)\}_i ST={(xiS,yiS)}i是一些训练样例, B T = { ( x j B , y j B ) } j B_T=\{(x_j^B,y_j^B)\}_j BT={(xjB,yjB)}j是测试样例。
  • 也就是说训练数据中的每个任务是一个独立的问题,这个问题由一些训练样例和测试样例刻画。

在测试的时候,模型将会获得一个任务 T T T S T S_T ST,并希望回答 B T B_T BT中的问题,即返回 y j B y_j^B yjB给定 x j B x_j^B xjB。我们希望训练一个模型使得它能够快速地通过 S T S_T ST中的几个样例回答 B T B_T BT的问题,尽管测试时的 T T T和训练时的 T T T可能不太一样。

  • 举个例子,如下图所示,在训练的时候一个任务是分辨猫和鸟,一个任务是分辨花和单车,而在测试的时候任务就变成了分辨狗和水獭。这要求模型学会在非常有限的样例中学会回答问题。
  • 又比如:你有很多门考试,你需要强化自己的能力使得你在一门新的科目上只需要看几张考卷就能够回答问题。

这就是元学习(Meta Learning)。这种少样本的泛化能力也被称作Few-shot Learning.
在这里插入图片描述

接下来将介绍几个元学习的模型。

1.1 度量学习 Metric Learning

度量学习(Metric Learning)是使用最近邻居分类(Nearest Neighbor Classifier)来完成上述的问题。

  • 最近邻居分类:对于一个测试样例,在训练数据中找到与它最“像”的样例,输出这个最“像”的样例的标签。
    • 例如:在猫和鸟的分类中,给一张图片,看看在训练数据中与它最像的图片是猫还是鸟。

如何评判最“像”?学习一个度量(Metric) D ( x , x ′ ) D(x,x') D(x,x).

  • 例如:欧几里得距离(越小代表 x x x x ′ x' x越像)

Siamese Nerual Network[1] 在2015年就使用了度量学习以解决少样本图像分类的问题:

  • 度量被定义为 f ( x , y ) = σ ( W ∣ ϕ ( x ) − ϕ ( y ) ∣ ) f(x,y)=\sigma(W|\phi(x)-\phi(y)|) f(x,y)=σ(Wϕ(x)ϕ(y))
  • 这里 ϕ ( x ) \phi(x) ϕ(x)将原图映射到了表征空间中,过了一个权重矩阵 W W W后经过sigmoid输出 x , y x,y x,y标签相同的概率。
  • 在一个任务中,定义 ( x i S , y i ) (x_i^S,y_i) (xiS,yi)为第 i i i个训练样例,对于测试样例 x x x,它将被分类为 y = y arg ⁡ max ⁡ i f ( x , x i S ) y=y_{\arg\max_if(x,x_i^S)} y=yargmaxif(x,xiS).

Matching Network[2] 在2016年引入了注意力机制(Attention)将最近邻居分类进行了加权:

  • 与上一个网络类似,对于测试样例 x x x和训练样例 x i S x_i^S xiS,分别计算它们的表征向量(embedding) f θ ( x ) f_\theta(x) fθ(x) g θ ( x i S ) g_\theta(x_i^S) gθ(xiS)
  • 在分类的时候不是直接 arg ⁡ max ⁡ \arg\max argmax,而是根据 α i = s o f t m a x ( g i ⊤ f ) \alpha_i=\mathrm{softmax}(g_i^\top f) αi=softmax(gif)计算 x i S x_i^S xiS的权重,最终分类结果为 y = ∑ i α i y i y=\sum_{i}\alpha_iy_i y=iαiyi(这里可以简单假设 y i ∈ [ 0 , 1 ] y_i\in[0,1] yi[0,1]为二分类问题)。
  • 可以将 f θ f_\theta fθ用LSTM模型计算,即 f θ = f θ ( x , S ) = L S T M θ ( S , x ) f_\theta=f_\theta(x,S)=LSTM_\theta(S,x) fθ=fθ(x,S)=LSTMθ(S,x),这样可以使 f θ f_\theta fθ对于不同的任务有不一样的结果(增强了表达能力)。对 g θ g_\theta gθ同理。

1.2 贝叶斯推断 Bayesian Inference

贝叶斯推断(Bayesian Inference)的核心思路在于学习一个概率模型 P θ ( x , y ) P_\theta(x,y) Pθ(x,y)表示 x x x(输入样例)和 y y y(样例的标签)的联合分布。 P θ ( x , y ) P_\theta(x,y) Pθ(x,y)具有一些简单的结构使得我们能够在上面计算一些概率(例如:一个有向无环图,每个节点是一个事件,每个边上表示一些条件概率)。

在训练的时候希望最大化 ∑ i log ⁡ P θ ( x i S , y i S ) \sum_i\log P_\theta(x^S_i,y^S_i) ilogPθ(xiS,yiS)。在测试的时候找到满足后验概率 P θ ( y ∣ x , S ) P_\theta(y|x,S) Pθ(yx,S)最大的 y y y

构建贝叶斯模型的通用方法是Probabilistic Programming(如:Church)。

例如在Lake在2015年的 一篇工作[3] 就使用概率模型实现了人类水平的概念学习。以手写字符作为例子(如下图所示),简单来说:

  • 每个字符可以分解成一些基本部件(如线条、曲线)

  • 每个部件组合成字符的不同部分(如数字"3"由两个向右的半圆组成)

  • 不同部件之间有连接关系(如写数字"3"的时候先写上面的半圆、再写下面的半圆)

  • 概率模型就是通过建模这个写字的过程一步步地"写"出了一个字

    在这里插入图片描述

1.3 序列模型 Sequence Model

一个简单的想法是直接把所有的训练数据排成一个序列丢给递归神经网络(Recurrent Nerual Network),然后直接输出测试样例的标签,即学习
f θ ( x ∣ S ) = L S T M ( x 1 S , y 1 S , . . . , x N S , y N S , x ) f_\theta(x|S)=LSTM(x_1^S,y_1^S,...,x_N^S,y_N^S,x) fθ(xS)=LSTM(x1S,y1S,...,xNS,yNS,x)
最早由Memory-Augmented Neural Networks (DeepMind, ICML 2016)[4] 提出了这种模型。

事实上,将训练数据排成序列的方法也能够直接应用在GPT等大语言模型(同时也是序列模型),这使得大语言模型具备一定的少样本学习的能力。

1.4 梯度下降 Gradient Descent

在模型训练的时候我们都要使用梯度下降,但在少样本学习的时候训练数据太少,仅仅基于很少的样本难以训练一个模型。

如果我们想基于新任务非常有限的样本进行梯度下降,初始的网络就非常关键。我们希望先在很多任务上训一个好的神经网络,在新的少样本任务中只需要进行很少的梯度下降(即微调)就能够达到很好的效果。

  • 这就像一个人已经见过很多的任务、学到很多知识了。此时他面对新的任务时只需要再稍微学习一下就可以了。

模型无关元学习(Model-Agnostic Meta-Learning, MAML)[5] 首次提出了基于一个好的初始参数,在新的任务下再进行梯度下降以微调的方法。

  • 在训练的时候,如何才能找到一个好的初始化参数呢?
  • 答:找到一个参数使得所有任务经过几步梯度下降就能有好的结果。
  • 多步随机梯度下降(SGD)太难计算了,怎么求导?
  • 答:只进行一步随机梯度下降,这样模型的计算就是固定的了。

2020年,Sun进一步提出了测试时训练(Test Time Training)[6] 的方法:

  • 在过去的任务中,模型仅仅通过在训练任务上进行训练,然后固定参数再完成测试任务。
  • 在这个工作中提出了在完成测试任务的时候进行自监督学习以加强模型能力的方法。
  • 打个比方:
    • 过去的模型先是做了很多科目的考试,然后再去做另一个科目的测试,但是在测试的时候会忘记每一道刚刚回答的题目。
    • 而现在的模型在做新的科目的测试时,它将每一道题的题面也加入到了训练中,相当于是它能够从考卷中也学习到东西(尽管没有标准答案,也能通过自监督学习的方式“自己考自己”),这样模型就变得更强了。

2 Learning to Learn

在元学习中另一个有趣的话题是如何Learn to learn,即通过神经网络来学习如何学习一个神经网络(套娃),例如学一个网络的结构、学一些超参数等等。

2.1 学习梯度下降

《Learning to learn by gradient descent by gradient descent》[7]首次提出了用一个神经网络(元优化器)来学习如何优化另一个神经网络(任务网络)的参数。元优化器通过观察任务网络的梯度信息来调整其参数,从而实现更高效的学习和收敛。具体来说:

  • 学了一个LSTM模型的优化器(optimizer) m ϕ ( ∇ θ ) m_\phi(\nabla _\theta) mϕ(θ)负责更新 θ \theta θ
    • θ k + 1 = θ k + m ϕ ( ∇ θ ) , ∇ θ = ∂ L ( X , Y , θ ) ∂ θ \theta^{k+1}=\theta^k+m_\phi(\nabla_\theta),\nabla _\theta=\frac{\partial L(X,Y,\theta)}{\partial \theta} θk+1=θk+mϕ(θ),θ=θL(X,Y,θ)

Google Brain于2017年在《Learned Optimizers that Scale and Generalize》[8] 进一步优化了基于神经网络的优化器的架构。

Quoc Le于2017年在《Neural Optimizer Search with Reinforcement Learning》[9]中将强化学习应用在优化器上,在自然语言处理任务与图像识别任务上表现很好。

2.2 学习损失函数

我们甚至可以使用强化学习来学习损失函数。

  • RL^2(OpenAI, 2017)[10] 使用LSTM来评估当前状态、调整策略,使内层强化学习代理能够在新任务上快速适应和学习。
  • Evolved Policy Gradient (OpenAI, 2018)[11] 使用进化算法来学习强化学习的损失函数,损失函数通过神经网络表示。
  • Meta-Gradient Reinforcement Learning (DeepMind 2020)[12] 提出了一种元梯度强化学习(Meta-Gradient Reinforcement Learning, MGRL)的方法,该方法通过优化元参数(meta-parameters)来提高强化学习算法的表现。

2.3 学习数据集

我们甚至可以学习数据集!

Dataset Distillation (Wang et al, MIT & Berkeley, 2018)[13] 提出了基于梯度下降的数据蒸馏的框架:

  • 对于一个很大的数据集和一个模型,我们可以基于这个数据集训练出一个很小的数据集,使得这个很小的数据集和很大的数据集在这个模型上的效果差不多。这有利于加快模型的训练。
  • 如下图所示,我们可以从MNIST数据集中合成出10张图片就可以训练出 94 % 94\% 94%正确率的模型
    • 尽管合成的数据“看起来”并没有什么意义。

在这里插入图片描述

2.4 学习神经网络结构

我们还可以通过神经网络学习神经网络的架构!

NAS-RL(Zoph and Le, ICLR 2017)[14] 提出了一种基于强化学习(Reinforcement Learning, RL)的神经架构搜索(Neural Architecture Search, NAS)方法。该方法通过训练一个RNN控制器来自动生成神经网络架构,并使用强化学习算法优化控制器,使其生成的架构在目标任务上表现最佳。

3 总结

本文简单介绍了两类与Learn to Learn有关的问题

  • 元学习(Meta-Learning)或少样本学习(Few-Shot Learning)
    • 少样本学习在大模型强大的Few-Shot能力下已经成为主流,在很多下游任务上大放异彩。
  • Learn to learn nerual network by nerual network:即通过神经网络来优化训练神经网络的过程
    • 由于效率不高,再套一层娃(Learn to learn to learn) 的效果也并不好。
    • 由于Learn to learn复杂的结构和极高的算力要求,如今主流的神经网络并没有采用此类的方法。

Reference

[1]Koch, G., Zemel, R., & Salakhutdinov, R. (2015). Siamese Neural Networks for One-shot Image Recognition. In Proceedings of the 32nd International Conference on Machine Learning (ICML).

[2]Vinyals, O., Blundell, C., Lillicrap, T., Kavukcuoglu, K., & Wierstra, D. (2016). Matching Networks for One Shot Learning. In Advances in Neural Information Processing Systems (NIPS).

[3]Lake, B. M., Salakhutdinov, R., & Tenenbaum, J. B. (2015). Human-level concept learning through probabilistic program induction. Science, 350(6266), 1332-1338.

[4]Santoro, A., Bartunov, S., Botvinick, M., Wierstra, D., & Lillicrap, T. (2016). Meta-Learning with Memory-Augmented Neural Networks. In Proceedings of the 33rd International Conference on Machine Learning (ICML).

[5]Finn, C., Abbeel, P., & Levine, S. (2017). Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks. In Proceedings of the 34th International Conference on Machine Learning (ICML).

[6]Sun, Y., Liu, X., Harada, T. (2020). Test-Time Training with Self-Supervised Learning. In Proceedings of the 37th International Conference on Machine Learning (ICML).

[7]Andrychowicz, M., Denil, M., Gomez, S., Hoffman, M. W., Pfau, D., Schaul, T., … & de Freitas, N. (2016). Learning to learn by gradient descent by gradient descent. In Advances in Neural Information Processing Systems (NIPS).

[8]Wichrowska, O., Maheswaranathan, N., Hoffman, M. W., Colmenarejo, S. G., Denil, M., Freitas, N. D., & Sohl-Dickstein, J. (2017). Learned Optimizers that Scale and Generalize. In Proceedings of the 34th International Conference on Machine Learning (ICML).

[9]Bello, I., Pham, H., Le, Q. V., Norouzi, M., & Bengio, S. (2017). Neural Optimizer Search with Reinforcement Learning. In Proceedings of the 34th International Conference on Machine Learning (ICML).

[10]Duan, Y., Schulman, J., Chen, X., Bartlett, P., Sutskever, I., & Abbeel, P. (2017). RL^2: Fast Reinforcement Learning via Slow Reinforcement Learning. In Proceedings of the 34th International Conference on Machine Learning (ICML).

[11]Houthooft, R., Chen, X., Duan, Y., Schulman, J., De Turck, F., & Abbeel, P. (2018). Evolved Policy Gradients. In Proceedings of the 35th International Conference on Machine Learning (ICML).

[12]Xu, Z., van Hasselt, H., & Silver, D. (2020). Meta-Gradient Reinforcement Learning. In Advances in Neural Information Processing Systems (NeurIPS).

[13]Wang, T., Zhu, J.-Y., Torralba, A., & Efros, A. A. (2018). Dataset Distillation. In Proceedings of the 35th International Conference on Machine Learning (ICML).

[14]Zoph, B., & Le, Q. V. (2017). Neural Architecture Search with Reinforcement Learning. In Proceedings of the 5th International Conference on Learning Representations (ICLR).

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/753087.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

cython 笔记

数据类型 # bool 类型 // bool_type_ptactice.pyx cdef bint a 123 # 非0 为 真 , 0 为假 cdef bint b -123 cdef bint c 0 py_a a # cdef 定义的内容没法直接在python中直接引用 py_b b py_c c// main.py import pyximport pyximport.install(language_le…

超详细之IDEA上传项目到Gitee完整步骤

1. 注册gitee 账号密码,gitee官网地址:Gitee官网,注册完成后,登录。 2. 创建仓库,在主页左下角有新建按钮,点击新建后会进入到此页面填写仓库信息。 3. 创建完成后复制仓库地址 4. 打开IntelliJ IDEA新建或…

Python 语法基础一

1.变量 python 中变量很简单,不需要指定数据类型,直接使用等号定义就好。python变量里面存的是内存地址,也就是这个值存在内存里面的哪个地方,如果再把这个变量赋值给另一个变量,新的变量通过之前那个变量知道那个变量…

新品Coming Soon!OAK-D-SR-PoE:使用3D+AI视觉结合ToF实现箱体测量和鉴别!

OAKChina 新品:OAK-D SR PoE结合ToF实现箱体检测 3DAI解决方案提供商 手动测量箱体、缺陷、大小等操作可能是一项繁琐并且劳累而机械的任务,但OAK中国本次将提供了更好的解决方案:3DAI视觉处理箱体的识别和检测,使用了即将发布的…

在Ubuntu上安装VNC服务器教程

Ubuntu上安装VNC服务器方法:按照root安装TeactVnc,随后运行vncserver输入密码,安装并打开RickVNC客户端,输入服务器的IP,最后连接输入密码即可。 VNC或虚拟网络计算,可让您连接到远程Linux / Unix服务器的…

力扣 刷题 使用双指针进行数组去重分析

目录 双指针 一、26.删除有序数组中的重复项 题目 题解 二、80. 删除有序数组中的重复项 II 题目 题解 三、27. 移除元素 题目 题解 双指针 我们这里所说的双指针实际上并不是真正的指针,它只是两个变量,用于标识数组的索引等,因其…

基于AiService实现智能文章小助手

顾名思义,这个应用就是希望能利用大模型的能力来帮助我写文章,那这样一个应用该如何利用LangChain4j来实现呢?接下来我们来利用AiService进行实现。 AiService代理 首先,我们定义一个接口Writer,表示作家&#xff1a…

高质量AIGC/ChatGPT/大模型资料分享

2023年要说科技圈什么最火爆,一定是ChatGPT、AIGC(人工智能生成内容)和大型语言模型。这些技术前沿如同科技世界的新潮流,巨浪拍岸,引发各界关注。ChatGPT的互动性和逼真度让人们瞠目,它能与用户展开流畅对…

谷歌如何进行失效链接建设?

失效链接建设是一种高效的外链建设策略,通过发现并利用失效链接来提升自己网站的SEO。以下是详细的步骤: 寻找失效页面:你需要使用SEO工具,如Ahrefs,来查找与你的网站内容相关的失效页面。这些页面可能是竞争对手的失…

Vue项目生产环境的打包优化

Vue项目生产环境的打包优化 前言 在这篇文章我们讨论Vue项目生产环境的打包优化,并按步骤展示实际优化过程中的修改和前后对比。 背景 刚开始的打包体积为48.71M 优化 步骤一:删除viser-vue viser-vue底层依赖antv/g2等库一并被删除,…

【EI会议】2024年机械、计算机工程与材料国际会议 (MCEM 2024)

2024年机械、计算机工程与材料国际会议 (MCEM 2024) 2024 International Conference on Mechanical, Computer Engineering and Materials 【重要信息】 大会地点:广州 官网地址:http://www.ismcem.com 投稿邮箱:ismcemsub-conf.com 【注意…

《XR应用开发者头显运行需求调研报告》重磅发布 ,开发者更加关注集成和可扩展性!

近期,LarkXR发布了一项新的解决方案,实现了3D/XR企业级应用全面接入Apple Vision Pro等头显设备。作为长期陪伴在XR行业开发者身边的技术伙伴,Paraverse平行云发起了此次行业调研,希望通过调研更直观地了解开发者在使用头显运行XR…

IDEA 导出ER图无表关系

一、通过IDEA导出的ER图无表关系,如下: 二、解决无表关系方法 1)这是建表时,user_work表中的t_id不规范,导致idea 找不到虚拟外键,也就不能绘制虚拟外键关系。那我们把user_work表t_id命名规范,t_id是user表…

VBA 批量变换文件名

1. 页面布局 在“main”Sheet中按照下面的格式编辑。 2. 实现代码 Private wsMain As Worksheet Private intIdx As LongPrivate Sub getExcelBookList(strPath As String)Dim fso As ObjectDim objFile As ObjectDim objFolder As ObjectSet fso = CreateObject("Scrip…

Steam新用户怎么参加夏促 Steam最新注册账号+下载客户端教程

steam夏促来了,这里给新玩家科普一下,steam就是一个游戏平台,里面的海量的各种游戏,而steam经常会有各种打折的活动,夏促就是其中之一,并且是其中规模最大的之一,涵盖游戏数量多,优惠…

ZW3D二次开发_CAM_添加刀具

在ZW3D2025中可以通过库添加刀具,代码如下: int idx_tool;int ret ZwCamToolInsertFromLibrary("", "001 METRIC TOOLS.xlsx", "10 mm Flat Endmill", &idx_tool); 平台功能添加刀具如下:

【Linux】进程间通信_2

文章目录 七、进程间通信1. 进程间通信分类管道 未完待续 七、进程间通信 1. 进程间通信分类 管道 管道的四种情况: ①管道内部没有数据,并且具有写端的进程没有关闭写端,读端就要阻塞等待,知道管道pipe内部有数据。 ②管道内部…

esp8266 GPIO

功能综述 ESP8266 的 16 个通⽤ IO 的管脚位置和名称如下表所示。 管脚功能选择 功能选择寄存器 PERIPHS_IO_MUX_MTDI_U(不同的 GPIO,该寄存器不同) PIN_FUNC_SELECT(PERIPHS_IO_MUX_MTDI_U,FUNC_GPIO12);PERIPHS_IO_MUX_为前缀。后面的…

uniapp开发企业微信内部应用

最近一直忙着开发项目,终于1.0版本开发完成,抽时间自己总结下在项目开发中遇到的技术点。此次项目属于自研产品,公司扩展业务,需要在企业微信中开发内部应用。因为工作中使用的是钉钉,很少使用企业微信,对于…

C# 警告 warning MSB3884: 无法找到规则集文件“MinimumRecommendedRules.ruleset”

警告 warning MSB3884: 无法找到规则集文件“MinimumRecommendedRules.ruleset” C:\Program Files\Microsoft Visual Studio\2022\Professional\MSBuild\Current\Bin\amd64\Microsoft.CSharp.CurrentVersion.targets(129,9): warning MSB3884: 无法找到规则集文件“MinimumRe…