我试图使用cublasDgemm()通过矩阵的转置来计算矩阵的乘积。输入矩阵和我期望的代码输出如下(分别为A和C):

    | 1 4 7 |        | 66 78 |
A = | 2 5 8 |    C = | 78 93 |

然而,我得到了奇怪的结果,我有点难以理解cublas/cuda使用的维度(column major)。任何小费都将不胜感激!
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <cuda_runtime.h>
#include "cublas_v2.h"
#define M 3
#define N 2
#define IDX2C(i,j,ld) (((j)*(ld))+(i))

int main (void){
    cudaError_t cudaStat;
    cublasStatus_t stat;
    cublasHandle_t handle;
    int i, j;
    double *devPtrA, *devPtrC;
    double *a = 0, *c = 0;

    const double alpha = 1;
    const double beta = 0;

    // initialize host arrays
    a = (double *)malloc (M * N * sizeof (*a));
    c = (double *)malloc (N * N * sizeof (*c));
    if (!a || !c) {
        printf ("host memory allocation failed");
        return EXIT_FAILURE;
    }

    // fill input array
    for (j = 0; j < N; j++) {
        for (i = 0; i < M; i++) {
            a[IDX2C(i,j,M)] = (double)(i * M + j + 1);
            printf ("%7.0f", a[IDX2C(i,j,M)]);
        }
        printf ("\n");
    }

    // set device to 0 (for double processing)
    cudaStat = cudaSetDevice(0);
    if (cudaStat != cudaSuccess) {
        printf("could not set device 0");
        return EXIT_FAILURE;
    }

    // allocate device arrays
    cudaStat = cudaMalloc ((void**)&devPtrA, M*N*sizeof(*a));
    if (cudaStat != cudaSuccess) {
        printf ("device memory allocation of A failed");
        return EXIT_FAILURE;
    }
    cudaStat = cudaMalloc ((void**)&devPtrC, N*N*sizeof(*c));
    if (cudaStat != cudaSuccess) {
        printf ("device memory allocation of C failed");
        return EXIT_FAILURE;
    }

    // create the cublas handle
    stat = cublasCreate(&handle);
    if (stat != CUBLAS_STATUS_SUCCESS) {
        printf ("CUBLAS initialization failed\n");
        return EXIT_FAILURE;
    }

    // set the matrix a
    stat = cublasSetMatrix (M, N, sizeof(*a), a, M, devPtrA, M);
    if (stat != CUBLAS_STATUS_SUCCESS) {
        printf ("data download failed");
        cudaFree (devPtrA);
        cudaFree (devPtrC);
        cublasDestroy(handle);
        return EXIT_FAILURE;
    }

    stat = cublasDgemm(handle, CUBLAS_OP_N, CUBLAS_OP_T, M, M, N, &alpha, devPtrA, M, devPtrA, M, &beta, devPtrC, M);
    if (stat!= CUBLAS_STATUS_SUCCESS) {
        switch (stat) {
            case CUBLAS_STATUS_NOT_INITIALIZED:
                printf("CUBLAS_STATUS_NOT_INITIALIZED\n");
                break;
            case CUBLAS_STATUS_INVALID_VALUE:
                printf("CUBLAS_STATUS_INVALID_VALUE\n");
                break;
            case CUBLAS_STATUS_ARCH_MISMATCH:
                printf("CUBLAS_STATUS_ARCH_MISMATCH\n");
                break;
            case CUBLAS_STATUS_EXECUTION_FAILED:
                printf("CUBLAS_STATUS_EXECUTION_FAILED\n");
                break;
            default:
                printf("??\n");
        }

        printf("Error: %d\n", (int)stat);
        cudaFree (devPtrA);
        cudaFree (devPtrC);
        cublasDestroy(handle);
        return EXIT_FAILURE;
    }

    // get matrix c
    stat = cublasGetMatrix (N, N, sizeof(*c), devPtrC, N, c, N);
    if (stat != CUBLAS_STATUS_SUCCESS) {
        printf ("data upload failed");
        cudaFree (devPtrC);
        cublasDestroy(handle);
        return EXIT_FAILURE;
    }

    // cleanup cuda/cublas
    cudaFree (devPtrA);
    cudaFree (devPtrC);
    cublasDestroy(handle);

    // print result
    for (j = 0; j < N; j++) {
        for (i = 0; i < N; i++) {
            printf ("%7.0f", c[IDX2C(i,j,M)]);
        }
        printf ("\n");
    }

    // clear host data
    free(a);
    free(c);
    return EXIT_SUCCESS;
}

最佳答案

第一个问题是你要用行主格式填写矩阵A。要解决这个问题,只需交换i和j指数。在列主格式中,前导维度应该是行数,在您的情况下,是N。

for (j = 0; j < N; j++) {
    for (i = 0; i < M; i++) {
        a[IDX2C(j,i,N)] = (double)(i * M + j + 1);
        printf ("%7.0f", a[IDX2C(j,i,N)]);
    }
    printf ("\n");
}

您还在cublasDgemm调用中交换维度,应该如下所示:
stat = cublasDgemm(handle, CUBLAS_OP_N, CUBLAS_OP_T, N, N, M, &alpha, devPtrA, N, devPtrA, N, &beta, devPtrC, N);

最后,你用M作为C矩阵的前导维数,这里应该是N:
printf ("%7.0f", c[IDX2C(i,j,N)]);

09-08 11:49