MLR(mixed logistic regression)算法原理及实现

article/2025/11/9 22:23:51

MLR(mixed logistic regression)算法

参考https://zhuanlan.zhihu.com/p/77798409?utm_source=wechat_session
盖坤ppt : https://wenku.baidu.com/view/b0e8976f2b160b4e767fcfdc.html
原文:《Learning Piece-wise Linear Models from Large Scale Data for Ad Click Prediction》

MLR算法创新地提出并实现了直接在原始空间学习特征之间的非线性关系
MLR算法模型,这是一篇来自阿里盖坤团队的方案(LS-PLM),发表于2017年,但实际在2012年就已经提出并应用于实际业务中(膜拜ing),当时主流仍然是我们上一篇提到过的的LR模型,而本文作者创新性地提出了MLR(mixed logistic regression, 混合逻辑斯特回归)算法,引领了广告领域CTR预估算法的全新升级。

1.背景

CTR预估(click-through-rate prediction)是广告行业比较常见的问题,根据用户的历史行为来判断用户对广告点击的可能性。在常见工业场景中,该问题的输入往往是数以万计的稀疏特征向量,在进行特征交叉后会维数会更高,比较常见的就是采用逻辑回归模型加一些正则化,因为逻辑回归模型计算开销小且容易实现并行。之前提到的facebook的一篇论文(LR+GBDT)中先用树模型做分类之后再加一个逻辑回归模型,最后得出效果出奇的好,应该也是工业界比较常用的方法,同时树模型的选择或者说是再构造特征的特性也逐渐被大家所关注。另一种比较有效的就是因子分解模型系列,包括FM及其的其他变种,它们的主要思想就是构造交叉特征或者是二阶的特征来一起进行训练。
作者主要提出了一种piece-wise的线性模型,并且给出了其在大规模数据上的训练算法,称之为LS-PLM(Large Scale Piecewise Linear Model),LS-PLM采用了分治的思想,先分成几个局部再用线性模型拟合,这两部都采用监督学习的方式,来优化总体的预测误差,总的来说有以下优势:

  • 端到端的非线性学习
    从模型端自动挖掘数据中蕴藏的非线性模式,省去了大量的人工特征设计,这 使得MLR算法可以端到端地完成训练,在不同场景中的迁移和应用非常轻松。通过分区来达到拟合非线性函数的效果;
  • 可伸缩性(scalability)
    与逻辑回归模型相似,都可以很好的处理复杂的样本与高维的特征,并且做到了分布式并行;
  • 稀疏性
    对于在线学习系统,模型的稀疏性较为重要,所以采用了 L 1 L_1 L1 L 2 L_2 L2 正则化,模型的学习和在线预测性能更好。当然,目标函数非凸非光滑为算法优带来了新的挑战。

MLR方法

思想分而治之,由很多个LR模型组合而成。用分片线性模式来拟合高维空间的非线性模式,形式化表述如下:
在这里插入图片描述
当我们将softmax函数作为分割函数 σ ( x ) \sigma(x) σ(x),将sigma函数作为拟合函数 η ( x ) \eta(x) η(x)的时候,该模型为:
在这里插入图片描述
目标损失函数为
在这里插入图片描述
同时MLR还引入了结构化先验、分组稀疏、线性偏置、模型级联、增量训练、Common Feature Trick来提升模型性能。
结构化先验
MLR中非常重要的就是如何划分原始特征空间。
通过引入结构化先验,我们使用用户特征来划分特征空间,使用广告特征来进行基分类器的训练,减小了模型的探索空间,收敛更容易。
同时,这也是符合我们认知的:不同的人群具有聚类特性,同一类人群具有类似的广告点击偏好。
增量训练
实验证明,MLR利用结构先验(用户特征进行聚类,广告特征进行分类)进行pretrain,然后再增量进行全空间参数寻优训练,会使得收敛步数更少,收敛更稳定。
在这里插入图片描述
模型级联
盖坤在PPT讲解到,MLR支持与LR的级联式训练。有点类似于Wide & Deep,一些强Feature配置成级联形式能够提高模型的收敛性。例如典型的应用方法是:以统计反馈类特征构建第一层模型,输出FBctr级联到第二级大规模稀疏ID特征中去,能得到更好的提升效果。

此外盖坤还提出了带
在这里插入图片描述

代码实现

