Skip to content

Commit 3f11f41

Browse files
matajoh authored and cdrnet committed
Adding blas exception handling to the library.
1 parent fdc773a commit 3f11f41

File tree

5 files changed

+89
-64
lines changed

5 files changed

+89
-64
lines changed

src/NativeProviders/CUDA/blas.cpp

Lines changed: 77 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,167 +1,180 @@
11
#include <stdio.h>
22
#include "cublas_v2.h"
33
#include "cuda_runtime.h"
4-
#include "wrapper_common.h"
4+
#include "wrapper_cuda.h"
55

66
template<typename T, typename AXPY>
7-
void cuda_axpy(const cublasHandle_t blasHandle, const int n, const T alpha, const T x[], int incX, T y[], int incY, AXPY axpy)
7+
void cuda_axpy(const cublasHandle_t blasHandle, const int n, const T alpha, const T x[], int incX, T y[], int incY, AXPY axpy, cudaError_t *error, cublasStatus_t *blasStatus)
88
{
99
T *d_X = NULL;
1010
T *d_Y = NULL;
11-
cudaMalloc((void**)&d_X, n*sizeof(T));
12-
cudaMalloc((void**)&d_Y, n*sizeof(T));
11+
*error = cudaError_t::cudaSuccess;
12+
*blasStatus = cublasStatus_t::CUBLAS_STATUS_SUCCESS;
1313

14-
cublasSetVector(n, sizeof(T), x, incX, d_X, incX);
15-
cublasSetVector(n, sizeof(T), y, incY, d_Y, incY);
14+
SAFECUDACALL(error, cudaMalloc((void**)&d_X, n*sizeof(T)))
15+
SAFECUDACALL(error, cudaMalloc((void**)&d_Y, n*sizeof(T)))
1616

17-
axpy(blasHandle, n, &alpha, d_X, incX, d_Y, incX);
17+
SAFECUDACALL(blasStatus, cublasSetVector(n, sizeof(T), x, incX, d_X, incX))
18+
SAFECUDACALL(blasStatus, cublasSetVector(n, sizeof(T), y, incY, d_Y, incY))
1819

19-
cublasGetVector(n, sizeof(T), d_Y, incY, y, incY);
20+
SAFECUDACALL(blasStatus, axpy(blasHandle, n, &alpha, d_X, incX, d_Y, incX))
2021

22+
SAFECUDACALL(blasStatus, cublasGetVector(n, sizeof(T), d_Y, incY, y, incY))
23+
24+
exit:
2125
cudaFree(d_X);
2226
cudaFree(d_Y);
2327
}
2428

2529
template<typename T, typename SCAL>
26-
void cuda_scal(const cublasHandle_t blasHandle, const int n, const T alpha, T x[], int incX, SCAL scal)
30+
void cuda_scal(const cublasHandle_t blasHandle, const int n, const T alpha, T x[], int incX, SCAL scal, cudaError_t *error, cublasStatus_t *blasStatus)
2731
{
2832
T *d_X = NULL;
29-
cudaMalloc((void**)&d_X, n*sizeof(T));
30-
31-
cublasSetVector(n, sizeof(T), x, incX, d_X, incX);
32-
33-
scal(blasHandle, n, &alpha, d_X, incX);
33+
*error = cudaError_t::cudaSuccess;
34+
*blasStatus = cublasStatus_t::CUBLAS_STATUS_SUCCESS;
3435

35-
cublasGetVector(n, sizeof(T), d_X, incX, x, incX);
36+
SAFECUDACALL(error, cudaMalloc((void**)&d_X, n*sizeof(T)))
37+
SAFECUDACALL(blasStatus, cublasSetVector(n, sizeof(T), x, incX, d_X, incX))
38+
SAFECUDACALL(blasStatus, scal(blasHandle, n, &alpha, d_X, incX))
39+
SAFECUDACALL(blasStatus, cublasGetVector(n, sizeof(T), d_X, incX, x, incX))
3640

41+
exit:
3742
cudaFree(d_X);
3843
}
3944

