Skip to content

Commit ca0a1ab

Browse files
matajohcdrnet
authored andcommitted
Cleaned up the safe call mechanism further, and made all the managed changes.
1 parent 43d7a79 commit ca0a1ab

11 files changed

+184
-193
lines changed

src/NativeProviders/CUDA/blas.cpp

Lines changed: 55 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -4,201 +4,173 @@
44
#include "wrapper_cuda.h"
55

66
template<typename T, typename AXPY>
7-
void cuda_axpy(const cublasHandle_t blasHandle, const int n, const T alpha, const T x[], int incX, T y[], int incY, AXPY axpy, cudaError_t *error, cublasStatus_t *blasStatus)
7+
CudaResults cuda_axpy(const cublasHandle_t blasHandle, const int n, const T alpha, const T x[], int incX, T y[], int incY, AXPY axpy)
88
{
99
T *d_X = NULL;
1010
T *d_Y = NULL;
11-
*error = cudaError_t::cudaSuccess;
12-
*blasStatus = cublasStatus_t::CUBLAS_STATUS_SUCCESS;
11+
CudaResults results;
1312

14-
SAFECUDACALL(error, cudaMalloc((void**)&d_X, n*sizeof(T)))
15-
SAFECUDACALL(error, cudaMalloc((void**)&d_Y, n*sizeof(T)))
13+
SAFECUDACALL(results.error, cudaMalloc((void**)&d_X, n*sizeof(T)));
14+
SAFECUDACALL(results.error, cudaMalloc((void**)&d_Y, n*sizeof(T)));
1615

17-
SAFECUDACALL(blasStatus, cublasSetVector(n, sizeof(T), x, incX, d_X, incX))
18-
SAFECUDACALL(blasStatus, cublasSetVector(n, sizeof(T), y, incY, d_Y, incY))
16+
SAFECUDACALL(results.blasStatus, cublasSetVector(n, sizeof(T), x, incX, d_X, incX));
17+
SAFECUDACALL(results.blasStatus, cublasSetVector(n, sizeof(T), y, incY, d_Y, incY));
1918

20-
SAFECUDACALL(blasStatus, axpy(blasHandle, n, &alpha, d_X, incX, d_Y, incX))
19+
SAFECUDACALL(results.blasStatus, axpy(blasHandle, n, &alpha, d_X, incX, d_Y, incX));
2120

22-
SAFECUDACALL(blasStatus, cublasGetVector(n, sizeof(T), d_Y, incY, y, incY))
21+
SAFECUDACALL(results.blasStatus, cublasGetVector(n, sizeof(T), d_Y, incY, y, incY));
2322

2423
exit:
2524
cudaFree(d_X);
2625
cudaFree(d_Y);
26+
27+
return results;
2728
}
2829

2930
template<typename T, typename SCAL>
30-
void cuda_scal(const cublasHandle_t blasHandle, const int n, const T alpha, T x[], int incX, SCAL scal, cudaError_t *error, cublasStatus_t *blasStatus)
31+
CudaResults cuda_scal(const cublasHandle_t blasHandle, const int n, const T alpha, T x[], int incX, SCAL scal)
3132
{
3233
T *d_X = NULL;
33-
*error = cudaError_t::cudaSuccess;
34-
*blasStatus = cublasStatus_t::CUBLAS_STATUS_SUCCESS;
34+
CudaResults results;
3535

36-
SAFECUDACALL(error, cudaMalloc((void**)&d_X, n*sizeof(T)))
37-
SAFECUDACALL(blasStatus, cublasSetVector(n, sizeof(T), x, incX, d_X, incX))
38-
SAFECUDACALL(blasStatus, scal(blasHandle, n, &alpha, d_X, incX))
39-
SAFECUDACALL(blasStatus, cublasGetVector(n, sizeof(T), d_X, incX, x, incX))
36+
SAFECUDACALL(results.error, cudaMalloc((void**)&d_X, n*sizeof(T)));
37+
SAFECUDACALL(results.blasStatus, cublasSetVector(n, sizeof(T), x, incX, d_X, incX));
38+
SAFECUDACALL(results.blasStatus, scal(blasHandle, n, &alpha, d_X, incX));
39+
SAFECUDACALL(results.blasStatus, cublasGetVector(n, sizeof(T), d_X, incX, x, incX));
4040

4141
exit:
4242
cudaFree(d_X);
43+
44+
return results;
4345
}
4446