import numpy as np
import pandas as pd
import time
import tensorflow as tf
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score#数据处理
def data_concat(train,test):train['type']=1test['type'] =2all_columns=['age','workclass','fnlwgt','education','education-num','marital-status','occupation','relationship','race','sex','capital-gain','capital-loss','hours-per-week','native-country','label','type']all_data=pd.concat([train,test],axis=0)all_data.columns=all_columnsreturn all_datadef data_processing(train,test):df=data_concat(train,test)continus_columns=['age','fnlwgt','education-num','capital-gain','capital-loss','hours-per-week']category_columns=['workclass','education','marital-status','occupation','relationship','race','sex','native-country']#类别变量做one_hot_encodingdf=pd.get_dummies(df,columns=category_columns)#连续数据标准化for col in continus_columns:ss=StandardScaler()df[col]=ss.fit_transform(df[[col]])df['label']=df['label'].apply(lambda x: 1 if  x.strip()=='>50K'  else 0)return dfos.getcwd()
train_data=pd.read_table(r'E:/996/推荐系统/DATA/recsys-data/MLR/adult.data',header=None,delimiter=',')
test_data=pd.read_table(r'E:/996/推荐系统/DATA/recsys-data/MLR/adult.txt',header=None,delimiter=',')
test_data[14]=test_data[14].apply(lambda x: x[:-1])
df= data_processing(train_data,test_data)
train_data=df[df['type']==1].drop(['type'],axis=1)
test_data=df[df['type']==2].drop(['type'],axis=1)   #mlr模型训练及测试print(tf.__path__)
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()def train_and_test(train_data,test_data,m=2,learning_rate=0.1):#m=2#分片参数为2
#learning_rate#学习率为0.3train_y= train_data['label']train_x=train_data.drop('label',axis=1)test_y= test_data['label']test_x=test_data.drop('label',axis=1)x=tf.placeholder(tf.float32,shape=[None,108])#特征向量维度为108y=tf.placeholder(tf.float32,shape=[None])u=tf.Variable(tf.random_normal([108,m],0.0,0.5),name='u')w=tf.Variable(tf.random_normal([108,m],0.0,0.5),name='w')U=tf.matmul(x,u)p1=tf.nn.softmax(U)W=tf.matmul(x,w)p2=tf.nn.softmax(W)pred=tf.reduce_sum(tf.multiply(p1,p2),axis=1)cost1=tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=pred,labels=y))cost=tf.add_n([cost1])train_opt=tf.train.FtrlOptimizer(learning_rate).minimize(cost)time_start=time.time()#会话with tf.Session() as sess:sess.run(tf.global_variables_initializer())#初始化 train_dict={x:train_x,y:train_y}for epoch in range(500):_,cost_train,pred_train=sess.run([train_opt,cost,pred],feed_dict=train_dict)train_auc=roc_auc_score(train_y,pred_train)time_end=time.time()test_dict={x:test_x,y:test_y}result=[]if epoch %100==0:_,cost_test,pred_test=sess.run([train_opt,cost,pred],feed_dict=test_dict)test_auc=roc_auc_score(test_y,pred_test)print("epoch:%d,time:%d,train_auc:%f,test_auc:%f"%(epoch,(time_end-time_start),train_auc,test_auc))result.append([epoch,pred_test,train_auc,test_auc])return resulttrain_and_test(train_data,test_data,m=2,learning_rate=0.1)

http://chatgpt.dhexx.cn/article/LFWqX9fx.shtml

相关文章

多元线性回归MLR

多元线性回归(MLR) 文章目录 多元线性回归(MLR)由极大似然估计(MLE, Maximum likelihood estimation)推导MSE简单导数知识推导解析解( θ ( X T X ) − 1 X T Y \theta (X^TX)^{-1}X^TY θ(XTX)−1XTY)来个例子试一试吧 不用解析解, 用梯度下降求解 θ \theta θ梯度下降法另…

MySql MVCC 详解

注意以下操作都是以InnoDB引擎为操作基准。 一,前置知识准备 1,MVCC简介 MVCC 是多版本并发控制(Multiversion Concurrency Control)的缩写。它是一种数据库事务管理技术,用于解决并发访问数据库的问题。MVCC 通过创…

MVCC总结

MVCC多版本并发控制 数据库中的并发分为三种情况: 读读:不存在任何问题,也不需要并发控制 读写:有数据安全问题,脏读,幻读,不可重复读 写写:有数据安全问题,可能存在…

Mysql中的MVCC

Mysql到底是怎么实现MVCC的?这个问题无数人都在问,但google中并无答案,本文尝试从Mysql源码中寻找答案。 在Mysql中MVCC是在Innodb存储引擎中得到支持的,Innodb为每行记录都实现了三个隐藏字段: 6字节的事务ID&#xf…

