自学内容网 自学内容网

机器学习系列----关联分析

   

目录

1. 关联分析的基本概念

1.1定义

1.2常用算法

2.Apriori 算法的实现

2.1 工作原理

2.2 算法步骤

2.3 优缺点

2.4 时间复杂度

2.5实际运用----市场购物篮分析

3. FP-Growth 算法

3.1 工作原理

3.2 算法步骤

3.3 优缺点

3.4 时间复杂度

3.5实际运用——网页点击行为分析

 4.Apriori 与 FP-Growth 的对比

5.具体项目——基于Apriori 算法的市场篮子分析

5.1项目目标

5.2项目目录结构

5.3核心功能实现


在机器学习和数据挖掘领域,关联分析(Association Analysis) 是一种非常重要的技术,尤其在市场篮子分析和推荐系统中得到了广泛的应用。它的核心任务是发现不同变量或项目之间的有趣关系,通常以关联规则的形式表示。关联分析通过揭示数据之间的隐藏模式,帮助我们理解数据的结构,并为决策提供支持。

    本文将介绍关联分析的基本概念、常用算法以及一个较为复杂的项目实现,以帮助大家更好地理解和应用关联分析技术。

1. 关联分析的基本概念

1.1定义

  关联分析主要用于挖掘数据集中的频繁项集(Frequent Itemsets)和关联规则(Association Rules)。关联规则通常采用“如果-那么”的形式,即:如果条件 A 成立,则条件 B 成立。最常见的应用场景是市场篮子分析,在这个场景中,A 和 B 代表顾客购买的商品。

关联规则通常包含三个重要的度量:

支持度(Support):规则中项集的出现频率,表示在整个数据集中,A 和 B 同时出现的概率。
置信度(Confidence):规则的可靠性,表示在所有包含 A 的记录中,有多少比例同时包含 B。
提升度(Lift):规则的增强程度,衡量规则中项集之间的关联程度,表示 A 出现时,B 出现的概率是 A 不出现时的多少倍。
这些度量帮助我们从数据中筛选出最有意义的规则。

1.2常用算法

在实际应用中,关联分析最常用的算法是 Apriori 算法 和 FP-Growth 算法。

Apriori 算法
Apriori 算法通过逐层生成候选频繁项集,来挖掘数据中最频繁的项集。它的核心思想是“剪枝”:如果某个项集不是频繁的,那么它的所有超集也一定不是频繁的。该算法的过程通常包括以下几个步骤:

从单个项集开始,生成所有候选项集。
计算候选项集的支持度,并筛选出频繁项集。
使用频繁项集生成更大的候选项集,并重复步骤2。
FP-Growth 算法
FP-Growth(Frequent Pattern Growth)算法是一种基于压缩数据结构FP树(Frequent Pattern Tree)的算法,相比于 Apriori,FP-Growth 更加高效。它通过构建 FP 树来压缩数据,并递归地挖掘频繁项集,避免了候选项集生成的过程,因此在大数据集上具有较好的性能。

2.Apriori 算法的实现

2.1 工作原理

Apriori 算法是最早用于挖掘频繁项集和生成关联规则的算法之一,其核心思想是通过“剪枝”来减少候选项集的数量。算法的基本步骤如下:

生成候选项集:

首先,从单个项集开始,生成所有候选项集。然后,通过计算每个候选项集的支持度,筛选出支持度大于最小支持度的频繁项集。
接着,使用频繁项集生成更大的候选项集,直到不能生成新的候选项集为止。
剪枝:

剪枝策略是 Apriori 算法的核心思想。如果一个项集不是频繁的,那么它的所有超集也不可能是频繁的。因此,可以通过剪除不频繁的项集来减少计算量。
迭代计算:

每次迭代都会生成更大的项集,并筛选出频繁项集。这个过程会持续,直到找不到新的频繁项集为止。

2.2 算法步骤

假设数据集为 (D),最小支持度为 (min_support),最小置信度为 (min_confidence),Apriori 算法的基本步骤如下:

初始化:

将单项集(长度为1的项集)作为候选项集生成。
频繁项集生成:

