在mlr3vers中,还可以进行生存分析。关于生存分析的理论内容请参考以前的推文。

1、加载R包

library("mlr3verse")
library("mlr3proba")
library("survival")

2、设定任务

task = as_task_surv(survival::rats, 
                    time = "time",
                    event = "status", 
                    id = "rats")

task$head()
##    time status litter rx sex
## 1:  101      0      1  1   f
## 2:   49      1      1  0   f
## 3:  104      0      1  0   f
## 4:   91      0      2  1   m
## 5:  104      0      2  0   m
## 6:  102      0      2  0   m
#绘制KM曲线
autoplot(task,rhs="sex")

R语言: mlr3机器学习--生存分析-LMLPHP

3、构建生存分析模型

在mlr3vese中内置了“rats”任务,可直接使用。

t = tsk("rats")#使用自带的数据及和任务
# 数据划分
split = partition(t)
# 模型训练
p = lrn("surv.coxph")$train(t, split$train)$predict(t, split$test)
p
## <PredictionSurv> for 99 observations:
##     row_ids time status       crank          lp     distr
##           3  104  FALSE -0.07449981 -0.07449981 <list[1]>
##           9  104  FALSE -0.07152396 -0.07152396 <list[1]>
##          11  104  FALSE -2.69451424 -2.69451424 <list[1]>
## ---                                                      
##         235   80   TRUE  0.71123465  0.71123465 <list[1]>
##         247   73   TRUE  0.71718634  0.71718634 <list[1]>
##         249   66   TRUE  0.04750991  0.04750991 <list[1]>

4、生存分析预测

mlr3框架中的生存分析,可以做出以下预测:

  • response- 预测生存时间。
  • distr- 预测的生存分布,
  • lp- 线性预测,
  • crank- 连续的风险

4.1 预测生存时间(predict_type = “response” )

预测生存时间实际上是生存分析中最不常见的预测类型。对于许多参与者,我们很少观察到真实的生存时间,因此任何生存模型都不太可能自信地预测生存时间。

在下面的示例中,我们用生存支持向量机 (mlr_learners_surv.svm) 进行训练和预测。

library(mlr3extralearners)
library(survivalsvm)
pred = lrn("surv.svm", 
           type = "regression", 
           gamma.mu = 1e-3)$train(t, split$train)$predict(t,split$test)

预测生存时间与真实生存时间比较

data.frame(pred = pred$response[1:3], truth = pred$truth[1:3])
##       pred truth
## 1 88.14720  104+
## 2 87.85996  104+
## 3 87.72024  104+

从输出中可以看出,我们的预测都小于真实的观测时间,这意味着我们的模型肯定低估了真相。由于受到种种限制,所以生存时间预测很少用。

4.2 预测概率(predict_type = “distr” )

在生存分析中,分布预测更为常见。mlr3proba 中的大多数生存模型默认会进行概率预测。

t = tsk("rats")
split = partition(t)
p = lrn("surv.coxph")$train(t, split$train)$predict(t, split$test)
p$distr[1:3]$survival(50)
##        [,1]      [,2]      [,3]
## 50 0.946988 0.9991727 0.9462497

输出表明,预测前三只的大鼠在时间 50分别存活的几率为 95.7%、98.4%、97.7%。

4.3 预测风险(predict_type = “crank”)

学术论文中通常会在生存分析中提及“风险”预测(因此生存模型通常被称为“风险预测模型”),而没有定义“风险”的含义。通常风险被定义为crank。

我们继续使用前面的示例,我们输出前三个预测。输出告诉我们,第一只大鼠的死亡风险较高(值越大表示风险越高),第三只大鼠的死亡风险最低。预测之间的距离也告诉我们,第一只和第二只大鼠之间的风险差异大于于第二只和第三只大鼠之间的风险差异。实际值本身是没有意义的。

p$crank[1:3]
##          1          2          3 
##  0.6082167 -3.5785387  0.6224343

5、生存分析模型评估

一般来说,生存分析模型评估方法包括:

  • Discrimination measures(区分度)– 量化模型是否正确识别一个观测值是否比另一个观测值面临更高的风险。
  • Calibration measures (校准度)– 量化平均预测是否接近真实(不幸的是,在生存环境中,校准的所有定义都是模糊的)
  • Scoring rules-(评分规则) 量化概率预测是否接近真实值

mlr3verse框架中生存分析模型的部分评估方法:

head(as.data.table(mlr_measures)[
  task_type == "surv", c("key", "predict_type")])
##                   key predict_type
## 1:         surv.brier        distr
## 2:   surv.calib_alpha        distr
## 3:    surv.calib_beta           lp
## 4: surv.chambless_auc           lp
## 5:        surv.cindex        crank
## 6:        surv.dcalib        distr

我们建议使用RCLL(mlr_measures_surv.rcll)来评估预测的质量, 使用一致性指数(mlr_measures_surv.cindex)来评估模型的区分度 及D-Calibration (mlr_measures_surv.dcalib)来评估模型的校准。

p$score(msrs(c("surv.rcll", "surv.cindex", "surv.dcalib")))
##   surv.rcll surv.cindex surv.dcalib 
##   3.4355262   0.7925743   0.2582002

在上面的代码中,我们使用推荐的评估指标,发现此模型的性能似乎还可以,因为 RCLL 和 DCalib 相对较低 ,C index大于 0.5。

6 、全部代码

library(mlr3verse)
library(mlr3extralearners)

task = tsk("grace")$filter(1:500)

# 评估标准
msr_txt = c("surv.rcll", "surv.cindex", "surv.dcalib")
measures = msrs(msr_txt)

# 定义surv.glmnet
pipe = as_learner(ppl(
  "distrcompositor",
  learner = lrn("surv.glmnet"),
  estimator = "kaplan",
  form = "ph"
))
pipe$id = "Coxnet"

# 设定学习器集
learners = c(lrns(c("surv.coxph", "surv.kaplan")), pipe)

# 基准测试
bmr = benchmark(benchmark_grid(task, 
                               learners, 
                               rsmp("cv", folds = 3)))
# 评估指标
bmr$aggregate(measures)[, c("learner_id", ..msr_txt)]
##     learner_id surv.rcll surv.cindex surv.dcalib
## 1:  surv.coxph  2.815028   0.8188342    1.391847
## 2: surv.kaplan  2.960156   0.5000000    5.749886
## 3:      Coxnet  2.817485   0.8192240    5.143964
04-14 05:26