Actual source code: cupmobject.hpp

  1: #pragma once

  3: #include <petsc/private/deviceimpl.h>
  4: #include <petsc/private/cupmsolverinterface.hpp>

  6: #include <cstring> // std::memset

  8: inline PetscErrorCode PetscStrFreeAllocpy(const char target[], char **dest) noexcept
  9: {
 10:   PetscFunctionBegin;
 11:   PetscAssertPointer(dest, 2);
 12:   if (*dest) {
 13:     PetscAssertPointer(*dest, 2);
 14:     PetscCall(PetscFree(*dest));
 15:   }
 16:   PetscCall(PetscStrallocpy(target, dest));
 17:   PetscFunctionReturn(PETSC_SUCCESS);
 18: }

 20: namespace Petsc
 21: {

 23: namespace device
 24: {

 26: namespace cupm
 27: {

 29: namespace impl
 30: {

 32: namespace
 33: {

 35: // ==========================================================================================
 36: // UseCUPMHostAllocGuard
 37: //
// A simple RAII helper for PetscMallocSet[CUDA|HIP]Host(). It exists because integrating the
// regular versions would be an enormous pain to square with the templated types.
 40: // ==========================================================================================
 41: template <DeviceType T>
 42: class UseCUPMHostAllocGuard : Interface<T> {
 43: public:
 44:   PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T);

 46:   UseCUPMHostAllocGuard(bool) noexcept;
 47:   ~UseCUPMHostAllocGuard() noexcept;

 49:   PETSC_NODISCARD bool value() const noexcept;

 51: private:
 52:   // would have loved to just do
 53:   //
 54:   // const auto oldmalloc = PetscTrMalloc;
 55:   //
 56:   // but in order to use auto the member needs to be static; in order to be static it must
 57:   // also be constexpr -- which in turn requires an initializer (also implicitly required by
 58:   // auto). But constexpr needs a constant expression initializer, so we can't initialize it
 59:   // with global (mutable) variables...
 60: #define DECLTYPE_AUTO(left, right) decltype(right) left = right
 61:   const DECLTYPE_AUTO(oldmalloc_, PetscTrMalloc);
 62:   const DECLTYPE_AUTO(oldfree_, PetscTrFree);
 63:   const DECLTYPE_AUTO(oldrealloc_, PetscTrRealloc);
 64: #undef DECLTYPE_AUTO
 65:   bool v_;
 66: };

 68: // ==========================================================================================
 69: // UseCUPMHostAllocGuard -- Public API
 70: // ==========================================================================================

 72: template <DeviceType T>
 73: inline UseCUPMHostAllocGuard<T>::UseCUPMHostAllocGuard(bool useit) noexcept : v_(useit)
 74: {
 75:   PetscFunctionBegin;
 76:   if (useit) {
 77:     // all unused arguments are un-named, this saves having to add PETSC_UNUSED to them all
 78:     PetscTrMalloc = [](std::size_t sz, PetscBool clear, int, const char *, const char *, void **ptr) {
 79:       PetscFunctionBegin;
 80:       PetscCallCUPM(cupmMallocHost(ptr, sz));
 81:       if (clear) std::memset(*ptr, 0, sz);
 82:       PetscFunctionReturn(PETSC_SUCCESS);
 83:     };
 84:     PetscTrFree = [](void *ptr, int, const char *, const char *) {
 85:       PetscFunctionBegin;
 86:       PetscCallCUPM(cupmFreeHost(ptr));
 87:       PetscFunctionReturn(PETSC_SUCCESS);
 88:     };
 89:     PetscTrRealloc = [](std::size_t, int, const char *, const char *, void **) {
 90:       // REVIEW ME: can be implemented by malloc->copy->free?
 91:       SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "%s has no realloc()", cupmName());
 92:     };
 93:   }
 94:   PetscFunctionReturnVoid();
 95: }

 97: template <DeviceType T>
 98: inline UseCUPMHostAllocGuard<T>::~UseCUPMHostAllocGuard() noexcept
 99: {
100:   PetscFunctionBegin;
101:   if (value()) {
102:     PetscTrMalloc  = oldmalloc_;
103:     PetscTrFree    = oldfree_;
104:     PetscTrRealloc = oldrealloc_;
105:   }
106:   PetscFunctionReturnVoid();
107: }

109: template <DeviceType T>
110: inline bool UseCUPMHostAllocGuard<T>::value() const noexcept
111: {
112:   return v_;
113: }

115: } // anonymous namespace

117: template <DeviceType T, PetscMemType MemoryType, PetscMemoryAccessMode AccessMode>
118: class RestoreableArray : Interface<T> {
119: public:
120:   PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T);

122:   static constexpr auto memory_type = MemoryType;
123:   static constexpr auto access_type = AccessMode;

125:   using value_type        = PetscScalar;
126:   using pointer_type      = value_type *;
127:   using cupm_pointer_type = cupmScalar_t *;

129:   PETSC_NODISCARD pointer_type      data() const noexcept;
130:   PETSC_NODISCARD cupm_pointer_type cupmdata() const noexcept;

132:   operator pointer_type() const noexcept;
133:   // in case pointer_type == cupmscalar_pointer_type we don't want this overload to exist, so
134:   // we make a dummy template parameter to allow SFINAE to nix it for us
135:   template <typename U = pointer_type, typename = util::enable_if_t<!std::is_same<U, cupm_pointer_type>::value>>
136:   operator cupm_pointer_type() const noexcept;

138: protected:
139:   constexpr explicit RestoreableArray(PetscDeviceContext) noexcept;

141:   value_type        *ptr_  = nullptr;
142:   PetscDeviceContext dctx_ = nullptr;
143: };

145: // ==========================================================================================
146: // RestoreableArray - Static Variables
147: // ==========================================================================================

149: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
150: const PetscMemType RestoreableArray<T, MT, MA>::memory_type;

152: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
153: const PetscMemoryAccessMode RestoreableArray<T, MT, MA>::access_type;

155: // ==========================================================================================
156: // RestoreableArray - Public API
157: // ==========================================================================================

159: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
160: constexpr inline RestoreableArray<T, MT, MA>::RestoreableArray(PetscDeviceContext dctx) noexcept : dctx_{dctx}
161: {
162: }

164: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
165: inline typename RestoreableArray<T, MT, MA>::pointer_type RestoreableArray<T, MT, MA>::data() const noexcept
166: {
167:   return ptr_;
168: }

170: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
171: inline typename RestoreableArray<T, MT, MA>::cupm_pointer_type RestoreableArray<T, MT, MA>::cupmdata() const noexcept
172: {
173:   return cupmScalarPtrCast(data());
174: }

176: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
177: inline RestoreableArray<T, MT, MA>::operator pointer_type() const noexcept
178: {
179:   return data();
180: }

// In case pointer_type == cupm_pointer_type we don't want this overload to exist, so we add
// a dummy template parameter to allow SFINAE to nix it for us
184: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
185: template <typename U, typename>
186: inline RestoreableArray<T, MT, MA>::operator cupm_pointer_type() const noexcept
187: {
188:   return cupmdata();
189: }

191: template <DeviceType T>
192: class CUPMObject : SolverInterface<T> {
193: protected:
194:   PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T);

196: private:
197:   // The final stop in the GetHandles_/GetFromHandles_ chain. This retrieves the various
198:   // compute handles and ensure the given PetscDeviceContext is of the right type
199:   static PetscErrorCode GetFromHandleDispatch_(PetscDeviceContext, cupmBlasHandle_t *, cupmSolverHandle_t *, cupmStream_t *) noexcept;
200:   static PetscErrorCode GetHandleDispatch_(PetscDeviceContext *, cupmBlasHandle_t *, cupmSolverHandle_t *, cupmStream_t *) noexcept;

202: protected:
203:   PETSC_NODISCARD static constexpr PetscRandomType PETSCDEVICERAND() noexcept;

205:   // Helper routines to retrieve various combinations of handles. The first set (GetHandles_)
206:   // gets a PetscDeviceContext along with it, while the second set (GetHandlesFrom_) assumes
207:   // you've gotten the PetscDeviceContext already, and retrieves the handles from it. All of
208:   // them check that the PetscDeviceContext is of the appropriate type
209:   static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmBlasHandle_t * = nullptr, cupmSolverHandle_t * = nullptr, cupmStream_t * = nullptr) noexcept;