在每一轮迭代中,从上一次得到的频繁项集生成候选项集。
计算这些候选项集的支持度,筛选出频繁项集。
生成规则:

基于频繁项集生成关联规则,并计算每条规则的置信度和提升度。
过滤出置信度和提升度大于最小阈值的规则。

2.3 优缺点

优点:

算法简单,容易理解。
可以用于发现各种类型的关联规则,不限于二元关联。
缺点:

计算代价较高,因为候选项集生成的过程会导致大量的计算。
在数据集较大时,计算频繁项集的过程中需要扫描数据库多次,效率较低。

2.4 时间复杂度

假设数据集中的项数为 (n),数据集的大小为 (m),最小支持度为 (min_support)。
最坏情况下,Apriori 需要扫描数据集 (k) 次((k) 为频繁项集的最大长度),每次扫描时,需要生成和验证大量的候选项集,因此时间复杂度通常较高。

2.5实际运用----市场购物篮分析

from mlxtend.frequent_patterns import apriori, association_rules
import pandas as pd

# 示例交易数据
transactions = [
    ['牛奶', '面包', '黄油'],
    ['面包', '黄油'],
    ['牛奶', '黄油'],
    ['面包', '牛奶', '黄油'],
]

# 转换为 DataFrame 格式
from mlxtend.preprocessing import TransactionEncoder
te = TransactionEncoder()
te_ary = te.fit(transactions).transform(transactions)
df = pd.DataFrame(te_ary, columns=te.columns_)

# 使用 Apriori 算法挖掘频繁项集
frequent_itemsets = apriori(df, min_support=0.5, use_colnames=True)

# 生成关联规则
rules = association_rules(frequent_itemsets, metric="confidence", min_threshold=0.7)

# 输出关联规则
print(rules)

3. FP-Growth 算法

3.1 工作原理

FP-Growth 算法(Frequent Pattern Growth)是一个改进的算法,旨在提高频繁项集挖掘的效率。与 Apriori 不同,FP-Growth 避免了候选项集生成过程,通过构建一个压缩的树结构(FP-tree)来存储数据集,从而显著提高了性能。

3.2 算法步骤

FP-Growth 算法的步骤如下:

构建 FP-树:

扫描数据集,统计各项的频率,然后按频率排序。
根据频率排序后的项集,构建一个树结构(FP-tree)。树的每个节点表示一个项,每个路径表示一个事务的项集。
递归挖掘频繁项集:

从 FP-树中逐步挖掘频繁项集。对于每个节点,构造条件模式基(Conditional Pattern Base)并递归地构建条件 FP-树,直到无法生成新的频繁项集为止。
生成关联规则:

基于挖掘出的频繁项集,生成关联规则,并计算规则的置信度。

3.3 优缺点

优点:

FP-Growth 通过构建 FP 树来压缩数据,避免了频繁项集生成的过程,因此通常比 Apriori 更高效。
不需要多次扫描整个数据集,计算速度较快,适合大规模数据集。
FP-Growth 支持快速的递归计算,避免了生成大量候选项集的开销。
缺点:

需要存储树结构,可能会消耗较多内存,尤其是在数据集非常大的情况下。
需要构建树结构,理解和实现相对复杂。

3.4 时间复杂度

FP-Growth 的时间复杂度通常较低,因为它仅需要扫描数据集两次:第一次构建 FP 树,第二次递归地挖掘频繁项集。
如果数据集非常大,构建和递归过程中的内存消耗可能会成为瓶颈,但总体上比 Apriori 更高效。

3.5实际运用——网页点击行为分析

from mlxtend.frequent_patterns import fpgrowth, association_rules
import pandas as pd

# 示例用户点击数据
transactions = [
    ['首页', '产品页面A', '购物车'],
    ['首页', '产品页面B', '购物车'],
    ['产品页面A', '购物车'],
    ['首页', '产品页面A', '产品页面B', '购物车'],
]

# 转换为 DataFrame 格式
from mlxtend.preprocessing import TransactionEncoder
te = TransactionEncoder()
te_ary = te.fit(transactions).transform(transactions)
df = pd.DataFrame(te_ary, columns=te.columns_)

