LSTM算法原理&实现
🤿

LSTM算法原理&实现

text
LSTM算法原理学习&输出&Python代码实现
Tags
机器学习
深度学习
RNN
Created
Sep 7, 2022 07:26 AM

应用背景

Recurrent network的应用主要以下部分:
  1. 文本相关:主要用于自然语言处理(NLP)、对话系统、情感分析、机器翻译等领域;Google翻译用的就是一个7-8层的LSTM模型;
  1. 时序相关:时序问题,诸如预测天气、温度、股票等等;
  1. 行为演化相关:用户行为序列中用户兴趣的演化,诸如DIEN中GRU的应用;
以上问题都有先后顺序的概念,而传统全连接网络缺乏”记忆“,即RNN的做法就是维护一些中间状态信息

SimpleRNN

Recurrent Neural Network用来维护中间状态,记录之前看到的信息;
notion image
notion image
RNN如图所示,每次循环时会用到上一次的计算结果
state_t = 0 #时刻t的状态 for input_t in input_sequence: # 在timesteps上loop output_t = f(input_t, state_t) # input_t state_t得到时刻t输出 state_t = output_t # 用当前输出去更新内部状态 """ 其中f为一个函数,完成input和state到output的转换; f = activation(dot(W, input) + dot(U, state) + b) """

优缺点

  • 优点:处理文本句子和时间序列的效果会比普通DNN要好,中间状态理论上维护额从开头到现在的所有信息;
  • 缺点:不能处理长文本句子或时间序列,因为会出现梯度消失,网络几乎不可训练;

LSTM算法原理

💡
LSTM就是用来解决RNN中梯度消失问题的,从而可以处理Long-term sequences
LSTM增加一个可以相隔多个timesteps来传递信息的方法
notion image
因此当前时刻的输出就受到三个信息的影响:当前时刻的输入、当前时刻的状态、传送带上带来的很久之前的信息:
output_t = activation(dot(state_t, Uo) + dot(input_t, Wo) + dot(C_t, Vo) + bo)
因此需要更新的为:State_c,C_t,其中state_c和SimpleRNN中相同的更新方式,C_t更新如下:
i_t = activation(dot(state_t, Ui) + dot(input_t, Wi) + bi) # 输入门 f_t = activation(dot(state_t, Uf) + dot(input_t, Wf) + bf) # 遗忘门 k_t = activation(dot(state_t, Uk) + dot(input_t, Wk) + bk) # 遗忘门 # 更新公式 C_t+1 = i_t * k_t + c_t * f_t # 长记忆 # 解释 c_t * f_t 用作让模型忘记不相关的信息; i_k * k_t 用作让模型提供关于当前时刻的信息;
notion image
notion image

优缺点

  • 优点:解决了SimpleRNN梯度消失的问题,可以用来处理long-term sequence;
  • 缺点:计算复杂度高,谷歌翻译也只有7-8层LSTM;

实践指南

RNN表达能力

有时候RNN表达能力有限,为了增加RNN的表达能力,可以使用stack rnn layers来增加其表达能力;

过拟合

RNN LSTM同样会过拟合,解决方法类似dropout,在整个timesteps上使用一个固定的drop mask,其中分别有对输入的drop_rate以及recurrent connection(state_t输入到SimpleRNN的部分)的drop_rate;

GRU算法

LSTM的计算比较慢,所以有了Gated Recurrent Unit(GRU),你可以认为他是经过特殊优化提速的LSTM,但是他的表达能力也是受到限制的;

代码实现