211:   // triple
212:   static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmBlasHandle_t *, cupmStream_t *) noexcept;
213:   static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmSolverHandle_t *, cupmStream_t *) noexcept;

215:   // double
216:   static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmSolverHandle_t *) noexcept;
217:   static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmStream_t *) noexcept;

219:   // single
220:   static PetscErrorCode GetHandles_(cupmBlasHandle_t *) noexcept;
221:   static PetscErrorCode GetHandles_(cupmSolverHandle_t *) noexcept;
222:   static PetscErrorCode GetHandles_(cupmStream_t *) noexcept;

224:   static PetscErrorCode GetHandlesFrom_(PetscDeviceContext, cupmBlasHandle_t *, cupmSolverHandle_t * = nullptr, cupmStream_t * = nullptr) noexcept;
225:   static PetscErrorCode GetHandlesFrom_(PetscDeviceContext, cupmSolverHandle_t *, cupmStream_t * = nullptr) noexcept;
226:   static PetscErrorCode GetHandlesFrom_(PetscDeviceContext, cupmStream_t *) noexcept;

228:   // disallow implicit conversion
229:   template <typename U>
230:   PETSC_NODISCARD static UseCUPMHostAllocGuard<T> UseCUPMHostAlloc(U) noexcept = delete;
231:   // utility for using cupmHostAlloc()
232:   PETSC_NODISCARD static UseCUPMHostAllocGuard<T> UseCUPMHostAlloc(bool) noexcept;
233:   PETSC_NODISCARD static UseCUPMHostAllocGuard<T> UseCUPMHostAlloc(PetscBool) noexcept;

