|
1 | 1 | #include <stdio.h>
|
2 | 2 | #include "cublas_v2.h"
|
3 | 3 | #include "cuda_runtime.h"
|
4 |
| -#include "wrapper_common.h" |
| 4 | +#include "wrapper_cuda.h" |
5 | 5 |
|
6 | 6 | template<typename T, typename AXPY>
|
7 |
| -void cuda_axpy(const cublasHandle_t blasHandle, const int n, const T alpha, const T x[], int incX, T y[], int incY, AXPY axpy) |
| 7 | +void cuda_axpy(const cublasHandle_t blasHandle, const int n, const T alpha, const T x[], int incX, T y[], int incY, AXPY axpy, cudaError_t *error, cublasStatus_t *blasStatus) |
8 | 8 | {
|
9 | 9 | T *d_X = NULL;
|
10 | 10 | T *d_Y = NULL;
|
11 |
| - cudaMalloc((void**)&d_X, n*sizeof(T)); |
12 |
| - cudaMalloc((void**)&d_Y, n*sizeof(T)); |
| 11 | + *error = cudaError_t::cudaSuccess; |
| 12 | + *blasStatus = cublasStatus_t::CUBLAS_STATUS_SUCCESS; |
13 | 13 |
|
14 |
| - cublasSetVector(n, sizeof(T), x, incX, d_X, incX); |
15 |
| - cublasSetVector(n, sizeof(T), y, incY, d_Y, incY); |
| 14 | + SAFECUDACALL(error, cudaMalloc((void**)&d_X, n*sizeof(T))) |
| 15 | + SAFECUDACALL(error, cudaMalloc((void**)&d_Y, n*sizeof(T))) |
16 | 16 |
|
17 |
| - axpy(blasHandle, n, &alpha, d_X, incX, d_Y, incX); |
| 17 | + SAFECUDACALL(blasStatus, cublasSetVector(n, sizeof(T), x, incX, d_X, incX)) |
| 18 | + SAFECUDACALL(blasStatus, cublasSetVector(n, sizeof(T), y, incY, d_Y, incY)) |
18 | 19 |
|
19 |
| - cublasGetVector(n, sizeof(T), d_Y, incY, y, incY); |
| 20 | + SAFECUDACALL(blasStatus, axpy(blasHandle, n, &alpha, d_X, incX, d_Y, incX)) |
20 | 21 |
|
| 22 | + SAFECUDACALL(blasStatus, cublasGetVector(n, sizeof(T), d_Y, incY, y, incY)) |
| 23 | + |
| 24 | +exit: |
21 | 25 | cudaFree(d_X);
|
22 | 26 | cudaFree(d_Y);
|
23 | 27 | }
|
24 | 28 |
|
25 | 29 | template<typename T, typename SCAL>
|
26 |
| -void cuda_scal(const cublasHandle_t blasHandle, const int n, const T alpha, T x[], int incX, SCAL scal) |
| 30 | +void cuda_scal(const cublasHandle_t blasHandle, const int n, const T alpha, T x[], int incX, SCAL scal, cudaError_t *error, cublasStatus_t *blasStatus) |
27 | 31 | {
|
28 | 32 | T *d_X = NULL;
|
29 |
| - cudaMalloc((void**)&d_X, n*sizeof(T)); |
30 |
| - |
31 |
| - cublasSetVector(n, sizeof(T), x, incX, d_X, incX); |
32 |
| - |
33 |
| - scal(blasHandle, n, &alpha, d_X, incX); |
| 33 | + *error = cudaError_t::cudaSuccess; |
| 34 | + *blasStatus = cublasStatus_t::CUBLAS_STATUS_SUCCESS; |
34 | 35 |
|
35 |
| - cublasGetVector(n, sizeof(T), d_X, incX, x, incX); |
| 36 | + SAFECUDACALL(error, cudaMalloc((void**)&d_X, n*sizeof(T))) |
| 37 | + SAFECUDACALL(blasStatus, cublasSetVector(n, sizeof(T), x, incX, d_X, incX)) |
| 38 | + SAFECUDACALL(blasStatus, scal(blasHandle, n, &alpha, d_X, incX)) |
| 39 | + SAFECUDACALL(blasStatus, cublasGetVector(n, sizeof(T), d_X, incX, x, incX)) |
36 | 40 |
|
| 41 | +exit: |
37 | 42 | cudaFree(d_X);
|
38 | 43 | }
|
39 | 44 |
|
40 | 45 | template<typename T, typename DOT>
|
41 |
| -void cuda_dot(const cublasHandle_t blasHandle, const int n, const T x[], int incX, const T y[], int incY, T* result, DOT dot) |
| 46 | +void cuda_dot(const cublasHandle_t blasHandle, const int n, const T x[], int incX, const T y[], int incY, T* result, DOT dot, cudaError_t *error, cublasStatus_t *blasStatus) |
42 | 47 | {
|
43 | 48 | T *d_X = NULL;
|
44 | 49 | T *d_Y = NULL;
|
45 |
| - cudaMalloc((void**)&d_X, n*sizeof(T)); |
46 |
| - cudaMalloc((void**)&d_Y, n*sizeof(T)); |
| 50 | + *error = cudaError_t::cudaSuccess; |
| 51 | + *blasStatus = cublasStatus_t::CUBLAS_STATUS_SUCCESS; |
47 | 52 |
|
48 |
| - cublasSetVector(n, sizeof(T), x, incX, d_X, incX); |
49 |
| - cublasSetVector(n, sizeof(T), y, incY, d_Y, incY); |
| 53 | + SAFECUDACALL(error, cudaMalloc((void**)&d_X, n*sizeof(T))) |
| 54 | + SAFECUDACALL(error, cudaMalloc((void**)&d_Y, n*sizeof(T))) |
50 | 55 |
|
51 |
| - dot(blasHandle, n, d_X, incX, d_Y, incY, result); |
| 56 | + SAFECUDACALL(blasStatus, cublasSetVector(n, sizeof(T), x, incX, d_X, incX)) |
| 57 | + SAFECUDACALL(blasStatus, cublasSetVector(n, sizeof(T), y, incY, d_Y, incY)) |
52 | 58 |
|
| 59 | + SAFECUDACALL(blasStatus, dot(blasHandle, n, d_X, incX, d_Y, incY, result)) |
| 60 | + |
| 61 | +exit: |
53 | 62 | cudaFree(d_X);
|
54 | 63 | cudaFree(d_Y);
|
55 | 64 | }
|
56 | 65 |
|
57 | 66 | template<typename T, typename GEMM>
|
58 |
| -void cuda_gemm(const cublasHandle_t handle, const cublasOperation_t transa, const cublasOperation_t transb, int m, int n, int k, const T alpha, const T A[], int lda, const T B[], int ldb, const T beta, T C[], int ldc, GEMM gemm) |
| 67 | +void cuda_gemm(const cublasHandle_t handle, const cublasOperation_t transa, const cublasOperation_t transb, int m, int n, int k, const T alpha, const T A[], int lda, const T B[], int ldb, const T beta, T C[], int ldc, GEMM gemm, cudaError_t *error, cublasStatus_t *blasStatus) |
59 | 68 | {
|
60 | 69 | T *d_A = NULL;
|
61 |
| - cudaMalloc((void**)&d_A, m*k*sizeof(T)); |
62 |
| - cublasSetMatrix(m, k, sizeof(T), A, m, d_A, m); |
63 |
| - |
64 | 70 | T *d_B = NULL;
|
65 |
| - cudaMalloc((void**)&d_B, k*n*sizeof(T)); |
66 |
| - cublasSetMatrix(k, n, sizeof(T), B, k, d_B, k); |
67 |
| - |
68 | 71 | T *d_C = NULL;
|
69 |
| - cudaMalloc((void**)&d_C, m*n*sizeof(T)); |
70 |
| - cublasSetMatrix(m, n, sizeof(T), C, m, d_C, m); |
| 72 | + *error = cudaError_t::cudaSuccess; |
| 73 | + *blasStatus = cublasStatus_t::CUBLAS_STATUS_SUCCESS; |
| 74 | + |
| 75 | + SAFECUDACALL(error, cudaMalloc((void**)&d_A, m*k*sizeof(T))) |
| 76 | + SAFECUDACALL(blasStatus, cublasSetMatrix(m, k, sizeof(T), A, m, d_A, m)) |
| 77 | + |
| 78 | + SAFECUDACALL(error, cudaMalloc((void**)&d_B, k*n*sizeof(T))) |
| 79 | + SAFECUDACALL(blasStatus, cublasSetMatrix(k, n, sizeof(T), B, k, d_B, k)) |
| 80 | + |
| 81 | + SAFECUDACALL(error, cudaMalloc((void**)&d_C, m*n*sizeof(T))) |
| 82 | + SAFECUDACALL(blasStatus, cublasSetMatrix(m, n, sizeof(T), C, m, d_C, m)) |
71 | 83 |
|
72 |
| - gemm(handle, transa, transb, m, n, k, &alpha, d_A, lda, d_B, ldb, &beta, d_C, ldc); |
| 84 | + SAFECUDACALL(blasStatus, gemm(handle, transa, transb, m, n, k, &alpha, d_A, lda, d_B, ldb, &beta, d_C, ldc)) |
73 | 85 |
|
74 |
| - cublasGetMatrix(m, n, sizeof(T), d_C, m, C, m); |
| 86 | + SAFECUDACALL(blasStatus, cublasGetMatrix(m, n, sizeof(T), d_C, m, C, m)) |
75 | 87 |
|
| 88 | +exit: |
76 | 89 | cudaFree(d_A);
|
77 | 90 | cudaFree(d_B);
|
78 | 91 | cudaFree(d_C);
|
79 | 92 | }
|
80 | 93 |
|
81 | 94 | extern "C" {
|
82 | 95 |
|
83 |
| - DLLEXPORT void s_axpy(const cublasHandle_t blasHandle, const int n, const float alpha, const float x[], float y[]){ |
84 |
| - cuda_axpy(blasHandle, n, alpha, x, 1, y, 1, cublasSaxpy); |
| 96 | + DLLEXPORT void s_axpy(const cublasHandle_t blasHandle, const int n, const float alpha, const float x[], float y[], cudaError_t *error, cublasStatus_t *blasStatus){ |
| 97 | + cuda_axpy(blasHandle, n, alpha, x, 1, y, 1, cublasSaxpy, error, blasStatus); |
85 | 98 | }
|
86 | 99 |
|
87 |
| - DLLEXPORT void d_axpy(const cublasHandle_t blasHandle, const int n, const double alpha, const double x[], double y[]){ |
88 |
| - cuda_axpy(blasHandle, n, alpha, x, 1, y, 1, cublasDaxpy); |
| 100 | + DLLEXPORT void d_axpy(const cublasHandle_t blasHandle, const int n, const double alpha, const double x[], double y[], cudaError_t *error, cublasStatus_t *blasStatus){ |
| 101 | + cuda_axpy(blasHandle, n, alpha, x, 1, y, 1, cublasDaxpy, error, blasStatus); |
89 | 102 | }
|
90 | 103 |
|
91 |
| - DLLEXPORT void c_axpy(const cublasHandle_t blasHandle, const int n, const cuComplex alpha, const cuComplex x[], cuComplex y[]){ |
92 |
| - cuda_axpy(blasHandle, n, alpha, x, 1, y, 1, cublasCaxpy); |
| 104 | + DLLEXPORT void c_axpy(const cublasHandle_t blasHandle, const int n, const cuComplex alpha, const cuComplex x[], cuComplex y[], cudaError_t *error, cublasStatus_t *blasStatus){ |
| 105 | + cuda_axpy(blasHandle, n, alpha, x, 1, y, 1, cublasCaxpy, error, blasStatus); |
93 | 106 | }
|
94 | 107 |
|
95 |
| - DLLEXPORT void z_axpy(const cublasHandle_t blasHandle, const int n, const cuDoubleComplex alpha, const cuDoubleComplex x[], cuDoubleComplex y[]){ |
96 |
| - cuda_axpy(blasHandle, n, alpha, x, 1, y, 1, cublasZaxpy); |
| 108 | + DLLEXPORT void z_axpy(const cublasHandle_t blasHandle, const int n, const cuDoubleComplex alpha, const cuDoubleComplex x[], cuDoubleComplex y[], cudaError_t *error, cublasStatus_t *blasStatus){ |
| 109 | + cuda_axpy(blasHandle, n, alpha, x, 1, y, 1, cublasZaxpy, error, blasStatus); |
97 | 110 | }
|
98 | 111 |
|
99 |
| - DLLEXPORT void s_scale(const cublasHandle_t blasHandle, const int n, const float alpha, float x[]){ |
100 |
| - cuda_scal(blasHandle, n, alpha, x, 1, cublasSscal); |
| 112 | + DLLEXPORT void s_scale(const cublasHandle_t blasHandle, const int n, const float alpha, float x[], cudaError_t *error, cublasStatus_t *blasStatus){ |
| 113 | + cuda_scal(blasHandle, n, alpha, x, 1, cublasSscal, error, blasStatus); |
101 | 114 | }
|
102 | 115 |
|
103 |
| - DLLEXPORT void d_scale(const cublasHandle_t blasHandle, const int n, const double alpha, double x[]){ |
104 |
| - cuda_scal(blasHandle, n, alpha, x, 1, cublasDscal); |
| 116 | + DLLEXPORT void d_scale(const cublasHandle_t blasHandle, const int n, const double alpha, double x[], cudaError_t *error, cublasStatus_t *blasStatus){ |
| 117 | + cuda_scal(blasHandle, n, alpha, x, 1, cublasDscal, error, blasStatus); |
105 | 118 | }
|
106 | 119 |
|
107 |
| - DLLEXPORT void c_scale(const cublasHandle_t blasHandle, const int n, const cuComplex alpha, cuComplex x[]){ |
108 |
| - cuda_scal(blasHandle, n, alpha, x, 1, cublasCscal); |
| 120 | + DLLEXPORT void c_scale(const cublasHandle_t blasHandle, const int n, const cuComplex alpha, cuComplex x[], cudaError_t *error, cublasStatus_t *blasStatus){ |
| 121 | + cuda_scal(blasHandle, n, alpha, x, 1, cublasCscal, error, blasStatus); |
109 | 122 | }
|
110 | 123 |
|
111 |
| - DLLEXPORT void z_scale(const cublasHandle_t blasHandle, const int n, const cuDoubleComplex alpha, cuDoubleComplex x[]){ |
112 |
| - cuda_scal(blasHandle, n, alpha, x, 1, cublasZscal); |
| 124 | + DLLEXPORT void z_scale(const cublasHandle_t blasHandle, const int n, const cuDoubleComplex alpha, cuDoubleComplex x[], cudaError_t *error, cublasStatus_t *blasStatus){ |
| 125 | + cuda_scal(blasHandle, n, alpha, x, 1, cublasZscal, error, blasStatus); |
113 | 126 | }
|
114 | 127 |
|
115 |
| - DLLEXPORT float s_dot_product(const cublasHandle_t blasHandle, const int n, const float x[], const float y[]){ |
| 128 | + DLLEXPORT float s_dot_product(const cublasHandle_t blasHandle, const int n, const float x[], const float y[], cudaError_t *error, cublasStatus_t *blasStatus){ |
116 | 129 | float ret;
|
117 |
| - cuda_dot(blasHandle, n, x, 1, y, 1, &ret, cublasSdot); |
| 130 | + cuda_dot(blasHandle, n, x, 1, y, 1, &ret, cublasSdot, error, blasStatus); |
118 | 131 | return ret;
|
119 | 132 | }
|
120 | 133 |
|
121 |
| - DLLEXPORT double d_dot_product(const cublasHandle_t blasHandle, const int n, const double x[], const double y[]){ |
| 134 | + DLLEXPORT double d_dot_product(const cublasHandle_t blasHandle, const int n, const double x[], const double y[], cudaError_t *error, cublasStatus_t *blasStatus){ |
122 | 135 | double ret;
|
123 |
| - cuda_dot(blasHandle, n, x, 1, y, 1, &ret, cublasDdot); |
| 136 | + cuda_dot(blasHandle, n, x, 1, y, 1, &ret, cublasDdot, error, blasStatus); |
124 | 137 | return ret;
|
125 | 138 | }
|
126 | 139 |
|
127 |
| - DLLEXPORT cuComplex c_dot_product(const cublasHandle_t blasHandle, const int n, const cuComplex x[], const cuComplex y[]){ |
| 140 | + DLLEXPORT cuComplex c_dot_product(const cublasHandle_t blasHandle, const int n, const cuComplex x[], const cuComplex y[], cudaError_t *error, cublasStatus_t *blasStatus){ |
128 | 141 | cuComplex ret;
|
129 |
| - cuda_dot(blasHandle, n, x, 1, y, 1, &ret, cublasCdotu); |
| 142 | + cuda_dot(blasHandle, n, x, 1, y, 1, &ret, cublasCdotu, error, blasStatus); |
130 | 143 | return ret;
|
131 | 144 | }
|
132 | 145 |
|
133 |
| - DLLEXPORT cuDoubleComplex z_dot_product(const cublasHandle_t blasHandle, const int n, const cuDoubleComplex x[], const cuDoubleComplex y[]){ |
| 146 | + DLLEXPORT cuDoubleComplex z_dot_product(const cublasHandle_t blasHandle, const int n, const cuDoubleComplex x[], const cuDoubleComplex y[], cudaError_t *error, cublasStatus_t *blasStatus){ |
134 | 147 | cuDoubleComplex ret;
|
135 |
| - cuda_dot(blasHandle, n, x, 1, y, 1, &ret, cublasZdotu); |
| 148 | + cuda_dot(blasHandle, n, x, 1, y, 1, &ret, cublasZdotu, error, blasStatus); |
136 | 149 | return ret;
|
137 | 150 | }
|
138 | 151 |
|
139 |
| - DLLEXPORT void s_matrix_multiply(const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const float alpha, const float x[], const float y[], const float beta, float c[]){ |
| 152 | + DLLEXPORT void s_matrix_multiply(const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const float alpha, const float x[], const float y[], const float beta, float c[], cudaError_t *error, cublasStatus_t *blasStatus){ |
140 | 153 | int lda = transA == CUBLAS_OP_N ? m : k;
|
141 | 154 | int ldb = transB == CUBLAS_OP_N ? k : n;
|
142 | 155 |
|
143 |
| - cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasSgemm); |
| 156 | + cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasSgemm, error, blasStatus); |
144 | 157 | }
|
145 | 158 |
|
146 |
| - DLLEXPORT void d_matrix_multiply(const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const double alpha, const double x[], const double y[], const double beta, double c[]){ |
| 159 | + DLLEXPORT void d_matrix_multiply(const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const double alpha, const double x[], const double y[], const double beta, double c[], cudaError_t *error, cublasStatus_t *blasStatus){ |
147 | 160 | int lda = transA == CUBLAS_OP_N ? m : k;
|
148 | 161 | int ldb = transB == CUBLAS_OP_N ? k : n;
|
149 | 162 |
|
150 |
| - cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasDgemm); |
| 163 | + cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasDgemm, error, blasStatus); |
151 | 164 | }
|
152 | 165 |
|
153 |
| - DLLEXPORT void c_matrix_multiply(const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const cuComplex alpha, const cuComplex x[], const cuComplex y[], const cuComplex beta, cuComplex c[]){ |
| 166 | + DLLEXPORT void c_matrix_multiply(const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const cuComplex alpha, const cuComplex x[], const cuComplex y[], const cuComplex beta, cuComplex c[], cudaError_t *error, cublasStatus_t *blasStatus){ |
154 | 167 | int lda = transA == CUBLAS_OP_N ? m : k;
|
155 | 168 | int ldb = transB == CUBLAS_OP_N ? k : n;
|
156 | 169 |
|
157 |
| - cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasCgemm); |
| 170 | + cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasCgemm, error, blasStatus); |
158 | 171 | }
|
159 | 172 |
|
160 |
| - DLLEXPORT void z_matrix_multiply(const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const cuDoubleComplex alpha, const cuDoubleComplex x[], const cuDoubleComplex y[], const cuDoubleComplex beta, cuDoubleComplex c[]){ |
| 173 | + DLLEXPORT void z_matrix_multiply(const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const cuDoubleComplex alpha, const cuDoubleComplex x[], const cuDoubleComplex y[], const cuDoubleComplex beta, cuDoubleComplex c[], cudaError_t *error, cublasStatus_t *blasStatus){ |
161 | 174 | int lda = transA == CUBLAS_OP_N ? m : k;
|
162 | 175 | int ldb = transB == CUBLAS_OP_N ? k : n;
|
163 | 176 |
|
164 |
| - cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasZgemm); |
| 177 | + cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasZgemm, error, blasStatus); |
165 | 178 | }
|
166 | 179 |
|
167 | 180 | }
|
|
0 commit comments