FM、FFM以及DeepFM

news2025/5/26 6:26:13

FM部分

  • 什么是FM
    FM是factor machine的简写,中文翻译为因子分解机。
  • 为什么需要FM
    在进行特征建模的过程中,经常会遇到两种情况:
    1. 对特征直接进行建模,未考虑特征之间的关联信息;
    2. 特征高维稀疏,导致计算量大,特征权值更新缓慢;
      FM正好能解决特征交互问题;另外FM 通过引入隐向量,能够降低稀疏特征的维度;提高交互特征参数评估;
  • FM长啥样
    1. 在这里插入图片描述
    2. 特征组合
      在这里插入图片描述
      矩阵分解提供了一种解决思路。在model-based的协同过滤中,一个rating矩阵可以分解为user矩阵和item矩阵,每个user和item都可以采用一个隐向量表示。比如在下图中的例子中,我们把每个user表示成一个二维向量,同时把每个item表示成一个二维向量,两个向量的点积就是矩阵中user对item的打分。
      上图中n个user,经过one-hot编码,变成n1n维,属于高度稀疏矩阵;为了减少维度,可以通过引入隐向量,将每个user映射为k维,则特征矩阵维度变为n1k维,即上图中的表示形式。

类似地,所有二次项参数 <Vi,Vj>可以组成一个对称阵 W(为了方便说明FM的由来,对角元素可以设置为正实数),那么这个矩阵就可以分解为 W=Vt*V,V 的第 j列( vj)便是第 j维特征( xj)的隐向量。换句话说,特征分量xj 和xi 的交叉项系数就等于xi对应的隐向量与xj对应的隐向量的内积,即每个参数 wij=<vi,vj>,这就是FM模型的核心思想。
关于隐向量,这里的 vi是xi 特征的低纬稠密表达,实际中隐向量的长度通常远小于特征维度N,在我做的实验中长度都是4。在实际的CTR场景中,数据都是很稀疏的category特征,通常表示成离散的one-hot形式,这种编码方式,使得one-hot vector非常长,而且很稀疏,同时特征总量也骤然增加,达到千万级甚至亿级别都是有可能的,而实际上的category特征数目可能只有几百维。FM学到的隐向量可以看做是特征的一种embedding表示,把离散特征转化为Dense Feature,这种Dense Feature还可以后续和DNN来结合,作为DNN的输入,事实上用于DNN的CTR也是这个思路来做的。

  • FM应用场景
    在这里插入图片描述

  • FM code实现
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

# -*- coding: utf-8 -*-

from __future__ import division
from math import exp
from numpy import *
from random import normalvariate  # 正态分布
from sklearn import preprocessing
import numpy as np

'''
    data : 数据的路径
    feature_potenital : 潜在分解维度数
    alpha : 学习速率
    iter : 迭代次数
    _w,_w_0,_v : 拆分子矩阵的weight
    with_col : 是否带有columns_name
    first_col : 首列有价值的feature的index
'''