235:   // A debug check to ensure that a given pointer-memtype pairing taken from user-land is
236:   // actually correct. Errors on mismatch
237:   static PetscErrorCode CheckPointerMatchesMemType_(const void *, PetscMemType) noexcept;
238: };

240: template <DeviceType T>
241: inline constexpr PetscRandomType CUPMObject<T>::PETSCDEVICERAND() noexcept
242: {
243:   // REVIEW ME: HIP default rng?
244:   return T == DeviceType::CUDA ? PETSCCURAND : PETSCRANDER48;
245: }

247: template <DeviceType T>
248: inline PetscErrorCode CUPMObject<T>::GetFromHandleDispatch_(PetscDeviceContext dctx, cupmBlasHandle_t *blas_handle, cupmSolverHandle_t *solver_handle, cupmStream_t *stream_handle) noexcept
249: {
250:   PetscFunctionBegin;
252:   if (blas_handle) {
253:     PetscAssertPointer(blas_handle, 2);
254:     *blas_handle = nullptr;
255:   }
256:   if (solver_handle) {
257:     PetscAssertPointer(solver_handle, 3);
258:     *solver_handle = nullptr;
259:   }
260:   if (stream_handle) {
261:     PetscAssertPointer(stream_handle, 4);
262:     *stream_handle = nullptr;
263:   }
264:   if (PetscDefined(USE_DEBUG)) {
265:     PetscDeviceType dtype;

267:     PetscCall(PetscDeviceContextGetDeviceType(dctx, &dtype));
268:     PetscCheckCompatibleDeviceTypes(PETSC_DEVICE_CUPM(), -1, dtype, 1);
269:   }
270:   if (blas_handle) PetscCall(PetscDeviceContextGetBLASHandle_Internal(dctx, blas_handle));
271:   if (solver_handle) PetscCall(PetscDeviceContextGetSOLVERHandle_Internal(dctx, solver_handle));
272:   if (stream_handle) {
273:     cupmStream_t *stream = nullptr;

275:     PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, (void **)&stream));
276:     *stream_handle = *stream;
277:   }
278:   PetscFunctionReturn(PETSC_SUCCESS);
279: }

