Actual source code: cupmblasinterface.hpp

  1: #pragma once

  3: #include <petsc/private/cupminterface.hpp>
  4: #include <petsc/private/petscadvancedmacros.h>

  6: #include <limits> // std::numeric_limits

  8: namespace Petsc
  9: {

 11: namespace device
 12: {

 14: namespace cupm
 15: {

 17: namespace impl
 18: {

 20: #define PetscCallCUPMBLAS_(__abort_fn__, __comm__, ...) \
 21:   do { \
 22:     PetscStackUpdateLine; \
 23:     const cupmBlasError_t cberr_p_ = __VA_ARGS__; \
 24:     if (PetscUnlikely(cberr_p_ != CUPMBLAS_STATUS_SUCCESS)) { \
 25:       if (((cberr_p_ == CUPMBLAS_STATUS_NOT_INITIALIZED) || (cberr_p_ == CUPMBLAS_STATUS_ALLOC_FAILED)) && PetscDeviceInitialized(PETSC_DEVICE_CUPM())) { \
 26:         __abort_fn__(__comm__, PETSC_ERR_GPU_RESOURCE, \
 27:                      "%s error %d (%s). Reports not initialized or alloc failed; " \
 28:                      "this indicates the GPU may have run out resources", \
 29:                      cupmBlasName(), static_cast<PetscErrorCode>(cberr_p_), cupmBlasGetErrorName(cberr_p_)); \
 30:       } \
 31:       __abort_fn__(__comm__, PETSC_ERR_GPU, "%s error %d (%s)", cupmBlasName(), static_cast<PetscErrorCode>(cberr_p_), cupmBlasGetErrorName(cberr_p_)); \
 32:     } \
 33:   } while (0)

 35: #define PetscCallCUPMBLAS(...)             PetscCallCUPMBLAS_(SETERRQ, PETSC_COMM_SELF, __VA_ARGS__)
 36: #define PetscCallCUPMBLASAbort(comm_, ...) PetscCallCUPMBLAS_(SETERRABORT, comm_, __VA_ARGS__)

 38: // given cupmBlasaxpy() then
 39: // T = PETSC_CUPBLAS_FP_TYPE
 40: // given cupmBlasnrm2() then
 41: // T = PETSC_CUPMBLAS_FP_INPUT_TYPE
 42: // u = PETSC_CUPMBLAS_FP_RETURN_TYPE
 43: #if PetscDefined(USE_COMPLEX)
 44:   #if PetscDefined(USE_REAL_SINGLE)
 45:     #define PETSC_CUPMBLAS_FP_TYPE_U       C
 46:     #define PETSC_CUPMBLAS_FP_TYPE_L       c
 47:     #define PETSC_CUPMBLAS_FP_INPUT_TYPE_U S
 48:     #define PETSC_CUPMBLAS_FP_INPUT_TYPE_L s
 49:   #elif PetscDefined(USE_REAL_DOUBLE)
 50:     #define PETSC_CUPMBLAS_FP_TYPE_U       Z
 51:     #define PETSC_CUPMBLAS_FP_TYPE_L       z
 52:     #define PETSC_CUPMBLAS_FP_INPUT_TYPE_U D
 53:     #define PETSC_CUPMBLAS_FP_INPUT_TYPE_L d
 54:   #endif
 55:   #define PETSC_CUPMBLAS_FP_RETURN_TYPE_U PETSC_CUPMBLAS_FP_TYPE_U
 56:   #define PETSC_CUPMBLAS_FP_RETURN_TYPE_L PETSC_CUPMBLAS_FP_TYPE_L
 57: #else
 58:   #if PetscDefined(USE_REAL_SINGLE)
 59:     #define PETSC_CUPMBLAS_FP_TYPE_U S
 60:     #define PETSC_CUPMBLAS_FP_TYPE_L s
 61:   #elif PetscDefined(USE_REAL_DOUBLE)
 62:     #define PETSC_CUPMBLAS_FP_TYPE_U D
 63:     #define PETSC_CUPMBLAS_FP_TYPE_L d
 64:   #endif
 65:   #define PETSC_CUPMBLAS_FP_INPUT_TYPE_U PETSC_CUPMBLAS_FP_TYPE_U
 66:   #define PETSC_CUPMBLAS_FP_INPUT_TYPE_L PETSC_CUPMBLAS_FP_TYPE_L
 67:   #define PETSC_CUPMBLAS_FP_RETURN_TYPE_U
 68:   #define PETSC_CUPMBLAS_FP_RETURN_TYPE_L
 69: #endif // USE_COMPLEX

 71: #if !defined(PETSC_CUPMBLAS_FP_TYPE_U) && !PetscDefined(USE_REAL___FLOAT128)
 72:   #error "Unsupported floating-point type for CUDA/HIP BLAS"
 73: #endif

 75: // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED() - Helper macro to build a "modified"
 76: // blas function name made of two type letters, used for routines (e.g. nrm2, asum) whose
 77: //
 78: // input param:
 79: // func - base suffix of the blas function, e.g. nrm2
 80: //
 81: // notes:
 82: // requires PETSC_CUPMBLAS_FP_INPUT_TYPE to be defined as the blas floating point input type
 83: // letter ("S" for real/complex single, "D" for real/complex double).
 84: //
 85: // requires PETSC_CUPMBLAS_FP_RETURN_TYPE to be defined as the blas floating point output type
 86: // letter ("c" for complex single, "z" for complex double, and empty for real
 87: // single/double).
 88: //
 89: // In their infinite wisdom nvidia/amd have made the upper-case vs lower-case scheme
 90: // infuriatingly inconsistent...
 91: //
 92: // example usage:
 93: // #define PETSC_CUPMBLAS_FP_INPUT_TYPE  S
 94: // #define PETSC_CUPMBLAS_FP_RETURN_TYPE
 95: // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(nrm2) -> Snrm2
 96: //
 97: // #define PETSC_CUPMBLAS_FP_INPUT_TYPE  D
 98: // #define PETSC_CUPMBLAS_FP_RETURN_TYPE z
 99: // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(nrm2) -> Dznrm2
100: #define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(func) PetscConcat(PetscConcat(PETSC_CUPMBLAS_FP_INPUT_TYPE, PETSC_CUPMBLAS_FP_RETURN_TYPE), func)

102: // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE() - Helper macro to build Iamax and Iamin
103: // because their names put an "I" in front of the lower-case type letter (e.g. Isamax, Izamin)
104: //
105: // input param:
106: // func - base suffix of the blas function, either amax or amin
107: //
108: // notes:
109: // The macro name literally stands for "I" ## "floating point type" because shockingly enough,
110: // that's what it does.
111: //
112: // requires PETSC_CUPMBLAS_FP_TYPE_L to be defined as the lower-case blas floating point type
113: // letter ("c" for complex single, "z" for complex double, "s" for real single, and "d" for
114: // real double).
115: //
116: // example usage:
117: // #define PETSC_CUPMBLAS_FP_TYPE_L s
118: // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(amax) -> Isamax
119: //
120: // #define PETSC_CUPMBLAS_FP_TYPE_L z
121: // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(amin) -> Izamin
122: #define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(func) PetscConcat(I, PetscConcat(PETSC_CUPMBLAS_FP_TYPE_L, func))

124: // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD() - Helper macro to build a "standard"
125: // blas function name, i.e. the single type letter immediately followed by the suffix
126: //
127: // input param:
128: // func - base suffix of the blas function, e.g. axpy, scal
129: //
130: // notes:
131: // requires PETSC_CUPMBLAS_FP_TYPE to be defined as the blas floating-point letter ("C" for
132: // complex single, "Z" for complex double, "S" for real single, "D" for real double).
133: //
134: // example usage:
135: // #define PETSC_CUPMBLAS_FP_TYPE S
136: // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(axpy) -> Saxpy
137: //
138: // #define PETSC_CUPMBLAS_FP_TYPE Z
139: // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(axpy) -> Zaxpy
140: #define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(func) PetscConcat(PETSC_CUPMBLAS_FP_TYPE, func)

142: // PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT() - Alias a CUDA/HIP blas function under a suffix
143: // that differs from the library's own suffix (e.g. our "dot" -> their "dotc")
144: //
145: // input params:
146: // MACRO_SUFFIX - suffix to one of the above blas function builder macros, e.g. STANDARD or
147: // IFPTYPE
148: // our_suffix   - the suffix of the alias function
149: // their_suffix - the suffix of the function being aliased
150: //
151: // notes:
152: // requires PETSC_CUPMBLAS_PREFIX to be defined as the specific CUDA/HIP blas function
153: // prefix. requires any other specific definitions required by the specific builder macro to
154: // also be defined. See PETSC_CUPM_ALIAS_FUNCTION() for the exact expansion of the
155: // function alias.
156: //
157: // example usage:
158: // #define PETSC_CUPMBLAS_PREFIX  cublas
159: // #define PETSC_CUPMBLAS_FP_TYPE C
160: // PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD,dot,dotc) ->
161: // template <typename... T>
162: // static constexpr auto cupmBlasXdot(T&&... args) *noexcept and returntype detection*
163: // {
164: //   return cublasCdotc(std::forward<T>(args)...);
165: // }
166: #define PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(MACRO_SUFFIX, our_suffix, their_suffix) \
167:   PETSC_CUPM_ALIAS_FUNCTION(PetscConcat(cupmBlasX, our_suffix), PetscConcat(PETSC_CUPMBLAS_PREFIX, PetscConcat(PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_, MACRO_SUFFIX)(their_suffix)))

169: // PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION() - Alias a CUDA/HIP blas function
170: //
171: // input params:
172: // MACRO_SUFFIX - suffix to one of the above blas function builder macros, e.g. STANDARD or
173: // IFPTYPE
174: // suffix       - the suffix shared by our alias (cupmBlasX<suffix>) and the library function
175: //
176: // notes:
177: // see PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(), this macro just calls that one with "suffix" as
178: // both "our_suffix" and "their_suffix"
179: #define PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MACRO_SUFFIX, suffix) PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(MACRO_SUFFIX, suffix, suffix)

181: // PETSC_CUPMBLAS_ALIAS_FUNCTION() - Alias a CUDA/HIP library function
182: //
183: // input params:
184: // suffix - the common suffix between CUDA and HIP of the alias function
185: //
186: // notes:
187: // requires PETSC_CUPMBLAS_PREFIX to be defined as the specific CUDA/HIP blas library
188: // prefix. see PETSC_CUPM_ALIAS_FUNCTION() for the precise expansion of this macro.
// Unlike the BLAS aliases above, the generated name is cupmBlas<suffix> (no "X") and no
// type letter is inserted into the aliased name.
189: //
190: // example usage:
191: // #define PETSC_CUPMBLAS_PREFIX hipblas
192: // PETSC_CUPMBLAS_ALIAS_FUNCTION(Create) ->
193: // template <typename... T>
194: // static constexpr auto cupmBlasCreate(T&&... args) *noexcept and returntype detection*
195: // {
196: //   return hipblasCreate(std::forward<T>(args)...);
197: // }
198: #define PETSC_CUPMBLAS_ALIAS_FUNCTION(suffix) PETSC_CUPM_ALIAS_FUNCTION(PetscConcat(cupmBlas, suffix), PetscConcat(PETSC_CUPMBLAS_PREFIX, suffix))

200: template <DeviceType>
201: struct BlasInterfaceImpl;

// cupmBlasHandleWrapper - give a raw library handle its own distinct C++ type.
//
// HIP declares both hipblasHandle_t and hipsolverHandle_t as plain void *, so overloads taking one
// or the other cannot be told apart. Wrapping the raw handle T in this template, with a different
// index I per library, produces distinct types. The wrapper still converts implicitly back to T
// wherever the library API wants the raw handle.
template <typename T, std::size_t I>
class cupmBlasHandleWrapper {
public:
  constexpr cupmBlasHandleWrapper() noexcept = default;

  // wrap an existing raw handle
  constexpr cupmBlasHandleWrapper(T raw) noexcept : handle_{std::move(raw)}
  {
    // the wrapper is required to stay a standard-layout type
    static_assert(std::is_standard_layout<cupmBlasHandleWrapper>::value, "");
  }

  // reset the stored handle to null (the previously held handle is NOT destroyed)
  cupmBlasHandleWrapper &operator=(std::nullptr_t) noexcept
  {
    this->handle_ = nullptr;
    return *this;
  }

  // implicit conversion so a wrapper can be passed straight to library routines
  operator T() const { return this->handle_; }

  // address of the stored handle, presumably for calls that write a handle through a pointer
  const T *ptr_to() const { return &this->handle_; }
  T       *ptr_to() { return &this->handle_; }

private:
  T handle_{};
};

// cuBLAS specialization of BlasInterfaceImpl.
//
// The PETSC_CUPMBLAS_* naming macros are defined right before the specialization so that every
// PETSC_CUPMBLAS_ALIAS_* expansion inside it resolves to the matching cublas...() routine.
// cuBLAS uses the upper-case type letters and a lower-case second letter (e.g. cublasScnrm2).
// The macros are #undef'd right after the struct so the HIP specialization can redefine them.
// NOTE(review): PETSC_CUPMBLAS_PREFIX_U appears to be unused within this header.
227: #if PetscDefined(HAVE_CUDA)
228:   #define PETSC_CUPMBLAS_PREFIX         cublas
229:   #define PETSC_CUPMBLAS_PREFIX_U       CUBLAS
230:   #define PETSC_CUPMBLAS_FP_TYPE        PETSC_CUPMBLAS_FP_TYPE_U
231:   #define PETSC_CUPMBLAS_FP_INPUT_TYPE  PETSC_CUPMBLAS_FP_INPUT_TYPE_U
232:   #define PETSC_CUPMBLAS_FP_RETURN_TYPE PETSC_CUPMBLAS_FP_RETURN_TYPE_L
233: template <>
234: struct PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL BlasInterfaceImpl<DeviceType::CUDA> : Interface<DeviceType::CUDA> {
235:   // typedefs (the raw cublasHandle_t is wrapped, see cupmBlasHandleWrapper)
236:   using cupmBlasHandle_t      = cupmBlasHandleWrapper<cublasHandle_t, 0>;
237:   using cupmBlasError_t       = cublasStatus_t;
238:   using cupmBlasInt_t         = int;
239:   using cupmBlasPointerMode_t = cublasPointerMode_t;

241:   // values: backend-neutral names for the cuBLAS enumerators used by PETSc
242:   static const auto CUPMBLAS_STATUS_SUCCESS         = CUBLAS_STATUS_SUCCESS;
243:   static const auto CUPMBLAS_STATUS_NOT_INITIALIZED = CUBLAS_STATUS_NOT_INITIALIZED;
244:   static const auto CUPMBLAS_STATUS_ALLOC_FAILED    = CUBLAS_STATUS_ALLOC_FAILED;
245:   static const auto CUPMBLAS_POINTER_MODE_HOST      = CUBLAS_POINTER_MODE_HOST;
246:   static const auto CUPMBLAS_POINTER_MODE_DEVICE    = CUBLAS_POINTER_MODE_DEVICE;
247:   static const auto CUPMBLAS_OP_T                   = CUBLAS_OP_T;
248:   static const auto CUPMBLAS_OP_N                   = CUBLAS_OP_N;
249:   static const auto CUPMBLAS_OP_C                   = CUBLAS_OP_C;
250:   static const auto CUPMBLAS_FILL_MODE_LOWER        = CUBLAS_FILL_MODE_LOWER;
251:   static const auto CUPMBLAS_FILL_MODE_UPPER        = CUBLAS_FILL_MODE_UPPER;
252:   static const auto CUPMBLAS_SIDE_LEFT              = CUBLAS_SIDE_LEFT;
253:   static const auto CUPMBLAS_SIDE_RIGHT             = CUBLAS_SIDE_RIGHT;
254:   static const auto CUPMBLAS_DIAG_NON_UNIT          = CUBLAS_DIAG_NON_UNIT;

256:   // utility functions: cupmBlas<Name>() forwards to cublas<Name>()
257:   PETSC_CUPMBLAS_ALIAS_FUNCTION(Create)
258:   PETSC_CUPMBLAS_ALIAS_FUNCTION(Destroy)
259:   PETSC_CUPMBLAS_ALIAS_FUNCTION(GetStream)
260:   PETSC_CUPMBLAS_ALIAS_FUNCTION(SetStream)
261:   PETSC_CUPMBLAS_ALIAS_FUNCTION(GetPointerMode)
262:   PETSC_CUPMBLAS_ALIAS_FUNCTION(SetPointerMode)

264:   // level 1 BLAS; for complex scalars "dot" is the conjugated dotc and "dotu" the unconjugated dotu, for real scalars both map to dot
265:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, axpy)
266:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, copy)
267:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, scal)
268:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dot, PetscIfPetscDefined(USE_COMPLEX, dotc, dot))
269:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dotu, PetscIfPetscDefined(USE_COMPLEX, dotu, dot))
270:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, swap)
271:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, nrm2)
272:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(IFPTYPE, amax)
273:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, asum)

