Actual source code: cupmobject.hpp

  1: #pragma once

  3: #include <petsc/private/deviceimpl.h>
  4: #include <petsc/private/cupmsolverinterface.hpp>

  6: #include <cstring> // std::memset

  8: inline PetscErrorCode PetscStrFreeAllocpy(const char target[], char **dest) noexcept
  9: {
 10:   PetscFunctionBegin;
 11:   PetscAssertPointer(dest, 2);
 12:   if (*dest) {
 13:     PetscAssertPointer(*dest, 2);
 14:     PetscCall(PetscFree(*dest));
 15:   }
 16:   PetscCall(PetscStrallocpy(target, dest));
 17:   PetscFunctionReturn(PETSC_SUCCESS);
 18: }

 20: namespace Petsc
 21: {

 23: namespace device
 24: {

 26: namespace cupm
 27: {

 29: namespace impl
 30: {

 32: namespace
 33: {

 35: // ==========================================================================================
 36: // UseCUPMHostAllocGuard
 37: //
// A simple RAII helper for PetscMallocSet[CUDA|HIP]Host(). It exists because integrating the
// regular versions would be an enormous pain to square with the templated types...
 40: // ==========================================================================================
 41: template <DeviceType T>
 42: class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL UseCUPMHostAllocGuard : Interface<T> {
 43: public:
 44:   PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T);

 46:   UseCUPMHostAllocGuard(bool) noexcept;
 47:   ~UseCUPMHostAllocGuard() noexcept;

 49:   PETSC_NODISCARD bool value() const noexcept;

 51: private:
 52:   // would have loved to just do
 53:   //
 54:   // const auto oldmalloc = PetscTrMalloc;
 55:   //
 56:   // but in order to use auto the member needs to be static; in order to be static it must
 57:   // also be constexpr -- which in turn requires an initializer (also implicitly required by
 58:   // auto). But constexpr needs a constant expression initializer, so we can't initialize it
 59:   // with global (mutable) variables...
 60: #define DECLTYPE_AUTO(left, right) decltype(right) left = right
 61:   const DECLTYPE_AUTO(oldmalloc_, PetscTrMalloc);
 62:   const DECLTYPE_AUTO(oldfree_, PetscTrFree);
 63:   const DECLTYPE_AUTO(oldrealloc_, PetscTrRealloc);
 64: #undef DECLTYPE_AUTO
 65:   bool v_;
 66: };

 68: // ==========================================================================================
 69: // UseCUPMHostAllocGuard -- Public API
 70: // ==========================================================================================

 72: template <DeviceType T>
 73: inline UseCUPMHostAllocGuard<T>::UseCUPMHostAllocGuard(bool useit) noexcept : v_(useit)
 74: {
 75:   PetscFunctionBegin;
 76:   if (useit) {
 77:     // all unused arguments are un-named, this saves having to add PETSC_UNUSED to them all
 78:     PetscTrMalloc = [](std::size_t sz, PetscBool clear, int, const char *, const char *, void **ptr) {
 79:       PetscFunctionBegin;
 80:       PetscCallCUPM(cupmMallocHost(ptr, sz));
 81:       if (clear) std::memset(*ptr, 0, sz);
 82:       PetscFunctionReturn(PETSC_SUCCESS);
 83:     };
 84:     PetscTrFree = [](void *ptr, int, const char *, const char *) {
 85:       PetscFunctionBegin;
 86:       PetscCallCUPM(cupmFreeHost(ptr));
 87:       PetscFunctionReturn(PETSC_SUCCESS);
 88:     };
 89:     PetscTrRealloc = [](std::size_t, int, const char *, const char *, void **) {
 90:       // REVIEW ME: can be implemented by malloc->copy->free?
 91:       SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "%s has no realloc()", cupmName());
 92:     };
 93:   }
 94:   PetscFunctionReturnVoid();
 95: }

 97: template <DeviceType T>
 98: inline UseCUPMHostAllocGuard<T>::~UseCUPMHostAllocGuard() noexcept
 99: {
100:   PetscFunctionBegin;
101:   if (value()) {
102:     PetscTrMalloc  = oldmalloc_;
103:     PetscTrFree    = oldfree_;
104:     PetscTrRealloc = oldrealloc_;
105:   }
106:   PetscFunctionReturnVoid();
107: }

109: template <DeviceType T>
110: inline bool UseCUPMHostAllocGuard<T>::value() const noexcept
111: {
112:   return v_;
113: }

115: } // anonymous namespace

117: template <DeviceType T, PetscMemType MemoryType, PetscMemoryAccessMode AccessMode>
118: class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL RestoreableArray : Interface<T> {
119: public:
120:   PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T);

122:   static constexpr auto memory_type = MemoryType;
123:   static constexpr auto access_type = AccessMode;

125:   using value_type        = PetscScalar;
126:   using pointer_type      = value_type *;
127:   using cupm_pointer_type = cupmScalar_t *;

129:   PETSC_NODISCARD pointer_type      data() const noexcept;
130:   PETSC_NODISCARD cupm_pointer_type cupmdata() const noexcept;

132:   operator pointer_type() const noexcept;
133:   // in case pointer_type == cupmscalar_pointer_type we don't want this overload to exist, so
134:   // we make a dummy template parameter to allow SFINAE to nix it for us
135:   template <typename U = pointer_type, typename = util::enable_if_t<!std::is_same<U, cupm_pointer_type>::value>>
136:   operator cupm_pointer_type() const noexcept;

138: protected:
139:   constexpr explicit RestoreableArray(PetscDeviceContext) noexcept;

141:   value_type        *ptr_  = nullptr;
142:   PetscDeviceContext dctx_ = nullptr;
143: };

145: // ==========================================================================================
146: // RestoreableArray - Static Variables
147: // ==========================================================================================

149: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
150: const PetscMemType RestoreableArray<T, MT, MA>::memory_type;

152: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
153: const PetscMemoryAccessMode RestoreableArray<T, MT, MA>::access_type;

155: // ==========================================================================================
156: // RestoreableArray - Public API
157: // ==========================================================================================

159: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
160: constexpr inline RestoreableArray<T, MT, MA>::RestoreableArray(PetscDeviceContext dctx) noexcept : dctx_{dctx}
161: {
162: }

164: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
165: inline typename RestoreableArray<T, MT, MA>::pointer_type RestoreableArray<T, MT, MA>::data() const noexcept
166: {
167:   return ptr_;
168: }

170: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
171: inline typename RestoreableArray<T, MT, MA>::cupm_pointer_type RestoreableArray<T, MT, MA>::cupmdata() const noexcept
172: {
173:   return cupmScalarPtrCast(data());
174: }

176: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
177: inline RestoreableArray<T, MT, MA>::operator pointer_type() const noexcept
178: {
179:   return data();
180: }

182: // in case pointer_type == cupmscalar_pointer_type we don't want this overload to exist, so
183: // we make a dummy template parameter to allow SFINAE to nix it for us
184: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
185: template <typename U, typename>
186: inline RestoreableArray<T, MT, MA>::operator cupm_pointer_type() const noexcept
187: {
188:   return cupmdata();
189: }

191: template <DeviceType T>
192: class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL CUPMObject : SolverInterface<T> {
193: protected:
194:   PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T);

196: private:
197:   // The final stop in the GetHandles_/GetFromHandles_ chain. This retrieves the various
198:   // compute handles and ensure the given PetscDeviceContext is of the right type
199:   static PetscErrorCode GetFromHandleDispatch_(PetscDeviceContext, cupmBlasHandle_t *, cupmSolverHandle_t *, cupmStream_t *) noexcept;
200:   static PetscErrorCode GetHandleDispatch_(PetscDeviceContext *, cupmBlasHandle_t *, cupmSolverHandle_t *, cupmStream_t *) noexcept;

202: protected:
203:   PETSC_NODISCARD static constexpr PetscRandomType PETSCDEVICERAND() noexcept;

205:   // Helper routines to retrieve various combinations of handles. The first set (GetHandles_)
206:   // gets a PetscDeviceContext along with it, while the second set (GetHandlesFrom_) assumes
207:   // you've gotten the PetscDeviceContext already, and retrieves the handles from it. All of
208:   // them check that the PetscDeviceContext is of the appropriate type
209:   static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmBlasHandle_t * = nullptr, cupmSolverHandle_t * = nullptr, cupmStream_t * = nullptr) noexcept;

