|
4 | 4 | #include "wrapper_cuda.h"
|
5 | 5 |
|
/**
 * Runs an AXPY-style cuBLAS routine on host-resident vectors.
 *
 * The host vectors are copied into temporary device buffers, `axpy` is
 * invoked on them, and the updated `y` is copied back to the host.
 *
 * @param blasHandle cuBLAS handle forwarded to `axpy`.
 * @param n          Number of elements in each vector.
 * @param alpha      Scalar forwarded (by address) to `axpy`.
 * @param x          Host input vector, read with stride `incX`.
 * @param incX       Element stride of `x` on the host.
 * @param y          Host input/output vector, accessed with stride `incY`.
 * @param incY       Element stride of `y` on the host.
 * @param axpy       Callable with the cublas?axpy signature (the exported
 *                   wrappers pass cublasSaxpy/Daxpy/Caxpy/Zaxpy).
 * @return CudaResults whose `error` holds the CUDA runtime status and whose
 *         `blasStatus` holds the cuBLAS status of the last call attempted.
 *
 * NOTE(review): SAFECUDACALL is defined in wrapper_cuda.h and is presumed to
 * store the call's status in its first argument and `goto exit` on failure —
 * confirm against the header.
 */
template<typename T, typename AXPY>
CudaResults cuda_axpy(const cublasHandle_t blasHandle, const int n, const T alpha, const T x[], int incX, T y[], int incY, AXPY axpy)
{
    T *d_X = NULL;
    T *d_Y = NULL;
    CudaResults results;

    // Start from an explicit "success" state so that a status field which is
    // never assigned (e.g. blasStatus when cudaMalloc fails first) is still
    // well defined, regardless of how CudaResults default-initialises.
    results.error = cudaSuccess;
    results.blasStatus = CUBLAS_STATUS_SUCCESS;

    SAFECUDACALL(results.error, cudaMalloc((void**)&d_X, n*sizeof(T)));
    SAFECUDACALL(results.error, cudaMalloc((void**)&d_Y, n*sizeof(T)));

    // The device buffers hold exactly n elements, so the vectors are packed
    // contiguously (stride 1) on the device; the host strides are applied
    // only on the host side of each copy.
    SAFECUDACALL(results.blasStatus, cublasSetVector(n, sizeof(T), x, incX, d_X, 1));
    SAFECUDACALL(results.blasStatus, cublasSetVector(n, sizeof(T), y, incY, d_Y, 1));

    SAFECUDACALL(results.blasStatus, axpy(blasHandle, n, &alpha, d_X, 1, d_Y, 1));

    // Scatter the result back into y using the caller's stride.
    SAFECUDACALL(results.blasStatus, cublasGetVector(n, sizeof(T), d_Y, 1, y, incY));

exit:
    // cudaFree accepts NULL, so this is safe on every exit path.
    cudaFree(d_X);
    cudaFree(d_Y);

    return results;
}
|
28 | 29 |
|
/**
 * Runs a SCAL-style cuBLAS routine on a host-resident vector.
 *
 * The host vector is copied into a temporary device buffer, `scal` is invoked
 * on it, and the result is copied back over `x`.
 *
 * @param blasHandle cuBLAS handle forwarded to `scal`.
 * @param n          Number of elements in the vector.
 * @param alpha      Scalar forwarded (by address) to `scal`.
 * @param x          Host input/output vector, accessed with stride `incX`.
 * @param incX       Element stride of `x` on the host.
 * @param scal       Callable with the cublas?scal signature (the exported
 *                   wrappers pass cublasSscal/Dscal/Cscal/Zscal).
 * @return CudaResults whose `error` holds the CUDA runtime status and whose
 *         `blasStatus` holds the cuBLAS status of the last call attempted.
 *
 * NOTE(review): SAFECUDACALL is defined in wrapper_cuda.h and is presumed to
 * store the call's status in its first argument and `goto exit` on failure —
 * confirm against the header.
 */
