我正在尝试将软件包quantedacaret一起使用,以根据经过训练的样本对文本进行分类。作为测试运行,我想将quanteda的内置朴素贝叶斯分类器与caret中的那些进行比较。但是,我似乎无法使caret正常工作。

这是一些可复制的代码。首先在quanteda端:

library(quanteda)
library(quanteda.corpora)
library(caret)
corp <- data_corpus_movies
set.seed(300)
id_train <- sample(docnames(corp), size = 1500, replace = FALSE)

# get training set
training_dfm <- corpus_subset(corp, docnames(corp) %in% id_train) %>%
  dfm(stem = TRUE)

# get test set (documents not in id_train, make features equal)
test_dfm <- corpus_subset(corp, !docnames(corp) %in% id_train) %>%
  dfm(stem = TRUE) %>%
  dfm_select(pattern = training_dfm,
             selection = "keep")

# train model on sentiment
nb_quanteda <- textmodel_nb(training_dfm, docvars(training_dfm, "Sentiment"))

# predict and evaluate
actual_class <- docvars(test_dfm, "Sentiment")
predicted_class <- predict(nb_quanteda, newdata = test_dfm)
class_table_quanteda <- table(actual_class, predicted_class)
class_table_quanteda
#>             predicted_class
#> actual_class neg pos
#>          neg 202  47
#>          pos  49 202


不错。精度为80.8%,无需调整。现在,据我所知caret

training_m <- convert(training_dfm, to = "matrix")
test_m <- convert(test_dfm, to = "matrix")
nb_caret <- train(x = training_m,
                  y = as.factor(docvars(training_dfm, "Sentiment")),
                  method = "naive_bayes",
                  trControl = trainControl(method = "none"),
                  tuneGrid = data.frame(laplace = 1,
                                        usekernel = FALSE,
                                        adjust = FALSE),
                  verbose = TRUE)

predicted_class_caret <- predict(nb_caret, newdata = test_m)
class_table_caret <- table(actual_class, predicted_class_caret)
class_table_caret
#>             predicted_class_caret
#> actual_class neg pos
#>          neg 246   3
#>          pos 249   2


在这里,不仅准确性极差(49.6%-大概是机会),而且几乎从未预测过pos类!因此,我敢肯定我在这里遗漏了一些关键的内容,因为我认为实现应该相当相似,但不确定是什么。

我已经查看了quanteda函数的源代码(希望它可能始终建立在caret或基础包上),并且看到正在进行一些加权和平滑处理。如果我在训练之前将其应用于dfm(稍后再设置laplace = 0),则精度会更好一些。但也只有53%。

最佳答案

答案是,插入符(使用naivebayes软件包中的naive_bayes)采用高斯分布,而quanteda::textmodel_nb()基于更适合文本的多项式分布(也可以选择使用Bernoulli分布)。

textmodel_nb()的文档复制了IIR书中的示例(Manning,Raghavan和Schütze2008),还引用了Jurafsky和Martin(2018)的另一个示例。看到:


曼宁(Manning),克里斯托弗(Christopher D.),普拉巴卡(Prabhakar Raghavan)和辛里奇(HinrichSchütze)。 2008年。信息检索简介。剑桥大学出版社(第13章)。 https://nlp.stanford.edu/IR-book/pdf/irbookonlinereading.pdf
Jurafsky,Daniel和James H. Martin。 2018年。语音和语言处理。自然语言处理,计算语言学和语音识别简介。第三版草案,2018年9月23日(第4章)。 https://web.stanford.edu/~jurafsky/slp3/4.pdf


另一个软件包e1071产生的结果与基于高斯分布的结果相同。

library("e1071")
nb_e1071 <- naiveBayes(x = training_m,
                       y = as.factor(docvars(training_dfm, "Sentiment")))
nb_e1071_pred <- predict(nb_e1071, newdata = test_m)
table(actual_class, nb_e1071_pred)
##             nb_e1071_pred
## actual_class neg pos
##          neg 246   3
##          pos 249   2


但是,插入符号和e1071都在密集矩阵上工作,这是它们与在稀疏dfm上运行的Quanteda方法相比如此令人费解的缓慢的原因之一。因此,从适当性,效率和(根据您的结果)分类器的性能的角度来看,应该很清楚地选择哪一个分类器!

library("rbenchmark")
benchmark(
    quanteda = {
        nb_quanteda <- textmodel_nb(training_dfm, docvars(training_dfm, "Sentiment"))
        predicted_class <- predict(nb_quanteda, newdata = test_dfm)
    },
    caret = {
        nb_caret <- train(x = training_m,
                          y = as.factor(docvars(training_dfm, "Sentiment")),
                          method = "naive_bayes",
                          trControl = trainControl(method = "none"),
                          tuneGrid = data.frame(laplace = 1,
                                                usekernel = FALSE,
                                                adjust = FALSE),
                          verbose = FALSE)
        predicted_class_caret <- predict(nb_caret, newdata = test_m)
    },
    e1071 = {
        nb_e1071 <- naiveBayes(x = training_m,
                       y = as.factor(docvars(training_dfm, "Sentiment")))
        nb_e1071_pred <- predict(nb_e1071, newdata = test_m)
    },
    replications = 1
)
##       test replications elapsed relative user.self sys.self user.child sys.child
## 2    caret            1  29.042  123.583    25.896    3.095          0         0
## 3    e1071            1 217.177  924.157   215.587    1.169          0         0
## 1 quanteda            1   0.235    1.000     0.213    0.023          0         0

08-20 00:21