# 使用 FP-Growth 算法挖掘频繁项集
frequent_itemsets = fpgrowth(df, min_support=0.6, use_colnames=True)

# 生成关联规则
rules = association_rules(frequent_itemsets, metric="confidence", min_threshold=0.75)

# 输出关联规则
print(rules)

 4.Apriori 与 FP-Growth 的对比

5.具体项目——基于Apriori 算法的市场篮子分析

5.1项目目标

数据准备:从数据源(如HBase、HDFS等)读取用户的购买记录。
数据处理:清洗数据,进行合适的格式化。
Apriori算法实现:用于发现频繁项集,并生成关联规则。
结果存储:将关联规则存储到HBase或HDFS,并进行进一步分析。
性能优化:使用Spark进行分布式计算,以提高算法效率。

5.2项目目录结构

market-basket-analysis/
├── data/
│   └── transactions.csv       # 输入的交易数据
├── src/
│   ├── main/
│   │   ├── scala/
│   │   │   ├── Apriori.scala   # Apriori算法核心实现
│   │   │   ├── DataPreprocessor.scala  # 数据清洗与预处理
│   │   │   ├── HBaseConnector.scala     # HBase连接与数据存储
│   │   │   ├── SparkApriori.scala  # 使用Spark并行化Apriori算法
│   │   │   └── ResultWriter.scala  # 结果存储模块
│   ├── test/
│   │   └── AprioriTest.scala    # 单元测试
├── pom.xml                    # Maven构建文件
└── README.md                  # 项目说明文档

5.3核心功能实现

1. 数据预处理(DataPreprocessor.scala)
首先,我们需要从存储系统(如HDFS、HBase)读取原始数据,进行清洗和格式化。

import org.apache.spark.sql.SparkSession

object DataPreprocessor {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.appName("Market Basket Analysis").getOrCreate()
    
    // 假设数据源是一个CSV文件,格式为: 用户ID, 商品ID
    val data = spark.read.option("header", "true").csv("data/transactions.csv")
    
    // 将交易数据格式化成购物篮的形式
    val transactions = data.groupBy("user_id").agg(collect_list("product_id").alias("basket"))
    
    transactions.show()
    
    // 你可以将数据保存到HDFS、HBase或者其他地方
    transactions.write.parquet("data/processed_transactions")
    
    spark.stop()
  }
}

2. Apriori算法实现(Apriori.scala)
Apriori.scala 实现了经典的 Apriori 算法,用于发现频繁项集和生成关联规则。为了简化,我们使用 RDD 操作,且算法会输出每一轮的候选项集。

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession

object Apriori {

  def generateFrequentItemsets(transactions: RDD[Set[String]], minSupport: Double): RDD[(Set[String], Int)] = {
    var k = 1
    var frequentItemsets: RDD[(Set[String], Int)] = transactions.flatMap(t => t.subsets(k))
      .map(itemset => (itemset, 1))
      .reduceByKey(_ + _)
      .filter { case (_, count) => count >= minSupport * transactions.count() }

    // 生成k项集
    while (!frequentItemsets.isEmpty()) {
      k += 1
      val candidateItemsets = frequentItemsets.flatMap { case (itemset, _) =>
        itemset.subsets(k).toSeq
      }.map(itemset => (itemset, 1))
        .reduceByKey(_ + _)
        .filter { case (_, count) => count >= minSupport * transactions.count() }

      frequentItemsets = candidateItemsets
    }

    frequentItemsets
  }

  def generateAssociationRules(frequentItemsets: RDD[(Set[String], Int)], minConfidence: Double): RDD[(Set[String], Set[String], Double)] = {
    frequentItemsets.flatMap { case (itemset, support) =>
      itemset.subsets(itemset.size - 1).map { antecedent =>
        val consequent = itemset -- antecedent
        val confidence = support.toDouble / frequentItemsets.lookup(antecedent).head
        if (confidence >= minConfidence) Some(antecedent, consequent, confidence)
        else None
      }.flatten
    }
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.appName("Apriori Algorithm").getOrCreate()
    val sc = spark.sparkContext

    // 加载数据
    val transactions = sc.textFile("data/processed_transactions")
      .map(line => line.split(",").toSet) // 数据转换为 Set 形式的交易数据

    val minSupport = 0.03
    val minConfidence = 0.6

    // 生成频繁项集
    val frequentItemsets = generateFrequentItemsets(transactions, minSupport)
    
    // 生成关联规则
    val associationRules = generateAssociationRules(frequentItemsets, minConfidence)

    // 输出结果
    associationRules.saveAsTextFile("data/association_rules")

    spark.stop()
  }
}

 3. 分布式 Apriori 算法(SparkApriori.scala)