4547
template<typename T, typename DOT>
46-
void cuda_dot(const cublasHandle_t blasHandle, const int n, const T x[], int incX, const T y[], int incY, T* result, DOT dot, cudaError_t *error, cublasStatus_t *blasStatus)
48+
CudaResults cuda_dot(const cublasHandle_t blasHandle, const int n, const T x[], int incX, const T y[], int incY, T* result, DOT dot)
4749
{
4850
T *d_X = NULL;
4951
T *d_Y = NULL;
50-
*error = cudaError_t::cudaSuccess;
51-
*blasStatus = cublasStatus_t::CUBLAS_STATUS_SUCCESS;
52+
CudaResults results;
5253

53-
SAFECUDACALL(error, cudaMalloc((void**)&d_X, n*sizeof(T)))
54-
SAFECUDACALL(error, cudaMalloc((void**)&d_Y, n*sizeof(T)))
54+
SAFECUDACALL(results.error, cudaMalloc((void**)&d_X, n*sizeof(T)));
55+
SAFECUDACALL(results.error, cudaMalloc((void**)&d_Y, n*sizeof(T)));
5556

56-
SAFECUDACALL(blasStatus, cublasSetVector(n, sizeof(T), x, incX, d_X, incX))
57-
SAFECUDACALL(blasStatus, cublasSetVector(n, sizeof(T), y, incY, d_Y, incY))
57+
SAFECUDACALL(results.blasStatus, cublasSetVector(n, sizeof(T), x, incX, d_X, incX));
58+
SAFECUDACALL(results.blasStatus, cublasSetVector(n, sizeof(T), y, incY, d_Y, incY));
5859

59-
SAFECUDACALL(blasStatus, dot(blasHandle, n, d_X, incX, d_Y, incY, result))
60+
SAFECUDACALL(results.blasStatus, dot(blasHandle, n, d_X, incX, d_Y, incY, result));
6061

6162
exit:
6263
cudaFree(d_X);
6364
cudaFree(d_Y);
65+
66+
return results;
6467
}
6568

6669
template<typename T, typename GEMM>
67-
void cuda_gemm(const cublasHandle_t handle, const cublasOperation_t transa, const cublasOperation_t transb, int m, int n, int k, const T alpha, const T A[], int lda, const T B[], int ldb, const T beta, T C[], int ldc, GEMM gemm, cudaError_t *error, cublasStatus_t *blasStatus)
70+
CudaResults cuda_gemm(const cublasHandle_t handle, const cublasOperation_t transa, const cublasOperation_t transb, int m, int n, int k, const T alpha, const T A[], int lda, const T B[], int ldb, const T beta, T C[], int ldc, GEMM gemm)
6871
{
6972
T *d_A = NULL;
7073
T *d_B = NULL;
7174
T *d_C = NULL;
72-
*error = cudaError_t::cudaSuccess;
73-
*blasStatus = cublasStatus_t::CUBLAS_STATUS_SUCCESS;
75+
CudaResults results;
7476

75-
SAFECUDACALL(error, cudaMalloc((void**)&d_A, m*k*sizeof(T)))
76-
SAFECUDACALL(blasStatus, cublasSetMatrix(m, k, sizeof(T), A, m, d_A, m))
77+
SAFECUDACALL(results.error, cudaMalloc((void**)&d_A, m*k*sizeof(T)));
78+
SAFECUDACALL(results.blasStatus, cublasSetMatrix(m, k, sizeof(T), A, m, d_A, m));
7779

78-
SAFECUDACALL(error, cudaMalloc((void**)&d_B, k*n*sizeof(T)))
79-
SAFECUDACALL(blasStatus, cublasSetMatrix(k, n, sizeof(T), B, k, d_B, k))
80+
SAFECUDACALL(results.error, cudaMalloc((void**)&d_B, k*n*sizeof(T)));
81+
SAFECUDACALL(results.blasStatus, cublasSetMatrix(k, n, sizeof(T), B, k, d_B, k));
8082

81-
SAFECUDACALL(error, cudaMalloc((void**)&d_C, m*n*sizeof(T)))
82-
SAFECUDACALL(blasStatus, cublasSetMatrix(m, n, sizeof(T), C, m, d_C, m))
83+
SAFECUDACALL(results.error, cudaMalloc((void**)&d_C, m*n*sizeof(T)));
84+
SAFECUDACALL(results.blasStatus, cublasSetMatrix(m, n, sizeof(T), C, m, d_C, m));
8385

84-
SAFECUDACALL(blasStatus, gemm(handle, transa, transb, m, n, k, &alpha, d_A, lda, d_B, ldb, &beta, d_C, ldc))
86+
SAFECUDACALL(results.blasStatus, gemm(handle, transa, transb, m, n, k, &alpha, d_A, lda, d_B, ldb, &beta, d_C, ldc));
8587

86-
SAFECUDACALL(blasStatus, cublasGetMatrix(m, n, sizeof(T), d_C, m, C, m))
88+
SAFECUDACALL(results.blasStatus, cublasGetMatrix(m, n, sizeof(T), d_C, m, C, m));
8789

8890
exit:
8991
cudaFree(d_A);
9092
cudaFree(d_B);
9193
cudaFree(d_C);
94+
95+
return results;
9296
}
9397

