Using LightGBM in Production
1. Scoring a single-machine model on a Spark cluster
Extremely fast: about a billion rows scored in a little over ten minutes!
1.1 Main function - resolves the model path
# coding=utf-8
import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.sql.types import DoubleType, IntegerType
import argparse, bisect, os
import lightgbm as lgb
import pandas as pd

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dt", type=str)
    args = parser.parse_args()
    dt = args.dt
    spark = SparkSession.builder.appName("model_predict").enableHiveSupport().getOrCreate()
    sc = spark.sparkContext
    sc.setLogLevel("ERROR")
    current_path = os.getcwd() + '/files/'  # deployed project path; 'files' is the archive alias set at submit time
    model_file_path = current_path + 'superior_user/lgb_model'
    pred(dt, spark, model_file_path)

if __name__ == '__main__':
    main()
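
For context, the lgb_model file loaded above is an ordinary single-machine LightGBM model saved as text. A minimal sketch of how such a file might be produced offline (the training data and parameters here are hypothetical; what matters is that the categorical columns match the prediction side):

import lightgbm as lgb
import pandas as pd

train_df = pd.read_parquet('train.parquet')  # hypothetical offline training set
for col in ['fea1', 'fea2', 'fea3']:         # same categorical columns as at prediction time
    train_df[col] = train_df[col].astype('category')
train_set = lgb.Dataset(train_df[['fea1', 'fea2', 'fea3', 'fea4']], label=train_df['label'])
gbm = lgb.train({'objective': 'binary'}, train_set, num_boost_round=100)
gbm.save_model('superior_user/lgb_model')    # the file shipped inside the project archive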
1.2 Define the cluster scoring function
Any feature processing applied at training time must be repeated here at prediction time.
The pandas UDF runs on every executor across the cluster; each invocation receives a batch of rows, with every selected column arriving as a pandas Series.
def predict_udf(model, feature_list):
    def inner(*selectCols):
        # rebuild a pandas DataFrame from the incoming column batches
        X = pd.DataFrame(dict(zip(feature_list, selectCols)))
        # X = X.astype("float")
        for col in ['fea1', 'fea2', 'fea3']:  # declare the categorical features
            X[col] = X[col].astype('category')
        y_pre = model.predict(X)
        return pd.Series(y_pre)
    return F.pandas_udf(inner, returnType=DoubleType())
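
Before submitting to the cluster, the batch logic can be smoke-tested locally by pushing pandas Series through the same steps (the stub model below is hypothetical; a real lgb.Booster behaves the same way):

import pandas as pd

class StubModel:  # stands in for the lgb.Booster in a local test
    def predict(self, X):
        return [0.5] * len(X)

feature_list = ['fea1', 'fea2', 'fea3', 'fea4']
batches = [pd.Series(['a', 'b']), pd.Series(['x', 'y']),
           pd.Series(['u', 'v']), pd.Series([1.0, 2.0])]
X = pd.DataFrame(dict(zip(feature_list, batches)))
for col in ['fea1', 'fea2', 'fea3']:
    X[col] = X[col].astype('category')
print(StubModel().predict(X))  # -> [0.5, 0.5]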
1.3 Writing out the model scores
def pred(dt, spark, model_path):
    gbm = lgb.Booster(model_file=model_path)
    feature_list = ['fea1', 'fea2', 'fea3', 'fea4']
    udf = predict_udf(gbm, feature_list)
    sql = f"select * from db.table_name where dt = '{dt}'"
    predict_data = spark.sql(sql)
    pred_df = predict_data.select('uid', 'market', 'site_tp', 'country_nm',  # keep the id/dimension columns needed downstream
                                  udf(*feature_list).alias('score'))         # attach the model score
    pred_df.show(10)
    pred_df.createOrReplaceTempView("pred_df")
2. Post-processing the results
2.1 Zeroing out blacklisted users
black_df = spark.sql("""
    select member_id as uid, 0 as label
    from db.table_name
    where black_flag = '1'
    group by member_id
""")
black_df.createOrReplaceTempView("black_df")
if pred_df.count() == 0:
    raise Exception("the table is null")
spark.sql("""
    create table db.table_name_1 stored as orc as
    select t1.uid, t1.market, t1.site_tp, t1.country_nm,
           if(t2.uid is null, t1.score, 0) as score,
           if(t2.uid is not null, 1, 0) as bad_flag
    from pred_df t1
    left join black_df t2
      on t1.uid = t2.uid
""")
2.2 Bucketing scores by market
percentile_data = spark.sql("""
    select market,
           percentile_approx(score, array(0.20, 0.40, 0.60, 0.80)) as score_bucket
    from db.table_name_1
    group by market
""").toPandas()
sc = spark.sparkContext
d = {}
for i in percentile_data.to_dict("records"):
    market = i["market"]
    score_bucket = i["score_bucket"]
    d[market] = score_bucket
d_bc = sc.broadcast(d)  # ship the per-market cut points to every executor

def my_udf1(market, score):
    if market is None:
        return 0
    score_bucket = d_bc.value.get(market)
    if score_bucket is None:  # unseen market: fall back to bucket 0
        return 0
    # bisect_left over the 4 cut points yields a bucket id in 0..4
    return bisect.bisect_left(score_bucket, score)

my_udf1 = F.udf(my_udf1, IntegerType())

stat_df = spark.sql("select * from db.table_name_1 where score is not null")
stat_df = stat_df.withColumn("score_bucket", my_udf1("market", "score"))
stat_df.createOrReplaceTempView("stat_df")
spark.sql(f"""
    insert overwrite table db.result_table partition(dt='{dt}')
    select uid, site_tp, country_nm,
           market, score,
           cast(score_bucket as int) as score_bucket
    from stat_df
""")
check_result(dt, spark)
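
To make the bucket semantics concrete: bisect_left over the four quantile cut points maps each score to a bucket id from 0 (bottom quintile) to 4 (top quintile), with ties landing in the lower bucket. A small worked example with hypothetical cut points:

import bisect

cut_points = [0.21, 0.43, 0.58, 0.80]  # hypothetical q20/q40/q60/q80 for one market
for score in [0.10, 0.21, 0.50, 0.99]:
    print(score, '->', bisect.bisect_left(cut_points, score))
# 0.10 -> 0   bottom quintile
# 0.21 -> 0   a score equal to a cut point falls into the lower bucket
# 0.50 -> 2
# 0.99 -> 4   top quintile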
2.3 Validating the results
def check_result(dt, spark):
    df = spark.sql(f"""
        select market, cnt, all_cnt, cnt / all_cnt as rate
        from (
            select market,
                   count(distinct if(score_bucket = 4, uid, null)) as cnt,  -- top bucket; ids run 0..4
                   count(distinct uid) as all_cnt
            from db.result_table
            where dt = '{dt}'
            group by market
        ) t
    """)
    pdf = df.toPandas()
    for d in pdf.to_dict("records"):
        market = d.get("market")
        rate = d.get("rate")
        if market:
            # with quintile cut points, the top bucket should hold roughly 20% of users
            if rate <= 0.18 or rate >= 0.22:
                raise Exception("the superior rate is not good, check please!")
    return 0
3. Launch script (configuration file)
#!/bin/bash
set -euxo pipefail
echo "$(pwd)"
# tmp_dir, folder_name and dt are injected by the scheduler environment
cd ${tmp_dir}
if [ ! -d ${folder_name} ]; then
    mkdir ${folder_name}
fi
cd ${folder_name}
sh ../code_pull.sh https://gitlab.baidu.cn/aiapp/in_score.git main
echo "$(ls)"
name="superior_user"
project_name="in_score"
cd ${project_name}
# ../${project_name}.zip#files unpacks the project code under the alias 'files';
# the conda environment archive unpacks under the alias 'env'
spark-submit --master yarn \
    --deploy-mode cluster \
    --driver-cores 4 \
    --driver-memory 20G \
    --num-executors 200 \
    --executor-cores 4 \
    --executor-memory 4G \
    --name ${name} \
    --conf spark.yarn.priority=100 \
    --conf spark.storage.memoryFraction=0.5 \
    --conf spark.shuffle.memoryFraction=0.5 \
    --conf spark.driver.maxResultSize=10G \
    --conf spark.dynamicAllocation.enabled=false \
    --conf spark.executor.extraJavaOptions='-Xss128M' \
    --conf spark.sql.autoBroadcastJoinThreshold=-1 \
    --conf spark.sql.adaptive.enabled=true \
    --conf spark.yarn.dist.archives=../${project_name}.zip#files \
    --conf spark.yarn.appMasterEnv.PYTHONPATH=files \
    --conf spark.executorEnv.PYTHONPATH=files \
    --conf spark.pyspark.python=./env/bin/python \
    --archives s3://xxaiapp/individual/bbb/condaenv/mzpy38_v2.tar.gz#env \
    superior_user/train_test.py --dt ${dt}  # run the main entry script
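
The environment archive referenced by --archives must be a relocatable Python environment containing lightgbm, pandas and pyspark. One common way to build such a tarball (an assumption about this project's tooling, not stated in the source) is conda-pack, e.g. `conda pack -n mzpy38 -o mzpy38_v2.tar.gz`.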
Source: https://blog.csdn.net/MusicDancing/article/details/143582057