Actual source code: vecseqcupm.hpp
1: #pragma once
3: #include <petsc/private/veccupmimpl.h>
4: #include <petsc/private/cpp/utility.hpp>
6: #include <../src/sys/objects/device/impls/cupm/kernels.hpp>
7: #include <../src/vec/vec/impls/dvecimpl.h>
9: namespace Petsc
10: {
12: namespace vec
13: {
15: namespace cupm
16: {
18: namespace impl
19: {
21: // ==========================================================================================
22: // VecSeq_CUPM
23: // ==========================================================================================
// VecSeq_CUPM - sequential vector implementation for the CUPM device backends. The explicit
// instantiation declarations at the bottom of this header cover DeviceType::CUDA and
// DeviceType::HIP.
//
// The class uses the CRTP pattern: it passes itself as the second template argument of
// Vec_CUPMBase. The inheritance is private (the default for `class`).
// PETSC_VEC_CUPM_BASE_CLASS_HEADER presumably re-exposes the base-class names this type needs
// (e.g. base_type) -- TODO confirm against veccupmimpl.h.
//
// This block only declares the members; none of them are defined here.
25: template <device::cupm::DeviceType T>
26: class VecSeq_CUPM : Vec_CUPMBase<T, VecSeq_CUPM<T>> {
27: public:
28: PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecSeq_CUPM<T>);
30: private:
// Host-implementation hooks. VecIMPLCast_ maps a Vec to its Vec_Seq data (presumably a cast of the
// Vec's data pointer -- confirm). VECIMPLCUPM_/VECIMPL_ are constexpr VecType names (presumably the
// device and host vector type names). The *_IMPL_ hooks are presumably called by Vec_CUPMBase --
// TODO confirm.
31: PETSC_NODISCARD static Vec_Seq *VecIMPLCast_(Vec) noexcept;
32: PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept;
33: PETSC_NODISCARD static constexpr VecType VECIMPL_() noexcept;
35: static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept;
36: static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept;
37: static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept;
38: static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept;
40: static PetscErrorCode MaybeIncrementEmptyLocalVec(Vec) noexcept;
42: // common core for min and max
// (takes a tuple functor and a unary functor, then the Vec and the index/value outputs)
43: template <typename TupleFuncT, typename UnaryFuncT>
44: static PetscErrorCode MinMax_(TupleFuncT &&, UnaryFuncT &&, Vec, PetscInt *, PetscReal *) noexcept;
45: // common core for pointwise binary and pointwise unary thrust functions
// The trailing PetscDeviceContext defaults to nullptr in all three helpers. PointwiseBinaryDispatch_
// additionally receives a callback with the synchronous (Vec, Vec, Vec) signature; when it is used
// instead of the functor is decided by the definition, which is not in this block.
46: template <typename BinaryFuncT>
47: static PetscErrorCode PointwiseBinary_(BinaryFuncT &&, Vec, Vec, Vec, PetscDeviceContext = nullptr) noexcept;
48: template <typename BinaryFuncT>
49: static PetscErrorCode PointwiseBinaryDispatch_(PetscErrorCode (*)(Vec, Vec, Vec), BinaryFuncT &&, Vec, Vec, Vec, PetscDeviceContext = nullptr) noexcept;
50: template <typename UnaryFuncT>
51: static PetscErrorCode PointwiseUnary_(UnaryFuncT &&, Vec, Vec, PetscDeviceContext = nullptr) noexcept;
52: // mdot dispatchers
// MDot_ is tag-dispatched on std::true_type / std::false_type ("use complex"). The
// *_kernel_dispatch_ helpers come in two overloads each: one expanding a util::index_sequence, and
// one templated on an int that takes a PetscInt by reference (presumably a running offset into the
// Vec array -- confirm in the definition).
53: static PetscErrorCode MDot_(/* use complex = */ std::true_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
54: static PetscErrorCode MDot_(/* use complex = */ std::false_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
55: template <std::size_t... Idx>
56: static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, util::index_sequence<Idx...>) noexcept;
57: template <int>
58: static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, PetscInt &) noexcept;
59: template <std::size_t... Idx>
60: static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, util::index_sequence<Idx...>) noexcept;
61: template <int>
62: static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, PetscInt &) noexcept;
63: // common core for the various create routines
// (both the host and device pointers default to nullptr)
64: static PetscErrorCode CreateSeqCUPM_(Vec, PetscDeviceContext, PetscScalar * /*host_ptr*/ = nullptr, PetscScalar * /*device_ptr*/ = nullptr) noexcept;
66: public:
67: // callable directly via a bespoke function
68: static PetscErrorCode CreateSeqCUPM(MPI_Comm, PetscInt, PetscInt, Vec *, PetscBool) noexcept;
69: static PetscErrorCode CreateSeqCUPMWithBothArrays(MPI_Comm, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept;
71: static PetscErrorCode InitializeAsyncFunctions(Vec) noexcept;
72: static PetscErrorCode ClearAsyncFunctions(Vec) noexcept;
74: // callable indirectly via function pointers
// Naming convention: every XxxAsync overload takes the same arguments as Xxx plus a trailing
// PetscDeviceContext.
75: static PetscErrorCode Duplicate(Vec, Vec *) noexcept;
76: static PetscErrorCode AYPX(Vec, PetscScalar, Vec) noexcept;
77: static PetscErrorCode AYPXAsync(Vec, PetscScalar, Vec, PetscDeviceContext) noexcept;
78: static PetscErrorCode AXPY(Vec, PetscScalar, Vec) noexcept;
79: static PetscErrorCode AXPYAsync(Vec, PetscScalar, Vec, PetscDeviceContext) noexcept;
80: static PetscErrorCode PointwiseDivide(Vec, Vec, Vec) noexcept;
81: static PetscErrorCode PointwiseDivideAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
82: static PetscErrorCode PointwiseMult(Vec, Vec, Vec) noexcept;
83: static PetscErrorCode PointwiseMultAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
84: static PetscErrorCode PointwiseMax(Vec, Vec, Vec) noexcept;
85: static PetscErrorCode PointwiseMaxAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
86: static PetscErrorCode PointwiseMaxAbs(Vec, Vec, Vec) noexcept;
87: static PetscErrorCode PointwiseMaxAbsAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
88: static PetscErrorCode PointwiseMin(Vec, Vec, Vec) noexcept;
89: static PetscErrorCode PointwiseMinAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
90: static PetscErrorCode Reciprocal(Vec) noexcept;
91: static PetscErrorCode ReciprocalAsync(Vec, PetscDeviceContext) noexcept;
92: static PetscErrorCode Abs(Vec) noexcept;
93: static PetscErrorCode AbsAsync(Vec, PetscDeviceContext) noexcept;
94: static PetscErrorCode SqrtAbs(Vec) noexcept;
95: static PetscErrorCode SqrtAbsAsync(Vec, PetscDeviceContext) noexcept;
96: static PetscErrorCode Exp(Vec) noexcept;
97: static PetscErrorCode ExpAsync(Vec, PetscDeviceContext) noexcept;
98: static PetscErrorCode Log(Vec) noexcept;
99: static PetscErrorCode LogAsync(Vec, PetscDeviceContext) noexcept;
100: static PetscErrorCode WAXPY(Vec, PetscScalar, Vec, Vec) noexcept;
101: static PetscErrorCode WAXPYAsync(Vec, PetscScalar, Vec, Vec, PetscDeviceContext) noexcept;
102: static PetscErrorCode MAXPY(Vec, PetscInt, const PetscScalar[], Vec *) noexcept;
103: static PetscErrorCode MAXPYAsync(Vec, PetscInt, const PetscScalar[], Vec *, PetscDeviceContext) noexcept;
104: static PetscErrorCode Dot(Vec, Vec, PetscScalar *) noexcept;
105: static PetscErrorCode MDot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept;
106: static PetscErrorCode Set(Vec, PetscScalar) noexcept;
107: static PetscErrorCode SetAsync(Vec, PetscScalar, PetscDeviceContext) noexcept;
108: static PetscErrorCode Scale(Vec, PetscScalar) noexcept;
109: static PetscErrorCode ScaleAsync(Vec, PetscScalar, PetscDeviceContext) noexcept;
110: static PetscErrorCode TDot(Vec, Vec, PetscScalar *) noexcept;
111: static PetscErrorCode Copy(Vec, Vec) noexcept;
112: static PetscErrorCode CopyAsync(Vec, Vec, PetscDeviceContext) noexcept;
113: static PetscErrorCode Swap(Vec, Vec) noexcept;
114: static PetscErrorCode SwapAsync(Vec, Vec, PetscDeviceContext) noexcept;
115: static PetscErrorCode AXPBY(Vec, PetscScalar, PetscScalar, Vec) noexcept;
116: static PetscErrorCode AXPBYAsync(Vec, PetscScalar, PetscScalar, Vec, PetscDeviceContext) noexcept;
117: static PetscErrorCode AXPBYPCZ(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec) noexcept;
118: static PetscErrorCode AXPBYPCZAsync(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec, PetscDeviceContext) noexcept;
119: static PetscErrorCode Norm(Vec, NormType, PetscReal *) noexcept;
120: static PetscErrorCode ErrorWnorm(Vec, Vec, Vec, NormType, PetscReal, Vec, PetscReal, Vec, PetscReal, PetscReal *, PetscInt *, PetscReal *, PetscInt *, PetscReal *, PetscInt *) noexcept;
121: static PetscErrorCode DotNorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept;
122: static PetscErrorCode Conjugate(Vec) noexcept;
123: static PetscErrorCode ConjugateAsync(Vec, PetscDeviceContext) noexcept;
// The local-vector accessors are templated on the requested PetscMemoryAccessMode.
124: template <PetscMemoryAccessMode>
125: static PetscErrorCode GetLocalVector(Vec, Vec) noexcept;
126: template <PetscMemoryAccessMode>
127: static PetscErrorCode RestoreLocalVector(Vec, Vec) noexcept;
128: static PetscErrorCode Max(Vec, PetscInt *, PetscReal *) noexcept;
129: static PetscErrorCode Min(Vec, PetscInt *, PetscReal *) noexcept;
130: static PetscErrorCode Sum(Vec, PetscScalar *) noexcept;
131: static PetscErrorCode Shift(Vec, PetscScalar) noexcept;
132: static PetscErrorCode ShiftAsync(Vec, PetscScalar, PetscDeviceContext) noexcept;
133: static PetscErrorCode SetRandom(Vec, PetscRandom) noexcept;
134: static PetscErrorCode BindToCPU(Vec, PetscBool) noexcept;
// COO assembly entry points (the device side uses the add_coo_values kernel in namespace kernels
// below -- presumably from SetValuesCOO; confirm in the definition).
135: static PetscErrorCode SetPreallocationCOO(Vec, PetscCount, const PetscInt[]) noexcept;
136: static PetscErrorCode SetValuesCOO(Vec, const PetscScalar[], InsertMode) noexcept;
137: };
139: namespace kernels
140: {
142: template <typename F>
143: PETSC_DEVICE_INLINE_DECL void add_coo_values_impl(const PetscScalar *PETSC_RESTRICT vv, PetscCount n, const PetscCount *PETSC_RESTRICT jmap, const PetscCount *PETSC_RESTRICT perm, InsertMode imode, PetscScalar *PETSC_RESTRICT xv, F &&xvindex)
144: {
145: ::Petsc::device::cupm::kernels::util::grid_stride_1D(n, [=](PetscCount i) {
146: const auto end = jmap[i + 1];
147: const auto idx = xvindex(i);
148: PetscScalar sum = 0.0;
150: for (auto k = jmap[i]; k < end; ++k) sum += vv[perm[k]];
152: if (imode == INSERT_VALUES) {
153: xv[idx] = sum;
154: } else {
155: xv[idx] += sum;
156: }
157: });
158: return;
159: }
161: namespace
162: {
164: PETSC_KERNEL_DECL void add_coo_values(const PetscScalar *PETSC_RESTRICT v, PetscCount n, const PetscCount *PETSC_RESTRICT jmap1, const PetscCount *PETSC_RESTRICT perm1, InsertMode imode, PetscScalar *PETSC_RESTRICT xv)
165: {
166: add_coo_values_impl(v, n, jmap1, perm1, imode, xv, [](PetscCount i) { return i; });
167: return;
168: }
170: } // namespace
// Only compiled when PetscDefined(USING_HCC) is true. The exact expression below exists to placate
// clang, so do not "simplify" it without re-checking the warning.
172: #if PetscDefined(USING_HCC)
173: namespace do_not_use
174: {
176: // Needed to silence the clang warning:
177: //
178: // warning: function 'FUNCTION NAME' is not needed and will not be emitted
179: //
180: // The warning is spurious, since the function *is* used; however, the host compiler does not
181: // appear to see this, likely because the function using it is in a template.
182: //
183: // This warning appeared in clang-11, and still persists as of clang-15 (21/02/2023)
184: inline void silence_warning_function_add_coo_values_is_not_needed_and_will_not_be_emitted()
185: {
186: (void)add_coo_values;
187: }
189: } // namespace do_not_use
190: #endif
192: } // namespace kernels
194: } // namespace impl
196: // ==========================================================================================
197: // VecSeq_CUPM - Implementations
198: // ==========================================================================================
200: template <device::cupm::DeviceType T>
201: inline PetscErrorCode VecCreateSeqCUPMAsync(MPI_Comm comm, PetscInt n, Vec *v) noexcept
202: {
203: PetscFunctionBegin;
204: PetscAssertPointer(v, 4);
205: PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPM(comm, 0, n, v, PETSC_TRUE));
206: PetscFunctionReturn(PETSC_SUCCESS);
207: }
209: template <device::cupm::DeviceType T>
210: inline PetscErrorCode VecCreateSeqCUPMWithArraysAsync(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept
211: {
212: PetscFunctionBegin;
213: if (n && cpuarray) PetscAssertPointer(cpuarray, 4);
214: PetscAssertPointer(v, 6);
215: PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(comm, bs, n, cpuarray, gpuarray, v));
216: PetscFunctionReturn(PETSC_SUCCESS);
217: }
219: template <PetscMemoryAccessMode mode, device::cupm::DeviceType T>
220: inline PetscErrorCode VecCUPMGetArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
221: {
222: PetscFunctionBegin;
224: PetscAssertPointer(a, 2);
225: PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
226: PetscCall(impl::VecSeq_CUPM<T>::template GetArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx));
227: PetscFunctionReturn(PETSC_SUCCESS);
228: }
230: template <PetscMemoryAccessMode mode, device::cupm::DeviceType T>
231: inline PetscErrorCode VecCUPMRestoreArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
232: {
233: PetscFunctionBegin;
235: PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
236: PetscCall(impl::VecSeq_CUPM<T>::template RestoreArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx));
237: PetscFunctionReturn(PETSC_SUCCESS);
238: }
240: template <device::cupm::DeviceType T>
241: inline PetscErrorCode VecCUPMGetArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
242: {
243: PetscFunctionBegin;
244: PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx));
245: PetscFunctionReturn(PETSC_SUCCESS);
246: }
248: template <device::cupm::DeviceType T>
249: inline PetscErrorCode VecCUPMRestoreArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
250: {
251: PetscFunctionBegin;
252: PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx));
253: PetscFunctionReturn(PETSC_SUCCESS);
254: }
256: template <device::cupm::DeviceType T>
257: inline PetscErrorCode VecCUPMGetArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
258: {
259: PetscFunctionBegin;
260: PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx));
261: PetscFunctionReturn(PETSC_SUCCESS);
262: }
264: template <device::cupm::DeviceType T>
265: inline PetscErrorCode VecCUPMRestoreArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
266: {
267: PetscFunctionBegin;
268: PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx));
269: PetscFunctionReturn(PETSC_SUCCESS);
270: }
272: template <device::cupm::DeviceType T>
273: inline PetscErrorCode VecCUPMGetArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
274: {
275: PetscFunctionBegin;
276: PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx));
277: PetscFunctionReturn(PETSC_SUCCESS);
278: }
280: template <device::cupm::DeviceType T>
281: inline PetscErrorCode VecCUPMRestoreArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
282: {
283: PetscFunctionBegin;
284: PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx));
285: PetscFunctionReturn(PETSC_SUCCESS);
286: }
288: template <device::cupm::DeviceType T>
289: inline PetscErrorCode VecCUPMPlaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept
290: {
291: PetscFunctionBegin;
293: PetscCall(impl::VecSeq_CUPM<T>::template PlaceArray<PETSC_MEMTYPE_DEVICE>(vin, a));
294: PetscFunctionReturn(PETSC_SUCCESS);
295: }
297: template <device::cupm::DeviceType T>
298: inline PetscErrorCode VecCUPMReplaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept
299: {
300: PetscFunctionBegin;
302: PetscCall(impl::VecSeq_CUPM<T>::template ReplaceArray<PETSC_MEMTYPE_DEVICE>(vin, a));
303: PetscFunctionReturn(PETSC_SUCCESS);
304: }
306: template <device::cupm::DeviceType T>
307: inline PetscErrorCode VecCUPMResetArrayAsync(Vec vin) noexcept
308: {
309: PetscFunctionBegin;
311: PetscCall(impl::VecSeq_CUPM<T>::template ResetArray<PETSC_MEMTYPE_DEVICE>(vin));
312: PetscFunctionReturn(PETSC_SUCCESS);
313: }
315: } // namespace cupm
317: } // namespace vec
319: } // namespace Petsc
// Explicit instantiation declarations (extern template). Translation units that include this header
// do not implicitly instantiate VecSeq_CUPM for CUDA/HIP. The matching explicit instantiation
// definitions must therefore exist in some other translation unit (presumably the backend-specific
// sources -- confirm). Each one is only declared when the corresponding backend is available.
321: #if PetscDefined(HAVE_CUDA)
322: extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecSeq_CUPM<::Petsc::device::cupm::DeviceType::CUDA>;
323: #endif
325: #if PetscDefined(HAVE_HIP)
326: extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecSeq_CUPM<::Petsc::device::cupm::DeviceType::HIP>;
327: #endif