Actual source code: vecseqcupm.hpp
1: #pragma once
3: #include <petsc/private/veccupmimpl.h>
4: #include <petsc/private/cpp/utility.hpp>
6: #include <../src/sys/objects/device/impls/cupm/kernels.hpp>
7: #include <../src/vec/vec/impls/dvecimpl.h>
9: namespace Petsc
10: {
12: namespace vec
13: {
15: namespace cupm
16: {
18: namespace impl
19: {
21: // ==========================================================================================
22: // VecSeq_CUPM
23: // ==========================================================================================
25: template <device::cupm::DeviceType T>
26: class VecSeq_CUPM : Vec_CUPMBase<T, VecSeq_CUPM<T>> {
27: public:
28: PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecSeq_CUPM<T>);
30: private:
31: PETSC_NODISCARD static Vec_Seq *VecIMPLCast_(Vec) noexcept;
32: PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept;
33: PETSC_NODISCARD static constexpr VecType VECIMPL_() noexcept;
35: static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept;
36: static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept;
37: static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept;
38: static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept;
40: static PetscErrorCode MaybeIncrementEmptyLocalVec(Vec) noexcept;
42: // common core for min and max
43: template <typename TupleFuncT, typename UnaryFuncT>
44: static PetscErrorCode MinMax_(TupleFuncT &&, UnaryFuncT &&, Vec, PetscInt *, PetscReal *) noexcept;
45: // common core for pointwise binary and pointwise unary thrust functions
46: template <typename BinaryFuncT>
47: static PetscErrorCode PointwiseBinary_(BinaryFuncT &&, Vec, Vec, Vec, PetscDeviceContext = nullptr) noexcept;
48: template <typename BinaryFuncT>
49: static PetscErrorCode PointwiseBinaryDispatch_(PetscErrorCode (*)(Vec, Vec, Vec), BinaryFuncT &&, Vec, Vec, Vec, PetscDeviceContext = nullptr) noexcept;
50: template <typename UnaryFuncT>
51: static PetscErrorCode PointwiseUnary_(UnaryFuncT &&, Vec, Vec, PetscDeviceContext = nullptr) noexcept;
52: // mdot dispatchers
53: static PetscErrorCode MDot_(/* use complex = */ std::true_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
54: static PetscErrorCode MDot_(/* use complex = */ std::false_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
55: template <std::size_t... Idx>
56: static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, util::index_sequence<Idx...>) noexcept;
57: template <int>
58: static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, PetscInt &) noexcept;
59: template <std::size_t... Idx>
60: static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, util::index_sequence<Idx...>) noexcept;
61: template <int>
62: static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, PetscInt &) noexcept;
63: // common core for the various create routines
64: static PetscErrorCode CreateSeqCUPM_(Vec, PetscDeviceContext, PetscScalar * /*host_ptr*/ = nullptr, PetscScalar * /*device_ptr*/ = nullptr) noexcept;
66: public:
67: // callable directly via a bespoke function
68: static PetscErrorCode CreateSeqCUPM(MPI_Comm, PetscInt, PetscInt, Vec *, PetscBool) noexcept;
69: static PetscErrorCode CreateSeqCUPMWithBothArrays(MPI_Comm, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept;
71: static PetscErrorCode InitializeAsyncFunctions(Vec) noexcept;
72: static PetscErrorCode ClearAsyncFunctions(Vec) noexcept;
74: // callable indirectly via function pointers
75: static PetscErrorCode Duplicate(Vec, Vec *) noexcept;
76: static PetscErrorCode AYPX(Vec, PetscScalar, Vec) noexcept;
77: static PetscErrorCode AYPXAsync(Vec, PetscScalar, Vec, PetscDeviceContext) noexcept;
78: static PetscErrorCode AXPY(Vec, PetscScalar, Vec) noexcept;
79: static PetscErrorCode AXPYAsync(Vec, PetscScalar, Vec, PetscDeviceContext) noexcept;
80: static PetscErrorCode PointwiseDivide(Vec, Vec, Vec) noexcept;
81: static PetscErrorCode PointwiseDivideAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
82: static PetscErrorCode PointwiseMult(Vec, Vec, Vec) noexcept;
83: static PetscErrorCode PointwiseMultAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
84: static PetscErrorCode PointwiseMax(Vec, Vec, Vec) noexcept;
85: static PetscErrorCode PointwiseMaxAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
86: static PetscErrorCode PointwiseMaxAbs(Vec, Vec, Vec) noexcept;
87: static PetscErrorCode PointwiseMaxAbsAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
88: static PetscErrorCode PointwiseMin(Vec, Vec, Vec) noexcept;
89: static PetscErrorCode PointwiseMinAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
90: static PetscErrorCode PointwiseSignAsync(Vec, Vec, VecSignMode, PetscDeviceContext) noexcept;
91: static PetscErrorCode Reciprocal(Vec) noexcept;
92: static PetscErrorCode ReciprocalAsync(Vec, PetscDeviceContext) noexcept;
93: static PetscErrorCode Abs(Vec) noexcept;
94: static PetscErrorCode AbsAsync(Vec, PetscDeviceContext) noexcept;
95: static PetscErrorCode SqrtAbs(Vec) noexcept;
96: static PetscErrorCode SqrtAbsAsync(Vec, PetscDeviceContext) noexcept;
97: static PetscErrorCode Exp(Vec) noexcept;
98: static PetscErrorCode ExpAsync(Vec, PetscDeviceContext) noexcept;
99: static PetscErrorCode Log(Vec) noexcept;
100: static PetscErrorCode LogAsync(Vec, PetscDeviceContext) noexcept;
101: static PetscErrorCode WAXPY(Vec, PetscScalar, Vec, Vec) noexcept;
102: static PetscErrorCode WAXPYAsync(Vec, PetscScalar, Vec, Vec, PetscDeviceContext) noexcept;
103: static PetscErrorCode MAXPY(Vec, PetscInt, const PetscScalar[], Vec *) noexcept;
104: static PetscErrorCode MAXPYAsync(Vec, PetscInt, const PetscScalar[], Vec *, PetscDeviceContext) noexcept;
105: static PetscErrorCode Dot(Vec, Vec, PetscScalar *) noexcept;
106: static PetscErrorCode MDot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept;
107: static PetscErrorCode Set(Vec, PetscScalar) noexcept;
108: static PetscErrorCode SetAsync(Vec, PetscScalar, PetscDeviceContext) noexcept;
109: static PetscErrorCode Scale(Vec, PetscScalar) noexcept;
110: static PetscErrorCode ScaleAsync(Vec, PetscScalar, PetscDeviceContext) noexcept;
111: static PetscErrorCode TDot(Vec, Vec, PetscScalar *) noexcept;
112: static PetscErrorCode Copy(Vec, Vec) noexcept;
113: static PetscErrorCode CopyAsync(Vec, Vec, PetscDeviceContext) noexcept;
114: static PetscErrorCode Swap(Vec, Vec) noexcept;
115: static PetscErrorCode SwapAsync(Vec, Vec, PetscDeviceContext) noexcept;
116: static PetscErrorCode AXPBY(Vec, PetscScalar, PetscScalar, Vec) noexcept;
117: static PetscErrorCode AXPBYAsync(Vec, PetscScalar, PetscScalar, Vec, PetscDeviceContext) noexcept;
118: static PetscErrorCode AXPBYPCZ(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec) noexcept;
119: static PetscErrorCode AXPBYPCZAsync(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec, PetscDeviceContext) noexcept;
120: static PetscErrorCode Norm(Vec, NormType, PetscReal *) noexcept;
121: static PetscErrorCode ErrorWnorm(Vec, Vec, Vec, NormType, PetscReal, Vec, PetscReal, Vec, PetscReal, PetscReal *, PetscInt *, PetscReal *, PetscInt *, PetscReal *, PetscInt *) noexcept;
122: static PetscErrorCode DotNorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept;
123: static PetscErrorCode Conjugate(Vec) noexcept;
124: static PetscErrorCode ConjugateAsync(Vec, PetscDeviceContext) noexcept;
125: template <PetscMemoryAccessMode>
126: static PetscErrorCode GetLocalVector(Vec, Vec) noexcept;
127: template <PetscMemoryAccessMode>
128: static PetscErrorCode RestoreLocalVector(Vec, Vec) noexcept;
129: static PetscErrorCode Max(Vec, PetscInt *, PetscReal *) noexcept;
130: static PetscErrorCode Min(Vec, PetscInt *, PetscReal *) noexcept;
131: static PetscErrorCode Sum(Vec, PetscScalar *) noexcept;
132: static PetscErrorCode Shift(Vec, PetscScalar) noexcept;
133: static PetscErrorCode ShiftAsync(Vec, PetscScalar, PetscDeviceContext) noexcept;
134: static PetscErrorCode SetRandom(Vec, PetscRandom) noexcept;
135: static PetscErrorCode BindToCPU(Vec, PetscBool) noexcept;
136: static PetscErrorCode SetPreallocationCOO(Vec, PetscCount, const PetscInt[]) noexcept;
137: static PetscErrorCode SetValuesCOO(Vec, const PetscScalar[], InsertMode) noexcept;
138: };
140: namespace kernels
141: {
143: template <typename F>
144: PETSC_DEVICE_INLINE_DECL void add_coo_values_impl(const PetscScalar *PETSC_RESTRICT vv, PetscCount n, const PetscCount *PETSC_RESTRICT jmap, const PetscCount *PETSC_RESTRICT perm, InsertMode imode, PetscScalar *PETSC_RESTRICT xv, F &&xvindex)
145: {
146: ::Petsc::device::cupm::kernels::util::grid_stride_1D(n, [=](PetscCount i) {
147: const auto end = jmap[i + 1];
148: const auto idx = xvindex(i);
149: PetscScalar sum = 0.0;
151: for (auto k = jmap[i]; k < end; ++k) sum += vv[perm[k]];
153: if (imode == INSERT_VALUES) {
154: xv[idx] = sum;
155: } else {
156: xv[idx] += sum;
157: }
158: });
159: return;
160: }
162: namespace
163: {
164: PETSC_PRAGMA_DIAGNOSTIC_IGNORED_BEGIN("-Wunused-function")
165: PETSC_KERNEL_DECL void add_coo_values(const PetscScalar *PETSC_RESTRICT v, PetscCount n, const PetscCount *PETSC_RESTRICT jmap1, const PetscCount *PETSC_RESTRICT perm1, InsertMode imode, PetscScalar *PETSC_RESTRICT xv)
166: {
167: add_coo_values_impl(v, n, jmap1, perm1, imode, xv, [](PetscCount i) { return i; });
168: return;
169: }
170: PETSC_PRAGMA_DIAGNOSTIC_IGNORED_END()
171: } // namespace
173: } // namespace kernels
175: } // namespace impl
177: // ==========================================================================================
178: // VecSeq_CUPM - Implementations
179: // ==========================================================================================
181: template <device::cupm::DeviceType T>
182: inline PetscErrorCode VecCreateSeqCUPMAsync(MPI_Comm comm, PetscInt n, Vec *v) noexcept
183: {
184: PetscFunctionBegin;
185: PetscAssertPointer(v, 4);
186: PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPM(comm, 0, n, v, PETSC_TRUE));
187: PetscFunctionReturn(PETSC_SUCCESS);
188: }
190: template <device::cupm::DeviceType T>
191: inline PetscErrorCode VecCreateSeqCUPMWithArraysAsync(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept
192: {
193: PetscFunctionBegin;
194: if (n && cpuarray) PetscAssertPointer(cpuarray, 4);
195: PetscAssertPointer(v, 6);
196: PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(comm, bs, n, cpuarray, gpuarray, v));
197: PetscFunctionReturn(PETSC_SUCCESS);
198: }
200: template <PetscMemoryAccessMode mode, device::cupm::DeviceType T>
201: inline PetscErrorCode VecCUPMGetArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
202: {
203: PetscFunctionBegin;
205: PetscAssertPointer(a, 2);
206: PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
207: PetscCall(impl::VecSeq_CUPM<T>::template GetArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx));
208: PetscFunctionReturn(PETSC_SUCCESS);
209: }
211: template <PetscMemoryAccessMode mode, device::cupm::DeviceType T>
212: inline PetscErrorCode VecCUPMRestoreArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
213: {
214: PetscFunctionBegin;
216: PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
217: PetscCall(impl::VecSeq_CUPM<T>::template RestoreArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx));
218: PetscFunctionReturn(PETSC_SUCCESS);
219: }
221: template <device::cupm::DeviceType T>
222: inline PetscErrorCode VecCUPMGetArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
223: {
224: PetscFunctionBegin;
225: PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx));
226: PetscFunctionReturn(PETSC_SUCCESS);
227: }
229: template <device::cupm::DeviceType T>
230: inline PetscErrorCode VecCUPMRestoreArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
231: {
232: PetscFunctionBegin;
233: PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx));
234: PetscFunctionReturn(PETSC_SUCCESS);
235: }
237: template <device::cupm::DeviceType T>
238: inline PetscErrorCode VecCUPMGetArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
239: {
240: PetscFunctionBegin;
241: PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx));
242: PetscFunctionReturn(PETSC_SUCCESS);
243: }
245: template <device::cupm::DeviceType T>
246: inline PetscErrorCode VecCUPMRestoreArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
247: {
248: PetscFunctionBegin;
249: PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx));
250: PetscFunctionReturn(PETSC_SUCCESS);
251: }
253: template <device::cupm::DeviceType T>
254: inline PetscErrorCode VecCUPMGetArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
255: {
256: PetscFunctionBegin;
257: PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx));
258: PetscFunctionReturn(PETSC_SUCCESS);
259: }
261: template <device::cupm::DeviceType T>
262: inline PetscErrorCode VecCUPMRestoreArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
263: {
264: PetscFunctionBegin;
265: PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx));
266: PetscFunctionReturn(PETSC_SUCCESS);
267: }
269: template <device::cupm::DeviceType T>
270: inline PetscErrorCode VecCUPMPlaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept
271: {
272: PetscFunctionBegin;
274: PetscCall(impl::VecSeq_CUPM<T>::template PlaceArray<PETSC_MEMTYPE_DEVICE>(vin, a));
275: PetscFunctionReturn(PETSC_SUCCESS);
276: }
278: template <device::cupm::DeviceType T>
279: inline PetscErrorCode VecCUPMReplaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept
280: {
281: PetscFunctionBegin;
283: PetscCall(impl::VecSeq_CUPM<T>::template ReplaceArray<PETSC_MEMTYPE_DEVICE>(vin, a));
284: PetscFunctionReturn(PETSC_SUCCESS);
285: }
287: template <device::cupm::DeviceType T>
288: inline PetscErrorCode VecCUPMResetArrayAsync(Vec vin) noexcept
289: {
290: PetscFunctionBegin;
292: PetscCall(impl::VecSeq_CUPM<T>::template ResetArray<PETSC_MEMTYPE_DEVICE>(vin));
293: PetscFunctionReturn(PETSC_SUCCESS);
294: }
296: } // namespace cupm
298: } // namespace vec
300: } // namespace Petsc
302: #if PetscDefined(HAVE_CUDA)
303: extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecSeq_CUPM<::Petsc::device::cupm::DeviceType::CUDA>;
304: #endif
306: #if PetscDefined(HAVE_HIP)
307: extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecSeq_CUPM<::Petsc::device::cupm::DeviceType::HIP>;
308: #endif