@@ -93,88 +93,112 @@ void cuda_gemm(const cublasHandle_t handle, const cublasOperation_t transa, cons
93
93
94
94
extern " C" {
95
95
96
- DLLEXPORT void s_axpy (const cublasHandle_t blasHandle, const int n, const float alpha, const float x[], float y[], cudaError_t *error, cublasStatus_t *blasStatus){
97
- cuda_axpy (blasHandle, n, alpha, x, 1 , y, 1 , cublasSaxpy, error, blasStatus);
96
+ DLLEXPORT CudaResults s_axpy (const cublasHandle_t blasHandle, const int n, const float alpha, const float x[], float y[]){
97
+ CudaResults ret;
98
+ cuda_axpy (blasHandle, n, alpha, x, 1 , y, 1 , cublasSaxpy, &ret.error , &ret.blasStatus );
99
+ return ret;
98
100
}
99
101
100
- DLLEXPORT void d_axpy (const cublasHandle_t blasHandle, const int n, const double alpha, const double x[], double y[], cudaError_t *error, cublasStatus_t *blasStatus){
101
- cuda_axpy (blasHandle, n, alpha, x, 1 , y, 1 , cublasDaxpy, error, blasStatus);
102
+ DLLEXPORT CudaResults d_axpy (const cublasHandle_t blasHandle, const int n, const double alpha, const double x[], double y[]){
103
+ CudaResults ret;
104
+ cuda_axpy (blasHandle, n, alpha, x, 1 , y, 1 , cublasDaxpy, &ret.error , &ret.blasStatus );
105
+ return ret;
102
106
}
103
107
104
- DLLEXPORT void c_axpy (const cublasHandle_t blasHandle, const int n, const cuComplex alpha, const cuComplex x[], cuComplex y[], cudaError_t *error, cublasStatus_t *blasStatus){
105
- cuda_axpy (blasHandle, n, alpha, x, 1 , y, 1 , cublasCaxpy, error, blasStatus);
108
+ DLLEXPORT CudaResults c_axpy (const cublasHandle_t blasHandle, const int n, const cuComplex alpha, const cuComplex x[], cuComplex y[]){
109
+ CudaResults ret;
110
+ cuda_axpy (blasHandle, n, alpha, x, 1 , y, 1 , cublasCaxpy, &ret.error , &ret.blasStatus );
111
+ return ret;
106
112
}
107
113
108
- DLLEXPORT void z_axpy (const cublasHandle_t blasHandle, const int n, const cuDoubleComplex alpha, const cuDoubleComplex x[], cuDoubleComplex y[], cudaError_t *error, cublasStatus_t *blasStatus){
109
- cuda_axpy (blasHandle, n, alpha, x, 1 , y, 1 , cublasZaxpy, error, blasStatus);
114
+ DLLEXPORT CudaResults z_axpy (const cublasHandle_t blasHandle, const int n, const cuDoubleComplex alpha, const cuDoubleComplex x[], cuDoubleComplex y[]){
115
+ CudaResults ret;
116
+ cuda_axpy (blasHandle, n, alpha, x, 1 , y, 1 , cublasZaxpy, &ret.error , &ret.blasStatus );
117
+ return ret;
110
118
}
111
119
112
- DLLEXPORT void s_scale (const cublasHandle_t blasHandle, const int n, const float alpha, float x[], cudaError_t *error, cublasStatus_t *blasStatus){
113
- cuda_scal (blasHandle, n, alpha, x, 1 , cublasSscal, error, blasStatus);
120
+ DLLEXPORT CudaResults s_scale (const cublasHandle_t blasHandle, const int n, const float alpha, float x[]){
121
+ CudaResults ret;
122
+ cuda_scal (blasHandle, n, alpha, x, 1 , cublasSscal, &ret.error , &ret.blasStatus );
123
+ return ret;
114
124
}
115
125
116
- DLLEXPORT void d_scale (const cublasHandle_t blasHandle, const int n, const double alpha, double x[], cudaError_t *error, cublasStatus_t *blasStatus){
117
- cuda_scal (blasHandle, n, alpha, x, 1 , cublasDscal, error, blasStatus);
126
+ DLLEXPORT CudaResults d_scale (const cublasHandle_t blasHandle, const int n, const double alpha, double x[]){
127
+ CudaResults ret;
128
+ cuda_scal (blasHandle, n, alpha, x, 1 , cublasDscal, &ret.error , &ret.blasStatus );
129
+ return ret;
118
130
}
119
131
120
- DLLEXPORT void c_scale (const cublasHandle_t blasHandle, const int n, const cuComplex alpha, cuComplex x[], cudaError_t *error, cublasStatus_t *blasStatus){
121
- cuda_scal (blasHandle, n, alpha, x, 1 , cublasCscal, error, blasStatus);
132
+ DLLEXPORT CudaResults c_scale (const cublasHandle_t blasHandle, const int n, const cuComplex alpha, cuComplex x[]){
133
+ CudaResults ret;
134
+ cuda_scal (blasHandle, n, alpha, x, 1 , cublasCscal, &ret.error , &ret.blasStatus );
135
+ return ret;
122
136
}
123
137
124
- DLLEXPORT void z_scale (const cublasHandle_t blasHandle, const int n, const cuDoubleComplex alpha, cuDoubleComplex x[], cudaError_t *error, cublasStatus_t *blasStatus){
125
- cuda_scal (blasHandle, n, alpha, x, 1 , cublasZscal, error, blasStatus);
138
+ DLLEXPORT CudaResults z_scale (const cublasHandle_t blasHandle, const int n, const cuDoubleComplex alpha, cuDoubleComplex x[]){
139
+ CudaResults ret;
140
+ cuda_scal (blasHandle, n, alpha, x, 1 , cublasZscal, &ret.error , &ret.blasStatus );
141
+ return ret;
126
142
}
127
143
128
- DLLEXPORT float s_dot_product (const cublasHandle_t blasHandle, const int n, const float x[], const float y[], cudaError_t *error, cublasStatus_t *blasStatus ){
129
- float ret;
130
- cuda_dot (blasHandle, n, x, 1 , y, 1 , &ret , cublasSdot, error, blasStatus);
144
+ DLLEXPORT CudaResults s_dot_product (const cublasHandle_t blasHandle, const int n, const float x[], const float y[], float *result ){
145
+ CudaResults ret;
146
+ cuda_dot (blasHandle, n, x, 1 , y, 1 , result , cublasSdot, &ret. error , &ret. blasStatus );
131
147
return ret;
132
148
}
133
149
134
- DLLEXPORT double d_dot_product (const cublasHandle_t blasHandle, const int n, const double x[], const double y[], cudaError_t *error, cublasStatus_t *blasStatus ){
135
- double ret;
136
- cuda_dot (blasHandle, n, x, 1 , y, 1 , &ret , cublasDdot, error, blasStatus);
150
+ DLLEXPORT CudaResults d_dot_product (const cublasHandle_t blasHandle, const int n, const double x[], const double y[], double *result ){
151
+ CudaResults ret;
152
+ cuda_dot (blasHandle, n, x, 1 , y, 1 , result , cublasDdot, &ret. error , &ret. blasStatus );
137
153
return ret;
138
154
}
139
155
140
- DLLEXPORT cuComplex c_dot_product (const cublasHandle_t blasHandle, const int n, const cuComplex x[], const cuComplex y[], cudaError_t *error, cublasStatus_t *blasStatus ){
141
- cuComplex ret;
142
- cuda_dot (blasHandle, n, x, 1 , y, 1 , &ret , cublasCdotu, error, blasStatus);
156
+ DLLEXPORT CudaResults c_dot_product (const cublasHandle_t blasHandle, const int n, const cuComplex x[], const cuComplex y[], cuComplex *result ){
157
+ CudaResults ret;
158
+ cuda_dot (blasHandle, n, x, 1 , y, 1 , result , cublasCdotu, &ret. error , &ret. blasStatus );
143
159
return ret;
144
160
}
145
161
146
- DLLEXPORT cuDoubleComplex z_dot_product (const cublasHandle_t blasHandle, const int n, const cuDoubleComplex x[], const cuDoubleComplex y[], cudaError_t *error, cublasStatus_t *blasStatus ){
147
- cuDoubleComplex ret;
148
- cuda_dot (blasHandle, n, x, 1 , y, 1 , &ret , cublasZdotu, error, blasStatus);
162
+ DLLEXPORT CudaResults z_dot_product (const cublasHandle_t blasHandle, const int n, const cuDoubleComplex x[], const cuDoubleComplex y[], cuDoubleComplex *result ){
163
+ CudaResults ret;
164
+ cuda_dot (blasHandle, n, x, 1 , y, 1 , result , cublasZdotu, &ret. error , &ret. blasStatus );
149
165
return ret;
150
166
}
151
167
152
- DLLEXPORT void s_matrix_multiply (const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const float alpha, const float x[], const float y[], const float beta, float c[], cudaError_t *error, cublasStatus_t *blasStatus){
168
+ DLLEXPORT CudaResults s_matrix_multiply (const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const float alpha, const float x[], const float y[], const float beta, float c[]){
169
+ CudaResults ret;
153
170
int lda = transA == CUBLAS_OP_N ? m : k;
154
171
int ldb = transB == CUBLAS_OP_N ? k : n;
155
172
156
- cuda_gemm (blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasSgemm, error, blasStatus);
173
+ cuda_gemm (blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasSgemm, &ret.error , &ret.blasStatus );
174
+ return ret;
157
175
}
158
176
159
- DLLEXPORT void d_matrix_multiply (const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const double alpha, const double x[], const double y[], const double beta, double c[], cudaError_t *error, cublasStatus_t *blasStatus){
177
+ DLLEXPORT CudaResults d_matrix_multiply (const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const double alpha, const double x[], const double y[], const double beta, double c[]){
178
+ CudaResults ret;
160
179
int lda = transA == CUBLAS_OP_N ? m : k;
161
180
int ldb = transB == CUBLAS_OP_N ? k : n;
162
181
163
- cuda_gemm (blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasDgemm, error, blasStatus);
182
+ cuda_gemm (blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasDgemm, &ret.error , &ret.blasStatus );
183
+ return ret;
164
184
}
165
185
166
- DLLEXPORT void c_matrix_multiply (const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const cuComplex alpha, const cuComplex x[], const cuComplex y[], const cuComplex beta, cuComplex c[], cudaError_t *error, cublasStatus_t *blasStatus){
186
+ DLLEXPORT CudaResults c_matrix_multiply (const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const cuComplex alpha, const cuComplex x[], const cuComplex y[], const cuComplex beta, cuComplex c[]){
187
+ CudaResults ret;
167
188
int lda = transA == CUBLAS_OP_N ? m : k;
168
189
int ldb = transB == CUBLAS_OP_N ? k : n;
169
190
170
- cuda_gemm (blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasCgemm, error, blasStatus);
191
+ cuda_gemm (blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasCgemm, &ret.error , &ret.blasStatus );
192
+ return ret;
171
193
}
172
194
173
- DLLEXPORT void z_matrix_multiply (const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const cuDoubleComplex alpha, const cuDoubleComplex x[], const cuDoubleComplex y[], const cuDoubleComplex beta, cuDoubleComplex c[], cudaError_t *error, cublasStatus_t *blasStatus){
195
+ DLLEXPORT CudaResults z_matrix_multiply (const cublasHandle_t blasHandle, cublasOperation_t transA, cublasOperation_t transB, const int m, const int n, const int k, const cuDoubleComplex alpha, const cuDoubleComplex x[], const cuDoubleComplex y[], const cuDoubleComplex beta, cuDoubleComplex c[]){
196
+ CudaResults ret;
174
197
int lda = transA == CUBLAS_OP_N ? m : k;
175
198
int ldb = transB == CUBLAS_OP_N ? k : n;
176
199
177
- cuda_gemm (blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasZgemm, error, blasStatus);
200
+ cuda_gemm (blasHandle, transA, transB, m, n, k, alpha, x, lda, y, ldb, beta, c, m, cublasZgemm, &ret.error , &ret.blasStatus );
201
+ return ret;
178
202
}
179
203
180
204
}
0 commit comments