2016-09-30

I have imbalanced data and I want to run stratified cross-validation with the area under the precision-recall curve (PR AUC) as my evaluation metric: prSummary in the R caret package for imbalanced data.

I am using prSummary in the R package caret with manually specified fold indices, and I get an error when the performance is computed.

Below is a reproducible example. I found that the PR AUC is computed on only ten samples and, because of the imbalance, those ten samples contain only one class, so the PR AUC cannot be computed. (I discovered that only ten samples are used because I modified prSummary to force the function to print out its data.)

library(randomForest) 
library(mlbench) 
library(caret) 

# Load Dataset 
data(Sonar) 
dataset <- Sonar 
x <- dataset[,1:60] 
y <- dataset[,61] 
# make this data very imbalanced 
y[4:length(y)] <- "M" 
y <- as.factor(y) 
dataset$Class <- y 

# create index and indexOut 
seed <- 1 
set.seed(seed) 
folds <- 2 
idxAll <- 1:nrow(x) 
cvIndex <- createFolds(factor(y), folds, returnTrain = TRUE) 
cvIndexOut <- lapply(1:length(cvIndex), function(i){ 
    idxAll[-cvIndex[[i]]] 
}) 
names(cvIndexOut) <- names(cvIndex) 

# set the index, indexOut and prSummaryCorrect 
control <- trainControl(index = cvIndex, indexOut = cvIndexOut, 
          method = "cv", summaryFunction = prSummary, classProbs = TRUE) 
metric <- "AUC" 
set.seed(seed) 
mtry <- sqrt(ncol(x)) 
tunegrid <- expand.grid(.mtry=mtry) 
rf_default <- train(Class~., data=dataset, method="rf", metric=metric, tuneGrid=tunegrid, trControl=control) 

Here is the error message:

Error in ROCR::prediction(y_pred, y_true) : 
Number of classes is not equal to 2. 
ROCR currently supports only evaluation of binary classification tasks. 
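The error originates inside MLmetrics::PRAUC, which delegates to ROCR::prediction (as the traceback shows); when the labels passed in contain only one class, ROCR raises exactly this error. A minimal sketch reproducing the cause, assuming ROCR is installed:

```r
library(ROCR)

scores <- c(0.9, 0.8, 0.7)
labels <- c(1, 1, 1)   # only one class present

# ROCR::prediction() refuses single-class labels; capture its message
msg <- tryCatch(
    { ROCR::prediction(scores, labels); "no error" },
    error = function(e) conditionMessage(e)
)
print(msg)
```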

Answer


I think I found something strange...

Even though I specified the cross-validation indices, the summary function (prSummary or any other summary function) still picks ten samples, seemingly at random (I am not sure), to compute performance.
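That would explain the failure: if, as I suspect, caret probes the summary function on a random sample of about ten rows, then with only 3 "R" rows out of 208 the probe usually contains a single class. A quick back-of-the-envelope check in base R:

```r
# Sonar has 208 rows; after y[4:length(y)] <- "M" only 3 rows are class "R".
# Probability that a random sample of 10 rows contains no "R" at all:
p_only_M <- choose(205, 10) / choose(208, 10)
round(p_only_M, 3)   # about 0.862
```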

My workaround is to define a summary function wrapped in tryCatch so that this error does not abort training.

prSummaryCorrect <- function(data, lev = NULL, model = NULL) { 
    library(MLmetrics) 
    if (length(levels(data$obs)) != 2) 
        stop(paste("Your outcome has", length(levels(data$obs)), 
                   "levels. prSummaryCorrect() expects exactly two.")) 
    if (!all(levels(data[, "pred"]) == levels(data[, "obs"]))) 
        stop("levels of observed and predicted data do not match") 

    # PRAUC() calls ROCR::prediction(), which errors when this resample 
    # contains only one observed class; return NA instead of aborting. 
    auc <- tryCatch( 
        MLmetrics::PRAUC(y_pred = data[, lev[2]], 
                         y_true = ifelse(data$obs == lev[2], 1, 0)), 
        warning = function(w) NA_real_, 
        error = function(e) NA_real_) 

    # lev[2] is treated as the positive class, as in the original code. 
    c(AUC = auc, 
      Precision = caret::precision(data = data$pred, reference = data$obs, relevant = lev[2]), 
      Recall = caret::recall(data = data$pred, reference = data$obs, relevant = lev[2]), 
      F = caret::F_meas(data = data$pred, reference = data$obs, relevant = lev[2])) 
}
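The key point of the wrapper is that tryCatch returns the value of whichever handler fires, so a failing metric degrades to NA instead of aborting train(). The same pattern in isolation, in base R (the helper name is my own, for illustration):

```r
# Hypothetical helper: evaluate a metric, returning NA on error.
safe_metric <- function(fun, ...) {
    tryCatch(fun(...), error = function(e) NA_real_)
}

safe_metric(function(x) stop("boom"), 1)   # NA instead of an error
safe_metric(mean, c(1, 2, 3))              # 2
```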