4045
template<typename T, typename DOT>
41-
void cuda_dot(const cublasHandle_t blasHandle, const int n, const T x[], int incX, const T y[], int incY, T* result, DOT dot)
46+
void cuda_dot(const cublasHandle_t blasHandle, const int n, const T x[], int incX, const T y[], int incY, T* result, DOT dot, cudaError_t *error, cublasStatus_t *blasStatus)
4247
{
4348
T *d_X = NULL;
4449
T *d_Y = NULL;
45-
cudaMalloc((void**)&d_X, n*sizeof(T));
46-
cudaMalloc((void**)&d_Y, n*sizeof(T));
50+
*error = cudaError_t::cudaSuccess;
51+
*blasStatus = cublasStatus_t::CUBLAS_STATUS_SUCCESS;
4752

48-
cublasSetVector(n, sizeof(T), x, incX, d_X, incX);
49-
cublasSetVector(n, sizeof(T), y, incY, d_Y, incY);
53+
SAFECUDACALL(error, cudaMalloc((void**)&d_X, n*sizeof(T)))
54+
SAFECUDACALL(error, cudaMalloc((void**)&d_Y, n*sizeof(T)))
5055

51-
dot(blasHandle, n, d_X, incX, d_Y, incY, result);
56+
SAFECUDACALL(blasStatus, cublasSetVector(n, sizeof(T), x, incX, d_X, incX))
57+
SAFECUDACALL(blasStatus, cublasSetVector(n, sizeof(T), y, incY, d_Y, incY))
5258

59+
SAFECUDACALL(blasStatus, dot(blasHandle, n, d_X, incX, d_Y, incY, result))
60+
61+
exit:
5362
cudaFree(d_X);
5463
cudaFree(d_Y);
5564
}
5665

5766
template<typename T, typename GEMM>
58-
void cuda_gemm(const cublasHandle_t handle, const cublasOperation_t transa, const cublasOperation_t transb, int m, int n, int k, const T alpha, const T A[], int lda, const T B[], int ldb, const T beta, T C[], int ldc, GEMM gemm)
67+
void cuda_gemm(const cublasHandle_t handle, const cublasOperation_t transa, const cublasOperation_t transb, int m, int n, int k, const T alpha, const T A[], int lda, const T B[], int ldb, const T beta, T C[], int ldc, GEMM gemm, cudaError_t *error, cublasStatus_t *blasStatus)
5968
{
6069
T *d_A = NULL;
61-
cudaMalloc((void**)&d_A, m*k*sizeof(T));
62-
cublasSetMatrix(m, k, sizeof(T), A, m, d_A, m);
63-
6470
T *d_B = NULL;
65-
cudaMalloc((void**)&d_B, k*n*sizeof(T));
66-
cublasSetMatrix(k, n, sizeof(T), B, k, d_B, k);
67-
6871
T *d_C = NULL;
69-
cudaMalloc((void**)&d_C, m*n*sizeof(T));
70-
cublasSetMatrix(m, n, sizeof(T), C, m, d_C, m);
72+
*error = cudaError_t::cudaSuccess;
73+
*blasStatus = cublasStatus_t::CUBLAS_STATUS_SUCCESS;
74+
75+
SAFECUDACALL(error, cudaMalloc((void**)&d_A, m*k*sizeof(T)))
76+
SAFECUDACALL(blasStatus, cublasSetMatrix(m, k, sizeof(T), A, m, d_A, m))
77+
78+
SAFECUDACALL(error, cudaMalloc((void**)&d_B, k*n*sizeof(T)))
79+
SAFECUDACALL(blasStatus, cublasSetMatrix(k, n, sizeof(T), B, k, d_B, k))
80+
81+
SAFECUDACALL(error, cudaMalloc((void**)&d_C, m*n*sizeof(T)))
82+
SAFECUDACALL(blasStatus, cublasSetMatrix(m, n, sizeof(T), C, m, d_C, m))
7183

72-
gemm(handle, transa, transb, m, n, k, &alpha, d_A, lda, d_B, ldb, &beta, d_C, ldc);
84+
SAFECUDACALL(blasStatus, gemm(handle, transa, transb, m, n, k, &alpha, d_A, lda, d_B, ldb, &beta, d_C, ldc))
7385

74-
cublasGetMatrix(m, n, sizeof(T), d_C, m, C, m);
86+
SAFECUDACALL(blasStatus, cublasGetMatrix(m, n, sizeof(T), d_C, m, C, m))
7587

88+
exit:
7689
cudaFree(d_A);
7790
cudaFree(d_B);
7891
cudaFree(d_C);
7992
}
8093

