
Using LightGBM in production (online scoring)

1. Scoring on a Spark cluster with a single-machine model

It is extremely fast: roughly a billion rows are scored in a little over ten minutes.
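
The workflow assumes a LightGBM Booster that was trained offline on a single machine and saved as a plain model file, which the Spark job later loads with lgb.Booster(model_file=...). A minimal sketch of that offline step (the synthetic data and hyperparameters below are illustrative only; the feature and category columns mirror the ones used in the UDF further down):

# Offline, single-machine training and export (a sketch with made-up data).
import numpy as np
import pandas as pd
import lightgbm as lgb

np.random.seed(0)
n = 1000
X_train = pd.DataFrame({
    'fea1': np.random.choice(['a', 'b', 'c'], n),
    'fea2': np.random.choice(['x', 'y'], n),
    'fea3': np.random.choice(['p', 'q', 'r'], n),
    'fea4': np.random.rand(n),
})
y_train = pd.Series(np.random.randint(0, 2, n))
for col in ['fea1', 'fea2', 'fea3']:          # same categorical columns as in the scoring UDF below
    X_train[col] = X_train[col].astype('category')

train_set = lgb.Dataset(X_train, label=y_train)
gbm = lgb.train({'objective': 'binary', 'verbosity': -1}, train_set, num_boost_round=50)
gbm.save_model('lgb_model')                   # ship this file as superior_user/lgb_model inside the project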

1.1 Main function: resolving the model path

# coding=utf-8
import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.sql.types import DoubleType
import argparse, bisect, os
import lightgbm as lgb
import pandas as pd

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dt", type=str)
    args = parser.parse_args()
    dt = args.dt
    spark = SparkSession.builder.appName("model_predict").enableHiveSupport().getOrCreate()
    sc = spark.sparkContext
    sc.setLogLevel("ERROR")

    current_path = os.getcwd() + '/files/'  # project path on the cluster; "files" is the alias of the shipped project archive
    model_file_path = current_path + 'superior_user/lgb_model'
    pred(dt, spark, model_file_path)

if __name__ == '__main__':
    main()

1.2 Defining the cluster scoring function

Whatever feature processing was applied at training time must be reproduced here at prediction time.

The function runs on every executor and handles the data of its own partition; with a pandas UDF, each selected column arrives as a pandas Series, so predictions are computed in batches.

def predict_udf(model, feature_list):
    def inner(*selectCols):
        X = pd.DataFrame(dict(zip(feature_list, selectCols)))
        # X = X.astype("float")
        for col in ['fea1', 'fea2', 'fea3']:  # columns treated as categorical features
            X[col] = X[col].astype('category')
        y_pre = model.predict(X)
        return pd.Series(y_pre)

    return F.pandas_udf(inner, returnType=DoubleType())
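
Because the pandas UDF receives each selected column as a pandas Series, the inner prediction logic can be exercised locally before submitting to the cluster. A quick smoke test, assuming the saved model file from above and a couple of hand-made rows (the values are invented):

# Local smoke test of the batch prediction logic; no Spark session needed for this part.
import pandas as pd
import lightgbm as lgb

feature_list = ['fea1', 'fea2', 'fea3', 'fea4']
gbm = lgb.Booster(model_file='superior_user/lgb_model')

# One pandas Series per feature, mimicking what the pandas UDF receives on an executor.
selectCols = [
    pd.Series(['a', 'b']),    # fea1
    pd.Series(['x', 'y']),    # fea2
    pd.Series(['p', 'q']),    # fea3
    pd.Series([0.1, 0.9]),    # fea4
]
X = pd.DataFrame(dict(zip(feature_list, selectCols)))
for col in ['fea1', 'fea2', 'fea3']:
    X[col] = X[col].astype('category')
print(pd.Series(gbm.predict(X)))              # two scores, the same shape the UDF returns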

1.3 Writing out the model scores

def pred(dt, spark, model_path):
    gbm = lgb.Booster(model_file=model_path)
    feature_list = ['fea1', 'fea2', 'fea3', 'fea4']
    udf = predict_udf(gbm, feature_list)

    sql = f"select * from db.table_name where dt = '{dt}'"
    predict_data = spark.sql(sql)
    pred_df = predict_data.select('uid', 'market',  # keep uid and market from the prediction set
        udf(*feature_list).alias('score')  # append the model score as a new column
    )
    pred_df.show(10)
    pred_df.createOrReplaceTempView("pred_df")
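
Before handing pred_df to post-processing, a quick look at the score distribution is cheap and catches obviously broken features early. A small optional check that could go at the end of pred() (approxQuantile and groupBy are standard DataFrame calls; the quantile choices are arbitrary):

    # Optional sanity checks on the freshly scored frame.
    print('score p05 / p50 / p95:', pred_df.approxQuantile('score', [0.05, 0.5, 0.95], 0.01))
    pred_df.groupBy('market').count().show()   # spot accidentally empty markets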

2. Post-processing the results

2.1 Zeroing out blacklisted users

black_df = spark.sql(f"""
    select member_id, 0 as label 
    from db.table_name
    where black_flag = '1'
    group by member_id 
""")
black_df.createOrReplaceTempView("black_df")
spark.sql("""
    create table db.table_name_1 stored as orc as 
    select t1.uid,
    if(t2.uid is null,t1.score,0) as score,
    if(t2.uid is not null,1,0) as bad_flag 
    from pred_df t1 
    left join black_df t2 
    on t1.uid = t2.uid 
""")

pred_df.createOrReplaceTempView("stat_df")
if pred_df.count() == 0:
    raise Exception("the table is null")

2.2 Bucketing scores by market

percentile_data = spark.sql("""
    select market,
    percentile_approx(score,array(0.20,0.40,0.60,0.80)) as score_bucket
    from db.table_name_1
    group by market
""").toPandas()
sc = spark.sparkContext
d = {}
for i in percentile_data.to_dict("records"):
    market = i["market"]
    score_bucket = i["score_bucket"]
    d[market] = score_bucket

d_bc = sc.broadcast(d)  # ship the per-market cut-points to the executors

def my_udf1(market, score):
    # map a score to a bucket index (0-4) via the market's quantile cut-points
    if market is None:
        return 0
    score_bucket = d_bc.value[market]
    return bisect.bisect_left(score_bucket, score)

my_udf1 = F.udf(my_udf1)  # default string return type; cast to int when writing out
stat_df = spark.sql(f"""select * from db.table_name_1 where score is not null""")
stat_df = stat_df.withColumn("score_bucket", my_udf1("market", "score"))
stat_df.createOrReplaceTempView("stat_df")
spark.sql(f"""
    insert overwrite table db.result_table partition(dt='{dt}')
    select member_id, site_tp, country_nm,
    market, score,
    cast(score_bucket as int) as score_bucket
    from stat_df 
""")
check_result(dt, spark)
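
The broadcast dictionary holds, for each market, the four quantile cut-points, and bisect_left turns a score into a bucket index from 0 to 4. A tiny standalone illustration of that mapping (the cut-point values are made up):

import bisect

score_bucket = [0.12, 0.25, 0.41, 0.67]        # hypothetical 20/40/60/80th-percentile cuts
for score in [0.05, 0.30, 0.67, 0.90]:
    print(score, '->', bisect.bisect_left(score_bucket, score))
# 0.05 -> 0   below the 20th-percentile cut
# 0.30 -> 2   between the 40th and 60th-percentile cuts
# 0.67 -> 3   a value equal to a cut falls into the lower bucket with bisect_left
# 0.90 -> 4   above the 80th-percentile cut (the top bucket)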

2.3 Validating the results

def check_result(dt, spark):
    df = spark.sql(f"""
        select market,cnt,all_cnt,cnt/all_cnt as rate
        from (
            select market,count(distinct if(score_bucket=5,uid,null)) as cnt,
            count(distinct member_id) as all_cnt 
            from db.result_table
            where dt='{dt}'
            group by market 
        )
    """)

    pdf = df.toPandas()
    for d in pdf.to_dict("records"):
        market = d.get("market")
        rate = d.get("rate")
        if market:
            if rate <= 0.045 or rate >= 0.055:
                raise Exception(f"superior rate for market {market} is {rate:.4f}, outside the expected (0.045, 0.055) range, please check!")
    return 0
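
The threshold logic itself is easy to exercise on a toy pandas frame before pointing it at the real table (the numbers below are invented and chosen to pass the check):

# Standalone check of the rate logic used in check_result.
import pandas as pd

pdf = pd.DataFrame({
    'market':  ['US', 'IN', None],
    'cnt':     [50, 120, 10],
    'all_cnt': [1000, 2400, 300],
})
pdf['rate'] = pdf['cnt'] / pdf['all_cnt']

for d in pdf.to_dict('records'):
    market, rate = d.get('market'), d.get('rate')
    if market:                                  # rows with a null market are skipped, as in check_result
        assert 0.045 < rate < 0.055, f'superior rate for {market} out of range: {rate:.4f}'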

3. Submission script


#!/bin/bash
set -euxo pipefail
# tmp_dir, folder_name and dt are assumed to be injected by the scheduling environment.
echo "$(pwd)"
cd ${tmp_dir}
if [ ! -d ${folder_name} ]; then
    mkdir ${folder_name}
fi
cd ${folder_name}
sh ../code_pull.sh https://gitlab.baidu.cn/aiapp/in_score.git main
echo "$(ls)"
name="superior_user"
project_name="in_score"
cd ${project_name}

spark-submit --master yarn \
--deploy-mode cluster \
--driver-cores 4 \
--driver-memory 20G \
--num-executors 200 \
--executor-cores 4 \
--executor-memory 4G \
--name ${name} \
--conf spark.yarn.priority=100 \
--conf spark.storage.memoryFraction=0.5 \
--conf spark.shuffle.memoryFraction=0.5 \
--conf spark.driver.maxResultSize=10G \
--conf spark.dynamicAllocation.enabled=false \
--conf spark.executor.extraJavaOptions='-Xss128M' \
--conf spark.sql.autoBroadcastJoinThreshold=-1 \
--conf spark.sql.adaptive.enabled=true \
--conf spark.yarn.dist.archives=../${project_name}.zip#files \
--conf spark.yarn.appMasterEnv.PYTHONPATH=files \
--conf spark.executorEnv.PYTHONPATH=files \
--conf spark.pyspark.python=./env/bin/python \
--archives s3://xxaiapp/individual/bbb/condaenv/mzpy38_v2.tar.gz#env \
superior_user/train_test.py --dt ${dt}
# The archived conda environment is unpacked under the alias "env" (hence ./env/bin/python);
# superior_user/train_test.py is the entry point, executed with the --dt argument.


Original post: https://blog.csdn.net/MusicDancing/article/details/143582057