template<typename T, typename SCAL>
CudaResults cuda_scal(const cublasHandle_t blasHandle, const int n, const T alpha, T x[], int incX, SCAL scal)
{
    T *d_X = NULL;
    CudaResults results;

    // Explicit "success" baseline: keeps blasStatus well defined even when
    // the allocation fails before any cuBLAS call is made.
    results.error = cudaSuccess;
    results.blasStatus = CUBLAS_STATUS_SUCCESS;

    SAFECUDACALL(results.error, cudaMalloc((void**)&d_X, n*sizeof(T)));

    // The device buffer holds exactly n elements, so the vector is stored
    // contiguously on the device and incX is applied on the host side only.
    SAFECUDACALL(results.blasStatus, cublasSetVector(n, sizeof(T), x, incX, d_X, 1));
    SAFECUDACALL(results.blasStatus, scal(blasHandle, n, &alpha, d_X, 1));
    SAFECUDACALL(results.blasStatus, cublasGetVector(n, sizeof(T), d_X, 1, x, incX));

exit:
    // cudaFree accepts NULL, so this is safe on every exit path.
    cudaFree(d_X);

    return results;
}
|
44 | 46 |
|
/**
 * Runs a DOT-style cuBLAS routine on two host-resident vectors.
 *
 * Both host vectors are copied into temporary device buffers and `dot` is
 * invoked on them; `dot` writes its output through `result`.
 *
 * @param blasHandle cuBLAS handle forwarded to `dot`.
 * @param n          Number of elements in each vector.
 * @param x          Host input vector, read with stride `incX`.
 * @param incX       Element stride of `x` on the host.
 * @param y          Host input vector, read with stride `incY`.
 * @param incY       Element stride of `y` on the host.
 * @param result     Destination forwarded unchanged to `dot`.
 *                   NOTE(review): presumably a host pointer, which requires
 *                   the handle to be in host pointer mode — confirm.
 * @param dot        Callable with the cublas?dot signature (the exported
 *                   wrappers pass cublasSdot/Ddot/Cdotu/Zdotu).
 * @return CudaResults whose `error` holds the CUDA runtime status and whose
 *         `blasStatus` holds the cuBLAS status of the last call attempted.
 *
 * NOTE(review): SAFECUDACALL is defined in wrapper_cuda.h and is presumed to
 * store the call's status in its first argument and `goto exit` on failure —
 * confirm against the header.
 */
template<typename T, typename DOT>
CudaResults cuda_dot(const cublasHandle_t blasHandle, const int n, const T x[], int incX, const T y[], int incY, T* result, DOT dot)
{
    T *d_X = NULL;
    T *d_Y = NULL;
    CudaResults results;

    // Explicit "success" baseline: keeps blasStatus well defined even when
    // an allocation fails before any cuBLAS call is made.
    results.error = cudaSuccess;
    results.blasStatus = CUBLAS_STATUS_SUCCESS;

    SAFECUDACALL(results.error, cudaMalloc((void**)&d_X, n*sizeof(T)));
    SAFECUDACALL(results.error, cudaMalloc((void**)&d_Y, n*sizeof(T)));

    // The device buffers hold exactly n elements, so both vectors are packed
    // contiguously on the device; host strides apply to the host side only.
    SAFECUDACALL(results.blasStatus, cublasSetVector(n, sizeof(T), x, incX, d_X, 1));
    SAFECUDACALL(results.blasStatus, cublasSetVector(n, sizeof(T), y, incY, d_Y, 1));

    SAFECUDACALL(results.blasStatus, dot(blasHandle, n, d_X, 1, d_Y, 1, result));

exit:
    // cudaFree accepts NULL, so this is safe on every exit path.
    cudaFree(d_X);
    cudaFree(d_Y);

    return results;
}
|
65 | 68 |
|
/**
 * Runs a GEMM-style cuBLAS routine on host-resident matrices.
 *
 * A, B and C are copied into temporary device buffers, `gemm` is invoked on
 * them, and the updated C is copied back to the host.
 *
 * @param handle  cuBLAS handle forwarded to `gemm`.
 * @param transa  Operation flag for A, forwarded to `gemm`.
 * @param transb  Operation flag for B, forwarded to `gemm`.
 * @param m,n,k   Problem dimensions forwarded to `gemm`; also used to size
 *                the buffers (A: m*k, B: k*n, C: m*n elements).
 * @param alpha   Scalar forwarded (by address) to `gemm`.
 * @param A       Host matrix holding m*k elements.
 * @param lda     Leading dimension of A as passed to `gemm`.
 * @param B       Host matrix holding k*n elements.
 * @param ldb     Leading dimension of B as passed to `gemm`.
 * @param beta    Scalar forwarded (by address) to `gemm`.
 * @param C       Host input/output matrix holding m*n elements.
 * @param ldc     Leading dimension of C as passed to `gemm`.
 * @param gemm    Callable with the cublas?gemm signature (the exported
 *                wrappers pass cublasSgemm/Dgemm/Cgemm/Zgemm).
 * @return CudaResults whose `error` holds the CUDA runtime status and whose
 *         `blasStatus` holds the cuBLAS status of the last call attempted.
 *
 * The host<->device copies always use the tight leading dimensions m, k and m
 * rather than lda/ldb/ldc, i.e. they treat each matrix as one contiguous
 * block of elements. NOTE(review): this presumes callers pass tightly packed
 * matrices, as the exported wrappers appear to do — confirm.
 *
 * NOTE(review): SAFECUDACALL is defined in wrapper_cuda.h and is presumed to
 * store the call's status in its first argument and `goto exit` on failure —
 * confirm against the header.
 */
