Actual source code: cupmobject.hpp
1: #pragma once
3: #include <petsc/private/deviceimpl.h>
4: #include <petsc/private/cupmsolverinterface.hpp>
6: #include <cstring> // std::memset
8: inline PetscErrorCode PetscStrFreeAllocpy(const char target[], char **dest) noexcept
9: {
10: PetscFunctionBegin;
11: PetscAssertPointer(dest, 2);
12: if (*dest) {
13: PetscAssertPointer(*dest, 2);
14: PetscCall(PetscFree(*dest));
15: }
16: PetscCall(PetscStrallocpy(target, dest));
17: PetscFunctionReturn(PETSC_SUCCESS);
18: }
20: namespace Petsc
21: {
23: namespace device
24: {
26: namespace cupm
27: {
29: namespace impl
30: {
32: namespace
33: {
35: // ==========================================================================================
36: // UseCUPMHostAllocGuard
37: //
38: // A simple RAII helper for PetscMallocSet[CUDA|HIP]Host(). it exists because integrating the
39: // regular versions would be an enormous pain to square with the templated types...
40: // ==========================================================================================
41: template <DeviceType T>
42: class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL UseCUPMHostAllocGuard : Interface<T> {
43: public:
44: PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T);
46: UseCUPMHostAllocGuard(bool) noexcept;
47: ~UseCUPMHostAllocGuard() noexcept;
49: PETSC_NODISCARD bool value() const noexcept;
51: private:
52: // would have loved to just do
53: //
54: // const auto oldmalloc = PetscTrMalloc;
55: //
56: // but in order to use auto the member needs to be static; in order to be static it must
57: // also be constexpr -- which in turn requires an initializer (also implicitly required by
58: // auto). But constexpr needs a constant expression initializer, so we can't initialize it
59: // with global (mutable) variables...
60: #define DECLTYPE_AUTO(left, right) decltype(right) left = right
61: const DECLTYPE_AUTO(oldmalloc_, PetscTrMalloc);
62: const DECLTYPE_AUTO(oldfree_, PetscTrFree);
63: const DECLTYPE_AUTO(oldrealloc_, PetscTrRealloc);
64: #undef DECLTYPE_AUTO
65: bool v_;
66: };
68: // ==========================================================================================
69: // UseCUPMHostAllocGuard -- Public API
70: // ==========================================================================================
72: template <DeviceType T>
73: inline UseCUPMHostAllocGuard<T>::UseCUPMHostAllocGuard(bool useit) noexcept : v_(useit)
74: {
75: PetscFunctionBegin;
76: if (useit) {
77: // all unused arguments are un-named, this saves having to add PETSC_UNUSED to them all
78: PetscTrMalloc = [](std::size_t sz, PetscBool clear, int, const char *, const char *, void **ptr) {
79: PetscFunctionBegin;
80: PetscCallCUPM(cupmMallocHost(ptr, sz));
81: if (clear) std::memset(*ptr, 0, sz);
82: PetscFunctionReturn(PETSC_SUCCESS);
83: };
84: PetscTrFree = [](void *ptr, int, const char *, const char *) {
85: PetscFunctionBegin;
86: PetscCallCUPM(cupmFreeHost(ptr));
87: PetscFunctionReturn(PETSC_SUCCESS);
88: };
89: PetscTrRealloc = [](std::size_t, int, const char *, const char *, void **) {
90: // REVIEW ME: can be implemented by malloc->copy->free?
91: SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "%s has no realloc()", cupmName());
92: };
93: }
94: PetscFunctionReturnVoid();
95: }
97: template <DeviceType T>
98: inline UseCUPMHostAllocGuard<T>::~UseCUPMHostAllocGuard() noexcept
99: {
100: PetscFunctionBegin;
101: if (value()) {
102: PetscTrMalloc = oldmalloc_;
103: PetscTrFree = oldfree_;
104: PetscTrRealloc = oldrealloc_;
105: }
106: PetscFunctionReturnVoid();
107: }
109: template <DeviceType T>
110: inline bool UseCUPMHostAllocGuard<T>::value() const noexcept
111: {
112: return v_;
113: }
115: } // anonymous namespace
117: template <DeviceType T, PetscMemType MemoryType, PetscMemoryAccessMode AccessMode>
118: class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL RestoreableArray : Interface<T> {
119: public:
120: PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T);
122: static constexpr auto memory_type = MemoryType;
123: static constexpr auto access_type = AccessMode;
125: using value_type = PetscScalar;
126: using pointer_type = value_type *;
127: using cupm_pointer_type = cupmScalar_t *;
129: PETSC_NODISCARD pointer_type data() const noexcept;
130: PETSC_NODISCARD cupm_pointer_type cupmdata() const noexcept;
132: operator pointer_type() const noexcept;
133: // in case pointer_type == cupmscalar_pointer_type we don't want this overload to exist, so
134: // we make a dummy template parameter to allow SFINAE to nix it for us
135: template <typename U = pointer_type, typename = util::enable_if_t<!std::is_same<U, cupm_pointer_type>::value>>
136: operator cupm_pointer_type() const noexcept;
138: protected:
139: constexpr explicit RestoreableArray(PetscDeviceContext) noexcept;
141: value_type *ptr_ = nullptr;
142: PetscDeviceContext dctx_ = nullptr;
143: };
145: // ==========================================================================================
146: // RestoreableArray - Static Variables
147: // ==========================================================================================
149: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
150: const PetscMemType RestoreableArray<T, MT, MA>::memory_type;
152: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
153: const PetscMemoryAccessMode RestoreableArray<T, MT, MA>::access_type;
155: // ==========================================================================================
156: // RestoreableArray - Public API
157: // ==========================================================================================
159: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
160: constexpr inline RestoreableArray<T, MT, MA>::RestoreableArray(PetscDeviceContext dctx) noexcept : dctx_{dctx}
161: {
162: }
164: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
165: inline typename RestoreableArray<T, MT, MA>::pointer_type RestoreableArray<T, MT, MA>::data() const noexcept
166: {
167: return ptr_;
168: }
170: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
171: inline typename RestoreableArray<T, MT, MA>::cupm_pointer_type RestoreableArray<T, MT, MA>::cupmdata() const noexcept
172: {
173: return cupmScalarPtrCast(data());
174: }
176: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
177: inline RestoreableArray<T, MT, MA>::operator pointer_type() const noexcept
178: {
179: return data();
180: }
182: // in case pointer_type == cupmscalar_pointer_type we don't want this overload to exist, so
183: // we make a dummy template parameter to allow SFINAE to nix it for us
184: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
185: template <typename U, typename>
186: inline RestoreableArray<T, MT, MA>::operator cupm_pointer_type() const noexcept
187: {
188: return cupmdata();
189: }
191: template <DeviceType T>
192: class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL CUPMObject : SolverInterface<T> {
193: protected:
194: PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T);
196: private:
197: // The final stop in the GetHandles_/GetFromHandles_ chain. This retrieves the various
198: // compute handles and ensure the given PetscDeviceContext is of the right type
199: static PetscErrorCode GetFromHandleDispatch_(PetscDeviceContext, cupmBlasHandle_t *, cupmSolverHandle_t *, cupmStream_t *) noexcept;
200: static PetscErrorCode GetHandleDispatch_(PetscDeviceContext *, cupmBlasHandle_t *, cupmSolverHandle_t *, cupmStream_t *) noexcept;
202: protected:
203: PETSC_NODISCARD static constexpr PetscRandomType PETSCDEVICERAND() noexcept;
205: // Helper routines to retrieve various combinations of handles. The first set (GetHandles_)
206: // gets a PetscDeviceContext along with it, while the second set (GetHandlesFrom_) assumes
207: // you've gotten the PetscDeviceContext already, and retrieves the handles from it. All of
208: // them check that the PetscDeviceContext is of the appropriate type
209: static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmBlasHandle_t * = nullptr, cupmSolverHandle_t * = nullptr, cupmStream_t * = nullptr) noexcept;
211: // triple
212: static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmBlasHandle_t *, cupmStream_t *) noexcept;
213: static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmSolverHandle_t *, cupmStream_t *) noexcept;
215: // double
216: static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmSolverHandle_t *) noexcept;
217: static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmStream_t *) noexcept;
219: // single
220: static PetscErrorCode GetHandles_(cupmBlasHandle_t *) noexcept;
221: static PetscErrorCode GetHandles_(cupmSolverHandle_t *) noexcept;
222: static PetscErrorCode GetHandles_(cupmStream_t *) noexcept;
224: static PetscErrorCode GetHandlesFrom_(PetscDeviceContext, cupmBlasHandle_t *, cupmSolverHandle_t * = nullptr, cupmStream_t * = nullptr) noexcept;
225: static PetscErrorCode GetHandlesFrom_(PetscDeviceContext, cupmSolverHandle_t *, cupmStream_t * = nullptr) noexcept;
226: static PetscErrorCode GetHandlesFrom_(PetscDeviceContext, cupmStream_t *) noexcept;
228: // disallow implicit conversion
229: template <typename U>
230: PETSC_NODISCARD static UseCUPMHostAllocGuard<T> UseCUPMHostAlloc(U) noexcept = delete;
231: // utility for using cupmHostAlloc()
232: PETSC_NODISCARD static UseCUPMHostAllocGuard<T> UseCUPMHostAlloc(bool) noexcept;
234: // A debug check to ensure that a given pointer-memtype pairing taken from user-land is
235: // actually correct. Errors on mismatch
236: static PetscErrorCode CheckPointerMatchesMemType_(const void *, PetscMemType) noexcept;
237: };
239: template <DeviceType T>
240: inline constexpr PetscRandomType CUPMObject<T>::PETSCDEVICERAND() noexcept
241: {
242: // REVIEW ME: HIP default rng?
243: return T == DeviceType::CUDA ? PETSCCURAND : PETSCRANDER48;
244: }
246: template <DeviceType T>
247: inline PetscErrorCode CUPMObject<T>::GetFromHandleDispatch_(PetscDeviceContext dctx, cupmBlasHandle_t *blas_handle, cupmSolverHandle_t *solver_handle, cupmStream_t *stream_handle) noexcept
248: {
249: PetscFunctionBegin;
251: if (blas_handle) {
252: PetscAssertPointer(blas_handle, 2);
253: *blas_handle = nullptr;
254: }
255: if (solver_handle) {
256: PetscAssertPointer(solver_handle, 3);
257: *solver_handle = nullptr;
258: }
259: if (stream_handle) {
260: PetscAssertPointer(stream_handle, 4);
261: *stream_handle = nullptr;
262: }
263: if (PetscDefined(USE_DEBUG)) {
264: PetscDeviceType dtype;
266: PetscCall(PetscDeviceContextGetDeviceType(dctx, &dtype));
267: PetscCheckCompatibleDeviceTypes(PETSC_DEVICE_CUPM(), -1, dtype, 1);
268: }
269: if (blas_handle) PetscCall(PetscDeviceContextGetBLASHandle_Internal(dctx, blas_handle));
270: if (solver_handle) PetscCall(PetscDeviceContextGetSOLVERHandle_Internal(dctx, solver_handle));
271: if (stream_handle) {
272: cupmStream_t *stream = nullptr;
274: PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, (void **)&stream));
275: *stream_handle = *stream;
276: }
277: PetscFunctionReturn(PETSC_SUCCESS);
278: }
280: template <DeviceType T>
281: inline PetscErrorCode CUPMObject<T>::GetHandleDispatch_(PetscDeviceContext *dctx, cupmBlasHandle_t *blas_handle, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
282: {
283: PetscDeviceContext dctx_loc = nullptr;
285: PetscFunctionBegin;
286: // silence uninitialized variable warnings
287: if (dctx) *dctx = nullptr;
288: PetscCall(PetscDeviceContextGetCurrentContext(&dctx_loc));
289: PetscCall(GetFromHandleDispatch_(dctx_loc, blas_handle, solver_handle, stream));
290: if (dctx) *dctx = dctx_loc;
291: PetscFunctionReturn(PETSC_SUCCESS);
292: }
294: template <DeviceType T>
295: inline PetscErrorCode CUPMObject<T>::GetHandles_(PetscDeviceContext *dctx, cupmBlasHandle_t *blas_handle, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
296: {
297: return GetHandleDispatch_(dctx, blas_handle, solver_handle, stream);
298: }
300: template <DeviceType T>
301: inline PetscErrorCode CUPMObject<T>::GetHandles_(PetscDeviceContext *dctx, cupmBlasHandle_t *blas_handle, cupmStream_t *stream) noexcept
302: {
303: return GetHandleDispatch_(dctx, blas_handle, nullptr, stream);
304: }
306: template <DeviceType T>
307: inline PetscErrorCode CUPMObject<T>::GetHandles_(PetscDeviceContext *dctx, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
308: {
309: return GetHandleDispatch_(dctx, nullptr, solver_handle, stream);
310: }
312: template <DeviceType T>
313: inline PetscErrorCode CUPMObject<T>::GetHandles_(PetscDeviceContext *dctx, cupmStream_t *stream) noexcept
314: {
315: return GetHandleDispatch_(dctx, nullptr, nullptr, stream);
316: }
318: template <DeviceType T>
319: inline PetscErrorCode CUPMObject<T>::GetHandles_(cupmBlasHandle_t *handle) noexcept
320: {
321: return GetHandleDispatch_(nullptr, handle, nullptr, nullptr);
322: }
324: template <DeviceType T>
325: inline PetscErrorCode CUPMObject<T>::GetHandles_(cupmSolverHandle_t *handle) noexcept
326: {
327: return GetHandleDispatch_(nullptr, nullptr, handle, nullptr);
328: }
330: template <DeviceType T>
331: inline PetscErrorCode CUPMObject<T>::GetHandles_(cupmStream_t *stream) noexcept
332: {
333: return GetHandleDispatch_(nullptr, nullptr, nullptr, stream);
334: }
336: template <DeviceType T>
337: inline PetscErrorCode CUPMObject<T>::GetHandlesFrom_(PetscDeviceContext dctx, cupmBlasHandle_t *blas_handle, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
338: {
339: return GetFromHandleDispatch_(dctx, blas_handle, solver_handle, stream);
340: }
342: template <DeviceType T>
343: inline PetscErrorCode CUPMObject<T>::GetHandlesFrom_(PetscDeviceContext dctx, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
344: {
345: return GetFromHandleDispatch_(dctx, nullptr, solver_handle, stream);
346: }
348: template <DeviceType T>
349: inline PetscErrorCode CUPMObject<T>::GetHandlesFrom_(PetscDeviceContext dctx, cupmStream_t *stream) noexcept
350: {
351: return GetFromHandleDispatch_(dctx, nullptr, nullptr, stream);
352: }
354: template <DeviceType T>
355: inline UseCUPMHostAllocGuard<T> CUPMObject<T>::UseCUPMHostAlloc(bool b) noexcept
356: {
357: return {b};
358: }
360: template <DeviceType T>
361: inline PetscErrorCode CUPMObject<T>::CheckPointerMatchesMemType_(const void *ptr, PetscMemType mtype) noexcept
362: {
363: PetscFunctionBegin;
364: if (PetscDefined(USE_DEBUG) && ptr) {
365: PetscMemType ptr_mtype;
367: PetscCall(PetscCUPMGetMemType(ptr, &ptr_mtype));
368: if (mtype == PETSC_MEMTYPE_HOST) {
369: PetscCheck(PetscMemTypeHost(ptr_mtype), PETSC_COMM_SELF, PETSC_ERR_POINTER, "Pointer %p declared as %s does not match actual memtype %s", ptr, PetscMemTypeToString(mtype), PetscMemTypeToString(ptr_mtype));
370: } else if (mtype == PETSC_MEMTYPE_DEVICE) {
371: // generic "device" memory should only care if the actual memtype is also generically
372: // "device"
373: PetscCheck(PetscMemTypeDevice(ptr_mtype), PETSC_COMM_SELF, PETSC_ERR_POINTER, "Pointer %p declared as %s does not match actual memtype %s", ptr, PetscMemTypeToString(mtype), PetscMemTypeToString(ptr_mtype));
374: } else {
375: PetscCheck(mtype == ptr_mtype, PETSC_COMM_SELF, PETSC_ERR_POINTER, "Pointer %p declared as %s does not match actual memtype %s", ptr, PetscMemTypeToString(mtype), PetscMemTypeToString(ptr_mtype));
376: }
377: }
378: PetscFunctionReturn(PETSC_SUCCESS);
379: }
// Pulls the solver-interface typedefs and the protected CUPMObject<T> helpers into the scope
// where the macro is expanded. The expansion deliberately ends without a semicolon so that
// the use site reads PETSC_CUPMOBJECT_HEADER(T);
#define PETSC_CUPMOBJECT_HEADER(T) \
  PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T); \
  using ::Petsc::device::cupm::impl::CUPMObject<T>::GetHandles_; \
  using ::Petsc::device::cupm::impl::CUPMObject<T>::GetHandlesFrom_; \
  using ::Petsc::device::cupm::impl::CUPMObject<T>::UseCUPMHostAlloc; \
  using ::Petsc::device::cupm::impl::CUPMObject<T>::CheckPointerMatchesMemType_; \
  using ::Petsc::device::cupm::impl::CUPMObject<T>::PETSCDEVICERAND
389: } // namespace impl
391: } // namespace cupm
393: } // namespace device
395: } // namespace Petsc