class fm(object):
    def __init__(self):
        self.data = None
        self.feature_potential = None
        self.alpha = None
        self.iter = None
        self._w = None
        self._w_0 = None
        self.v = None
        self.with_col = None
        self.first_col = None

    def min_max(self, data):
        self.data = data
        min_max_scaler = preprocessing.MinMaxScaler()
        return min_max_scaler.fit_transform(self.data)

    def loadDataSet(self, data, with_col=True, first_col=2):
        # 我就是闲的蛋疼,明明pd.read_table()可以直接度,非要搞这样的,显得代码很长,小数据下完全可以直接读嘛,唉~
        self.first_col = first_col
        dataMat = []
        labelMat = []
        fr = open(data)
        self.with_col = with_col
        if self.with_col:
            N = 0
            for line in fr.readlines():
                # N=1时干掉列表名
                if N > 0:
                    currLine = line.strip().split()
                    lineArr = []
                    featureNum = len(currLine)
                    for i in range(self.first_col, featureNum):
                        lineArr.append(float(currLine[i]))
                    dataMat.append(lineArr)
                    labelMat.append(float(currLine[1]) * 2 - 1)
                N = N + 1
        else:
            for line in fr.readlines():
                currLine = line.strip().split()
                lineArr = []
                featureNum = len(currLine)
                for i in range(2, featureNum):
                    lineArr.append(float(currLine[i]))
                dataMat.append(lineArr)
                labelMat.append(float(currLine[1]) * 2 - 1)
        return mat(self.min_max(dataMat)), labelMat

    def sigmoid(self, inx):
        # return 1.0/(1+exp(min(max(-inx,-10),10)))
        return 1.0 / (1 + exp(-inx))

    # 得到对应的特征weight的矩阵
    def fit(self, data, feature_potential=8, alpha=0.01, iter=100):
        # alpha是学习速率
        self.alpha = alpha
        self.feature_potential = feature_potential
        self.iter = iter
        # dataMatrix用的是mat, classLabels是列表
        dataMatrix, classLabels = self.loadDataSet(data)
        print('dataMatrix:',dataMatrix.shape)
        print('classLabels:',classLabels)
        k = self.feature_potential
        m, n = shape(dataMatrix)
        # 初始化参数
        w = zeros((n, 1))  # 其中n是特征的个数
        w_0 = 0.
        v = normalvariate(0, 0.2) * ones((n, k))
        for it in range(self.iter): # 迭代次数
            # 对每一个样本,优化
            for x in range(m):
                # 这边注意一个数学知识:对应点积的地方通常会有sum,对应位置积的地方通常都没有,详细参见矩阵运算规则,本处计算逻辑在:http://blog.csdn.net/google19890102/article/details/45532745
                # xi·vi,xi与vi的矩阵点积
                inter_1 = dataMatrix[x] * v
                # xi与xi的对应位置乘积   与   xi^2与vi^2对应位置的乘积    的点积
                inter_2 = multiply(dataMatrix[x], dataMatrix[x]) * multiply(v, v)  # multiply对应元素相乘
                # 完成交叉项,xi*vi*xi*vi - xi^2*vi^2
                interaction = sum(multiply(inter_1, inter_1) - inter_2) / 2.
                # 计算预测的输出
                p = w_0 + dataMatrix[x] * w + interaction
                print('classLabels[x]:',classLabels[x])
                print('预测的输出p:', p)
                # 计算sigmoid(y*pred_y)-1
                loss = self.sigmoid(classLabels[x] * p[0, 0]) - 1
                if loss >= -1:
                    loss_res = '正方向 '
                else:
                    loss_res = '反方向'
                # 更新参数
                w_0 = w_0 - self.alpha * loss * classLabels[x]
                for i in range(n):
                    if dataMatrix[x, i] != 0:
                        w[i, 0] = w[i, 0] - self.alpha * loss * classLabels[x] * dataMatrix[x, i]
                        for j in range(k):
                            v[i, j] = v[i, j] - self.alpha * loss * classLabels[x] * (
                                    dataMatrix[x, i] * inter_1[0, j] - v[i, j] * dataMatrix[x, i] * dataMatrix[x, i])
            print('the no %s times, the loss arrach %s' % (it, loss_res))
        self._w_0, self._w, self._v = w_0, w, v

    def predict(self, X):
        if (self._w_0 == None) or (self._w == None).any() or (self._v == None).any():
            raise NotFittedError("Estimator not fitted, call `fit` first")
        # 类型检查
        if isinstance(X, np.ndarray):
            pass
        else:
            try:
                X = np.array(X)
            except:
                raise TypeError("numpy.ndarray required for X")
        w_0 = self._w_0
        w = self._w
        v = self._v
        m, n = shape(X)
        result = []
        for x in range(m):
            inter_1 = mat(X[x]) * v
            inter_2 = mat(multiply(X[x], X[x])) * multiply(v, v)  # multiply对应元素相乘
            # 完成交叉项
            interaction = sum(multiply(inter_1, inter_1) - inter_2) / 2.
            p = w_0 + X[x] * w + interaction  # 计算预测的输出
            pre = self.sigmoid(p[0, 0])
            result.append(pre)
        return result

    def getAccuracy(self, data):
        dataMatrix, classLabels = self.loadDataSet(data)
        w_0 = self._w_0
        w = self._w
        v = self._v
        m, n = shape(dataMatrix)
        allItem = 0
        error = 0
        result = []
        for x in range(m):
            allItem += 1
            inter_1 = dataMatrix[x] * v
            inter_2 = multiply(dataMatrix[x], dataMatrix[x]) * multiply(v, v)  # multiply对应元素相乘
            # 完成交叉项
            interaction = sum(multiply(inter_1, inter_1) - inter_2) / 2.
            p = w_0 + dataMatrix[x] * w + interaction  # 计算预测的输出
            pre = self.sigmoid(p[0, 0])
            result.append(pre)
            if pre < 0.5 and classLabels[x] == 1.0:
                error += 1
            elif pre >= 0.5 and classLabels[x] == -1.0:
                error += 1
            else:
                continue
        # print(result)
        value = 1 - float(error) / allItem
        return value


