由于Armadillo(afaik)没有三角形求解器,因此我想使用 dtrtrs 中可用的LAPACK三角形求解器。我查看了以下两个(firstsecond)SO线程并将某些内容拼凑在一起,但是它不起作用。

我使用RStudio创建了一个新程序包,同时启用了RcppArmadillo。我有一个头文件header.h:

#include <RcppArmadillo.h>

#ifdef ARMA_USE_LAPACK
#if !defined(ARMA_BLAS_CAPITALS)
#define arma_dtrtrs dtrtrs
#else
#define arma_dtrtrs DTRTRS
#endif
#endif

extern "C" {
  void arma_fortran(arma_dtrtrs)(char* UPLO, char* TRANS, char* DIAG, int* N, int* NRHS,
                    double* A, int* LDA, double* B, int* LDB, int* INFO);
}

int trtrs(char uplo, char trans, char diag, int n, int nrhs, double* A, int lda, double* B, int ldb);

static int trisolve(const arma::mat &in_A, const arma::mat &in_b, arma::mat &out_x);

本质上是第一个链接问题的答案,还有包装函数和main函数。函数的内容在trisolve.cpp中,如下所示:
#include "header.h"

int trtrs(char uplo, char trans, char diag, int n, int nrhs, double* A, int lda, double* B, int ldb) {
  int info = 0;
  wrapper_dtrtrs_(&uplo, &trans, &diag, &n, &nrhs, A, &lda, B, &ldb, &info);
  return info;
}


static int trisolve(const arma::mat &in_A, const arma::mat &in_b, arma::mat &out_x) {
  size_t  rows = in_A.n_rows;
  size_t  cols = in_A.n_cols;

  double *A = new double[rows*cols];
  double *b = new double[in_b.size()];

  //Lapack has column-major order
  for(size_t col=0, D1_idx=0; col<cols; ++col)
  {
    for(size_t row = 0; row<rows; ++row)
    {
      // Lapack uses column major format
      A[D1_idx++] = in_A(row, col);
    }
    b[col] = in_b(col);
  }

  for(size_t row = 0; row<rows; ++row)
  {
    b[row] = in_b(row);
  }

  int info = trtrs('U', 'N', 'N', cols, 1, A, rows, b, rows);

  for(size_t col=0; col<cols; col++) {
    out_x(col)=b[col];
  }

  delete[] A;
  delete[] b;

  return 0;
}


// [[Rcpp::export]]

arma::mat RtoRcpp(arma::mat A, arma::mat b) {
  arma::uword n = A.n_rows;
  arma::mat x = arma::mat(n, 1, arma::fill::zeros);

  int info = trisolve(A, b, x);
  return x;
}

对我来说(至少)有两个问题:
  • 尝试编译时,我从头文件中获取:conflicting types for 'dtrtrs_'。但是,我看不到输入有什么问题(这实际上是从第二个链接线程中复制的)。
  • 毫不奇怪,wrapper_dtrtrts_不正确。但是从我从Armadillo的 compiler_setup.hpp 可以看出,arma_fortran应该为我创建一个称为wrapper_dtrtrs_的函数。我在cpp主文件中应使用的名称是什么?
  • 最佳答案

    Armadillo 已经使用dtrtrs来解决对角线问题。一些代码引用:

  • dtrtrs中被调用的lapack::trtrs:https://gitlab.com/conradsnicta/armadillo-code/blob/9.200.x/include/armadillo_bits/wrapper_lapack.hpp#L908
  • 用一个不错的调试语句在lapack::trtrs中调用
  • auxlib::solve_tri:https://gitlab.com/conradsnicta/armadillo-code/blob/9.200.x/include/armadillo_bits/auxlib_meat.hpp#L3983

  • 因此,如果我们可以触发此调试语句,则可以确保确实使用了dtrtrs:
    #define ARMA_EXTRA_DEBUG
    // [[Rcpp::depends(RcppArmadillo)]]
    #include <RcppArmadillo.h>
    
    // [[Rcpp::export]]
    void testTrisolve() {
      arma::mat A = arma::randu<arma::mat>(5,5);
      arma::mat B = arma::randu<arma::mat>(5,5);
    
      arma::mat X1 = arma::solve(A, B);
      arma::mat X3 = arma::solve(arma::trimatu(A), B);
    }
    
    /*** R
    testTrisolve()
    */
    

    这会产生很多调试消息,其中包括:
    lapack::gesvx()
    [...]
    lapack::trtrs()
    

    因此,我们清楚地看到在对角线情况下使用了dtrtrs

    至于您的原始问题:
  • 类型冲突是Aramdillo已经使用dtrtrs造成的,但签名略有不同(Aconst)。
  • Fortran函数的C级名称取决于ARMA_BLAS_UNDERSCOREARMA_USE_WRAPPER的值。我不确定是否总是这样,但是对我来说,前者是定义的,而后者没有定义(参见config.hpp),导致dtrtrs_作为名称。

  • 确实,如果我在Armadillo使用它的地方添加了const并将该函数称为dtrtrs_,则您的代码将编译而没有错误或警告(未使用的变量除外):
    // [[Rcpp::depends(RcppArmadillo)]]
    #include <RcppArmadillo.h>
    
    extern "C" {
      void arma_fortran(dtrtrs)(char* UPLO, char* TRANS, char* DIAG, int* N, int* NRHS,
                        const double* A, int* LDA, double* B, int* LDB, int* INFO);
    }
    
    int trtrs(char uplo, char trans, char diag, int n, int nrhs, double* A, int lda, double* B, int ldb) {
      int info = 0;
      dtrtrs_(&uplo, &trans, &diag, &n, &nrhs, A, &lda, B, &ldb, &info);
      return info;
    }
    
    [...]
    

    关于c++ - 直接在RcppArmadillo中调用LAPACK例程,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/52908185/

    10-13 00:25