# 主要参考keras LSTM代码进行实现 # 简单形式,即implementation!=1情况下,各个门的输入input使用相同的mask来进行dropout if 0. < self.dropout < 1.: inputs = inputs * dp_mask[0] # 统一乘dp_mask[0] z = backend.dot(inputs, self.kernel) z += backend.dot(h_tm1, self.recurrent_kernel) if self.use_bias: z = backend.bias_add(z, self.bias) z = array_ops.split(z, num_or_size_splits=4, axis=1) c, o = self._compute_carry_and_output_fused(z, c_tm1) def _compute_carry_and_output_fused(self, z, c_tm1): """Computes carry and output using fused kernels.""" z0, z1, z2, z3 = z i = self.recurrent_activation(z0) f = self.recurrent_activation(z1) c = f * c_tm1 + i * self.activation(z2) o = self.recurrent_activation(z3) return c, o ########################################################################### # 稍微复杂形式,即implementation==1情况下,各个门的输入input/状态state可以使用不同的mask参数来进行dropout if 0 < self.dropout < 1: inputs_i = inputs * dp_mask[0] inputs_f = inputs * dp_mask[1] inputs_c = inputs * dp_mask[2] inputs_o = inputs * dp_mask[3] else: inputs_i = inputs inputs_f = inputs inputs_c = inputs inputs_o = inputs k_i, k_f, k_c, k_o = array_ops.split(self.kernel, num_or_size_splits=4, axis=1) x_i = backend.dot(inputs_i, k_i) x_f = backend.dot(inputs_f, k_f) x_c = backend.dot(inputs_c, k_c) x_o = backend.dot(inputs_o, k_o) if self.use_bias: b_i, b_f, b_c, b_o = array_ops.split(self.bias, num_or_size_splits=4, axis=1) x_i = backend.bias_add(x_i, b_i) x_f = backend.bias_add(x_f, b_f) x_c = backend.bias_add(x_c, b_c) x_o = backend.bias_add(x_o, b_o) if 0 < self.recurrent_dropout < 1.: h_tm1_i = h_tm1 * rec_dp_mask[0] h_tm1_f = h_tm1 * rec_dp_mask[1] h_tm1_c = h_tm1 * rec_dp_mask[2] h_tm1_o = h_tm1 * rec_dp_mask[3] else: h_tm1_i = h_tm1 h_tm1_f = h_tm1 h_tm1_c = h_tm1 h_tm1_o = h_tm1 x = (x_i, x_f, x_c, x_o) h_tm1 = (h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o) c, o = self._compute_carry_and_output(x, h_tm1, c_tm1) def _compute_carry_and_output(self, x, h_tm1, c_tm1): """Computes carry and output using split kernels.""" x_i, x_f, x_c, x_o = x h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1 i = self.recurrent_activation( x_i + backend.dot(h_tm1_i, self.recurrent_kernel[:, :self.units])) f = self.recurrent_activation( x_f, backend.dot(h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2])) c = f * c_tm1 + i * self.activation(x_c + backend.dot( h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3])) o = self.recurrent_activation( x_o + backend.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:])) return c, o

参考

LSTM原理与实践,原来如此简单
LSTM全称 Long Short-Term Memory ,是1997年就被发明出来的算法,作者是谁说了你们也记不住干脆就不说了(主要是我记不住,逃...) 经过这么多年的发展,基本上没有什么理论创新,唯一值得说的一点也就是加入了Dropout来对抗过拟合。真的是应了那句话呀: Deep learning is an art more than a science. 怪不得学数学的一直看不起搞算法的,不怪人家,整天拿着个梯度下降搞来搞去,确实有点low。。。 即使这样,LSTM的应用依旧非常的广泛,而且效果还不错。但是,LSTM的原理稍显复杂,苦于没有找到非常好的资料,小编之前也是被各种博客绕的团团转,今天重新梳理了一次,发现并没有那么难,这里把总结的资料分享给大家。 认真阅读本文,你将学到: 1. RNN原理、应用背景、缺点2. LSTM产生原因、原理,以及关于LSTM各种"门"的一些intuition(哲学解释) (别怕,包教包会)3. 如何利用Keras使用LSTM来解决实际问题4. 关于Recurrent Network的一些常用技巧,包括:过拟合,stack rnn Recurrent network的应用主要如下两部分: 文本相关。主要应用于自然语言处理(NLP)、对话系统、情感分析、机器翻译等等领域,Google翻译用的就是一个7-8层的LSTM模型。 时序相关。就是时序预测问题(timeseries),诸如预测天气、温度、包括个人认为根本不可行的但是很多人依旧在做的预测股票价格问题 这些问题都有一个共同点,就是 有先后顺序的概念 的。举个例子: 根据前5天每个小时的温度,来预测接下来1个小时的温度。典型的时序问题,温度是从5天前,一小时一小时的记录到现在的,它们的顺序不能改变,否则含义就发生了变化;再比如情感分析中,判断一个人写的一篇文章或者说的一句话,它是积极地(positive),还是消极的(negative),这个人说的话写的文章,里面每个字都是有顺序的,不能随意改变,否则含义就不同了。 全连接网络Fully-Connected Network,或者卷积神经网络Convnet,他们在处理一个sequence(比如一个人写的一条影评),或者一个timeseries of data points(比如连续1个月记录的温度)的时候,他们 缺乏记忆。一条影评里的每一个字经过word embedding后,被当成了一个独立的个体输入到网络中;网络不清楚之前的,或者之后的文字是什么。这样的网络,我们称为 feedforward network 。 但是实际情况,我们理解一段文字的信息的时候,每个文字并不是独立的,我们的脑海里也有它的上下文。比如当你看到这段文字的时候,你还记得这篇文章开头表达过一些关于LSTM的信息; 所以,我们在脑海里维护一些信息,这些信息随着我们的阅读不断的更新,帮助我们来理解我们所看到的每一个字,每一句话。这就是RNN的做法: 维护一些中间状态信息 。 RNN是 Recurrent Neural Network 的缩写,它就是实现了我们来维护中间信息,记录之前看到信息这样最简单的一个概念的模型。 关于名称,你可以这样理解: Recurrent Neural Network = A network with a loop.
LSTM原理与实践,原来如此简单