我正在R中为Hadoop进行分布式线性回归计算,但是在实现它之前,我想验证我的计算是否与lm函数的结果一致。

我具有以下功能,这些功能试图实现Andrew Ng等人讨论的通用“求和”框架。在文件Map-Reduce for Machine Learning on Multicore中。

对于线性回归,这涉及将每行y_i和x_i映射到P_i和Q_i,从而:

P_i = x_i * transpose(x_i)
Q_i = x_i * y_i

然后减少求解系数theta:theta = (sum(P_i))^-1 * sum(Q_i)
R函数可以做到这一点:
calculate_p <- function(dat_row) {
  dat_row %*% t(dat_row)
}

calculate_q <- function(dat_row) {
  dat_row[1,1] * dat_row[, -1]
}

calculate_pq <- function(dat_row) {
  c(calculate_p(matrix(dat_row[-1], nrow=1)), calculate_q(matrix(dat_row, nrow=1)))
}

map_pq <- function(dat) {
  t(apply(dat, 1, calculate_pq))
}

reduce_pq <- function(pq) {
  (1 / sum(pq[, 1])) * apply(pq[, -1], 2, sum)
}

您可以通过运行以下命令在一些综合数据上实现它:
X <- matrix(rnorm(20*5), ncol = 5)
y <- as.matrix(rnorm(20))
reduce_pq(map_pq(cbind(y, X)))
[1]  0.010755882 -0.006339951 -0.034797768  0.067438662 -0.033557351
coef(lm.fit(X, y))
          x1           x2           x3           x4           x5
-0.038556283 -0.002963991 -0.195897701  0.422552974 -0.029823962

不幸的是,输出不匹配,所以很明显我做错了。有什么想法可以解决吗?

最佳答案

您在reduce_pq中采用的逆必须是矩阵逆。另外,我对某些功能做了一些改动。

calculate_p <- function(dat_row) {
    dat_row %*% t(dat_row)
}

calculate_q <- function(dat_row) {
    dat_row[1] * dat_row[-1]
}

calculate_pq <- function(dat_row) {
    c(calculate_p(dat_row[-1]), calculate_q(dat_row))
}

map_pq <- function(dat) {
    t(apply(dat, 1, calculate_pq))
}

reduce_pq <- function(pq) {
    solve(matrix(apply(pq[, 1:(ncol(X) * ncol(X))], 2, sum), nrow=ncol(X))) %*% apply(pq[, 1:ncol(X) + ncol(X)*ncol(X)], 2, sum)
}


set.seed(1)
X <- matrix(rnorm(20*5), ncol = 5)
y <- as.matrix(rnorm(20))

t(reduce_pq(map_pq(cbind(y, X))))
          [,1]      [,2]      [,3]       [,4]        [,5]
[1,] 0.1236914 0.2482445 0.5120975 -0.1104451 -0.04080922

coef(lm.fit(X,y))
         x1          x2          x3          x4          x5
 0.12369137  0.24824449  0.51209753 -0.11044507 -0.04080922

> all.equal(as.numeric(t(reduce_pq(map_pq(cbind(y, X))))), as.numeric(coef(lm.fit(X,y))))
[1] TRUE

关于r - Map降低基数R中的线性回归,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/12829577/

10-11 17:48