MVCC原理

在并发读写数据库时,读操作可能会不一致的数据(脏读)。为了避免这种情况,需要实现数据库的并发访问控制,最简单的方式就是加锁访问。由于,加锁会将读写操作串行化,所以不会出现不一致的状态。但…

MVCC机制

MVCC 1. MVCC是什么? MVCC,全称Multi-Version Concurrency Control,即多版本并发控制。MVCC是一种并发控制的方法,一般在数据库管理系统中,实现对数据库的并发访问,在编程语言中实现事务内存。 MVCC的具体…

MVCC实现原理

1、什么是MVCC mvcc多版本并发控制。 mvcc在mysql innodb中主要是为了提高数据库并发性能,用更好的方式去处理读写冲突,做到即使有读写冲突时,也能做到不加索,非阻塞并发读。。 2、实现原理: mvcc的实现是通过保存…

MySQL的MVCC及实现原理

一 概要 1.什么是 MVCC ? MVCC,全称 Multi-Version Concurrency Control ,即多版本并发控制。MVCC 是一种并发控制的方法,一般在数据库管理系统中,实现对数据库的并发访问,在编程语言中实现事务内存。 MVCC 在 MySQL…

InnoDB MVCC 机制

本文详细的介绍了什么是MVCC?为什么要有MVCC?以及MVCC的内部实现原理:包括Undo Log的版本链是如何组织的,RR、RC两个级别下一致性读是如何实现的等。通过案例、插图,以最通俗易懂的方式,让你彻底掌握MVCC的…

mysql mvcc 实例说明_Mysql MVCC

一、MVCC概述 MVCC,全称Multi-Version Concurrency Control,即多版本并发控制。整个MVCC多并发控制的目的就是为了实现读-写冲突不加锁,提高并发读写性能,而这个读指的就是快照度, 而非当前读,当前读实际上是一种加锁的…

MVCC

一、什么是MVCC MVCC(Multiversion concurrency control )是一种多版本并发控制机制。 二、MVCC是为了解决什么问题? 并发访问(读或写)数据库时,对正在事务内处理的数据做多版本的管理。以达到用来避免写操作的堵塞,从而引发读操…

MVCC详解

一、前言 全称Multi-Version Concurrency Control,即多版本并发控制,主要是为了提高数据库的并发性能。以下文章都是围绕InnoDB引擎来讲,因为myIsam不支持事务。 同一行数据平时发生读写请求时,会上锁阻塞住。但mvcc用更好的方式…

MVCC 机制的原理及实现

什么是 MVCC MVCC (Multiversion Concurrency Control) 中文全程叫多版本并发控制,是现代数据库(包括 MySQL、Oracle、PostgreSQL 等)引擎实现中常用的处理读写冲突的手段,目的在于提高数据库高并发场景下的吞吐性能。 如此一来…

深入浅出:MVCC详解

什么是MVCC: MVCC(Multi Version Concurrency Control的简称),代表多版本并发控制。与MVCC相对的,是基于锁的并发控制,Lock-Based Concurrency Control)。 MVCC最大的优势:读不加锁,读写不冲突。在读多写少…

什么是MVCC?MVCC解决了什么问题?MVCC的实现原理?

1.什么是MVCC? MVCC全称是【Multi-Version ConCurrency Control】,即多版本控制协议。 多版本控制(Multiversion Concurrency Control): 指的是一种提高并发的技术。最早的数据库系统,只有读读之间可以并发&#xff…

MVCC详解,深入浅出简单易懂

一、什么是MVCC? mvcc,也就是多版本并发控制,是为了在读取数据时不加锁来提高读取效率和并发性的一种手段。 数据库并发有以下几种场景: 读-读:不存在任何问题。读-写:有线程安全问题,可能出…

【MySQL笔记】正确的理解MySQL的MVCC及实现原理

MVCC多版本并发控制 如果觉得对你有帮助,能否点个赞或关个注,以示鼓励笔者呢?!博客目录 | 先点这里 !首先声明,MySQL 的测试环境是 5.7 前提概要 什么是 MVCC什么是当前读和快照读?当前读&…

Oracle自定义函数

使用Navicat的话,可以点击函数,新建函数,根据引导完成一个函数的基本搭建。 语法和Java类似,其中对于变量赋值要使用 : 进行赋值。 具体语法可以参考一下 Oracle 自定义函数语法与实例_桑汤奈伊伏的博客-CSDN博客_oracle 自定义函…