211:   // triple
212:   static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmBlasHandle_t *, cupmStream_t *) noexcept;
213:   static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmSolverHandle_t *, cupmStream_t *) noexcept;

215:   // double
216:   static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmSolverHandle_t *) noexcept;
217:   static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmStream_t *) noexcept;

219:   // single
220:   static PetscErrorCode GetHandles_(cupmBlasHandle_t *) noexcept;
221:   static PetscErrorCode GetHandles_(cupmSolverHandle_t *) noexcept;
222:   static PetscErrorCode GetHandles_(cupmStream_t *) noexcept;

224:   static PetscErrorCode GetHandlesFrom_(PetscDeviceContext, cupmBlasHandle_t *, cupmSolverHandle_t * = nullptr, cupmStream_t * = nullptr) noexcept;
225:   static PetscErrorCode GetHandlesFrom_(PetscDeviceContext, cupmSolverHandle_t *, cupmStream_t * = nullptr) noexcept;
226:   static PetscErrorCode GetHandlesFrom_(PetscDeviceContext, cupmStream_t *) noexcept;

228:   // disallow implicit conversion
229:   template <typename U>
230:   PETSC_NODISCARD static UseCUPMHostAllocGuard<T> UseCUPMHostAlloc(U) noexcept = delete;
231:   // utility for using cupmHostAlloc()
232:   PETSC_NODISCARD static UseCUPMHostAllocGuard<T> UseCUPMHostAlloc(bool) noexcept;

