排序模型进阶-FTRL
1 问题
在实际项目的时候,经常会遇到训练数据非常大导致一些算法实际上不能操作的问题。比如在推荐行业中,因为DSP的请求数据量特别大,一个星期的数据往往有上百G,这种级别的数据在训练的时候,直接套用一些算法框架是没办法训练的,基本上在特征工程的阶段就一筹莫展。通常采用采样、截断的方式获取更小的数据集,或者使用大数据集群的方式进行训练,但是这两种方式在作者看来目前存在两个问题:
- 采样数据或者截断数据的方式,非常的依赖前期的数据分析以及经验。
- 大数据集群的方式,目前spark原生支持的机器学习模型比较少;使用第三方的算法模型的话,需要spark集群的2.3以上;而且spark训练出来的模型往往比较复杂,实际线上运行的时候,对内存以及QPS的压力比较大。
2 在线优化算法-Online-learning
-
模型更新周期慢,不能有效反映线上的变化,最快小时级别,一 般是天级别甚至周级别。
-
模型参数少,预测的效果差;模型参数多线上predict的时候需要内存大,QPS无法保证。
-
对1采用On-line-learning的算法。
-
对2采用一些优化的方法,在保证精度的前提下,尽量获取稀疏解,从而降低模型参数的数量。
比较出名的在线最优化的方法有:
TG(Truncated Gradient)
FOBOS(Forward-Backward Splitting)
RDA(Regularized Dual Averaging)
FTRL(Follow the Regularized Leader)
SGD算法是常用的online learning算法,它能学习出不错的模型,但学出的模型不是稀疏的。为此,学术界和工业界都在研究这样一种online learning算法,它能学习出有效的且稀疏的模型
2 FTRL
一种获得稀疏模型的优化方法
算法原理:http://vividfree.github.io/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0/2015/12/05/understanding-FTRL-algorithm
3 离线数据训练FTRL模型
目的:通过离线TFRecords样本数据,训练FTRL模型
步骤:
1、构建模型
2、构建TFRecords的输入数据
3、train训练以及预测测试
完整代码:
import tensorflow as tf
from tensorflow.python import kerasclass LrWithFtrl(object):"""LR以FTRL方式优化"""def __init__(self):self.model = keras.Sequential([keras.layers.Dense(1, activation='sigmoid', input_shape=(121,))])@staticmethoddef read_ctr_records():# 定义转换函数,输入时序列化的def parse_tfrecords_function(example_proto):features = {"label": tf.FixedLenFeature([], tf.int64),"feature": tf.FixedLenFeature([], tf.string)}parsed_features = tf.parse_single_example(example_proto, features)feature = tf.decode_raw(parsed_features['feature'], tf.float64)feature = tf.reshape(tf.cast(feature, tf.float32), [1, 121])label = tf.reshape(tf.cast(parsed_features['label'], tf.float32), [1, 1])return feature, labeldataset = tf.data.TFRecordDataset(["./train_ctr_201904.tfrecords"])dataset = dataset.map(parse_tfrecords_function)dataset = dataset.shuffle(buffer_size=10000)dataset = dataset.repeat(10000)return datasetdef train(self, dataset):self.model.compile(optimizer=tf.train.FtrlOptimizer(0.03, l1_regularization_strength=0.01,l2_regularization_strength=0.01),loss='binary_crossentropy',metrics=['binary_accuracy'])self.model.fit(dataset, steps_per_epoch=10000, epochs=10)self.model.summary()self.model.save_weights('./ckpt/ctr_lr_ftrl.h5')def predict(self, inputs):"""预测:return:"""# 首先加载模型self.model.load_weights('/root/toutiao_project/reco_sys/offline/models/ckpt/ctr_lr_ftrl.h5')init = tf.global_variables_initializer()with tf.Session() as sess:sess.run(init)predictions = self.model.predict(sess.run(inputs))return predictionsif __name__ == '__main__':lwf = LrWithFtrl()dataset = lwf.read_ctr_records()inputs, labels = dataset.make_one_shot_iterator().get_next()print(inputs, labels)lwf.predict(inputs)
在线预测
def lrftrl_sort_service(reco_set, temp, hbu):"""排序返回推荐文章:param reco_set:召回合并过滤后的结果:param temp: 参数:param hbu: Hbase工具:return:"""print(344565)# 排序# 1、读取用户特征中心特征try:user_feature = eval(hbu.get_table_row('ctr_feature_user','{}'.format(temp.user_id).encode(),'channel:{}'.format(temp.channel_id).encode()))logger.info("{} INFO get user user_id:{} channel:{} profile data".format(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), temp.user_id, temp.channel_id))except Exception as e:user_feature = []logger.info("{} WARN get user user_id:{} channel:{} profile data failed".format(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), temp.user_id, temp.channel_id))reco_set = [13295, 44020, 14335, 4402, 2, 14839, 44024, 18427, 43997, 17375]if user_feature and reco_set:# 2、读取文章特征中心特征result = []for article_id in reco_set:try:article_feature = eval(hbu.get_table_row('ctr_feature_article','{}'.format(article_id).encode(),'article:{}'.format(article_id).encode()))except Exception as e:article_feature = []if not article_feature:article_feature = [0.0] * 111f = []f.extend(user_feature)f.extend(article_feature)result.append(f)# 4、预测并进行排序是筛选arr = np.array(result)# 加载逻辑回归模型lwf = LrWithFtrl()print(tf.convert_to_tensor(np.reshape(arr, [len(reco_set), 121])))predictions = lwf.predict(tf.constant(arr))df = pd.DataFrame(np.concatenate((np.array(reco_set).reshape(len(reco_set), 1), predictions),axis=1),columns=['article_id', 'prob'])df_sort = df.sort_values(by=['prob'], ascending=True)# 排序后,只将排名在前100个文章ID返回给用户推荐if len(df_sort) > 100:reco_set = list(df_sort.iloc[:100, 0])reco_set = list(df_sort.iloc[:, 0])return reco_set
4 TensorFlow FTRL 读取训练
训练数据说明
-
原始特征用MurmurHash3的方式,将特征id隐射到(Long.MinValue, Long.MaxValue)范围
-
保存成One-Hot的数据格式
算法参数 -
lambda1:L1正则系数,参考值:10 ~ 15
-
lambda2:L2正则系数,参考值:10 ~ 15
-
alpha:FTRL参数,参考值:0.1
-
beta:FTRL参数,参考值:1.0
-
batchSize: mini-batch的大小,参考值:10000