8194
extern "C" {
8295

83-
DLLEXPORT void s_axpy(const cublasHandle_t blasHandle, const int n, const float alpha, const float x[], float y[]){
84-
cuda_axpy(blasHandle, n, alpha, x, 1, y, 1, cublasSaxpy);
96+
DLLEXPORT void s_axpy(const cublasHandle_t blasHandle, const int n, const float alpha, const float x[], float y[], cudaError_t *error, cublasStatus_t *blasStatus){
97+
cuda_axpy(blasHandle, n, alpha, x, 1, y, 1, cublasSaxpy, error, blasStatus);
8598
}
8699

87-
DLLEXPORT void d_axpy(const cublasHandle_t blasHandle, const int n, const double alpha, const double x[], double y[]){
88-
cuda_axpy(blasHandle, n, alpha, x, 1, y, 1, cublasDaxpy);
100+
DLLEXPORT void d_axpy(const cublasHandle_t blasHandle, const int n, const double alpha, const double x[], double y[], cudaError_t *error, cublasStatus_t *blasStatus){
101+
cuda_axpy(blasHandle, n, alpha, x, 1, y, 1, cublasDaxpy, error, blasStatus);
89102
}
90103

91-
DLLEXPORT void c_axpy(const cublasHandle_t blasHandle, const int n, const cuComplex alpha, const cuComplex x[], cuComplex y[]){
92-
cuda_axpy(blasHandle, n, alpha, x, 1, y, 1, cublasCaxpy);
104+
DLLEXPORT void c_axpy(const cublasHandle_t blasHandle, const int n, const cuComplex alpha, const cuComplex x[], cuComplex y[], cudaError_t *error, cublasStatus_t *blasStatus){
105+
cuda_axpy(blasHandle, n, alpha, x, 1, y, 1, cublasCaxpy, error, blasStatus);
93106
}
94107

95-
DLLEXPORT void z_axpy(const cublasHandle_t blasHandle, const int n, const cuDoubleComplex alpha, const cuDoubleComplex x[], cuDoubleComplex y[]){
96-
cuda_axpy(blasHandle, n, alpha, x, 1, y, 1, cublasZaxpy);
108+
DLLEXPORT void z_axpy(const cublasHandle_t blasHandle, const int n, const cuDoubleComplex alpha, const cuDoubleComplex x[], cuDoubleComplex y[], cudaError_t *error, cublasStatus_t *blasStatus){
109+
cuda_axpy(blasHandle, n, alpha, x, 1, y, 1, cublasZaxpy, error, blasStatus);
97110
}
98111

99-
DLLEXPORT void s_scale(const cublasHandle_t blasHandle, const int n, const float alpha, float x[]){
100-
cuda_scal(blasHandle, n, alpha, x, 1, cublasSscal);
112+
DLLEXPORT void s_scale(const cublasHandle_t blasHandle, const int n, const float alpha, float x[], cudaError_t *error, cublasStatus_t *blasStatus){
113+
cuda_scal(blasHandle, n, alpha, x, 1, cublasSscal, error, blasStatus);
101114
}
102115

103-
DLLEXPORT void d_scale(const cublasHandle_t blasHandle, const int n, const double alpha, double x[]){
104-
cuda_scal(blasHandle, n, alpha, x, 1, cublasDscal);
116+
DLLEXPORT void d_scale(const cublasHandle_t blasHandle, const int n, const double alpha, double x[], cudaError_t *error, cublasStatus_t *blasStatus){
117+
cuda_scal(blasHandle, n, alpha, x, 1, cublasDscal, error, blasStatus);
105118
}
106119

107-
DLLEXPORT void c_scale(const cublasHandle_t blasHandle, const int n, const cuComplex alpha, cuComplex x[]){
108-
cuda_scal(blasHandle, n, alpha, x, 1, cublasCscal);
120+
DLLEXPORT void c_scale(const cublasHandle_t blasHandle, const int n, const cuComplex alpha, cuComplex x[], cudaError_t *error, cublasStatus_t *blasStatus){
121+
cuda_scal(blasHandle, n, alpha, x, 1, cublasCscal, error, blasStatus);
109122
}
110123

111-
DLLEXPORT void z_scale(const cublasHandle_t blasHandle, const int n, const cuDoubleComplex alpha, cuDoubleComplex x[]){
112-
cuda_scal(blasHandle, n, alpha, x, 1, cublasZscal);
124+
DLLEXPORT void z_scale(const cublasHandle_t blasHandle, const int n, const cuDoubleComplex alpha, cuDoubleComplex x[], cudaError_t *error, cublasStatus_t *blasStatus){
125+
cuda_scal(blasHandle, n, alpha, x, 1, cublasZscal, error, blasStatus);
113126
}
114127

115-
DLLEXPORT float s_dot_product(const cublasHandle_t blasHandle, const int n, const float x[], const float y[]){
128+
DLLEXPORT float s_dot_product(const cublasHandle_t blasHandle, const int n, const float x[], const float y[], cudaError_t *error, cublasStatus_t *blasStatus){
116129
float ret;
117-
cuda_dot(blasHandle, n, x, 1, y, 1, &ret, cublasSdot);
130+
cuda_dot(blasHandle, n, x, 1, y, 1, &ret, cublasSdot, error, blasStatus);
118131
return ret;
119132
}
120133

121-
DLLEXPORT double d_dot_product(const cublasHandle_t blasHandle, const int n, const double x[], const double y[]){
134+
DLLEXPORT double d_dot_product(const cublasHandle_t blasHandle, const int n, const double x[], const double y[], cudaError_t *error, cublasStatus_t *blasStatus){
122135
double ret;
123-
cuda_dot(blasHandle, n, x, 1, y, 1, &ret, cublasDdot);
136+
cuda_dot(blasHandle, n, x, 1, y, 1, &ret, cublasDdot, error, blasStatus);
124137
return ret;
125138
}
126139

127-
DLLEXPORT cuComplex c_dot_product(const cublasHandle_t blasHandle, const int n, const cuComplex x[], const cuComplex y[]){
140+
DLLEXPORT cuComplex c_dot_product(const cublasHandle_t blasHandle, const int n, const cuComplex x[], const cuComplex y[], cudaError_t *error, cublasStatus_t *blasStatus){
128141
cuComplex ret;
129-
cuda_dot(blasHandle, n, x, 1, y, 1, &ret, cublasCdotu);
142+
cuda_dot(blasHandle, n, x, 1, y, 1, &ret, cublasCdotu, error, blasStatus);
130143
return ret;
131144
}
132145

133-
DLLEXPORT cuDoubleComplex z_dot_product(const cublasHandle_t blasHandle, const int n, const cuDoubleComplex x[], const cuDoubleComplex y[]){
146+
DLLEXPORT cuDoubleComplex z_dot_product(const cublasHandle_t blasHandle, const int n, const cuDoubleComplex x[], const cuDoubleComplex y[], cudaError_t *error, cublasStatus_t *blasStatus){
134147
cuDoubleComplex ret;
135-
cuda_dot(blasHandle, n, x, 1, y, 1, &ret, cublasZdotu);
148+
cuda_dot(blasHandle, n, x, 1, y, 1, &ret, cublasZdotu, error, blasStatus);
136149
return ret;
137150
}
138151

139-
DLLEXPORT void s_matrix_multiply(const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const float alpha, const float x[], const float y[], const float beta, float c[]){
152+
DLLEXPORT void s_matrix_multiply(const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const float alpha, const float x[], const float y[], const float beta, float c[], cudaError_t *error, cublasStatus_t *blasStatus){
140153
int lda = transA == CUBLAS_OP_N ? m : k;
141154
int ldb = transB == CUBLAS_OP_N ? k : n;
142155

143-
cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasSgemm);
156+
cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasSgemm, error, blasStatus);
144157
}
145158

146-
DLLEXPORT void d_matrix_multiply(const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const double alpha, const double x[], const double y[], const double beta, double c[]){
159+
DLLEXPORT void d_matrix_multiply(const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const double alpha, const double x[], const double y[], const double beta, double c[], cudaError_t *error, cublasStatus_t *blasStatus){
147160
int lda = transA == CUBLAS_OP_N ? m : k;
148161
int ldb = transB == CUBLAS_OP_N ? k : n;
149162

150-
cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasDgemm);
163+
cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasDgemm, error, blasStatus);
151164
}
152165

153-
DLLEXPORT void c_matrix_multiply(const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const cuComplex alpha, const cuComplex x[], const cuComplex y[], const cuComplex beta, cuComplex c[]){
166+
DLLEXPORT void c_matrix_multiply(const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const cuComplex alpha, const cuComplex x[], const cuComplex y[], const cuComplex beta, cuComplex c[], cudaError_t *error, cublasStatus_t *blasStatus){
154167
int lda = transA == CUBLAS_OP_N ? m : k;
155168
int ldb = transB == CUBLAS_OP_N ? k : n;
156169

157-
cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasCgemm);
170+
cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasCgemm, error, blasStatus);
158171
}
159172

160-
DLLEXPORT void z_matrix_multiply(const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const cuDoubleComplex alpha, const cuDoubleComplex x[], const cuDoubleComplex y[], const cuDoubleComplex beta, cuDoubleComplex c[]){
173+
DLLEXPORT void z_matrix_multiply(const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const cuDoubleComplex alpha, const cuDoubleComplex x[], const cuDoubleComplex y[], const cuDoubleComplex beta, cuDoubleComplex c[], cudaError_t *error, cublasStatus_t *blasStatus){
161174
int lda = transA == CUBLAS_OP_N ? m : k;
162175
int ldb = transB == CUBLAS_OP_N ? k : n;
163176

164-
cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasZgemm);
177+
cuda_gemm(blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasZgemm, error, blasStatus);
165178
}
166179

167180
}

src/NativeProviders/CUDA/memory.c

Whitespace-only changes.

src/NativeProviders/CUDA/wrapper_cuda.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#ifndef WRAPPER_CUDA_H
2+
#define WRAPPER_CUDA_H
3+
4+
#include "wrapper_common.h"
5+
6+
/* Evaluates `call`, stores its status through the pointer expression `error`,
 * and jumps to the enclosing function's `exit:` label when the stored status
 * is non-zero. Both arguments are parenthesized so that expression arguments
 * (for example `p + 1`) expand with the intended precedence. `error` is
 * evaluated twice, so it must not have side effects. The expansion stays a
 * plain brace block because call sites invoke the macro without a trailing
 * semicolon. */
#define SAFECUDACALL(error,call) {*(error) = (call); if(*(error)){goto exit;}}
7+
8+
#endif

src/NativeProviders/Windows/CUDA/CUDAWrapper.vcxproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
</ItemGroup>
3030
<ItemGroup>
3131
<ClInclude Include="..\..\CUDA\resource.h" />
32+
<ClInclude Include="..\..\CUDA\wrapper_cuda.h" />
3233
</ItemGroup>
3334
<PropertyGroup Label="Globals">
3435
<ProjectGuid>{5A52B796-7F41-4C90-8DE2-F3F391C4482C}</ProjectGuid>

src/NativeProviders/Windows/CUDA/CUDAWrapper.vcxproj.filters

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,5 +37,8 @@
3737
<ClInclude Include="..\..\CUDA\resource.h">
3838
<Filter>Header Files</Filter>
3939
</ClInclude>
40+
<ClInclude Include="..\..\CUDA\wrapper_cuda.h">
41+
<Filter>Header Files</Filter>
42+
</ClInclude>
4043
</ItemGroup>
4144
</Project>

0 commit comments

Comments
 (0)