FNN算法原理&实现
🔎

FNN算法原理&实现

text
FNN算法的Python代码实现和原理
Tags
机器学习
深度学习
推荐系统
Created
Aug 25, 2022 09:49 AM

背景

FNN由伦敦大学学院的研究人员于2016年提出,模型结构类似于Deep Crossing模型;FNN模型的解决思路是用FM模型训练好的各特征隐向量初始化Embedding层的参数,相当于再初始化神经网络参数时,已经引入了有价值的先验信息。

原理

notion image
主要分为两阶段训练,阶段一训练一个FM,阶段二训练一个带嵌入层的DNN;
  • 阶段一:先使用带标签的训练集有监督的训练FM模型,训练完成后得到每个特征对应的隐向量;若输入特征维度为n,隐向量维度为k,则隐向量矩阵W的形状为[n, k];
  • 阶段二:利用隐向量矩阵W初始化DNN的嵌入层,然后有监督训练DNN。DNN的输入包含了FM学到的先验知识,可减轻DNN的学习压力;

优缺点

优点

  • 将FM学习得到的隐向量作为DNN的输入,隐向量包含了FM学到的先验知识,可减轻DNN的学习压力;
  • FM只考虑到二阶特征交互,忽略了高阶特征,利用后续的DNN将其进行弥补,提高模型的表达能力;

缺点

  • 采用两阶段、非端到端的训练方式,不利于模型的线上部署;
  • 将FM的隐向量直接拼接作为DNN的输入,忽略了field的概念;
  • FNN未考虑低阶特征组合,低阶、高阶特征都很重要;

代码实现

if __name__ == '__main__': file = '../../data/criteo/train_1w.txt' (X_train, y_train), (X_test, y_test) = create_criteo_dataset('fnn', file, test_size=0.3) k = 8 model = FM(k) optimizer = optimizers.gradient_descent_v2.SGD(0.01) train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train)) train_dataset = train_dataset.batch(32).prefetch(tf.data.experimental.AUTOTUNE) model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy']) model.fit(train_dataset, epochs=20) fm_pre = model(X_test) fm_pre = [1 if x > 0.5 else 0 for x in fm_pre] v = model.variables[2] X_train = tf.cast(tf.expand_dims(X_train, -1), tf.float32) X_train = tf.reshape(tf.multiply(X_train, v), shape=(-1, v.shape[0] * v.shape[1])) hidden_units = [256, 128, 64] model = DNN(hidden_units, 1, 'relu') optimizer = optimizers.gradient_descent_v2.SGD(0.0001) train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train)) train_dataset = train_dataset.batch(32).prefetch(tf.data.experimental.AUTOTUNE) model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy']) model.fit(train_dataset, epochs=50) X_test = tf.cast(tf.expand_dims(X_test, -1), tf.float32) X_test = tf.reshape(tf.multiply(X_test, v), shape=(-1, v.shape[0] * v.shape[1])) fnn_pre = model(X_test) fnn_pre = [1 if x > 0.5 else 0 for x in fnn_pre] print('FM Accuracy: ', accuracy_score(y_test, fm_pre)) print('FNN Accuracy: ', accuracy_score(y_test, fnn_pre))

参考