Erlo

使用PySpark计算AUC,KS与PSI

2023-08-20 16:30:30 发布   238 浏览  
页面报错/反馈
收藏 点赞

当特征数量或者模型数量很多的时候,使用PySpark去计算相关风控指标会节省很多的时间。网上关于使用PySpark计算相关风控指标的资料较少,尤其是PSI计算不管是国内还是国外相关的代码都没有正确的,这里抛砖引玉,写了三个风控常用的指标AUC,KS和PSI相关的计算方法,供参考。

AUC

AUC的相关概念网上已经有很多的很好的文章,这里不在赘述,AUC使用的到的计算公式如下:

[AUC=frac{sum_{iin positiveClass}rank_i-{displaystylefrac{M(1+M)}2}}{Mtimes N} ]

其中M为负类样本的数目,N为正类样本的数目

使用PySpark计算代码如下:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

true_y_col = 'y'
pred_y_col = 'pred_y'
date_col = 'day'


auc_df = df.filter(F.col(true_y_col)>=0).filter(F.col(pred_y_col)>=0)
           .select(true_y_col, pred_y_col, date_col, 'model_name')
           .withColumn('totalbad', F.sum(F.col(true_y_col)).over(Window.patitonBy(date_col, 'model_name').orderBy(F.lit(1))))
           .withColumn('totalgood', F.sum(1-F.col(true_y_col)).over(Window.patitonBy(date_col, 'model_name').orderBy(F.lit(1))))
           .withColumn('rnk2', F.row_number().over(Window.partitionBy(date_col, 'model_name').orderBy(F.col(pred_y_col).asc())))
           .filter(F.col(true_y_col)==1)
           .groupBy(date_col, 'model_name')
           .agg(((F.sum(F.col('rnk2'))-0.5*(F.max(F.col('totalbad')))*(1+F.max(F.col('totalbad'))))/(F.max(F.col('totalbad'))*F.max(F.col('totalgood')))).alias('AUC'))
           .orderBy('model_name', date_col)

KS

KS统计量是基于经验累积分布函数(Empirical Cumulative Distribution Function,ECDF)
建立的,一般定义为:

[KS=maxleft{left|cumleft(bad_rateright)-cumleft(good_rateright)right|right} ]

即为TPRFPR差值绝对值的最大值。

[KS=maxleft(left|TPR-FPRright|right) ]

KS计算方法有很多种,这里使用的是分箱法分别计算TPRFPR,然后得到KS。
使用PySpark计算代码如下:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

true_y_col = 'y'
pred_y_col = 'pred_y'
date_col = 'day'
nBins = 10

ks_df = df.filter(F.col(true_y_col)>=0).filter(F.col(pred_y_col)>=0)
          .select(true_y_col, pred_y_col, date_col, 'model_name')
          .withColumn('Bin', F.ntile(nBins).over(Window.partitionBy(date_col, 'model_name').orderBy(pred_y_col)))
          .groupBy(date_col, 'model_name', 'Bin').agg(F.sum(true_y_col).alias('N_1'), F.sum(1-F.col(true_y_col)).alias('N-0'))
          .withColumn('ALL_1', F.sum('N_1').over(Window.partitionBy(date_col, 'model_name')))
          .withColumn('ALL_0', F.sum('N_0').over(Window.partitionBy(date_col, 'model_name')))
          .withColumn('SUM_1', F.sum('N_1').over(Window.partitionBy(date_col, 'model_name').orderBy('Bin')))
          .withColumn('ALL_0', F.sum('N_0').over(Window.partitionBy(date_col, 'model_name').orderBy('Bin')))
          .withColumn('KSn', F.expr('round(abs(SUM_1/ALL_1-SUM_0/ALL_0),6)'))
          .withColumn('KS', F.round(F.max('KSn').over(Window.partitionBy(date_col, 'model_name')),6))

ks_df = ks_df.select(date_col, 'model_name', 'KS').filter(col('KS').isNotNull()).dropDuplicates()

PSI

群体稳定性指标(Population Stability Index,PSI)是风控场景常用的验证样本在各分数段的分布与建模样本分布的稳定性。在建模中,常用来筛选特征变量、评估模型稳定性

计算公式如下:

[psi=sum_{i=1}^nleft(A_i-E_iright)astlnleft(A_i/E_iright) ]

其中(A_i)代表的是第i个分箱中实际分布(actual)样本占比,同理(E_i)代表的是第i个分箱中预期分布(excepted)样本占比

使用PySpark计算代码如下:

from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.functions import when


date_col = 'day'
nBins = 10
feature_list = ['fea_1', 'fea_2', 'fea_3']

df = df.withColumn('flag', when(F.col(date_col) == 'actual_date', 0).when(F.col(date_col) == 'excepted_date', 1).otherwise(None)

quantitles = df.filter(F.col('flag') == 0)
               .approxQuantile(feature_list, [i/nBins for i in range(1, nBins)], 0.001) # 基准样本分箱

quantitles_dict = {col: quantitles[idx] for idx, col in enumerate(feature_list)}
f_quantitles_dict = F.create_map([F.lit(x) if isinstance(x, str) else F.array(*[F.lit(xx) for xx in x]) for i in quantitles_dict.items() for x in i])

unpivotExpr = "stack(3, 'fea_1', fea_1, 'fea_2', fea_2, 'fea_3', fea_3)"

psi_df = df.filter(F.col('flag').isNotNull()).select('flag', F.expr(unpivotExpr))
           .withColumn('Bin', when(F.col('value').isNull(), 'Missing').otherwise(
            when(F.col('value') 

Reference

登录查看全部

参与评论

评论留言

还没有评论留言,赶紧来抢楼吧~~

手机查看

返回顶部

给这篇文章打个标签吧~

棒极了 糟糕透顶 好文章 PHP JAVA JS 小程序 Python SEO MySql 确认