机器学习系列----关联分析
目录
在机器学习和数据挖掘领域,关联分析(Association Analysis) 是一种非常重要的技术,尤其在市场篮子分析和推荐系统中得到了广泛的应用。它的核心任务是发现不同变量或项目之间的有趣关系,通常以关联规则的形式表示。关联分析通过揭示数据之间的隐藏模式,帮助我们理解数据的结构,并为决策提供支持。
本文将介绍关联分析的基本概念、常用算法以及一个较为复杂的项目实现,以帮助大家更好地理解和应用关联分析技术。
1. 关联分析的基本概念
1.1定义
关联分析主要用于挖掘数据集中的频繁项集(Frequent Itemsets)和关联规则(Association Rules)。关联规则通常采用“如果-那么”的形式,即:如果条件 A 成立,则条件 B 成立。最常见的应用场景是市场篮子分析,在这个场景中,A 和 B 代表顾客购买的商品。
关联规则通常包含三个重要的度量:
支持度(Support):规则中项集的出现频率,表示在整个数据集中,A 和 B 同时出现的概率。
置信度(Confidence):规则的可靠性,表示在所有包含 A 的记录中,有多少比例同时包含 B。
提升度(Lift):规则的增强程度,衡量规则中项集之间的关联程度,表示 A 出现时,B 出现的概率是 A 不出现时的多少倍。
这些度量帮助我们从数据中筛选出最有意义的规则。
1.2常用算法
在实际应用中,关联分析最常用的算法是 Apriori 算法 和 FP-Growth 算法。
Apriori 算法
Apriori 算法通过逐层生成候选频繁项集,来挖掘数据中最频繁的项集。它的核心思想是“剪枝”:如果某个项集不是频繁的,那么它的所有超集也一定不是频繁的。该算法的过程通常包括以下几个步骤:
从单个项集开始,生成所有候选项集。
计算候选项集的支持度,并筛选出频繁项集。
使用频繁项集生成更大的候选项集,并重复步骤2。
FP-Growth 算法
FP-Growth(Frequent Pattern Growth)算法是一种基于压缩数据结构FP树(Frequent Pattern Tree)的算法,相比于 Apriori,FP-Growth 更加高效。它通过构建 FP 树来压缩数据,并递归地挖掘频繁项集,避免了候选项集生成的过程,因此在大数据集上具有较好的性能。
2.Apriori 算法的实现
2.1 工作原理
Apriori 算法是最早用于挖掘频繁项集和生成关联规则的算法之一,其核心思想是通过“剪枝”来减少候选项集的数量。算法的基本步骤如下:
生成候选项集:
首先,从单个项集开始,生成所有候选项集。然后,通过计算每个候选项集的支持度,筛选出支持度大于最小支持度的频繁项集。
接着,使用频繁项集生成更大的候选项集,直到不能生成新的候选项集为止。
剪枝:
剪枝策略是 Apriori 算法的核心思想。如果一个项集不是频繁的,那么它的所有超集也不可能是频繁的。因此,可以通过剪除不频繁的项集来减少计算量。
迭代计算:
每次迭代都会生成更大的项集,并筛选出频繁项集。这个过程会持续,直到找不到新的频繁项集为止。
2.2 算法步骤
假设数据集为 (D),最小支持度为 (min_support),最小置信度为 (min_confidence),Apriori 算法的基本步骤如下:
初始化:
将单项集(长度为1的项集)作为候选项集生成。
频繁项集生成:
在每一轮迭代中,从上一次得到的频繁项集生成候选项集。
计算这些候选项集的支持度,筛选出频繁项集。
生成规则:
基于频繁项集生成关联规则,并计算每条规则的置信度和提升度。
过滤出置信度和提升度大于最小阈值的规则。
2.3 优缺点
优点:
算法简单,容易理解。
可以用于发现各种类型的关联规则,不限于二元关联。
缺点:
计算代价较高,因为候选项集生成的过程会导致大量的计算。
在数据集较大时,计算频繁项集的过程中需要扫描数据库多次,效率较低。
2.4 时间复杂度
假设数据集中的项数为 (n),数据集的大小为 (m),最小支持度为 (min_support)。
最坏情况下,Apriori 需要扫描数据集 (k) 次((k) 为频繁项集的最大长度),每次扫描时,需要生成和验证大量的候选项集,因此时间复杂度通常较高。
2.5实际运用----市场购物篮分析
from mlxtend.frequent_patterns import apriori, association_rules
import pandas as pd
# 示例交易数据
transactions = [
['牛奶', '面包', '黄油'],
['面包', '黄油'],
['牛奶', '黄油'],
['面包', '牛奶', '黄油'],
]
# 转换为 DataFrame 格式
from mlxtend.preprocessing import TransactionEncoder
te = TransactionEncoder()
te_ary = te.fit(transactions).transform(transactions)
df = pd.DataFrame(te_ary, columns=te.columns_)
# 使用 Apriori 算法挖掘频繁项集
frequent_itemsets = apriori(df, min_support=0.5, use_colnames=True)
# 生成关联规则
rules = association_rules(frequent_itemsets, metric="confidence", min_threshold=0.7)
# 输出关联规则
print(rules)
3. FP-Growth 算法
3.1 工作原理
FP-Growth 算法(Frequent Pattern Growth)是一个改进的算法,旨在提高频繁项集挖掘的效率。与 Apriori 不同,FP-Growth 避免了候选项集生成过程,通过构建一个压缩的树结构(FP-tree)来存储数据集,从而显著提高了性能。
3.2 算法步骤
FP-Growth 算法的步骤如下:
构建 FP-树:
扫描数据集,统计各项的频率,然后按频率排序。
根据频率排序后的项集,构建一个树结构(FP-tree)。树的每个节点表示一个项,每个路径表示一个事务的项集。
递归挖掘频繁项集:
从 FP-树中逐步挖掘频繁项集。对于每个节点,构造条件模式基(Conditional Pattern Base)并递归地构建条件 FP-树,直到无法生成新的频繁项集为止。
生成关联规则:
基于挖掘出的频繁项集,生成关联规则,并计算规则的置信度。
3.3 优缺点
优点:
FP-Growth 通过构建 FP 树来压缩数据,避免了频繁项集生成的过程,因此通常比 Apriori 更高效。
不需要多次扫描整个数据集,计算速度较快,适合大规模数据集。
FP-Growth 支持快速的递归计算,避免了生成大量候选项集的开销。
缺点:
需要存储树结构,可能会消耗较多内存,尤其是在数据集非常大的情况下。
需要构建树结构,理解和实现相对复杂。
3.4 时间复杂度
FP-Growth 的时间复杂度通常较低,因为它仅需要扫描数据集两次:第一次构建 FP 树,第二次递归地挖掘频繁项集。
如果数据集非常大,构建和递归过程中的内存消耗可能会成为瓶颈,但总体上比 Apriori 更高效。
3.5实际运用——网页点击行为分析
from mlxtend.frequent_patterns import fpgrowth, association_rules
import pandas as pd
# 示例用户点击数据
transactions = [
['首页', '产品页面A', '购物车'],
['首页', '产品页面B', '购物车'],
['产品页面A', '购物车'],
['首页', '产品页面A', '产品页面B', '购物车'],
]
# 转换为 DataFrame 格式
from mlxtend.preprocessing import TransactionEncoder
te = TransactionEncoder()
te_ary = te.fit(transactions).transform(transactions)
df = pd.DataFrame(te_ary, columns=te.columns_)
# 使用 FP-Growth 算法挖掘频繁项集
frequent_itemsets = fpgrowth(df, min_support=0.6, use_colnames=True)
# 生成关联规则
rules = association_rules(frequent_itemsets, metric="confidence", min_threshold=0.75)
# 输出关联规则
print(rules)
4.Apriori 与 FP-Growth 的对比
5.具体项目——基于Apriori 算法的市场篮子分析
5.1项目目标
数据准备:从数据源(如HBase、HDFS等)读取用户的购买记录。
数据处理:清洗数据,进行合适的格式化。
Apriori算法实现:用于发现频繁项集,并生成关联规则。
结果存储:将关联规则存储到HBase或HDFS,并进行进一步分析。
性能优化:使用Spark进行分布式计算,以提高算法效率。
5.2项目目录结构
market-basket-analysis/
├── data/
│ └── transactions.csv # 输入的交易数据
├── src/
│ ├── main/
│ │ ├── scala/
│ │ │ ├── Apriori.scala # Apriori算法核心实现
│ │ │ ├── DataPreprocessor.scala # 数据清洗与预处理
│ │ │ ├── HBaseConnector.scala # HBase连接与数据存储
│ │ │ ├── SparkApriori.scala # 使用Spark并行化Apriori算法
│ │ │ └── ResultWriter.scala # 结果存储模块
│ ├── test/
│ │ └── AprioriTest.scala # 单元测试
├── pom.xml # Maven构建文件
└── README.md # 项目说明文档
5.3核心功能实现
1. 数据预处理(DataPreprocessor.scala)
首先,我们需要从存储系统(如HDFS、HBase)读取原始数据,进行清洗和格式化。
import org.apache.spark.sql.SparkSession
object DataPreprocessor {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder.appName("Market Basket Analysis").getOrCreate()
// 假设数据源是一个CSV文件,格式为: 用户ID, 商品ID
val data = spark.read.option("header", "true").csv("data/transactions.csv")
// 将交易数据格式化成购物篮的形式
val transactions = data.groupBy("user_id").agg(collect_list("product_id").alias("basket"))
transactions.show()
// 你可以将数据保存到HDFS、HBase或者其他地方
transactions.write.parquet("data/processed_transactions")
spark.stop()
}
}
2. Apriori算法实现(Apriori.scala)
Apriori.scala 实现了经典的 Apriori 算法,用于发现频繁项集和生成关联规则。为了简化,我们使用 RDD 操作,且算法会输出每一轮的候选项集。
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
object Apriori {
def generateFrequentItemsets(transactions: RDD[Set[String]], minSupport: Double): RDD[(Set[String], Int)] = {
var k = 1
var frequentItemsets: RDD[(Set[String], Int)] = transactions.flatMap(t => t.subsets(k))
.map(itemset => (itemset, 1))
.reduceByKey(_ + _)
.filter { case (_, count) => count >= minSupport * transactions.count() }
// 生成k项集
while (!frequentItemsets.isEmpty()) {
k += 1
val candidateItemsets = frequentItemsets.flatMap { case (itemset, _) =>
itemset.subsets(k).toSeq
}.map(itemset => (itemset, 1))
.reduceByKey(_ + _)
.filter { case (_, count) => count >= minSupport * transactions.count() }
frequentItemsets = candidateItemsets
}
frequentItemsets
}
def generateAssociationRules(frequentItemsets: RDD[(Set[String], Int)], minConfidence: Double): RDD[(Set[String], Set[String], Double)] = {
frequentItemsets.flatMap { case (itemset, support) =>
itemset.subsets(itemset.size - 1).map { antecedent =>
val consequent = itemset -- antecedent
val confidence = support.toDouble / frequentItemsets.lookup(antecedent).head
if (confidence >= minConfidence) Some(antecedent, consequent, confidence)
else None
}.flatten
}
}
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder.appName("Apriori Algorithm").getOrCreate()
val sc = spark.sparkContext
// 加载数据
val transactions = sc.textFile("data/processed_transactions")
.map(line => line.split(",").toSet) // 数据转换为 Set 形式的交易数据
val minSupport = 0.03
val minConfidence = 0.6
// 生成频繁项集
val frequentItemsets = generateFrequentItemsets(transactions, minSupport)
// 生成关联规则
val associationRules = generateAssociationRules(frequentItemsets, minConfidence)
// 输出结果
associationRules.saveAsTextFile("data/association_rules")
spark.stop()
}
}
3. 分布式 Apriori 算法(SparkApriori.scala)
为了提高计算效率,可以使用 Spark 对 Apriori 算法进行并行化。以下代码对 Apriori 算法进行了 Spark 优化,支持大规模数据集的处理。
import org.apache.spark.sql.SparkSession
import org.apache.spark.rdd.RDD
object SparkApriori {
def parallelGenerateFrequentItemsets(transactions: RDD[Set[String]], minSupport: Double): RDD[(Set[String], Int)] = {
// 同前面简单Apriori算法,使用RDD进行并行化
Apriori.generateFrequentItemsets(transactions, minSupport)
}
def parallelGenerateAssociationRules(frequentItemsets: RDD[(Set[String], Int)], minConfidence: Double): RDD[(Set[String], Set[String], Double)] = {
// 同前面Apriori算法,生成关联规则
Apriori.generateAssociationRules(frequentItemsets, minConfidence)
}
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder.appName("Distributed Apriori").getOrCreate()
val sc = spark.sparkContext
val transactions = sc.textFile("data/processed_transactions")
.map(line => line.split(",").toSet)
val minSupport = 0.03
val minConfidence = 0.6
// 并行化生成频繁项集
val frequentItemsets = parallelGenerateFrequentItemsets(transactions, minSupport)
// 并行化生成关联规则
val associationRules = parallelGenerateAssociationRules(frequentItemsets, minConfidence)
// 输出结果
associationRules.saveAsTextFile("data/association_rules_output")
spark.stop()
}
}
4. 结果存储模块(ResultWriter.scala)
结果存储部分使用 HBase 将生成的关联规则存储到表中,便于查询。
import org.apache.hadoop.hbase.client._
import org.apache.hadoop.hbase.util.Bytes
import org.apache.hadoop.hbase.HBaseConfiguration
import org.apache.hadoop.hbase.TableName
import org.apache.hadoop.hbase.client.ConnectionFactory
import org.apache.hadoop.hbase.HTableDescriptor
import org.apache.hadoop.hbase.TableNotFoundException
object ResultWriter {
def saveToHBase(rules: RDD[(Set[String], Set[String], Double)]): Unit = {
val conf = HBaseConfiguration.create()
val connection = ConnectionFactory.createConnection(conf)
try {
val table = connection.getTable(TableName.valueOf("association_rules"))
// 向 HBase 写入每条规则
rules.foreach { case (antecedent, consequent, confidence) =>
val rowKey = Bytes.toBytes(antecedent.mkString(",") + ":" + consequent.mkString(","))
val put = new Put(rowKey)
// 将前提和结论、置信度存储在对应的列族中
put.addColumn(Bytes.toBytes("rule"), Bytes.toBytes("antecedent"), Bytes.toBytes(antecedent.mkString(",")))
put.addColumn(Bytes.toBytes("rule"), Bytes.toBytes("consequent"), Bytes.toBytes(consequent.mkString(",")))
put.addColumn(Bytes.toBytes("rule"), Bytes.toBytes("confidence"), Bytes.toBytes(confidence.toString))
// 执行写操作
table.put(put)
}
println("Successfully written association rules to HBase.")
table.close()
} catch {
case e: TableNotFoundException =>
println("HBase table 'association_rules' not found. Please ensure the table exists.")
case e: Exception =>
println(s"An error occurred while saving results to HBase: ${e.getMessage}")
} finally {
// 关闭连接
connection.close()
}
}
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder.appName("Result Writer").getOrCreate()
// 假设之前我们已经生成了关联规则,存在文件中
val rules = spark.sparkContext.textFile("data/association_rules_output")
.map { line =>
val parts = line.split(",")
val antecedent = parts(0).split(":").toSet
val consequent = parts(1).split(":").toSet
val confidence = parts(2).toDouble
(antecedent, consequent, confidence)
}
// 将关联规则保存到 HBase
saveToHBase(rules)
spark.stop()
}
}
5. 测试模块(AprioriTest.scala)
为了确保代码的正确性,我们需要为 Apriori 算法编写一些单元测试。你可以使用 ScalaTest 或 JUnit 等框架进行测试。以下是使用 ScalaTest 编写的一个简单的测试示例
import org.scalatest._
import org.apache.spark.sql.SparkSession
import org.apache.spark.rdd.RDD
class AprioriTest extends FlatSpec with Matchers {
"Apriori algorithm" should "generate correct frequent itemsets" in {
val spark = SparkSession.builder.appName("AprioriTest").master("local").getOrCreate()
val sc = spark.sparkContext
// 示例交易数据
val transactions = sc.parallelize(Seq(
Set("apple", "banana", "cherry"),
Set("banana", "cherry"),
Set("apple", "banana"),
Set("apple", "cherry")
))
// 生成频繁项集
val frequentItemsets = Apriori.generateFrequentItemsets(transactions, minSupport = 0.5)
// 检查频繁项集
frequentItemsets.collect() should contain allOf(
(Set("apple"), 3),
(Set("banana"), 3),
(Set("cherry"), 3),
(Set("apple", "banana"), 2),
(Set("banana", "cherry"), 2)
)
spark.stop()
}
it should "generate correct association rules" in {
val spark = SparkSession.builder.appName("AprioriTest").master("local").getOrCreate()
val sc = spark.sparkContext
// 示例交易数据
val transactions = sc.parallelize(Seq(
Set("apple", "banana", "cherry"),
Set("banana", "cherry"),
Set("apple", "banana"),
Set("apple", "cherry")
))
// 生成频繁项集
val frequentItemsets = Apriori.generateFrequentItemsets(transactions, minSupport = 0.5)
// 生成关联规则
val associationRules = Apriori.generateAssociationRules(frequentItemsets, minConfidence = 0.5)
// 检查生成的关联规则
associationRules.collect() should contain allOf(
(Set("apple"), Set("banana"), 0.6666666666666666),
(Set("banana"), Set("apple"), 0.6666666666666666)
)
spark.stop()
}
}
6. 项目构建与依赖(pom.xml)
为了方便项目的构建和依赖管理,我们使用 Maven。以下是 pom.xml 文件的基础配置。
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.example</groupId>
<artifactId>market-basket-analysis</artifactId>
<version>1.0-SNAPSHOT</version>
<packaging>jar</packaging>
<dependencies>
<!-- Spark Core -->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_2.12</artifactId>
<version>3.3.0</version>
</dependency>
<!-- Spark SQL (for SparkSession) -->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_2.12</artifactId>
<version>3.3.0</version>
</dependency>
<!-- HBase client -->
<dependency>
<groupId>org.apache.hbase</groupId>
<artifactId>hbase-client</artifactId>
<version>2.4.8</version>
</dependency>
<!-- ScalaTest for unit testing -->
<dependency>
<groupId>org.scalatest</groupId>
<artifactId>scalatest_2.12</artifactId>
<version>3.2.9</version>
<scope>test</scope>
</dependency>
<!-- Hadoop Common (required for HBase and Spark) -->
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-common</artifactId>
<version>3.3.0</version>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.8.1</version>
<configuration>
<source>1.8</source>
<target>1.8</target>
</configuration>
</plugin>
</plugins>
</build>
</project>
原文地址:https://blog.csdn.net/DK22151/article/details/143864437
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!