275:   // level 2 BLAS (for real scalars hemv maps to symv)
276:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemv)
277:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, trmv)
278:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, trsv)
279:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gbmv)
280:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, tbmv)
281:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, tbsv)
282:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, hemv, PetscIfPetscDefined(USE_COMPLEX, hemv, symv))

284:   // level 3 BLAS
285:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemm)
286:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, trsm)

288:   // BLAS extensions
289:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, geam)
290:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, dgmm)

  // human-readable name of a cuBLAS status, delegated to PetscCUBLASGetErrorName()
292:   PETSC_NODISCARD static const char *cupmBlasGetErrorName(cupmBlasError_t status) noexcept { return PetscCUBLASGetErrorName(status); }
293: };
294:   #undef PETSC_CUPMBLAS_PREFIX
295:   #undef PETSC_CUPMBLAS_PREFIX_U
296:   #undef PETSC_CUPMBLAS_FP_TYPE
297:   #undef PETSC_CUPMBLAS_FP_INPUT_TYPE
298:   #undef PETSC_CUPMBLAS_FP_RETURN_TYPE
299: #endif // PetscDefined(HAVE_CUDA)

// hipBLAS specialization of BlasInterfaceImpl.
//
// Mirrors the CUDA specialization: the PETSC_CUPMBLAS_* naming macros are (re)defined with the
// hipblas prefix just before the struct so every PETSC_CUPMBLAS_ALIAS_* expansion resolves to the
// matching hipblas...() routine, and are #undef'd right after it.
// NOTE(review): PETSC_CUPMBLAS_PREFIX_U appears to be unused within this header.
301: #if PetscDefined(HAVE_HIP)
302:   #define PETSC_CUPMBLAS_PREFIX         hipblas
303:   #define PETSC_CUPMBLAS_PREFIX_U       HIPBLAS
304:   #define PETSC_CUPMBLAS_FP_TYPE        PETSC_CUPMBLAS_FP_TYPE_U
305:   #define PETSC_CUPMBLAS_FP_INPUT_TYPE  PETSC_CUPMBLAS_FP_INPUT_TYPE_U
306:   #define PETSC_CUPMBLAS_FP_RETURN_TYPE PETSC_CUPMBLAS_FP_RETURN_TYPE_L
307: template <>
308: struct PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL BlasInterfaceImpl<DeviceType::HIP> : Interface<DeviceType::HIP> {
309:   // typedefs (hipblasHandle_t is a void *, hence the wrapper, see cupmBlasHandleWrapper)
310:   using cupmBlasHandle_t      = cupmBlasHandleWrapper<hipblasHandle_t, 0>;
311:   using cupmBlasError_t       = hipblasStatus_t;
312:   using cupmBlasInt_t         = int; // rocblas will have its own
313:   using cupmBlasPointerMode_t = hipblasPointerMode_t;

315:   // values: backend-neutral names for the hipBLAS enumerators used by PETSc
316:   static const auto CUPMBLAS_STATUS_SUCCESS         = HIPBLAS_STATUS_SUCCESS;
317:   static const auto CUPMBLAS_STATUS_NOT_INITIALIZED = HIPBLAS_STATUS_NOT_INITIALIZED;
318:   static const auto CUPMBLAS_STATUS_ALLOC_FAILED    = HIPBLAS_STATUS_ALLOC_FAILED;
319:   static const auto CUPMBLAS_POINTER_MODE_HOST      = HIPBLAS_POINTER_MODE_HOST;
320:   static const auto CUPMBLAS_POINTER_MODE_DEVICE    = HIPBLAS_POINTER_MODE_DEVICE;
321:   static const auto CUPMBLAS_OP_T                   = HIPBLAS_OP_T;
322:   static const auto CUPMBLAS_OP_N                   = HIPBLAS_OP_N;
323:   static const auto CUPMBLAS_OP_C                   = HIPBLAS_OP_C;
324:   static const auto CUPMBLAS_FILL_MODE_LOWER        = HIPBLAS_FILL_MODE_LOWER;
325:   static const auto CUPMBLAS_FILL_MODE_UPPER        = HIPBLAS_FILL_MODE_UPPER;
326:   static const auto CUPMBLAS_SIDE_LEFT              = HIPBLAS_SIDE_LEFT;
327:   static const auto CUPMBLAS_SIDE_RIGHT             = HIPBLAS_SIDE_RIGHT;
328:   static const auto CUPMBLAS_DIAG_NON_UNIT          = HIPBLAS_DIAG_NON_UNIT;

330:   // utility functions: cupmBlas<Name>() forwards to hipblas<Name>()
331:   PETSC_CUPMBLAS_ALIAS_FUNCTION(Create)
332:   PETSC_CUPMBLAS_ALIAS_FUNCTION(Destroy)
333:   PETSC_CUPMBLAS_ALIAS_FUNCTION(GetStream)
334:   PETSC_CUPMBLAS_ALIAS_FUNCTION(SetStream)
335:   PETSC_CUPMBLAS_ALIAS_FUNCTION(GetPointerMode)
336:   PETSC_CUPMBLAS_ALIAS_FUNCTION(SetPointerMode)

338:   // level 1 BLAS; for complex scalars "dot" is the conjugated dotc and "dotu" the unconjugated dotu, for real scalars both map to dot
339:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, axpy)
340:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, copy)
341:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, scal)
342:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dot, PetscIfPetscDefined(USE_COMPLEX, dotc, dot))
343:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dotu, PetscIfPetscDefined(USE_COMPLEX, dotu, dot))
344:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, swap)
345:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, nrm2)
346:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(IFPTYPE, amax)
347:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, asum)