为了提高计算效率,可以使用 Spark 对 Apriori 算法进行并行化。以下代码对 Apriori 算法进行了 Spark 优化,支持大规模数据集的处理。

import org.apache.spark.sql.SparkSession
import org.apache.spark.rdd.RDD

object SparkApriori {

  def parallelGenerateFrequentItemsets(transactions: RDD[Set[String]], minSupport: Double): RDD[(Set[String], Int)] = {
    // 同前面简单Apriori算法,使用RDD进行并行化
    Apriori.generateFrequentItemsets(transactions, minSupport)
  }

  def parallelGenerateAssociationRules(frequentItemsets: RDD[(Set[String], Int)], minConfidence: Double): RDD[(Set[String], Set[String], Double)] = {
    // 同前面Apriori算法,生成关联规则
    Apriori.generateAssociationRules(frequentItemsets, minConfidence)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.appName("Distributed Apriori").getOrCreate()
    val sc = spark.sparkContext

    val transactions = sc.textFile("data/processed_transactions")
      .map(line => line.split(",").toSet)

    val minSupport = 0.03
    val minConfidence = 0.6

    // 并行化生成频繁项集
    val frequentItemsets = parallelGenerateFrequentItemsets(transactions, minSupport)
    
    // 并行化生成关联规则
    val associationRules = parallelGenerateAssociationRules(frequentItemsets, minConfidence)

    // 输出结果
    associationRules.saveAsTextFile("data/association_rules_output")

    spark.stop()
  }
}

4. 结果存储模块(ResultWriter.scala)
结果存储部分使用 HBase 将生成的关联规则存储到表中,便于查询。

import org.apache.hadoop.hbase.client._
import org.apache.hadoop.hbase.util.Bytes
import org.apache.hadoop.hbase.HBaseConfiguration
import org.apache.hadoop.hbase.TableName
import org.apache.hadoop.hbase.client.ConnectionFactory
import org.apache.hadoop.hbase.HTableDescriptor
import org.apache.hadoop.hbase.TableNotFoundException

object ResultWriter {

  def saveToHBase(rules: RDD[(Set[String], Set[String], Double)]): Unit = {
    val conf = HBaseConfiguration.create()
    val connection = ConnectionFactory.createConnection(conf)

    try {
      val table = connection.getTable(TableName.valueOf("association_rules"))
      
      // 向 HBase 写入每条规则
      rules.foreach { case (antecedent, consequent, confidence) =>
        val rowKey = Bytes.toBytes(antecedent.mkString(",") + ":" + consequent.mkString(","))
        val put = new Put(rowKey)
        
        // 将前提和结论、置信度存储在对应的列族中
        put.addColumn(Bytes.toBytes("rule"), Bytes.toBytes("antecedent"), Bytes.toBytes(antecedent.mkString(",")))
        put.addColumn(Bytes.toBytes("rule"), Bytes.toBytes("consequent"), Bytes.toBytes(consequent.mkString(",")))
        put.addColumn(Bytes.toBytes("rule"), Bytes.toBytes("confidence"), Bytes.toBytes(confidence.toString))
        
        // 执行写操作
        table.put(put)
      }

      println("Successfully written association rules to HBase.")

      table.close()
    } catch {
      case e: TableNotFoundException =>
        println("HBase table 'association_rules' not found. Please ensure the table exists.")
      case e: Exception =>
        println(s"An error occurred while saving results to HBase: ${e.getMessage}")
    } finally {
      // 关闭连接
      connection.close()
    }
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.appName("Result Writer").getOrCreate()

    // 假设之前我们已经生成了关联规则,存在文件中
    val rules = spark.sparkContext.textFile("data/association_rules_output")
      .map { line =>
        val parts = line.split(",")
        val antecedent = parts(0).split(":").toSet
        val consequent = parts(1).split(":").toSet
        val confidence = parts(2).toDouble
        (antecedent, consequent, confidence)
      }

    // 将关联规则保存到 HBase
    saveToHBase(rules)

    spark.stop()
  }
}

