导语: Google Tensor2Tensor 系统是一套十分强大的深度学习系统,在多个任务上的表现非常抢眼。尤其在机器翻译问题上,单模型的表现就可以超过之前方法的集成模型。这一套系统的模型结构、训练和优化技巧等,可以被利用到公司的产品线上,直接转化成生产力。本文对 Tensor2Tensor 系统从模型到代码进行了全面的解析,期望能够给大家提供有用的信息。
第一章:概述
Tensor2Tensor(T2T)是 Google Brain Team 在 Github 上开源出来的一套基于 TensorFlow 的深度学习系统。该系统最初是希望完全使用 Attention 方法来建模序列到序列(Sequence-to-Sequence,Seq2Seq)的问题,对应于《Attention Is All You Need》这篇论文。该项工作有一个有意思的名字叫 “Transformer”。随着系统的不断扩展,T2T 支持的功能变得越来越多,目前可以建模的问题包括:图像分类,语言模型、情感分析、语音识别、文本摘要,机器翻译。T2T 在很多任务上的表现很好,并且模型收敛比较快,在 TF 平台上的工程化代码实现的也非常好,是一个十分值得使用和学习的系统。
如果是从工程应用的角度出发,想快速的上手使用 T2T 系统,只需要对模型有一些初步的了解,阅读一下 workthrough 文档,很快就能做模型训练和数据解码了。这就是该系统想要达到的目的,即降低深度学习模型的使用门槛。系统对数据处理、模型、超参、计算设备都进行了较高的封装,在使用的时候只需要给到数据路径、指定要使用的模型和超参、说明计算设备就可以将系统运行起来了。
如果想深入了解系统的实现细节,在该系统上做二次开发或是实现一些研究性的想法,那就需要花费一定的时间和精力来对模型和代码进行研究。T2T 是一个较复杂的系统,笔者近期对模型和代码实现进行了全面的学习,同时对涉及到序列到序列功能的代码进行了剥离和重构,投入了较多的时间成本。因笔者是做自然语言处理研究的,这篇文章里主要关注的是 Transformer 模型。写这篇文章一方面是总结和记录一下这个过程中的一些收获,另一方面是把自己对 T2T 的理解分享出来,希望能够提供一些有用的信息给同学们。
第二章:序列到序列任务与 Transformer 模型
2.1 序列到序列任务与 Encoder-Decoder 框架
序列到序列(Sequence-to-Sequence)是自然语言处理中的一个常见任务,主要用来做泛文本生成的任务,像机器翻译、文本摘要、歌词 / 故事生成、对话机器人等。最具有代表性的一个任务就是机器翻译(Machine Translation),将一种语言的序列映射到另一个语言的序列。例如,在汉 - 英机器翻译任务中,模型要将一个汉语句子(词序列)转化成一个英语句子(词序列)。
目前 Encoder-Decoder 框架是解决序列到序列问题的一个主流模型。模型使用 Encoder 对 source sequence 进行压缩表示,使用 Decoder 基于源端的压缩表示生成 target sequence。该结构的好处是可以实现两个 sequence 之间 end-to-end 方式的建模,模型中所有的参数变量统一到一个目标函数下进行训练,模型表现较好。图 1 展示了 Encoder-Decoder 模型的结构,从底向上是一个机器翻译的过程。
图 1: 使用 Encoder-Decoder 模型建模序列到序列的问题
Encoder 和 Decoder 可以选用不同结构的 Neural Network,比如 RNN、CNN。RNN 的工作方式是对序列根据时间步,依次进行压缩表示。使用 RNN 的时候,一般会使用双向的 RNN 结构。具体方式是使用一个 RNN 对序列中的元素进行从左往右的压缩表示,另一个 RNN 对序列进行从右向左的压缩表示。两种表示被联合起来使用,作为最终序列的分布式表示。使用 CNN 结构的时候,一般使用多层的结构,来实现序列局部表示到全局表示的过程。使用 RNN 建模句子可以看做是一种时间序列的观点,使用 CNN 建模句子可以看做一种结构化的观点。使用 RNN 结构的序列到序列模型主要包括 RNNSearch、GNMT 等,使用 CNN 结构的序列到序列模型主要有 ConvS2S 等。
2.2 神经网络模型与语言距离依赖现象
Transformer 是一种建模序列的新方法,序列到序列的模型依然是沿用了上述经典的 Encoder-Decoder 结构,不同的是不再使用 RNN 或是 CNN 作为序列建模机制了,而是使用了 self-attention 机制。这种机制理论上的优势就是更容易捕获 “长距离依赖信息(long distance dependency)”。所谓的“长距离依赖信息” 可以这么来理解:1)一个词其实是一个可以表达多样性语义信息的符号(歧义问题)。2)一个词的语义确定,要依赖其所在的上下文环境。(根据上下文消岐)3)有的词可能需要一个范围较小的上下文环境就能确定其语义(短距离依赖现象),有的词可能需要一个范围较大的上下文环境才能确定其语义(长距离依赖现象)。
举个例子,看下面两句话:“山上有很多杜鹃,春天到了的时候,会漫山遍野的开放,非常美丽。” “山上有很多杜鹃,春天到了的时候,会漫山遍野的啼鸣,非常婉转。”在这两句话中,“杜鹃”分别指花(azalea)和鸟(cuckoo)。在机器翻译问题中,如果不看距其比较远的距离的词,很难将 “杜鹃” 这个词翻译正确。该例子是比较明显的一个例子,可以明显的看到词之间的远距离依赖关系。当然,绝大多数的词义在一个较小范围的上下文语义环境中就可以确定,像上述的例子在语言中占的比例会相对较小。我们期望的是模型既能够很好的学习到短距离的依赖知识,也能够学习到长距离依赖的知识。
那么,为什么 Transformer 中的 self-attention 理论上能够更好的捕获这种长短距离的依赖知识呢?我们直观的来看一下,基于 RNN、CNN、self-attention 的三种序列建模方法,任意两个词之间的交互距离上的区别。图 2 是一个使用双向 RNN 来对序列进行建模的方法。由于是对序列中的元素按顺序处理的,两个词之间的交互距离可以认为是他们之间的相对距离。W1 和 Wn 之间的交互距离是 n-1。带有门控(Gate)机制的 RNN 模型理论上可以对历史信息进行有选择的存储和遗忘,具有比纯 RNN 结构更好的表现,但是门控参数量一定的情况下,这种能力是一定的。随着句子的增长,相对距离的增大,存在明显的理论上限。
图 2 使用双向 RNN 对序列进行建模
图 3 展示了使用多层 CNN 对序列进行建模的方法。第一层的 CNN 单元覆盖的语义环境范围较小,第二层覆盖的语义环境范围会变大,依次类推,越深层的 CNN 单元,覆盖的语义环境会越大。一个词首先会在底层 CNN 单元上与其近距离的词产生交互,然后在稍高层次的 CNN 单元上与其更远一些词产生交互。所以,多层的 CNN 结构体现的是一种从局部到全局的特征抽取过程。词之间的交互距离,与他们的相对距离成正比。距离较远的词只能在较高的 CNN 节点上相遇,才产生交互。这个过程可能会存在较多的信息丢失。
图 3 使用多层 CNN 对序列进行建模
图 4 展示的是基于 self-attention 机制的序列建模方法。注意,为了使图展示的更清晰,少画了一些连接线,图中 “sentence” 层中的每个词和第一层 self-attention layer 中的节点都是全连接的关系,第一层 self-attention layer 和第二层 self-attention layer 之间的节点也都是全连接的关系。我们可以看到在这种建模方法中,任意两个词之间的交互距离都是 1,与词之间的相对距离不存在关系。这种方式下,每个词的语义的确定,都考虑了与整个句子中所有的词的关系。多层的 self-attention 机制,使得这种全局交互变的更加复杂,能够捕获到更多的信息。
图 4 使用 self-attention 对序列进行建模
综上,self-attention 机制在建模序列问题时,能够捕获长距离依赖知识,具有更好的理论基础。
2.3 self-attention 机制的形式化表达
上面小节介绍了 self-attention 机制的好处,本小结来介绍一下 self-attention 机制的的数学形式化表达。首先,从 attention 机制讲起。可以将 attention 机制看做一种 query 机制,即用一个 query 来检索一个 memory 区域。我们将 query 表示为 key_q,memory 是一个键值对集合(a set of key-value pairs),共有 M 项,其中的第 i 项我们表示为
2.4 “Attention is All You Need”
《Attention Is All You Need》这篇文章,描述了一个基于 self-attention 的序列到序列的模型,即 “Transformer”。该模型将 WMT2014 英 - 德翻译任务的 BLEU 值推到了新高,在英 - 法翻译任务上,接近于之前报出的最好成绩,而这仅仅是 Transformer 单模型的表现。之前报出的最好成绩都是基于集成方法的,需要训练多个模型,最后做集成。同时该模型也被用在英语的成分句法分析任务上,表现也基本接近于之前报出的最好模型成绩。该模型的收敛速度也非常的快,在英 - 法 3600 万句对的训练集上,只需要 8 卡并行 3.5 天就可以收敛。
该模型的表现的如此好的原因,其实不仅仅是一个 self-attention 机制导致的,实际上 Transformer 模型中使用了非常多有效的策略来使得模型对数据的拟合能力更强,收敛速度更快。整个 Transformer 的模型是一套解决方案,而不仅仅是对序列建模机制的改进。下面我们对其进行一一讲解。
2.4.1 Self-attention 机制的变种
首先,还是来讲一下 Transformer 中的 self-attention 机制。上面讲到了 self-attention 的基本形式,但是 Transformer 里面的 self-attention 机制是一种新的变种,体现在两点,一方面是加了一个缩放因子(scaling factor),另一方面是引入了多头机制(multi-head attention)。
缩放因子体现在 Attention 的计算公式中多了一个向量的维度作为分母,目的是想避免维度过大导致的点乘结果过大,进入 softmax 函数的饱和域,引起梯度过小。Transformer 中的 self-attention 计算公式如下:
多头机制是指,引入多组的参数矩阵来分别对 Q、K、V 进行线性变换求 self-attention 的结果,然后将所有的结果拼接起来作为最后的 self-attention 输出。这样描述可能不太好理解,一看公式和示意图就会明白了,如下:
图 5 单头和多头的 Attention 结构
这种方式使得模型具有多套比较独立的 attention 参数,理论上可以增强模型的能力。
2.4.2 位置编码(Positional Encoding)
self-attention 机制建模序列的方式,既不是 RNN 的时序观点,也不是 CNN 的结构化观点,而是一种词袋(bag of words)的观点。进一步阐述的话,应该说该机制视一个序列为扁平的结构,因为不论看上去距离多远的词,在 self-attention 机制中都为 1。这样的建模方式,实际上会丢失词之间的相对距离关系。举个例子就是,“牛 吃了 草”、“草 吃了 牛”,“吃了 牛 草” 三个句子建模出来的每个词对应的表示,会是一致的。
为了缓解这个问题,Transformer 中将词在句子中所处的位置映射成 vector,补充到其 embedding 中去。该思路并不是第一次被提出,CNN 模型其实也存在同样的难以建模相对位置(时序信息)的缺陷,Facebook 提出了位置编码的方法。一种直接的方式是,直接对绝对位置信息建模到 embedding 里面,即将词 Wi 的 i 映射成一个向量,加到其 embedding 中去。这种方式的缺点是只能建模有限长度的序列。Transformer 文章中提出了一种非常新颖的时序信息建模方式,就是利用三角函数的周期性,来建模词之间的相对位置关系。具体的方式是将绝对位置作为三角函数中的变量做计算,具体公式如下:
该公式的设计非常先验,尤其是分母部分,不太好解释。从笔者个人的观点来看,一方面三角函数有很好的周期性,也就是隔一定的距离,因变量的值会重复出现,这种特性可以用来建模相对距离;另一方面,三角函数的值域是 [-1,1],可以很好的提供 embedding 元素的值。
2.4.3 多层结构
Transformer 中的多层结构非常强大,使用了之前已经被验证过的很多有效的方法,包括:residual connection、layer normalization,另外还有 self-attention 层与 Feed Forward 层的堆叠使用,也是非常值得参考的结构。图 6 展示了 Transformer 的 Encoder 和 Decoder 一层的结构。
图 6 Transformer 模型结构
图 6 中,左侧的 Nx 代表一层的 Encoder,这一层中包含了两个子层(sub-layer),第一个子层是多头的 self-attention layer,第二个子层是一个 Feed Forward 层。每个子层的输入和输出都存在着 residual connection,这种方式理论上可以很好的回传梯度。Layer Normalization 的使用可以加快模型的收敛速度。self-attention 子层的计算,我们前面用了不少的篇幅讲过了,这里就不再赘述了。Feed Forward 子层实现中有两次线性变换,一次 Relu 非线性激活,具体计算公式如下:
文章的附页中将这种计算方式也看做是一种 attention 的变种形式。
图 6 中,右侧是 Decoder 中一层的结构,这一层中存在三个子层结构,第一层是 self-attention layer 用来建模已经生成的目标端句子。在训练的过程中,需要一个 mask 矩阵来控制每次 self-attention 计算的时候,只计算到前 t-1 个词,具体的实现方式,我们会在后面讲代码实现的时候进行说明。第二个子层是 Encoder 和 Decoder 之间的 attention 机制,也就是去源语言中找相关的语义信息,这部分的计算与其他序列到序列的注意力计算一致,在 Transformer 中使用了 dot-product 的方式。第三个子层是 Feed Forward 层,与 Encoder 中的子层完全一致。每个子层也都存在着 residual connection 和 layer normalization 操作,以加快模型收敛。
Transformer 中的这种多层 - 多子层的机制,可以使得模型的复杂度和可训练程度都变高,达到非常强的效果,值得我们借鉴。
2.4.4 优化方法与正则策略
模型的训练采用了 Adam 方法,文章提出了一种叫 warm up 的学习率调节方法,如公式所示:
公式比较先验,看上去比较复杂,其实逻辑表达起来比较清楚,需要预先设置一个 warmup_steps 超参。当训练步数 step_num 小于该值时,以括号中的第二项公式决定学习率,该公式实际是 step_num 变量的斜率为正的线性函数。当训练步数 step_num 大于 warm_steps 时,以括号中的第一项决定学习率,该公式就成了一个指数为负数的幂函数。所以整体来看,学习率呈先上升后下降的趋势,有利于模型的快速收敛。
模型中也采用了两项比较重要的正则化方法,一个就是常用的 dropout 方法,用在每个子层的后面和 attention 的计算中。另一个就是 label smoothing 方法,也就是训练的时候,计算交叉熵的时候,不再是 one-hot 的标准答案了,而是每个 0 值处也填充上一个非 0 的极小值。这样可以增强模型的鲁棒性,提升模型的 BLEU 值。这个思路其实也是一定程度在解决训练和解码过程中存在的 exposure bias 的问题。
2.4.5 本章小结
Transformer 系统的强大表现,不仅仅是 self-attention 机制,还需要上述的一系列配合使用的策略。设计该系统的研究者对深度学习模型和优化算法有着非常深刻的认识和敏锐的感觉,很多地方值得我们借鉴学习。Transformer 的代码实现工程化比较好,但是也存在一些地方不方便阅读和理解,后面的章节中会对 Transformer 的代码实现进行详细讲解,将整体结构讲清楚,把其中的疑难模块点出来。
第三章:Tensor2Tensor 系统实现深度解析
Tensor2Tensor 的系统存在一些特点,导致使用和理解的时候可能会存在一些需要时间来思考和消化的地方,在此根据个人的理解,写出一些自己曾经花费时间的地方。
3.1 使用篇
Tensor2Tensor 的使用是比较方便的,对于系统中可以支持的问题,直接给系统设置好下面的信息就可以运行了:数据,问题 (problem),模型,超参集合,运行设备。这里的实现其实是采用了设计模型中的工厂模式,即给定一个问题名字,返回给相应的处理类;给定一个超参名,返回一套超参的对象。实现这种方式的一个重点文件是 utils/registry.py。在系统启动的时候,所有的问题和超参都会在 registry 中注册,保存到_MODELS,_HPAPAMS,_RANGED_HPARAMS 中等待调用。
在此主要以序列到序列的系统使用和实现为主线进行讲解。系统的运行分三个阶段:数据处理,训练,解码。对应着三个入口:t2t-datagen,t2t-trainer,t2t-decoder。
数据处理的过程包括:
1.(下载)读取训练和开发数据。如果需要使用自己的数据的话,可以在问题中指定。
2.(读取)构造词汇表。可以使用自己预先构造好的词汇表。系统也提供构建 BPE 词汇表的方法。注意,这里有个实现细节是系统在抽取 BPE 词汇表的时候,有个参数,默认并非使用全量的数据。通过多次迭代尝试,得到最接近预设词汇表规模的一个词汇表。在大数据量的时候,这个迭代过程会非常慢。
3. 使用词汇表将单词映射成 id,每个句子后会加 EOS_ID,每个平行句对被构造成一个 dict 对象 ({‘inputs’:value,‘targets’:value}),将所有对象序列化,写入到文件中,供后面训练和评价使用。
模型训练的过程的过程主要通过高级的 Tensorflow API 来管理,只是需要指定数据、问题名、模型名、超参名、设备信息就可以运行了。比较关键的一个文件是 utils/trainer_lib.py 文件,在这个文件中,构建 Experiment、Estimator、Monitor 等来控制训练流程。使用者主要需要设置的就是训练过程的一些参数,比如训练最大迭代次数,模型评估的频率,模型评估的指标等。超参可以直接使用系统已有的参数集,也可以通过字符串的形式向内传参。简单的任务不太需要动超参,因为系统中的超参集合基本上都是经过实验效果验证的。需要注意的就是 batch_size 过大的时候,可能会导致显存不足,导致程序错误。一般是使用 continuous_train_and_eval 模式,使模型的训练和评估间隔进行,随时可以监控模型的表现。
解码的过程,可以提供整体文件、也可以是基于 Dataset 的,同时系统也提供 server 的方式,可以提供在线的服务,并没有什么特别好讲的。
3.2 深度掌握篇
3.2.1 Tensor2Tensor 系统实现的特点
下面列出了要深度掌握 Tensor2Tensor 系统时,可能因为其实现特点,会遇到的一些问题:
1. 系统支持多任务,任务混杂,导致代码结构比较复杂。在实现的时候,要考虑到整体的结构,所以会存在各种封装、继承、多态的实现。可能你只想用其中的一个功能,理解该功能对应的代码,但是却需要排除掉大量的不相关的代码。
2. 系统基于 Tensorflow 封装较高的 API。使用了 Tensorflow 中比较高的 API 来管理模型的训练和预测,Experiment,Monitor,Estimator,Dataset 对象的使用隐藏了比较多的控制流程,对于侧重应用的用户来说,可能是是好事情,设一设参数就能跑。但是对于想了解更多的开发人员来说,TF 该部分的文档实在很少,说的也不清楚,很多时候需要去阅读源代码才能知道实验到底是不是按照自己预期的进行的。这种方式也不太方便找 bug 和调试。
3. 某些方法调用比较深。原因应该还是出于整体结构和扩展性的考虑。这导致了实现一点很小功能的方法 A,需要再调一个其他方法 B,B 再去调用方法 C,实际上每个方法中就几行代码,甚至有的方法就是空操作。
4. 多层继承和多态也降低了代码的可读性。追溯一个类的某个方法的时候,需要看到其父类的父类的父类。。。这些父类和子类之间的方法又存在着调来调去的关系,同名方法又存在着覆盖的关系,所以要花一些时间来确定当前的方法名到底是调用的的哪个类中的方法。
5. 要求开发者有模型层面的理解和与代码实现的挂钩。肯定是要提高对模型逻辑的理解,但在读代码的过程中,会遇到两种问题:第一个,代码实现的是论文中的功能,但不是论文中的原始公式,可能要做变形以规避溢出的问题,或是实现更高的效率;第二个,某些代码实现与其论文中的表述存在不一致的情况。
3.2.2 总体逻辑模块
总体来说,对 T2T 系统的代码逻辑划分如下,共包括三个大的模块:
- 问题定义和数据管理的模块。该模块用来定义问题和处理数据,比如定义一个翻译的问题,里面定义抽词汇表和构造训练样本的方法。
- 模型定义和计算图构建的模块。该模块用来定义模型属性和计算图结构。
- 实验流程控制与并行化模块。该模块用于实验流程控制,设置可用计算设备,提供模型并行化运行方法。
图 7 Tensor2Tensor 主要逻辑模块
这里不会对代码做追踪式的分析,会分条的讲解一些阅读 Tensor2Tensor 系统代码时可能遇到的问题,点出一些重要的功能所在的位置和实现逻辑。
- 工厂模式。系统使用工厂模式管理问题、模型、超参、模态等模块的方法。前面在使用篇讲到了 registry.py 这个比较关键的文件,是系统总体管理和调度模块的一个核心文件。如果要在系统中增加新的问题、模型、超参、模态等,也都需要通过在类前加装饰器的方式来注册到 registry 中,否则系统找不到新加的模块。
- ** 问题类 (problem)。**data_generators/problem.py 中的 class Problem 是后面所有 problem 的基类。之前说到系统中的类之间的多层继承关系导致代码读起来比较麻烦,举个例子来说,一个翻译问题继承路线是这样的:Problem»Text2TextProblem»TranslateProblem»TranslateEndeWmtBpe32k» TranslateEndeWmt32k,中间各种的方法和变量覆盖,父类和子类之间方法的穿插调用,导致一些阅读困难。总的来说,一个序列到序列的问题应该包括以下信息和方法:数据文件信息,词汇表文件名、类型、大小,构造词汇表的方法,序列化训练数据和开发数据的方法,读取数据文件为 model(estimator)构造输入流 input_fn 的方法,设定问题评估 metric 的方法。可以总结来说,问题的属性定义、训练和评价样本的构造、数据的处理和读取,都由 problem 这个体系里面的类和方法来提供。
- 词汇表对象 (TextEncoder)。系统中有多种多样的词汇表(TextEncoder)对象,可以支持字母(character),子词(subword/bpe),词汇(token)等多重方式。TextEncoder 主要功能就是构建词汇表、实现符号到 id 的映射。T2T 里有构造 bpe 词汇表的方法,没有 word piece 词汇表的构造方法,也可以看出 T2T 研究团队和 GNMT 研究团队的区分。两个团队一直在交替的更新机器翻译任务的最高成绩。构建 BPE 词汇表的具体实现在 SubwordTextEncoder 中的 build_to_target_size()方法,该方法不是之前 Sennrich 使用迭代次数来控制词汇表大小的方式,而是使用二分查找的方式,通过搜索最优的 minimum token count 值来逼近预先设置的词汇表的大小。
- T2TModel 类。utils/t2t_model.py 中的 class T2TModel 是模型功能的基类,该类继承自 layer,Transformer 类便继承于此类。T2TModel 类中定义了模型的计算图结构,即给定 feature 后,模型是怎么根据 feature 进行图计算,得到 logit,loss,然后根据 loss 求梯度,调用 optimizer 进行梯度回传,进行参数更新的。构建计算图的目的是最终要构建 tf.estimator.EstimatorSpec()对象。可以理解为,所有的模型图计算过程都在该对象中被表达了。T2TModel 可以返回三种 EstimatorSpec 对象,分别用于训练、评价和解码。训练的过程可以支持数据并行,具体的实现是同时在多个数据片上激活计算图,得到的 loss 做平均,这是一种同步并行训练的方式。T2TModel 中也提供了供解码的方法。
- Transformer 类。models/transformer.py 中的 class Transformer 继承自 class T2TModel,为其父类构建图的时候,提供各种支持的方法,encode 方法可以使用 Encoder 结构对源端进行压缩表示,decode 方法使用 Decoder 结构对目标端进行生成。同时,transformer.py 中有多套参数供选择。模型中 feed-forward 子层的实现也在该文件中 (transformer_ffn_layer)。
- 数据并行类。devices.py 和 expert_utils.py 配合使用,主要功能是根据用户给定的并行设备参数,列出可以使用的设备名,然后给定一个能够调用这些设备,并行执行方法的方法。
- 实验流程控制。实验流程控制使用的是 Tensorflow 的高级 API 对象,主要包括 Experiment 对象、Estimator 对象、Dataset 对象。对这三个对象,我们可以这么理解:a) Experiment 是一次运行的实验,用来控制实验流程,输送数据到模型。b) Estimator 是具体的模型对象,可以包括训练、评估、解码三个功能。c) Dataset 为运行的实验过程读数据文件提供数据流。
- Experiment 对象。我们来看下图中 Experiment 初始化所需的形参就能更好的理解 “实验” 这个概念了。Experiment 对象中需要迭代中的各种 step 参数,需要一个 Estimator 对象,两个输入流函数(input)。Experiment 对象在运行中,将数据给到 Estimator 对象,然后控制训练和迭代流程。
图 8 Experiment 对象的部分形参
9.Estimator 对象。可以理解为模型对象,可以通过 Estimator 执行模型的训练、评估、解码。Estimator 对象最重要的一个形参是 model_fn,也就是具体执行训练、评估、解码的函数入口。三个入口分别对应着三个 EstimatorSpec 对象,如图 9,10 所示。
图 9 Estimator 中最重要的形参是 model_fn
图 10 Estimator 中的三种 model_fn,实现三种功能
从图 10 可以看出,用于训练的 EstimatorSpec 对象需要描述计算图中 feature 和(loss,train_op)之间的关系;用于评估的 EstimatorSpec 对象需要描述计算图中 feature 和(loss,eval_metrics_ops)之间的关系;用于评估的 EstimatorSpec 对象需要描述 features 和 predictions 之间的关系。
- Dataset 对象。该对象是读文件,构造训练和评估的数据流。训练和评估对应着两种不同的数据输入流,如图 11 所示。
图 11 Dataset 对象提供数据流
\11. Positional encoding 的实现。论文中的实现和代码中的实现存在公式变形和不一致的情况,可能会导致困惑,故在此指出。论文中 Positional encoding 中三角函数的参数部分公式如下:
代码中的实现需要对该公式做变形,以规避数值溢出的风险,公式变形过程如下:
还需要指出的是,论文中根据维度下标的奇偶性来交替使用 sin 和 cos 函数的说法,在代码中并不是这样实现的,而是前一半的维度使用 sin 函数,后一半的维度使用 cos 函数,并没有考虑奇偶性
12. 以 token 数量作为 batch size。这种方式比起以句子个数作为 batch size 的方式来,能到 batch 占显存的空间更加平均,不会导致因为训练数据导致的显存占用忽上忽下,造成显存空间不够用,导致程序崩溃。
\13. 如何做 mask。由于模型是以 batch 为单位进行训练的,batch 的句长以其中最长的那个句子为准,其他句子要做 padding。padding 项在计算的过程中如果不处理的话,会引入噪音,所以就需要 mask,来使 padding 项不对计算起作用。mask 在 attention 机制中的实现非常简单,就是在 softmax 之前,把 padding 位置元素加一个极大的负数,强制其 softmax 后的概率结果为 0。举个例子,[1,1,1] 经过 softmax 计算后结果约为 [0.33,0.33,0.33],[1,1,-1e9] softmax 的计算结果约为 [0.5, 0.5,0]。这样就相当于 mask 掉了数组中的第三项元素。在对 target sequence 进行建模的时候,需要保证每次只 attention 到前 t-1 个单词,这个地方也需要 mask,整体的 mask 是一个上三角矩阵,非 0 元素值为一个极大的负值。
\14. 基于 batch 的解码。解码的时候,如果是基于文件的,那么就会将句子组成 batch 来并行解码。这里有个小 trick,就是先对句子进行排序,然后从长的句子开始组 batch,翻译,再把句子恢复成原先的顺序返回。这种方式可以很好的检测到显存不足的错误,因为解句子最长的一个 batch 的时候,显存都是够得,那其他的 batch 也不存在问题。
总结
本文对 Google 的 Tensor2Tensor 系统进行了深度的解读,涉及到了比较多的方面,笔者也还需要对其进行更加深入的学习和研究,希望能够与对该模型以及 DL for NLP 技术感兴趣的同学们一起交流,共同进步!
问答
docker 和 docker-compose 有什么不同?
相关阅读
此文已由作者授权腾讯云 + 社区发布,原文链接:https://cloud.tencent.com/developer/article/1116709?fromSource=waitui
欢迎大家前往腾讯云 + 社区或关注云加社区微信公众号(QcloudCommunity),第一时间获取更多海量技术实践干货哦~
海量技术实践经验,尽在云加社区! https://cloud.tencent.com/developer?fromSource=waitui