349:   // level 2 BLAS (for real scalars hemv maps to symv)
350:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemv)
351:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, trmv)
352:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, trsv)
353:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gbmv)
354:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, tbmv)
355:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, tbsv)
356:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, hemv, PetscIfPetscDefined(USE_COMPLEX, hemv, symv))

358:   // level 3 BLAS
359:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemm)
360:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, trsm)

362:   // BLAS extensions
363:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, geam)
364:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, dgmm)

  // human-readable name of a hipBLAS status, delegated to PetscHIPBLASGetErrorName()
366:   PETSC_NODISCARD static const char *cupmBlasGetErrorName(cupmBlasError_t status) noexcept { return PetscHIPBLASGetErrorName(status); }
367: };
368:   #undef PETSC_CUPMBLAS_PREFIX
369:   #undef PETSC_CUPMBLAS_PREFIX_U
370:   #undef PETSC_CUPMBLAS_FP_TYPE
371:   #undef PETSC_CUPMBLAS_FP_INPUT_TYPE
372:   #undef PETSC_CUPMBLAS_FP_RETURN_TYPE
373: #endif // PetscDefined(HAVE_HIP)

// PETSC_CUPMBLAS_IMPL_CLASS_HEADER() - Re-export every type, value and function alias of
// BlasInterfaceImpl<T> into a class template deriving from it.
//
// input param:
// T - the DeviceType of the BlasInterfaceImpl specialization to pull names from
//
// notes:
// Also expands PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T). The using-declarations are needed
// because names from a dependent base class are not found by unqualified lookup inside a class
// template. The expansion has no trailing semicolon; the caller supplies it.
375: #define PETSC_CUPMBLAS_IMPL_CLASS_HEADER(T) \
376:   PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T); \
377:   /* introspection */ \
378:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasGetErrorName; \
379:   /* types */ \
380:   using cupmBlasHandle_t      = typename ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasHandle_t; \
381:   using cupmBlasError_t       = typename ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasError_t; \
382:   using cupmBlasInt_t         = typename ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasInt_t; \
383:   using cupmBlasPointerMode_t = typename ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasPointerMode_t; \
384:   /* values */ \
385:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_STATUS_SUCCESS; \
386:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_STATUS_NOT_INITIALIZED; \
387:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_STATUS_ALLOC_FAILED; \
388:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_POINTER_MODE_HOST; \
389:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_POINTER_MODE_DEVICE; \
390:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_OP_T; \
391:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_OP_N; \
392:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_OP_C; \
393:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_FILL_MODE_LOWER; \
394:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_FILL_MODE_UPPER; \
395:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_SIDE_LEFT; \
396:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_SIDE_RIGHT; \
397:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::CUPMBLAS_DIAG_NON_UNIT; \
398:   /* utility functions */ \
399:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasCreate; \
400:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasDestroy; \
401:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasGetStream; \
402:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasSetStream; \
403:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasGetPointerMode; \
404:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasSetPointerMode; \
405:   /* level 1 BLAS */ \
406:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXaxpy; \
407:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXcopy; \
408:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXscal; \
409:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXdot; \
410:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXdotu; \
411:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXswap; \
412:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXnrm2; \
413:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXamax; \
414:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXasum; \
415:   /* level 2 BLAS */ \
416:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXgemv; \
417:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXtrmv; \
418:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXtrsv; \
419:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXgbmv; \
420:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXtbmv; \
421:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXtbsv; \
422:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXhemv; \
423:   /* level 3 BLAS */ \
424:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXgemm; \
425:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXtrsm; \
426:   /* BLAS extensions */ \
427:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXgeam; \
428:   using ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>::cupmBlasXdgmm

