Actual source code: cupmobject.hpp
1: #pragma once
3: #include <petsc/private/deviceimpl.h>
4: #include <petsc/private/cupmsolverinterface.hpp>
6: #include <cstring> // std::memset
8: inline PetscErrorCode PetscStrFreeAllocpy(const char target[], char **dest) noexcept
9: {
10: PetscFunctionBegin;
11: PetscAssertPointer(dest, 2);
12: if (*dest) {
13: PetscAssertPointer(*dest, 2);
14: PetscCall(PetscFree(*dest));
15: }
16: PetscCall(PetscStrallocpy(target, dest));
17: PetscFunctionReturn(PETSC_SUCCESS);
18: }
20: namespace Petsc
21: {
23: namespace device
24: {
26: namespace cupm
27: {
29: namespace impl
30: {
32: namespace
33: {
35: // ==========================================================================================
36: // UseCUPMHostAllocGuard
37: //
38: // A simple RAII helper for PetscMallocSet[CUDA|HIP]Host(). it exists because integrating the
39: // regular versions would be an enormous pain to square with the templated types...
40: // ==========================================================================================
41: template <DeviceType T>
42: class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL UseCUPMHostAllocGuard : Interface<T> {
43: public:
44: PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T);
46: UseCUPMHostAllocGuard(bool) noexcept;
47: ~UseCUPMHostAllocGuard() noexcept;
49: PETSC_NODISCARD bool value() const noexcept;
51: private:
52: // would have loved to just do
53: //
54: // const auto oldmalloc = PetscTrMalloc;
55: //
56: // but in order to use auto the member needs to be static; in order to be static it must
57: // also be constexpr -- which in turn requires an initializer (also implicitly required by
58: // auto). But constexpr needs a constant expression initializer, so we can't initialize it
59: // with global (mutable) variables...
60: #define DECLTYPE_AUTO(left, right) decltype(right) left = right
61: const DECLTYPE_AUTO(oldmalloc_, PetscTrMalloc);
62: const DECLTYPE_AUTO(oldfree_, PetscTrFree);
63: const DECLTYPE_AUTO(oldrealloc_, PetscTrRealloc);
64: #undef DECLTYPE_AUTO
65: bool v_;
66: };
68: // ==========================================================================================
69: // UseCUPMHostAllocGuard -- Public API
70: // ==========================================================================================
72: template <DeviceType T>
73: inline UseCUPMHostAllocGuard<T>::UseCUPMHostAllocGuard(bool useit) noexcept : v_(useit)
74: {
75: PetscFunctionBegin;
76: if (useit) {
77: // all unused arguments are un-named, this saves having to add PETSC_UNUSED to them all
78: PetscTrMalloc = [](std::size_t sz, PetscBool clear, int, const char *, const char *, void **ptr) {
79: PetscFunctionBegin;
80: PetscCallCUPM(cupmMallocHost(ptr, sz));
81: if (clear) std::memset(*ptr, 0, sz);
82: PetscFunctionReturn(PETSC_SUCCESS);
83: };
84: PetscTrFree = [](void *ptr, int, const char *, const char *) {
85: PetscFunctionBegin;
86: PetscCallCUPM(cupmFreeHost(ptr));
87: PetscFunctionReturn(PETSC_SUCCESS);
88: };
89: PetscTrRealloc = [](std::size_t, int, const char *, const char *, void **) {
90: // REVIEW ME: can be implemented by malloc->copy->free?
91: SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "%s has no realloc()", cupmName());
92: };
93: }
94: PetscFunctionReturnVoid();
95: }
97: template <DeviceType T>
98: inline UseCUPMHostAllocGuard<T>::~UseCUPMHostAllocGuard() noexcept
99: {
100: PetscFunctionBegin;
101: if (value()) {
102: PetscTrMalloc = oldmalloc_;
103: PetscTrFree = oldfree_;
104: PetscTrRealloc = oldrealloc_;
105: }
106: PetscFunctionReturnVoid();
107: }
109: template <DeviceType T>
110: inline bool UseCUPMHostAllocGuard<T>::value() const noexcept
111: {
112: return v_;
113: }
115: } // anonymous namespace
117: template <DeviceType T, PetscMemType MemoryType, PetscMemoryAccessMode AccessMode>
118: class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL RestoreableArray : Interface<T> {
119: public:
120: PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(T);
122: static constexpr auto memory_type = MemoryType;
123: static constexpr auto access_type = AccessMode;
125: using value_type = PetscScalar;
126: using pointer_type = value_type *;
127: using cupm_pointer_type = cupmScalar_t *;
129: PETSC_NODISCARD pointer_type data() const noexcept;
130: PETSC_NODISCARD cupm_pointer_type cupmdata() const noexcept;
132: operator pointer_type() const noexcept;
133: // in case pointer_type == cupmscalar_pointer_type we don't want this overload to exist, so
134: // we make a dummy template parameter to allow SFINAE to nix it for us
135: template <typename U = pointer_type, typename = util::enable_if_t<!std::is_same<U, cupm_pointer_type>::value>>
136: operator cupm_pointer_type() const noexcept;
138: protected:
139: constexpr explicit RestoreableArray(PetscDeviceContext) noexcept;
141: value_type *ptr_ = nullptr;
142: PetscDeviceContext dctx_ = nullptr;
143: };
145: // ==========================================================================================
146: // RestoreableArray - Static Variables
147: // ==========================================================================================
149: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
150: const PetscMemType RestoreableArray<T, MT, MA>::memory_type;
152: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
153: const PetscMemoryAccessMode RestoreableArray<T, MT, MA>::access_type;
155: // ==========================================================================================
156: // RestoreableArray - Public API
157: // ==========================================================================================
159: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
160: constexpr inline RestoreableArray<T, MT, MA>::RestoreableArray(PetscDeviceContext dctx) noexcept : dctx_{dctx}
161: {
162: }
164: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
165: inline typename RestoreableArray<T, MT, MA>::pointer_type RestoreableArray<T, MT, MA>::data() const noexcept
166: {
167: return ptr_;
168: }
170: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
171: inline typename RestoreableArray<T, MT, MA>::cupm_pointer_type RestoreableArray<T, MT, MA>::cupmdata() const noexcept
172: {
173: return cupmScalarPtrCast(data());
174: }
176: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
177: inline RestoreableArray<T, MT, MA>::operator pointer_type() const noexcept
178: {
179: return data();
180: }
182: // in case pointer_type == cupmscalar_pointer_type we don't want this overload to exist, so
183: // we make a dummy template parameter to allow SFINAE to nix it for us
184: template <DeviceType T, PetscMemType MT, PetscMemoryAccessMode MA>
185: template <typename U, typename>
186: inline RestoreableArray<T, MT, MA>::operator cupm_pointer_type() const noexcept
187: {
188: return cupmdata();
189: }
191: template <DeviceType T>
192: class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL CUPMObject : SolverInterface<T> {
193: protected:
194: PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T);
196: private:
197: // The final stop in the GetHandles_/GetFromHandles_ chain. This retrieves the various
198: // compute handles and ensure the given PetscDeviceContext is of the right type
199: static PetscErrorCode GetFromHandleDispatch_(PetscDeviceContext, cupmBlasHandle_t *, cupmSolverHandle_t *, cupmStream_t *) noexcept;
200: static PetscErrorCode GetHandleDispatch_(PetscDeviceContext *, cupmBlasHandle_t *, cupmSolverHandle_t *, cupmStream_t *) noexcept;
202: protected:
203: PETSC_NODISCARD static constexpr PetscRandomType PETSCDEVICERAND() noexcept;
205: // Helper routines to retrieve various combinations of handles. The first set (GetHandles_)
206: // gets a PetscDeviceContext along with it, while the second set (GetHandlesFrom_) assumes
207: // you've gotten the PetscDeviceContext already, and retrieves the handles from it. All of
208: // them check that the PetscDeviceContext is of the appropriate type
209: static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmBlasHandle_t * = nullptr, cupmSolverHandle_t * = nullptr, cupmStream_t * = nullptr) noexcept;
211: // triple
212: static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmBlasHandle_t *, cupmStream_t *) noexcept;
213: static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmSolverHandle_t *, cupmStream_t *) noexcept;
215: // double
216: static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmSolverHandle_t *) noexcept;
217: static PetscErrorCode GetHandles_(PetscDeviceContext *, cupmStream_t *) noexcept;
219: // single
220: static PetscErrorCode GetHandles_(cupmBlasHandle_t *) noexcept;
221: static PetscErrorCode GetHandles_(cupmSolverHandle_t *) noexcept;
222: static PetscErrorCode GetHandles_(cupmStream_t *) noexcept;
224: static PetscErrorCode GetHandlesFrom_(PetscDeviceContext, cupmBlasHandle_t *, cupmSolverHandle_t * = nullptr, cupmStream_t * = nullptr) noexcept;
225: static PetscErrorCode GetHandlesFrom_(PetscDeviceContext, cupmSolverHandle_t *, cupmStream_t * = nullptr) noexcept;
226: static PetscErrorCode GetHandlesFrom_(PetscDeviceContext, cupmStream_t *) noexcept;
228: // disallow implicit conversion
229: template <typename U>
230: PETSC_NODISCARD static UseCUPMHostAllocGuard<T> UseCUPMHostAlloc(U) noexcept = delete;
231: // utility for using cupmHostAlloc()
232: PETSC_NODISCARD static UseCUPMHostAllocGuard<T> UseCUPMHostAlloc(bool) noexcept;
233: PETSC_NODISCARD static UseCUPMHostAllocGuard<T> UseCUPMHostAlloc(PetscBool) noexcept;
235: // A debug check to ensure that a given pointer-memtype pairing taken from user-land is
236: // actually correct. Errors on mismatch
237: static PetscErrorCode CheckPointerMatchesMemType_(const void *, PetscMemType) noexcept;
238: };
240: template <DeviceType T>
241: inline constexpr PetscRandomType CUPMObject<T>::PETSCDEVICERAND() noexcept
242: {
243: // REVIEW ME: HIP default rng?
244: return T == DeviceType::CUDA ? PETSCCURAND : PETSCRANDER48;
245: }
247: template <DeviceType T>
248: inline PetscErrorCode CUPMObject<T>::GetFromHandleDispatch_(PetscDeviceContext dctx, cupmBlasHandle_t *blas_handle, cupmSolverHandle_t *solver_handle, cupmStream_t *stream_handle) noexcept
249: {
250: PetscFunctionBegin;
252: if (blas_handle) {
253: PetscAssertPointer(blas_handle, 2);
254: *blas_handle = nullptr;
255: }
256: if (solver_handle) {
257: PetscAssertPointer(solver_handle, 3);
258: *solver_handle = nullptr;
259: }
260: if (stream_handle) {
261: PetscAssertPointer(stream_handle, 4);
262: *stream_handle = nullptr;
263: }
264: if (PetscDefined(USE_DEBUG)) {
265: PetscDeviceType dtype;
267: PetscCall(PetscDeviceContextGetDeviceType(dctx, &dtype));
268: PetscCheckCompatibleDeviceTypes(PETSC_DEVICE_CUPM(), -1, dtype, 1);
269: }
270: if (blas_handle) PetscCall(PetscDeviceContextGetBLASHandle_Internal(dctx, blas_handle));
271: if (solver_handle) PetscCall(PetscDeviceContextGetSOLVERHandle_Internal(dctx, solver_handle));
272: if (stream_handle) {
273: cupmStream_t *stream = nullptr;
275: PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, (void **)&stream));
276: *stream_handle = *stream;
277: }
278: PetscFunctionReturn(PETSC_SUCCESS);
279: }
281: template <DeviceType T>
282: inline PetscErrorCode CUPMObject<T>::GetHandleDispatch_(PetscDeviceContext *dctx, cupmBlasHandle_t *blas_handle, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
283: {
284: PetscDeviceContext dctx_loc = nullptr;
286: PetscFunctionBegin;
287: // silence uninitialized variable warnings
288: if (dctx) *dctx = nullptr;
289: PetscCall(PetscDeviceContextGetCurrentContext(&dctx_loc));
290: PetscCall(GetFromHandleDispatch_(dctx_loc, blas_handle, solver_handle, stream));
291: if (dctx) *dctx = dctx_loc;
292: PetscFunctionReturn(PETSC_SUCCESS);
293: }
295: template <DeviceType T>
296: inline PetscErrorCode CUPMObject<T>::GetHandles_(PetscDeviceContext *dctx, cupmBlasHandle_t *blas_handle, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
297: {
298: return GetHandleDispatch_(dctx, blas_handle, solver_handle, stream);
299: }
301: template <DeviceType T>
302: inline PetscErrorCode CUPMObject<T>::GetHandles_(PetscDeviceContext *dctx, cupmBlasHandle_t *blas_handle, cupmStream_t *stream) noexcept
303: {
304: return GetHandleDispatch_(dctx, blas_handle, nullptr, stream);
305: }
307: template <DeviceType T>
308: inline PetscErrorCode CUPMObject<T>::GetHandles_(PetscDeviceContext *dctx, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
309: {
310: return GetHandleDispatch_(dctx, nullptr, solver_handle, stream);
311: }
313: template <DeviceType T>
314: inline PetscErrorCode CUPMObject<T>::GetHandles_(PetscDeviceContext *dctx, cupmStream_t *stream) noexcept
315: {
316: return GetHandleDispatch_(dctx, nullptr, nullptr, stream);
317: }
319: template <DeviceType T>
320: inline PetscErrorCode CUPMObject<T>::GetHandles_(cupmBlasHandle_t *handle) noexcept
321: {
322: return GetHandleDispatch_(nullptr, handle, nullptr, nullptr);
323: }
325: template <DeviceType T>
326: inline PetscErrorCode CUPMObject<T>::GetHandles_(cupmSolverHandle_t *handle) noexcept
327: {
328: return GetHandleDispatch_(nullptr, nullptr, handle, nullptr);
329: }
331: template <DeviceType T>
332: inline PetscErrorCode CUPMObject<T>::GetHandles_(cupmStream_t *stream) noexcept
333: {
334: return GetHandleDispatch_(nullptr, nullptr, nullptr, stream);
335: }
337: template <DeviceType T>
338: inline PetscErrorCode CUPMObject<T>::GetHandlesFrom_(PetscDeviceContext dctx, cupmBlasHandle_t *blas_handle, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
339: {
340: return GetFromHandleDispatch_(dctx, blas_handle, solver_handle, stream);
341: }
343: template <DeviceType T>
344: inline PetscErrorCode CUPMObject<T>::GetHandlesFrom_(PetscDeviceContext dctx, cupmSolverHandle_t *solver_handle, cupmStream_t *stream) noexcept
345: {
346: return GetFromHandleDispatch_(dctx, nullptr, solver_handle, stream);
347: }
349: template <DeviceType T>
350: inline PetscErrorCode CUPMObject<T>::GetHandlesFrom_(PetscDeviceContext dctx, cupmStream_t *stream) noexcept
351: {
352: return GetFromHandleDispatch_(dctx, nullptr, nullptr, stream);
353: }
355: template <DeviceType T>
356: inline UseCUPMHostAllocGuard<T> CUPMObject<T>::UseCUPMHostAlloc(bool b) noexcept
357: {
358: return {b};
359: }
361: template <DeviceType T>
362: inline UseCUPMHostAllocGuard<T> CUPMObject<T>::UseCUPMHostAlloc(PetscBool b) noexcept
363: {
364: return UseCUPMHostAlloc(static_cast<bool>(b));
365: }
367: template <DeviceType T>
368: inline PetscErrorCode CUPMObject<T>::CheckPointerMatchesMemType_(const void *ptr, PetscMemType mtype) noexcept
369: {
370: PetscFunctionBegin;
371: if (PetscDefined(USE_DEBUG) && ptr) {
372: PetscMemType ptr_mtype;
374: PetscCall(PetscCUPMGetMemType(ptr, &ptr_mtype));
375: if (mtype == PETSC_MEMTYPE_HOST) {
376: PetscCheck(PetscMemTypeHost(ptr_mtype), PETSC_COMM_SELF, PETSC_ERR_POINTER, "Pointer %p declared as %s does not match actual memtype %s", ptr, PetscMemTypeToString(mtype), PetscMemTypeToString(ptr_mtype));
377: } else if (mtype == PETSC_MEMTYPE_DEVICE) {
378: // generic "device" memory should only care if the actual memtype is also generically
379: // "device"
380: PetscCheck(PetscMemTypeDevice(ptr_mtype), PETSC_COMM_SELF, PETSC_ERR_POINTER, "Pointer %p declared as %s does not match actual memtype %s", ptr, PetscMemTypeToString(mtype), PetscMemTypeToString(ptr_mtype));
381: } else {
382: PetscCheck(mtype == ptr_mtype, PETSC_COMM_SELF, PETSC_ERR_POINTER, "Pointer %p declared as %s does not match actual memtype %s", ptr, PetscMemTypeToString(mtype), PetscMemTypeToString(ptr_mtype));
383: }
384: }
385: PetscFunctionReturn(PETSC_SUCCESS);
386: }
// Place inside a class deriving from CUPMObject<T> to pull in the solver-interface typedefs
// and re-export CUPMObject<T>'s helpers at the point of expansion. The final
// using-declaration deliberately has no trailing semicolon, so the macro is written as
// `PETSC_CUPMOBJECT_HEADER(T);`.
#define PETSC_CUPMOBJECT_HEADER(T) \
  PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T); \
  using ::Petsc::device::cupm::impl::CUPMObject<T>::CheckPointerMatchesMemType_; \
  using ::Petsc::device::cupm::impl::CUPMObject<T>::GetHandles_; \
  using ::Petsc::device::cupm::impl::CUPMObject<T>::GetHandlesFrom_; \
  using ::Petsc::device::cupm::impl::CUPMObject<T>::PETSCDEVICERAND; \
  using ::Petsc::device::cupm::impl::CUPMObject<T>::UseCUPMHostAlloc
396: } // namespace impl
398: } // namespace cupm
400: } // namespace device
402: } // namespace Petsc