281: template <DeviceType T>
282: inline PetscErrorCode CUPMObject<T>::GetHandleDispatch_(PetscDeviceContext *dctx, cupmBlasHandle_t *blas_handle, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
283: {
284:   PetscDeviceContext dctx_loc = nullptr;

286:   PetscFunctionBegin;
287:   // silence uninitialized variable warnings
288:   if (dctx) *dctx = nullptr;
289:   PetscCall(PetscDeviceContextGetCurrentContext(&dctx_loc));
290:   PetscCall(GetFromHandleDispatch_(dctx_loc, blas_handle, solver_handle, stream));
291:   if (dctx) *dctx = dctx_loc;
292:   PetscFunctionReturn(PETSC_SUCCESS);
293: }

295: template <DeviceType T>
296: inline PetscErrorCode CUPMObject<T>::GetHandles_(PetscDeviceContext *dctx, cupmBlasHandle_t *blas_handle, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
297: {
298:   return GetHandleDispatch_(dctx, blas_handle, solver_handle, stream);
299: }

301: template <DeviceType T>
302: inline PetscErrorCode CUPMObject<T>::GetHandles_(PetscDeviceContext *dctx, cupmBlasHandle_t *blas_handle, cupmStream_t *stream) noexcept
303: {
304:   return GetHandleDispatch_(dctx, blas_handle, nullptr, stream);
305: }

307: template <DeviceType T>
308: inline PetscErrorCode CUPMObject<T>::GetHandles_(PetscDeviceContext *dctx, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
309: {
310:   return GetHandleDispatch_(dctx, nullptr, solver_handle, stream);
311: }

313: template <DeviceType T>
314: inline PetscErrorCode CUPMObject<T>::GetHandles_(PetscDeviceContext *dctx, cupmStream_t *stream) noexcept
315: {
316:   return GetHandleDispatch_(dctx, nullptr, nullptr, stream);
317: }

319: template <DeviceType T>
320: inline PetscErrorCode CUPMObject<T>::GetHandles_(cupmBlasHandle_t *handle) noexcept
321: {
322:   return GetHandleDispatch_(nullptr, handle, nullptr, nullptr);
323: }

325: template <DeviceType T>
326: inline PetscErrorCode CUPMObject<T>::GetHandles_(cupmSolverHandle_t *handle) noexcept
327: {
328:   return GetHandleDispatch_(nullptr, nullptr, handle, nullptr);
329: }

331: template <DeviceType T>
332: inline PetscErrorCode CUPMObject<T>::GetHandles_(cupmStream_t *stream) noexcept
333: {
334:   return GetHandleDispatch_(nullptr, nullptr, nullptr, stream);
335: }

337: template <DeviceType T>
338: inline PetscErrorCode CUPMObject<T>::GetHandlesFrom_(PetscDeviceContext dctx, cupmBlasHandle_t *blas_handle, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
339: {
340:   return GetFromHandleDispatch_(dctx, blas_handle, solver_handle, stream);
341: }

343: template <DeviceType T>
344: inline PetscErrorCode CUPMObject<T>::GetHandlesFrom_(PetscDeviceContext dctx, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
345: {
346:   return GetFromHandleDispatch_(dctx, nullptr, solver_handle, stream);
347: }

349: template <DeviceType T>
350: inline PetscErrorCode CUPMObject<T>::GetHandlesFrom_(PetscDeviceContext dctx, cupmStream_t *stream) noexcept
351: {
352:   return GetFromHandleDispatch_(dctx, nullptr, nullptr, stream);
353: }

355: template <DeviceType T>
356: inline UseCUPMHostAllocGuard<T> CUPMObject<T>::UseCUPMHostAlloc(bool b) noexcept
357: {
358:   return {b};
359: }

361: template <DeviceType T>
362: inline UseCUPMHostAllocGuard<T> CUPMObject<T>::UseCUPMHostAlloc(PetscBool b) noexcept
363: {
364:   return UseCUPMHostAlloc(static_cast<bool>(b));
365: }

367: template <DeviceType T>
368: inline PetscErrorCode CUPMObject<T>::CheckPointerMatchesMemType_(const void *ptr, PetscMemType mtype) noexcept
369: {
370:   PetscFunctionBegin;
371:   if (PetscDefined(USE_DEBUG) && ptr) {
372:     PetscMemType ptr_mtype;

374:     PetscCall(PetscCUPMGetMemType(ptr, &ptr_mtype));
375:     if (mtype == PETSC_MEMTYPE_HOST) {
376:       PetscCheck(PetscMemTypeHost(ptr_mtype), PETSC_COMM_SELF, PETSC_ERR_POINTER, "Pointer %p declared as %s does not match actual memtype %s", ptr, PetscMemTypeToString(mtype), PetscMemTypeToString(ptr_mtype));
377:     } else if (mtype == PETSC_MEMTYPE_DEVICE) {
378:       // generic "device" memory should only care if the actual memtype is also generically
379:       // "device"
380:       PetscCheck(PetscMemTypeDevice(ptr_mtype), PETSC_COMM_SELF, PETSC_ERR_POINTER, "Pointer %p declared as %s does not match actual memtype %s", ptr, PetscMemTypeToString(mtype), PetscMemTypeToString(ptr_mtype));
381:     } else {
382:       PetscCheck(mtype == ptr_mtype, PETSC_COMM_SELF, PETSC_ERR_POINTER, "Pointer %p declared as %s does not match actual memtype %s", ptr, PetscMemTypeToString(mtype), PetscMemTypeToString(ptr_mtype));
383:     }
384:   }
385:   PetscFunctionReturn(PETSC_SUCCESS);
386: }

// Pulls the interface typedefs and the CUPMObject<T> helper routines into the scope where it
// is expanded. The expansion deliberately carries no trailing semicolon, the user supplies it
#define PETSC_CUPMOBJECT_HEADER(T) \
  PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T); \
  using ::Petsc::device::cupm::impl::CUPMObject<T>::PETSCDEVICERAND; \
  using ::Petsc::device::cupm::impl::CUPMObject<T>::GetHandles_; \
  using ::Petsc::device::cupm::impl::CUPMObject<T>::GetHandlesFrom_; \
  using ::Petsc::device::cupm::impl::CUPMObject<T>::UseCUPMHostAlloc; \
  using ::Petsc::device::cupm::impl::CUPMObject<T>::CheckPointerMatchesMemType_

396: } // namespace impl

398: } // namespace cupm

400: } // namespace device

402: } // namespace Petsc