430: // The actual interface class
431: template <DeviceType T>
432: struct BlasInterface : BlasInterfaceImpl<T> {
433:   PETSC_CUPMBLAS_IMPL_CLASS_HEADER(T);

435:   PETSC_NODISCARD static constexpr const char *cupmBlasName() noexcept { return T == DeviceType::CUDA ? "cuBLAS" : "hipBLAS"; }

437:   static PetscErrorCode PetscCUPMBlasSetPointerModeFromPointer(cupmBlasHandle_t handle, const void *ptr) noexcept
438:   {
439:     auto mtype = PETSC_MEMTYPE_HOST;

441:     PetscFunctionBegin;
442:     PetscCall(PetscCUPMGetMemType(ptr, &mtype));
443:     PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, PetscMemTypeDevice(mtype) ? CUPMBLAS_POINTER_MODE_DEVICE : CUPMBLAS_POINTER_MODE_HOST));
444:     PetscFunctionReturn(PETSC_SUCCESS);
445:   }

447:   static PetscErrorCode checkCupmBlasIntCast(PetscInt x) noexcept
448:   {
449:     PetscFunctionBegin;
450:     PetscCheck((std::is_same<PetscInt, cupmBlasInt_t>::value) || (x <= std::numeric_limits<cupmBlasInt_t>::max()), PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "%" PetscInt_FMT " is too big for %s, which may be restricted to 32-bit integers", x, cupmBlasName());
451:     PetscCheck(x >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Passing negative integer (%" PetscInt_FMT ") to %s routine", x, cupmBlasName());
452:     PetscFunctionReturn(PETSC_SUCCESS);
453:   }

455:   static PetscErrorCode PetscCUPMBlasIntCast(PetscInt x, cupmBlasInt_t *y) noexcept
456:   {
457:     PetscFunctionBegin;
458:     *y = static_cast<cupmBlasInt_t>(x);
459:     PetscCall(checkCupmBlasIntCast(x));
460:     PetscFunctionReturn(PETSC_SUCCESS);
461:   }

463:   class CUPMBlasPointerModeGuard {
464:   public:
465:     CUPMBlasPointerModeGuard(const cupmBlasHandle_t &handle, cupmBlasPointerMode_t mode) noexcept : handle_{handle}
466:     {
467:       PetscFunctionBegin;
468:       PetscCallCUPMBLASAbort(PETSC_COMM_SELF, cupmBlasGetPointerMode(handle, &this->old_));
469:       if (this->old_ == mode) {
470:         this->set_ = false;
471:       } else {
472:         this->set_ = true;
473:         PetscCallCUPMBLASAbort(PETSC_COMM_SELF, cupmBlasSetPointerMode(handle, mode));
474:       }
475:       PetscFunctionReturnVoid();
476:     }

478:     CUPMBlasPointerModeGuard(const cupmBlasHandle_t &handle, PetscMemType mtype) noexcept : CUPMBlasPointerModeGuard{handle, PetscMemTypeDevice(mtype) ? CUPMBLAS_POINTER_MODE_DEVICE : CUPMBLAS_POINTER_MODE_HOST} { }

480:     ~CUPMBlasPointerModeGuard() noexcept
481:     {
482:       PetscFunctionBegin;
483:       if (this->set_) PetscCallCUPMBLASAbort(PETSC_COMM_SELF, cupmBlasSetPointerMode(this->handle_, this->old_));
484:       PetscFunctionReturnVoid();
485:     }

487:   private:
488:     cupmBlasHandle_t      handle_;
489:     cupmBlasPointerMode_t old_;
490:     bool                  set_;
491:   };
492: };

// PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING() - Bring the complete BLAS interface of
// BlasInterface<T> into scope of a class template deriving from it.
//
// input param:
// T - the DeviceType of the BlasInterface base
//
// notes:
// Expands PETSC_CUPMBLAS_IMPL_CLASS_HEADER(T) (the BlasInterfaceImpl aliases) and additionally
// re-exports the helpers defined in BlasInterface itself. The expansion has no trailing
// semicolon; the caller supplies it.
494: #define PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(T) \
495:   PETSC_CUPMBLAS_IMPL_CLASS_HEADER(T); \
496:   using ::Petsc::device::cupm::impl::BlasInterface<T>::cupmBlasName; \
497:   using ::Petsc::device::cupm::impl::BlasInterface<T>::PetscCUPMBlasSetPointerModeFromPointer; \
498:   using ::Petsc::device::cupm::impl::BlasInterface<T>::checkCupmBlasIntCast; \
499:   using ::Petsc::device::cupm::impl::BlasInterface<T>::PetscCUPMBlasIntCast; \
500:   using CUPMBlasPointerModeGuard = typename ::Petsc::device::cupm::impl::BlasInterface<T>::CUPMBlasPointerModeGuard

// Explicit instantiation declarations for each enabled backend. The matching explicit
// instantiation definitions are expected to live in a source file that is not visible here.
502: #if PetscDefined(HAVE_CUDA)
503: extern template struct PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL BlasInterface<DeviceType::CUDA>;
504: #endif

506: #if PetscDefined(HAVE_HIP)
507: extern template struct PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL BlasInterface<DeviceType::HIP>;
508: #endif

510: } // namespace impl

512: } // namespace cupm

514: } // namespace device

516: } // namespace Petsc