template<typename T, typename GEMM>
CudaResults cuda_gemm(const cublasHandle_t handle, const cublasOperation_t transa, const cublasOperation_t transb, int m, int n, int k, const T alpha, const T A[], int lda, const T B[], int ldb, const T beta, T C[], int ldc, GEMM gemm)
{
    T *d_A = NULL;
    T *d_B = NULL;
    T *d_C = NULL;
    CudaResults results;

    // Explicit "success" baseline: keeps blasStatus well defined even when
    // the first allocation fails before any cuBLAS call is made.
    results.error = cudaSuccess;
    results.blasStatus = CUBLAS_STATUS_SUCCESS;

    // Byte counts start from sizeof(T) so the whole product is evaluated in
    // size_t; multiplying the int dimensions first could overflow int for
    // large matrices.
    SAFECUDACALL(results.error, cudaMalloc((void**)&d_A, sizeof(T)*m*k));
    SAFECUDACALL(results.blasStatus, cublasSetMatrix(m, k, sizeof(T), A, m, d_A, m));

    SAFECUDACALL(results.error, cudaMalloc((void**)&d_B, sizeof(T)*k*n));
    SAFECUDACALL(results.blasStatus, cublasSetMatrix(k, n, sizeof(T), B, k, d_B, k));

    // C is uploaded as well because gemm reads it when beta is non-zero.
    SAFECUDACALL(results.error, cudaMalloc((void**)&d_C, sizeof(T)*m*n));
    SAFECUDACALL(results.blasStatus, cublasSetMatrix(m, n, sizeof(T), C, m, d_C, m));

    SAFECUDACALL(results.blasStatus, gemm(handle, transa, transb, m, n, k, &alpha, d_A, lda, d_B, ldb, &beta, d_C, ldc));

    SAFECUDACALL(results.blasStatus, cublasGetMatrix(m, n, sizeof(T), d_C, m, C, m));

exit:
    // cudaFree accepts NULL, so this is safe on every exit path.
    cudaFree(d_A);
    cudaFree(d_B);
    cudaFree(d_C);

    return results;
}
|
93 | 97 |
|
extern "C" {

    /*
     * C-linkage entry points. Each one binds a concrete element type and the
     * matching cuBLAS routine to one of the generic cuda_* templates, using
     * unit strides, and returns that template's CudaResults unchanged.
     * Prefixes: s = float, d = double, c = cuComplex, z = cuDoubleComplex.
     */

    /* ---- axpy: forwarded to cuda_axpy with cublas?axpy ---- */

    DLLEXPORT CudaResults s_axpy(const cublasHandle_t blasHandle, const int n, const float alpha, const float x[], float y[])
    {
        const CudaResults outcome = cuda_axpy(blasHandle, n, alpha, x, 1, y, 1, cublasSaxpy);
        return outcome;
    }

    DLLEXPORT CudaResults d_axpy(const cublasHandle_t blasHandle, const int n, const double alpha, const double x[], double y[])
    {
        const CudaResults outcome = cuda_axpy(blasHandle, n, alpha, x, 1, y, 1, cublasDaxpy);
        return outcome;
    }

    DLLEXPORT CudaResults c_axpy(const cublasHandle_t blasHandle, const int n, const cuComplex alpha, const cuComplex x[], cuComplex y[])
    {
        const CudaResults outcome = cuda_axpy(blasHandle, n, alpha, x, 1, y, 1, cublasCaxpy);
        return outcome;
    }

    DLLEXPORT CudaResults z_axpy(const cublasHandle_t blasHandle, const int n, const cuDoubleComplex alpha, const cuDoubleComplex x[], cuDoubleComplex y[])
    {
        const CudaResults outcome = cuda_axpy(blasHandle, n, alpha, x, 1, y, 1, cublasZaxpy);
        return outcome;
    }

    /* ---- scale: forwarded to cuda_scal with cublas?scal ---- */

    DLLEXPORT CudaResults s_scale(const cublasHandle_t blasHandle, const int n, const float alpha, float x[])
    {
        const CudaResults outcome = cuda_scal(blasHandle, n, alpha, x, 1, cublasSscal);
        return outcome;
    }

    DLLEXPORT CudaResults d_scale(const cublasHandle_t blasHandle, const int n, const double alpha, double x[])
    {
        const CudaResults outcome = cuda_scal(blasHandle, n, alpha, x, 1, cublasDscal);
        return outcome;
    }

    DLLEXPORT CudaResults c_scale(const cublasHandle_t blasHandle, const int n, const cuComplex alpha, cuComplex x[])
    {
        const CudaResults outcome = cuda_scal(blasHandle, n, alpha, x, 1, cublasCscal);
        return outcome;
    }

    DLLEXPORT CudaResults z_scale(const cublasHandle_t blasHandle, const int n, const cuDoubleComplex alpha, cuDoubleComplex x[])
    {
        const CudaResults outcome = cuda_scal(blasHandle, n, alpha, x, 1, cublasZscal);
        return outcome;
    }

    /* ---- dot product: forwarded to cuda_dot; the complex variants use the
     *      cublasCdotu / cublasZdotu routines ---- */

    DLLEXPORT CudaResults s_dot_product(const cublasHandle_t blasHandle, const int n, const float x[], const float y[], float *result)
    {
        const CudaResults outcome = cuda_dot(blasHandle, n, x, 1, y, 1, result, cublasSdot);
        return outcome;
    }

    DLLEXPORT CudaResults d_dot_product(const cublasHandle_t blasHandle, const int n, const double x[], const double y[], double *result)
    {
        const CudaResults outcome = cuda_dot(blasHandle, n, x, 1, y, 1, result, cublasDdot);
        return outcome;
    }

    DLLEXPORT CudaResults c_dot_product(const cublasHandle_t blasHandle, const int n, const cuComplex x[], const cuComplex y[], cuComplex *result)
    {
        const CudaResults outcome = cuda_dot(blasHandle, n, x, 1, y, 1, result, cublasCdotu);
        return outcome;
    }

    DLLEXPORT CudaResults z_dot_product(const cublasHandle_t blasHandle, const int n, const cuDoubleComplex x[], const cuDoubleComplex y[], cuDoubleComplex *result)
    {
        const CudaResults outcome = cuda_dot(blasHandle, n, x, 1, y, 1, result, cublasZdotu);
        return outcome;
    }

    /* ---- matrix multiply: forwarded to cuda_gemm with cublas?gemm.
     *      lda is m when A is not transposed and k otherwise; ldb is k when
     *      B is not transposed and n otherwise; ldc is always m. ---- */

    DLLEXPORT CudaResults s_matrix_multiply(const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const float alpha, const float x[], const float y[], const float beta, float c[])
    {
        int lda = k;
        int ldb = n;
        if (transA == CUBLAS_OP_N) { lda = m; }
        if (transB == CUBLAS_OP_N) { ldb = k; }

        return cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasSgemm);
    }

    DLLEXPORT CudaResults d_matrix_multiply(const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const double alpha, const double x[], const double y[], const double beta, double c[])
    {
        int lda = k;
        int ldb = n;
        if (transA == CUBLAS_OP_N) { lda = m; }
        if (transB == CUBLAS_OP_N) { ldb = k; }

        return cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasDgemm);
    }

    DLLEXPORT CudaResults c_matrix_multiply(const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const cuComplex alpha, const cuComplex x[], const cuComplex y[], const cuComplex beta, cuComplex c[])
    {
        int lda = k;
        int ldb = n;
        if (transA == CUBLAS_OP_N) { lda = m; }
        if (transB == CUBLAS_OP_N) { ldb = k; }

        return cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasCgemm);
    }

    DLLEXPORT CudaResults z_matrix_multiply(const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const cuDoubleComplex alpha, const cuDoubleComplex x[], const cuDoubleComplex y[], const cuDoubleComplex beta, cuDoubleComplex c[])
    {
        int lda = k;
        int ldb = n;
        if (transA == CUBLAS_OP_N) { lda = m; }
        if (transB == CUBLAS_OP_N) { ldb = k; }

        return cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasZgemm);
    }

}
|
|
0 commit comments