在Rcpp中使用NumericMatrix和NumericVec

在Rcpp中使用NumericMatrix和NumericVec

本文介绍了在Rcpp中使用NumericMatrix和NumericVector进行矩阵乘法的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我想知道是否有一种使用NumericMatrix和NumericVector类计算矩阵乘法的方法.我想知道是否有任何简单的方法为了帮助我避免以下循环进行此计算.我只想计算X%*%beta.

I am wondering is there a way of calculating matrix multiplication using NumericMatrix and NumericVector class. I am wondering if there is any simple wayto help me avoid the following loop to conduct this calculation. I just want to calculate X%*%beta.

// assume X and beta are initialized and X is of dimension (nsites, p),
// beta is a NumericVector with p elements.
for(int j = 0; j < nsites; j++)
 {
    temp = 0;

    for(int l = 0; l < p; l++) temp = temp + X(j,l) * beta[l];

}

非常感谢您!

推荐答案

以下是基于Dirk的注释的一些案例,这些案例通过重载的*运算符演示了Armadillo库的矩阵乘法:

Building off of Dirk's comment, here are a few cases that demonstrate the Armadillo library's matrix multiplication via the overloaded * operator:

#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]

// [[Rcpp::export(".mm")]]
arma::mat mm_mult(const arma::mat& lhs,
                  const arma::mat& rhs)
{
  return lhs * rhs;
}

// [[Rcpp::export(".vm")]]
arma::mat vm_mult(const arma::vec& lhs,
                  const arma::mat& rhs)
{
  return lhs.t() * rhs;
}

// [[Rcpp::export(".mv")]]
arma::mat mv_mult(const arma::mat& lhs,
                  const arma::vec& rhs)
{
  return lhs * rhs;
}

// [[Rcpp::export(".vv")]]
arma::mat vv_mult(const arma::vec& lhs,
                  const arma::vec& rhs)
{
  return lhs.t() * rhs;
}


然后,您可以定义R函数来分派适当的C ++函数:


You could then define an R function to dispatch the appropriate C++ function:

`%a*%` <- function(x,y) {

  if (is.matrix(x) && is.matrix(y)) {
    return(.mm(x,y))
  } else if (!is.matrix(x) && is.matrix(y)) {
    return(.vm(x,y))
  } else if (is.matrix(x) && !is.matrix(y)) {
    return(.mv(x,y))
  } else {
    return(.vv(x,y))
  }

}
##
mx <- matrix(1,nrow=3,ncol=3)
vx <- rep(1,3)
my <- matrix(.5,nrow=3,ncol=3)
vy <- rep(.5,3)


与R的%*%函数进行比较:


And comparing to R's %*% function:

R>  mx %a*% my
     [,1] [,2] [,3]
[1,]  1.5  1.5  1.5
[2,]  1.5  1.5  1.5
[3,]  1.5  1.5  1.5

R>  mx %*% my
     [,1] [,2] [,3]
[1,]  1.5  1.5  1.5
[2,]  1.5  1.5  1.5
[3,]  1.5  1.5  1.5
##
R>  vx %a*% my
     [,1] [,2] [,3]
[1,]  1.5  1.5  1.5

R>  vx %*% my
     [,1] [,2] [,3]
[1,]  1.5  1.5  1.5
##
R>  mx %a*% vy
     [,1]
[1,]  1.5
[2,]  1.5
[3,]  1.5

R>  mx %*% vy
     [,1]
[1,]  1.5
[2,]  1.5
[3,]  1.5
##
R>  vx %a*% vy
     [,1]
[1,]  1.5

R>  vx %*% vy
     [,1]
[1,]  1.5

这篇关于在Rcpp中使用NumericMatrix和NumericVector进行矩阵乘法的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!

08-19 23:41