9498
extern "C" {
9599

96100
DLLEXPORT CudaResults s_axpy(const cublasHandle_t blasHandle, const int n, const float alpha, const float x[], float y[]){
97-
CudaResults ret;
98-
cuda_axpy(blasHandle, n, alpha, x, 1, y, 1, cublasSaxpy, &ret.error, &ret.blasStatus);
99-
return ret;
101+
return cuda_axpy(blasHandle, n, alpha, x, 1, y, 1, cublasSaxpy);
100102
}
101103

102104
DLLEXPORT CudaResults d_axpy(const cublasHandle_t blasHandle, const int n, const double alpha, const double x[], double y[]){
103-
CudaResults ret;
104-
cuda_axpy(blasHandle, n, alpha, x, 1, y, 1, cublasDaxpy, &ret.error, &ret.blasStatus);
105-
return ret;
105+
return cuda_axpy(blasHandle, n, alpha, x, 1, y, 1, cublasDaxpy);
106106
}
107107

108108
DLLEXPORT CudaResults c_axpy(const cublasHandle_t blasHandle, const int n, const cuComplex alpha, const cuComplex x[], cuComplex y[]){
109-
CudaResults ret;
110-
cuda_axpy(blasHandle, n, alpha, x, 1, y, 1, cublasCaxpy, &ret.error, &ret.blasStatus);
111-
return ret;
109+
return cuda_axpy(blasHandle, n, alpha, x, 1, y, 1, cublasCaxpy);
112110
}
113111

114112
DLLEXPORT CudaResults z_axpy(const cublasHandle_t blasHandle, const int n, const cuDoubleComplex alpha, const cuDoubleComplex x[], cuDoubleComplex y[]){
115-
CudaResults ret;
116-
cuda_axpy(blasHandle, n, alpha, x, 1, y, 1, cublasZaxpy, &ret.error, &ret.blasStatus);
117-
return ret;
113+
return cuda_axpy(blasHandle, n, alpha, x, 1, y, 1, cublasZaxpy);
118114
}
119115

120116
DLLEXPORT CudaResults s_scale(const cublasHandle_t blasHandle, const int n, const float alpha, float x[]){
121-
CudaResults ret;
122-
cuda_scal(blasHandle, n, alpha, x, 1, cublasSscal, &ret.error, &ret.blasStatus);
123-
return ret;
117+
return cuda_scal(blasHandle, n, alpha, x, 1, cublasSscal);
124118
}
125119

126120
DLLEXPORT CudaResults d_scale(const cublasHandle_t blasHandle, const int n, const double alpha, double x[]){
127-
CudaResults ret;
128-
cuda_scal(blasHandle, n, alpha, x, 1, cublasDscal, &ret.error, &ret.blasStatus);
129-
return ret;
121+
return cuda_scal(blasHandle, n, alpha, x, 1, cublasDscal);
130122
}
131123

132124
DLLEXPORT CudaResults c_scale(const cublasHandle_t blasHandle, const int n, const cuComplex alpha, cuComplex x[]){
133-
CudaResults ret;
134-
cuda_scal(blasHandle, n, alpha, x, 1, cublasCscal, &ret.error, &ret.blasStatus);
135-
return ret;
125+
return cuda_scal(blasHandle, n, alpha, x, 1, cublasCscal);
136126
}
137127

138128
DLLEXPORT CudaResults z_scale(const cublasHandle_t blasHandle, const int n, const cuDoubleComplex alpha, cuDoubleComplex x[]){
139-
CudaResults ret;
140-
cuda_scal(blasHandle, n, alpha, x, 1, cublasZscal, &ret.error, &ret.blasStatus);
141-
return ret;
129+
return cuda_scal(blasHandle, n, alpha, x, 1, cublasZscal);
142130
}
143131

144132
DLLEXPORT CudaResults s_dot_product(const cublasHandle_t blasHandle, const int n, const float x[], const float y[], float *result){
145-
CudaResults ret;
146-
cuda_dot(blasHandle, n, x, 1, y, 1, result, cublasSdot, &ret.error, &ret.blasStatus);
147-
return ret;
133+
return cuda_dot(blasHandle, n, x, 1, y, 1, result, cublasSdot);
148134
}
149135

150136
DLLEXPORT CudaResults d_dot_product(const cublasHandle_t blasHandle, const int n, const double x[], const double y[], double *result){
151-
CudaResults ret;
152-
cuda_dot(blasHandle, n, x, 1, y, 1, result, cublasDdot, &ret.error, &ret.blasStatus);
153-
return ret;
137+
return cuda_dot(blasHandle, n, x, 1, y, 1, result, cublasDdot);
154138
}
155139

156140
DLLEXPORT CudaResults c_dot_product(const cublasHandle_t blasHandle, const int n, const cuComplex x[], const cuComplex y[], cuComplex *result){
157-
CudaResults ret;
158-
cuda_dot(blasHandle, n, x, 1, y, 1, result, cublasCdotu, &ret.error, &ret.blasStatus);
159-
return ret;
141+
return cuda_dot(blasHandle, n, x, 1, y, 1, result, cublasCdotu);
160142
}
161143

162144
DLLEXPORT CudaResults z_dot_product(const cublasHandle_t blasHandle, const int n, const cuDoubleComplex x[], const cuDoubleComplex y[], cuDoubleComplex *result){
163-
CudaResults ret;
164-
cuda_dot(blasHandle, n, x, 1, y, 1, result, cublasZdotu, &ret.error, &ret.blasStatus);
165-
return ret;
145+
return cuda_dot(blasHandle, n, x, 1, y, 1, result, cublasZdotu);
166146
}
167147

168148
DLLEXPORT CudaResults s_matrix_multiply(const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const float alpha, const float x[], const float y[], const float beta, float c[]){
169-
CudaResults ret;
170149
int lda = transA == CUBLAS_OP_N ? m : k;
171150
int ldb = transB == CUBLAS_OP_N ? k : n;
172151

173-
cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasSgemm, &ret.error, &ret.blasStatus);
174-
return ret;
152+
return cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasSgemm);
175153
}
176154

