应用背景
Recurrent network的应用主要以下部分:
- 文本相关:主要用于自然语言处理(NLP)、对话系统、情感分析、机器翻译等领域;Google翻译用的就是一个7-8层的LSTM模型;
- 时序相关:时序问题,诸如预测天气、温度、股票等等;
- 行为演化相关:用户行为序列中用户兴趣的演化,诸如DIEN中GRU的应用;
以上问题都有先后顺序的概念,而传统全连接网络缺乏”记忆“,即RNN的做法就是维护一些中间状态信息。
SimpleRNN
Recurrent Neural Network用来维护中间状态,记录之前看到的信息;
RNN如图所示,每次循环时会用到上一次的计算结果
state_t = 0 #时刻t的状态 for input_t in input_sequence: # 在timesteps上loop output_t = f(input_t, state_t) # input_t state_t得到时刻t输出 state_t = output_t # 用当前输出去更新内部状态 """ 其中f为一个函数,完成input和state到output的转换; f = activation(dot(W, input) + dot(U, state) + b) """
优缺点
- 优点:处理文本句子和时间序列的效果会比普通DNN要好,中间状态理论上维护额从开头到现在的所有信息;
- 缺点:不能处理长文本句子或时间序列,因为会出现梯度消失,网络几乎不可训练;
LSTM算法原理
LSTM就是用来解决RNN中梯度消失问题的,从而可以处理Long-term sequences
LSTM增加一个可以相隔多个timesteps来传递信息的方法
因此当前时刻的输出就受到三个信息的影响:当前时刻的输入、当前时刻的状态、传送带上带来的很久之前的信息:
output_t = activation(dot(state_t, Uo) + dot(input_t, Wo) + dot(C_t, Vo) + bo)
因此需要更新的为:State_c,C_t,其中state_c和SimpleRNN中相同的更新方式,C_t更新如下:
i_t = activation(dot(state_t, Ui) + dot(input_t, Wi) + bi) # 输入门 f_t = activation(dot(state_t, Uf) + dot(input_t, Wf) + bf) # 遗忘门 k_t = activation(dot(state_t, Uk) + dot(input_t, Wk) + bk) # 遗忘门 # 更新公式 C_t+1 = i_t * k_t + c_t * f_t # 长记忆 # 解释 c_t * f_t 用作让模型忘记不相关的信息; i_k * k_t 用作让模型提供关于当前时刻的信息;
优缺点
- 优点:解决了SimpleRNN梯度消失的问题,可以用来处理long-term sequence;
- 缺点:计算复杂度高,谷歌翻译也只有7-8层LSTM;
实践指南
RNN表达能力
有时候RNN表达能力有限,为了增加RNN的表达能力,可以使用stack rnn layers来增加其表达能力;
过拟合
RNN LSTM同样会过拟合,解决方法类似dropout,在整个timesteps上使用一个固定的drop mask,其中分别有对输入的drop_rate以及recurrent connection(state_t输入到SimpleRNN的部分)的drop_rate;
GRU算法
LSTM的计算比较慢,所以有了Gated Recurrent Unit(GRU),你可以认为他是经过特殊优化提速的LSTM,但是他的表达能力也是受到限制的;
代码实现
# 主要参考keras LSTM代码进行实现 # 简单形式,即implementation!=1情况下,各个门的输入input使用相同的mask来进行dropout if 0. < self.dropout < 1.: inputs = inputs * dp_mask[0] # 统一乘dp_mask[0] z = backend.dot(inputs, self.kernel) z += backend.dot(h_tm1, self.recurrent_kernel) if self.use_bias: z = backend.bias_add(z, self.bias) z = array_ops.split(z, num_or_size_splits=4, axis=1) c, o = self._compute_carry_and_output_fused(z, c_tm1) def _compute_carry_and_output_fused(self, z, c_tm1): """Computes carry and output using fused kernels.""" z0, z1, z2, z3 = z i = self.recurrent_activation(z0) f = self.recurrent_activation(z1) c = f * c_tm1 + i * self.activation(z2) o = self.recurrent_activation(z3) return c, o ########################################################################### # 稍微复杂形式,即implementation==1情况下,各个门的输入input/状态state可以使用不同的mask参数来进行dropout if 0 < self.dropout < 1: inputs_i = inputs * dp_mask[0] inputs_f = inputs * dp_mask[1] inputs_c = inputs * dp_mask[2] inputs_o = inputs * dp_mask[3] else: inputs_i = inputs inputs_f = inputs inputs_c = inputs inputs_o = inputs k_i, k_f, k_c, k_o = array_ops.split(self.kernel, num_or_size_splits=4, axis=1) x_i = backend.dot(inputs_i, k_i) x_f = backend.dot(inputs_f, k_f) x_c = backend.dot(inputs_c, k_c) x_o = backend.dot(inputs_o, k_o) if self.use_bias: b_i, b_f, b_c, b_o = array_ops.split(self.bias, num_or_size_splits=4, axis=1) x_i = backend.bias_add(x_i, b_i) x_f = backend.bias_add(x_f, b_f) x_c = backend.bias_add(x_c, b_c) x_o = backend.bias_add(x_o, b_o) if 0 < self.recurrent_dropout < 1.: h_tm1_i = h_tm1 * rec_dp_mask[0] h_tm1_f = h_tm1 * rec_dp_mask[1] h_tm1_c = h_tm1 * rec_dp_mask[2] h_tm1_o = h_tm1 * rec_dp_mask[3] else: h_tm1_i = h_tm1 h_tm1_f = h_tm1 h_tm1_c = h_tm1 h_tm1_o = h_tm1 x = (x_i, x_f, x_c, x_o) h_tm1 = (h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o) c, o = self._compute_carry_and_output(x, h_tm1, c_tm1) def _compute_carry_and_output(self, x, h_tm1, c_tm1): """Computes carry and output using split kernels.""" x_i, x_f, x_c, x_o = x h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1 i = self.recurrent_activation( x_i + backend.dot(h_tm1_i, self.recurrent_kernel[:, :self.units])) f = self.recurrent_activation( x_f, backend.dot(h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2])) c = f * c_tm1 + i * self.activation(x_c + backend.dot( h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3])) o = self.recurrent_activation( x_o + backend.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:])) return c, o