234:   // A debug check to ensure that a given pointer-memtype pairing taken from user-land is
235:   // actually correct. Errors on mismatch
236:   static PetscErrorCode CheckPointerMatchesMemType_(const void *, PetscMemType) noexcept;
237: };

239: template <DeviceType T>
240: inline constexpr PetscRandomType CUPMObject<T>::PETSCDEVICERAND() noexcept
241: {
242:   // REVIEW ME: HIP default rng?
243:   return T == DeviceType::CUDA ? PETSCCURAND : PETSCRANDER48;
244: }

246: template <DeviceType T>
247: inline PetscErrorCode CUPMObject<T>::GetFromHandleDispatch_(PetscDeviceContext dctx, cupmBlasHandle_t *blas_handle, cupmSolverHandle_t *solver_handle, cupmStream_t *stream_handle) noexcept
248: {
249:   PetscFunctionBegin;
251:   if (blas_handle) {
252:     PetscAssertPointer(blas_handle, 2);
253:     *blas_handle = nullptr;
254:   }
255:   if (solver_handle) {
256:     PetscAssertPointer(solver_handle, 3);
257:     *solver_handle = nullptr;
258:   }
259:   if (stream_handle) {
260:     PetscAssertPointer(stream_handle, 4);
261:     *stream_handle = nullptr;
262:   }
263:   if (PetscDefined(USE_DEBUG)) {
264:     PetscDeviceType dtype;

266:     PetscCall(PetscDeviceContextGetDeviceType(dctx, &dtype));
267:     PetscCheckCompatibleDeviceTypes(PETSC_DEVICE_CUPM(), -1, dtype, 1);
268:   }
269:   if (blas_handle) PetscCall(PetscDeviceContextGetBLASHandle_Internal(dctx, blas_handle));
270:   if (solver_handle) PetscCall(PetscDeviceContextGetSOLVERHandle_Internal(dctx, solver_handle));
271:   if (stream_handle) {
272:     cupmStream_t *stream = nullptr;

274:     PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, (void **)&stream));
275:     *stream_handle = *stream;
276:   }
277:   PetscFunctionReturn(PETSC_SUCCESS);
278: }