5. 测试模块(AprioriTest.scala)
为了确保代码的正确性,我们需要为 Apriori 算法编写一些单元测试。你可以使用 ScalaTest 或 JUnit 等框架进行测试。以下是使用 ScalaTest 编写的一个简单的测试示例

import org.scalatest._
import org.apache.spark.sql.SparkSession
import org.apache.spark.rdd.RDD

class AprioriTest extends FlatSpec with Matchers {

  "Apriori algorithm" should "generate correct frequent itemsets" in {
    val spark = SparkSession.builder.appName("AprioriTest").master("local").getOrCreate()
    val sc = spark.sparkContext

    // 示例交易数据
    val transactions = sc.parallelize(Seq(
      Set("apple", "banana", "cherry"),
      Set("banana", "cherry"),
      Set("apple", "banana"),
      Set("apple", "cherry")
    ))

    // 生成频繁项集
    val frequentItemsets = Apriori.generateFrequentItemsets(transactions, minSupport = 0.5)

    // 检查频繁项集
    frequentItemsets.collect() should contain allOf(
      (Set("apple"), 3),
      (Set("banana"), 3),
      (Set("cherry"), 3),
      (Set("apple", "banana"), 2),
      (Set("banana", "cherry"), 2)
    )
    
    spark.stop()
  }

  it should "generate correct association rules" in {
    val spark = SparkSession.builder.appName("AprioriTest").master("local").getOrCreate()
    val sc = spark.sparkContext

    // 示例交易数据
    val transactions = sc.parallelize(Seq(
      Set("apple", "banana", "cherry"),
      Set("banana", "cherry"),
      Set("apple", "banana"),
      Set("apple", "cherry")
    ))

    // 生成频繁项集
    val frequentItemsets = Apriori.generateFrequentItemsets(transactions, minSupport = 0.5)

    // 生成关联规则
    val associationRules = Apriori.generateAssociationRules(frequentItemsets, minConfidence = 0.5)

    // 检查生成的关联规则
    associationRules.collect() should contain allOf(
      (Set("apple"), Set("banana"), 0.6666666666666666),
      (Set("banana"), Set("apple"), 0.6666666666666666)
    )

    spark.stop()
  }
}

6. 项目构建与依赖(pom.xml)
为了方便项目的构建和依赖管理,我们使用 Maven。以下是 pom.xml 文件的基础配置。 

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.example</groupId>
    <artifactId>market-basket-analysis</artifactId>
    <version>1.0-SNAPSHOT</version>
    <packaging>jar</packaging>

    <dependencies>
        <!-- Spark Core -->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_2.12</artifactId>
            <version>3.3.0</version>
        </dependency>

        <!-- Spark SQL (for SparkSession) -->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_2.12</artifactId>
            <version>3.3.0</version>
        </dependency>

        <!-- HBase client -->
        <dependency>
            <groupId>org.apache.hbase</groupId>
            <artifactId>hbase-client</artifactId>
            <version>2.4.8</version>
        </dependency>

        <!-- ScalaTest for unit testing -->
        <dependency>
            <groupId>org.scalatest</groupId>
            <artifactId>scalatest_2.12</artifactId>
            <version>3.2.9</version>
            <scope>test</scope>
        </dependency>

        <!-- Hadoop Common (required for HBase and Spark) -->
        <dependency>
            <groupId>org.apache.hadoop</groupId>
            <artifactId>hadoop-common</artifactId>
            <version>3.3.0</version>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>3.8.1</version>
                <configuration>
                    <source>1.8</source>
                    <target>1.8</target>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>

 

 

 

 

 


原文地址:https://blog.csdn.net/DK22151/article/details/143864437

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!