177155
DLLEXPORT CudaResults d_matrix_multiply(const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const double alpha, const double x[], const double y[], const double beta, double c[]){
178-
CudaResults ret;
179156
int lda = transA == CUBLAS_OP_N ? m : k;
180157
int ldb = transB == CUBLAS_OP_N ? k : n;
181158

182-
cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasDgemm, &ret.error, &ret.blasStatus);
183-
return ret;
159+
return cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasDgemm);
184160
}
185161

186162
DLLEXPORT CudaResults c_matrix_multiply(const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const cuComplex alpha, const cuComplex x[], const cuComplex y[], const cuComplex beta, cuComplex c[]){
187-
CudaResults ret;
188163
int lda = transA == CUBLAS_OP_N ? m : k;
189164
int ldb = transB == CUBLAS_OP_N ? k : n;
190165

191-
cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasCgemm, &ret.error, &ret.blasStatus);
192-
return ret;
166+
return cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasCgemm);
193167
}
194168

195169
DLLEXPORT CudaResults z_matrix_multiply(const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const cuDoubleComplex alpha, const cuDoubleComplex x[], const cuDoubleComplex y[], const cuDoubleComplex beta, cuDoubleComplex c[]){
196-
CudaResults ret;
197170
int lda = transA == CUBLAS_OP_N ? m : k;
198171
int ldb = transB == CUBLAS_OP_N ? k : n;
199172

200-
cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasZgemm, &ret.error, &ret.blasStatus);
201-
return ret;
173+
return cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasZgemm);
202174
}
203175