280: template <DeviceType T>
281: inline PetscErrorCode CUPMObject<T>::GetHandleDispatch_(PetscDeviceContext *dctx, cupmBlasHandle_t *blas_handle, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
282: {
283:   PetscDeviceContext dctx_loc = nullptr;

285:   PetscFunctionBegin;
286:   // silence uninitialized variable warnings
287:   if (dctx) *dctx = nullptr;
288:   PetscCall(PetscDeviceContextGetCurrentContext(&dctx_loc));
289:   PetscCall(GetFromHandleDispatch_(dctx_loc, blas_handle, solver_handle, stream));
290:   if (dctx) *dctx = dctx_loc;
291:   PetscFunctionReturn(PETSC_SUCCESS);
292: }

294: template <DeviceType T>
295: inline PetscErrorCode CUPMObject<T>::GetHandles_(PetscDeviceContext *dctx, cupmBlasHandle_t *blas_handle, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
296: {
297:   return GetHandleDispatch_(dctx, blas_handle, solver_handle, stream);
298: }

300: template <DeviceType T>
301: inline PetscErrorCode CUPMObject<T>::GetHandles_(PetscDeviceContext *dctx, cupmBlasHandle_t *blas_handle, cupmStream_t *stream) noexcept
302: {
303:   return GetHandleDispatch_(dctx, blas_handle, nullptr, stream);
304: }

306: template <DeviceType T>
307: inline PetscErrorCode CUPMObject<T>::GetHandles_(PetscDeviceContext *dctx, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
308: {
309:   return GetHandleDispatch_(dctx, nullptr, solver_handle, stream);
310: }

312: template <DeviceType T>
313: inline PetscErrorCode CUPMObject<T>::GetHandles_(PetscDeviceContext *dctx, cupmStream_t *stream) noexcept
314: {
315:   return GetHandleDispatch_(dctx, nullptr, nullptr, stream);
316: }

318: template <DeviceType T>
319: inline PetscErrorCode CUPMObject<T>::GetHandles_(cupmBlasHandle_t *handle) noexcept
320: {
321:   return GetHandleDispatch_(nullptr, handle, nullptr, nullptr);
322: }

324: template <DeviceType T>
325: inline PetscErrorCode CUPMObject<T>::GetHandles_(cupmSolverHandle_t *handle) noexcept
326: {
327:   return GetHandleDispatch_(nullptr, nullptr, handle, nullptr);
328: }

330: template <DeviceType T>
331: inline PetscErrorCode CUPMObject<T>::GetHandles_(cupmStream_t *stream) noexcept
332: {
333:   return GetHandleDispatch_(nullptr, nullptr, nullptr, stream);
334: }

336: template <DeviceType T>
337: inline PetscErrorCode CUPMObject<T>::GetHandlesFrom_(PetscDeviceContext dctx, cupmBlasHandle_t *blas_handle, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
338: {
339:   return GetFromHandleDispatch_(dctx, blas_handle, solver_handle, stream);
340: }

342: template <DeviceType T>
343: inline PetscErrorCode CUPMObject<T>::GetHandlesFrom_(PetscDeviceContext dctx, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
344: {
345:   return GetFromHandleDispatch_(dctx, nullptr, solver_handle, stream);
346: }

348: template <DeviceType T>
349: inline PetscErrorCode CUPMObject<T>::GetHandlesFrom_(PetscDeviceContext dctx, cupmStream_t *stream) noexcept
350: {
351:   return GetFromHandleDispatch_(dctx, nullptr, nullptr, stream);
352: }

354: template <DeviceType T>
355: inline UseCUPMHostAllocGuard<T> CUPMObject<T>::UseCUPMHostAlloc(bool b) noexcept
356: {
357:   return {b};
358: }

360: template <DeviceType T>
361: inline PetscErrorCode CUPMObject<T>::CheckPointerMatchesMemType_(const void *ptr, PetscMemType mtype) noexcept
362: {
363:   PetscFunctionBegin;
364:   if (PetscDefined(USE_DEBUG) && ptr) {
365:     PetscMemType ptr_mtype;

367:     PetscCall(PetscCUPMGetMemType(ptr, &ptr_mtype));
368:     if (mtype == PETSC_MEMTYPE_HOST) {
369:       PetscCheck(PetscMemTypeHost(ptr_mtype), PETSC_COMM_SELF, PETSC_ERR_POINTER, "Pointer %p declared as %s does not match actual memtype %s", ptr, PetscMemTypeToString(mtype), PetscMemTypeToString(ptr_mtype));
370:     } else if (mtype == PETSC_MEMTYPE_DEVICE) {
371:       // generic "device" memory should only care if the actual memtype is also generically
372:       // "device"
373:       PetscCheck(PetscMemTypeDevice(ptr_mtype), PETSC_COMM_SELF, PETSC_ERR_POINTER, "Pointer %p declared as %s does not match actual memtype %s", ptr, PetscMemTypeToString(mtype), PetscMemTypeToString(ptr_mtype));
374:     } else {
375:       PetscCheck(mtype == ptr_mtype, PETSC_COMM_SELF, PETSC_ERR_POINTER, "Pointer %p declared as %s does not match actual memtype %s", ptr, PetscMemTypeToString(mtype), PetscMemTypeToString(ptr_mtype));
376:     }
377:   }
378:   PetscFunctionReturn(PETSC_SUCCESS);
379: }

// Pulls the solver-interface typedefs (via PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING)
// and the CUPMObject<T> helper routines into the current class scope with using-declarations.
// The final using-declaration carries no trailing semicolon, so the macro invocation must be
// followed by one.
#define PETSC_CUPMOBJECT_HEADER(T) \
  PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T); \
  using ::Petsc::device::cupm::impl::CUPMObject<T>::UseCUPMHostAlloc; \
  using ::Petsc::device::cupm::impl::CUPMObject<T>::GetHandles_; \
  using ::Petsc::device::cupm::impl::CUPMObject<T>::GetHandlesFrom_; \
  using ::Petsc::device::cupm::impl::CUPMObject<T>::PETSCDEVICERAND; \
  using ::Petsc::device::cupm::impl::CUPMObject<T>::CheckPointerMatchesMemType_

389: } // namespace impl

391: } // namespace cupm

393: } // namespace device

395: } // namespace Petsc