class NotFittedError(Exception):
    """
    Exception class to raise if estimator is used before fitting
    """
    pass


if __name__ == '__main__':
    fm()

参考连接

  1. https://zhuanlan.zhihu.com/p/3796326
  2. https://www.jianshu.com/p/9a3416ed683b
  3. https://www.cnblogs.com/wkang/p/9588360.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1593125.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【详细教程】MySQL 高可用架构代码实现

前言 对于 MySQL 数据库作为各个业务系统的存储介质&#xff0c;在系统中承担着非常重要的职责&#xff0c;如果数据库崩了&#xff0c;那么对于读和写数据库的操作都会受到影响。如果不能迅速恢复&#xff0c;对业务的影响是非常大的。之前 B 站不是出过一次事故么&#xff0…

解决jenkins运行sh报process apparently never started in XXX

个人记录 问题 process apparently never started in /var/jenkins_home/workspace/ks-springboot_mastertmp/durable-bbfe5f99(running Jenkins temporarily with -Dorg.jenkinsci.plugins.durabletask.BourneShellScript.LAUNCH_DIAGNOSTICStrue might make the problem cl…

Electron+React 搭建桌面应用

创建应用程序 创建 Electron 应用 使用 Webpack 创建新的 Electron 应用程序&#xff1a; npm init electron-applatest my-new-app -- --templatewebpack 启动应用 npm start 设置 Webpack 配置 添加依赖包&#xff0c;确保可以正确使用 JSX 和其他 React 功能&#xff…

Java基础(一)--语法入门

文章目录 第一章、语法入门一、Java简介1、JVM2、Java程序执行过程3、JDK4、JRE5、JDK、JRE和JVM三者关系 二、Java常量与变量1、标识符2、关键字3、保留字4、变量5、数据类型6、常量 三、运算符1、算术运算符2、赋值运算符3、关系运算符4、逻辑运算符5、条件运算符6、运算符的…

全国各省环境规制强度数据(2004-2022年)

01、数据简介 以保护环境为目的&#xff0c;对各种环境污染行为进行规制&#xff0c;政府相关政策规制&#xff0c;是社会性规制的重要内容&#xff0c;包含大气、水、废弃物、噪声污染等外部行为&#xff0c;对这些行为进行规制就是要将整个社会为其承担的成本转化为其自身承…

funasr 麦克风实时流语音识别

参考: https://github.com/alibaba-damo-academy/FunASR chunk_size 是用于流式传输延迟的配置。[0,10,5] 表示实时显示的粒度为 1060=600 毫秒,并且预测的向前信息为 560=300 毫秒。每个推理输入为 600 毫秒(采样点为 16000*0.6=960),输出为相应的文本。对于最后一个语音…

根据状态转移图实现时序电路

描述 某同步时序电路的状态转换图如下&#xff0c;→上表示“C/Y”&#xff0c;圆圈内为现态&#xff0c;→指向次态。 请使用D触发器和必要的逻辑门实现此同步时序电路&#xff0c;用Verilog语言描述。 如图所示&#xff1a; 电路的接口如下图所示&#xff0c;C是单bit数据…

MySQL基础入门上篇

MySQL基础 介绍 mysql -uroot -p -h127.0.0.1 -P3306项目设计 具备数据库一定的设计能力和操作数据的能力。 数据库设计DDL 定义 操作 显示所有数据库 show databases;创建数据库 create database db02;数据库名唯一&#xff0c;不能重复。 查询是否创建成功 加入一些…

文本检索粗读

一.前情提要 1.本文理论为主&#xff0c;并且仅为个人理解&#xff0c;能力一般&#xff0c;不喜勿喷 2.本文理论知识较为散碎 3.如有需要&#xff0c;以下是原文&#xff0c;更为完备 Neural Corpus Indexer 文档检索【论文精读47】_哔哩哔哩_bilibili 二.正文 &#xf…

重大璧山院_APP_apk_安卓端下载

主要是方便去重庆大学璧山研究院搞科研的学生&#xff0c; 这个安卓安装包&#xff0c;在网上很难搜到。 找半天才搞到手、蓝奏云下载 https://wwb.lanzn.com/iqnro1v1bwkh 密码:i3n2

防止邮箱发信泄露服务器IP教程

使用QQ邮箱,网易邮箱,189邮箱,新浪邮箱,139邮箱可能会泄露自己的服务器IP。 泄露原理&#xff1a;服务器通过请求登录SMTP邮箱服务器接口&#xff0c;对指定的收件人发送信息。 建议大家使用商业版的邮箱&#xff0c;比如阿里云邮箱发信等 防止邮件发信漏源主要关注的是确保邮件…

C语言 数据输入输出

本文 我们来说 数据的输入与输出 及数据的运算 在程序的运算工程中 往往需要输入一些数据 而程序的运算 所得到的运算结果又需要输出给用户 因此 数据的输入与输出 就显得非常重要 在C语言中 不提供专门的输入输出语句 所有的输入输出 都是通过对标准库的调用 来实现的 一般 …

权威Scrum敏捷开发企业级实训/敏捷开发培训课程

课程简介 Scrum是目前运用最为广泛的敏捷开发方法&#xff0c;是一个轻量级的项目管理和产品研发管理框架。 这是一个两天的实训课程&#xff0c;面向研发管理者、项目经理、产品经理、研发团队等&#xff0c;旨在帮助学员全面系统地学习Scrum和敏捷开发, 帮助企业快速启动敏…

抖音滑块验证码加密的盐的位置

最近更新后之前很容易找到盐的位置的方法变了&#xff0c;抖音特意把盐隐藏起来了 {"reply": "RJC","models": "yAd8rl","in_modal": "DTn0nD2","in_slide": "ou7H0Ngda","move": …

基于java+springboot+vue实现的网上购物系统(文末源码+Lw+ppt)23-42

摘 要 随着我国经济的高速发展与人们生活水平的日益提高&#xff0c;人们对生活质量的追求也多种多样。尤其在人们生活节奏不断加快的当下&#xff0c;人们更趋向于足不出户解决生活上的问题&#xff0c;网上购物系统展现了其蓬勃生命力和广阔的前景。与此同时&#xff0c;为…

走进MySQL:从认识到入门(针对初学者)

一&#xff0c;引言 MySQL是一款久负盛名且广泛应用的关系型数据库管理系统&#xff0c;自1995年Michael Widenius和David Axmark在瑞典和芬兰发起研发以来&#xff0c;其发展历程可谓辉煌且深远。作为开源软件的代表&#xff0c;MySQL以其卓越的成本效益、高性能及高可靠性赢得…

【数据结构与算法】:二叉树经典OJ

目录 1. 二叉树的前序遍历 (中&#xff0c;后序类似)2. 二叉树的最大深度3. 平衡二叉树4. 二叉树遍历 1. 二叉树的前序遍历 (中&#xff0c;后序类似) 这道题的意思是对二叉树进行前序遍历&#xff0c;把每个结点的值都存入一个数组中&#xff0c;并且返回这个数组。 思路&…

c++11 标准模板(STL)本地化库 - 平面类别(std::codecvt) - 在字符编码间转换,包括 UTF-8、UTF-16、UTF-32 (四)

本地化库 本地环境设施包含字符分类和字符串校对、数值、货币及日期/时间格式化和分析&#xff0c;以及消息取得的国际化支持。本地环境设置控制流 I/O 、正则表达式库和 C 标准库的其他组件的行为。 平面类别 在字符编码间转换&#xff0c;包括 UTF-8、UTF-16、UTF-32 std::…

ReactRouter

React-Router 概念&#xff1a;一个路劲path对应一个组件component 当我们在浏览器中访问一个path的时候&#xff0c;path对应的组件会在页面中进行渲染路由语法&#xff1a; import {createBrowserRouter, RouterProvider} from react-router-dom// 1. 创建router实例对象并…

【数据结构】习题之链表的回文结构和相交链表

&#x1f451;个人主页&#xff1a;啊Q闻 &#x1f387;收录专栏&#xff1a;《数据结构》 &#x1f389;前路漫漫亦灿灿 前言 今日的习题是关于链表的&#xff0c;分别是链表的回文结构和相交链表的判断。 链表的回文结构 题目为&#xff1a;链表的回文结…