204176
}

src/NativeProviders/CUDA/capabilities.cpp

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#include "wrapper_common.h"
1+
#include "wrapper_cuda.h"
22
#include "cuda_runtime.h"
33
#include "cublas_v2.h"
44
#include "cusolverDn.h"
@@ -79,20 +79,28 @@ extern "C" {
7979
}
8080
}
8181

82-
DLLEXPORT cublasStatus_t createBLASHandle(cublasHandle_t *blasHandle){
83-
return cublasCreate(blasHandle);
82+
DLLEXPORT CudaResults createBLASHandle(cublasHandle_t *blasHandle){
83+
CudaResults ret;
84+
ret.blasStatus = cublasCreate(blasHandle);
85+
return ret;
8486
}
8587

86-
DLLEXPORT cublasStatus_t destroyBLASHandle(cublasHandle_t blasHandle){
87-
return cublasDestroy(blasHandle);
88+
DLLEXPORT CudaResults destroyBLASHandle(cublasHandle_t blasHandle){
89+
CudaResults ret;
90+
ret.blasStatus = cublasDestroy(blasHandle);
91+
return ret;
8892
}
8993

90-
DLLEXPORT cusolverStatus_t createSolverHandle(cusolverDnHandle_t *solverHandle){
91-
return cusolverDnCreate(solverHandle);
94+
DLLEXPORT CudaResults createSolverHandle(cusolverDnHandle_t *solverHandle){
95+
CudaResults ret;
96+
ret.solverStatus = cusolverDnCreate(solverHandle);
97+
return ret;
9298
}
9399

94-
DLLEXPORT cusolverStatus_t destroySolverHandle(cusolverDnHandle_t solverHandle){
95-
return cusolverDnDestroy(solverHandle);
100+
DLLEXPORT CudaResults destroySolverHandle(cusolverDnHandle_t solverHandle){
101+
CudaResults ret;
102+
ret.solverStatus = cusolverDnDestroy(solverHandle);
103+
return ret;
96104
}
97105

98106
#ifdef __cplusplus

src/NativeProviders/CUDA/wrapper_cuda.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33

44
#include "wrapper_common.h"
55
#include "cuda_runtime.h"
6+
#include "cublas_v2.h"
67
#include "cusolver_common.h"
78

8-
#define SAFECUDACALL(error,call) {*error = call; if(*error){goto exit;}}
9+
#define SAFECUDACALL(error,call) {error = call; if(error){goto exit;}}
910

1011
typedef struct
1112
{

src/Numerics/Control.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,10 @@ public static ILinearAlgebraProvider LinearAlgebraProvider
261261
{
262262
value.InitializeVerify();
263263

264+
// dispose the previous value if necessary
265+
if (_linearAlgebraProvider != null && _linearAlgebraProvider is IDisposable)
266+
(_linearAlgebraProvider as IDisposable).Dispose();
267+
264268
// only actually set if verification did not throw
265269
_linearAlgebraProvider = value;
266270
}

0 commit comments

Comments
 (0)