Actual source code: vecseqcupm.hpp
1: #pragma once
3: #include <petsc/private/veccupmimpl.h>
4: #include <petsc/private/cpp/utility.hpp>
6: #include <../src/sys/objects/device/impls/cupm/kernels.hpp>
7: #include <../src/vec/vec/impls/dvecimpl.h>
9: namespace Petsc
10: {
12: namespace vec
13: {
15: namespace cupm
16: {
18: namespace impl
19: {
21: // ==========================================================================================
22: // VecSeq_CUPM
23: // ==========================================================================================
25: template <device::cupm::DeviceType T>
26: class VecSeq_CUPM : Vec_CUPMBase<T, VecSeq_CUPM<T>> {
27: public:
28: PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecSeq_CUPM<T>);
30: private:
31: PETSC_NODISCARD static Vec_Seq *VecIMPLCast_(Vec) noexcept;
32: PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept;
33: PETSC_NODISCARD static constexpr VecType VECIMPL_() noexcept;
35: static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept;
36: static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept;
37: static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept;
38: static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept;
40: static PetscErrorCode MaybeIncrementEmptyLocalVec(Vec) noexcept;
42: // common core for min and max
43: template <typename TupleFuncT, typename UnaryFuncT>
44: static PetscErrorCode MinMax_(TupleFuncT &&, UnaryFuncT &&, Vec, PetscInt *, PetscReal *) noexcept;
45: // common core for pointwise binary and pointwise unary thrust functions
46: template <typename BinaryFuncT>
47: static PetscErrorCode PointwiseBinary_(BinaryFuncT &&, Vec, Vec, Vec, PetscDeviceContext = nullptr) noexcept;
48: template <typename BinaryFuncT>
49: static PetscErrorCode PointwiseBinaryDispatch_(PetscErrorCode (*)(Vec, Vec, Vec), BinaryFuncT &&, Vec, Vec, Vec, PetscDeviceContext = nullptr) noexcept;
50: template <typename UnaryFuncT>
51: static PetscErrorCode PointwiseUnary_(UnaryFuncT &&, Vec, Vec, PetscDeviceContext = nullptr) noexcept;
52: // mdot dispatchers
53: static PetscErrorCode MDot_(/* use complex = */ std::true_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
54: static PetscErrorCode MDot_(/* use complex = */ std::false_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
55: template <std::size_t... Idx>
56: static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, util::index_sequence<Idx...>) noexcept;
57: template <int>
58: static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, PetscInt &) noexcept;
59: template <std::size_t... Idx>
60: static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, util::index_sequence<Idx...>) noexcept;
61: template <int>
62: static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, PetscInt &) noexcept;
63: // common core for the various create routines
64: static PetscErrorCode CreateSeqCUPM_(Vec, PetscDeviceContext, PetscScalar * /*host_ptr*/ = nullptr, PetscScalar * /*device_ptr*/ = nullptr) noexcept;
66: public:
67: // callable directly via a bespoke function
68: static PetscErrorCode CreateSeqCUPM(MPI_Comm, PetscInt, PetscInt, Vec *, PetscBool) noexcept;
69: static PetscErrorCode CreateSeqCUPMWithBothArrays(MPI_Comm, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept;
71: static PetscErrorCode InitializeAsyncFunctions(Vec) noexcept;
72: static PetscErrorCode ClearAsyncFunctions(Vec) noexcept;
74: // callable indirectly via function pointers
75: static PetscErrorCode Duplicate(Vec, Vec *) noexcept;
76: static PetscErrorCode AYPX(Vec, PetscScalar, Vec) noexcept;
77: static PetscErrorCode AYPXAsync(Vec, PetscScalar, Vec, PetscDeviceContext) noexcept;
78: static PetscErrorCode AXPY(Vec, PetscScalar, Vec) noexcept;
79: static PetscErrorCode AXPYAsync(Vec, PetscScalar, Vec, PetscDeviceContext) noexcept;
80: static PetscErrorCode PointwiseDivide(Vec, Vec, Vec) noexcept;
81: static PetscErrorCode PointwiseDivideAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
82: static PetscErrorCode PointwiseMult(Vec, Vec, Vec) noexcept;
83: static PetscErrorCode PointwiseMultAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
84: static PetscErrorCode PointwiseMax(Vec, Vec, Vec) noexcept;
85: static PetscErrorCode PointwiseMaxAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
86: static PetscErrorCode PointwiseMaxAbs(Vec, Vec, Vec) noexcept;
87: static PetscErrorCode PointwiseMaxAbsAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
88: static PetscErrorCode PointwiseMin(Vec, Vec, Vec) noexcept;
89: static PetscErrorCode PointwiseMinAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
90: static PetscErrorCode Reciprocal(Vec) noexcept;
91: static PetscErrorCode ReciprocalAsync(Vec, PetscDeviceContext) noexcept;
92: static PetscErrorCode Abs(Vec) noexcept;
93: static PetscErrorCode AbsAsync(Vec, PetscDeviceContext) noexcept;
94: static PetscErrorCode SqrtAbs(Vec) noexcept;
95: static PetscErrorCode SqrtAbsAsync(Vec, PetscDeviceContext) noexcept;
96: static PetscErrorCode Exp(Vec) noexcept;
97: static PetscErrorCode ExpAsync(Vec, PetscDeviceContext) noexcept;
98: static PetscErrorCode Log(Vec) noexcept;
99: static PetscErrorCode LogAsync(Vec, PetscDeviceContext) noexcept;
100: static PetscErrorCode WAXPY(Vec, PetscScalar, Vec, Vec) noexcept;
101: static PetscErrorCode WAXPYAsync(Vec, PetscScalar, Vec, Vec, PetscDeviceContext) noexcept;
102: static PetscErrorCode MAXPY(Vec, PetscInt, const PetscScalar[], Vec *) noexcept;
103: static PetscErrorCode MAXPYAsync(Vec, PetscInt, const PetscScalar[], Vec *, PetscDeviceContext) noexcept;
104: static PetscErrorCode Dot(Vec, Vec, PetscScalar *) noexcept;
105: static PetscErrorCode MDot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept;
106: static PetscErrorCode Set(Vec, PetscScalar) noexcept;
107: static PetscErrorCode SetAsync(Vec, PetscScalar, PetscDeviceContext) noexcept;
108: static PetscErrorCode Scale(Vec, PetscScalar) noexcept;
109: static PetscErrorCode ScaleAsync(Vec, PetscScalar, PetscDeviceContext) noexcept;
110: static PetscErrorCode TDot(Vec, Vec, PetscScalar *) noexcept;
111: static PetscErrorCode Copy(Vec, Vec) noexcept;
112: static PetscErrorCode CopyAsync(Vec, Vec, PetscDeviceContext) noexcept;
113: static PetscErrorCode Swap(Vec, Vec) noexcept;
114: static PetscErrorCode SwapAsync(Vec, Vec, PetscDeviceContext) noexcept;
115: static PetscErrorCode AXPBY(Vec, PetscScalar, PetscScalar, Vec) noexcept;
116: static PetscErrorCode AXPBYAsync(Vec, PetscScalar, PetscScalar, Vec, PetscDeviceContext) noexcept;
117: static PetscErrorCode AXPBYPCZ(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec) noexcept;
118: static PetscErrorCode AXPBYPCZAsync(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec, PetscDeviceContext) noexcept;
119: static PetscErrorCode Norm(Vec, NormType, PetscReal *) noexcept;
120: static PetscErrorCode ErrorWnorm(Vec, Vec, Vec, NormType, PetscReal, Vec, PetscReal, Vec, PetscReal, PetscReal *, PetscInt *, PetscReal *, PetscInt *, PetscReal *, PetscInt *) noexcept;
121: static PetscErrorCode DotNorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept;
122: static PetscErrorCode Conjugate(Vec) noexcept;
123: static PetscErrorCode ConjugateAsync(Vec, PetscDeviceContext) noexcept;
124: template <PetscMemoryAccessMode>
125: static PetscErrorCode GetLocalVector(Vec, Vec) noexcept;
126: template <PetscMemoryAccessMode>
127: static PetscErrorCode RestoreLocalVector(Vec, Vec) noexcept;
128: static PetscErrorCode Max(Vec, PetscInt *, PetscReal *) noexcept;
129: static PetscErrorCode Min(Vec, PetscInt *, PetscReal *) noexcept;
130: static PetscErrorCode Sum(Vec, PetscScalar *) noexcept;
131: static PetscErrorCode Shift(Vec, PetscScalar) noexcept;
132: static PetscErrorCode ShiftAsync(Vec, PetscScalar, PetscDeviceContext) noexcept;
133: static PetscErrorCode SetRandom(Vec, PetscRandom) noexcept;
134: static PetscErrorCode BindToCPU(Vec, PetscBool) noexcept;
135: static PetscErrorCode SetPreallocationCOO(Vec, PetscCount, const PetscInt[]) noexcept;
136: static PetscErrorCode SetValuesCOO(Vec, const PetscScalar[], InsertMode) noexcept;
137: };
139: namespace kernels
140: {
142: template <typename F>
143: PETSC_DEVICE_INLINE_DECL void add_coo_values_impl(const PetscScalar *PETSC_RESTRICT vv, PetscCount n, const PetscCount *PETSC_RESTRICT jmap, const PetscCount *PETSC_RESTRICT perm, InsertMode imode, PetscScalar *PETSC_RESTRICT xv, F &&xvindex)
144: {
145: ::Petsc::device::cupm::kernels::util::grid_stride_1D(n, [=](PetscCount i) {
146: const auto end = jmap[i + 1];
147: const auto idx = xvindex(i);
148: PetscScalar sum = 0.0;
150: for (auto k = jmap[i]; k < end; ++k) sum += vv[perm[k]];
152: if (imode == INSERT_VALUES) {
153: xv[idx] = sum;
154: } else {
155: xv[idx] += sum;
156: }
157: });
158: return;
159: }
161: namespace
162: {
163: PETSC_PRAGMA_DIAGNOSTIC_IGNORED_BEGIN("-Wunused-function")
164: PETSC_KERNEL_DECL void add_coo_values(const PetscScalar *PETSC_RESTRICT v, PetscCount n, const PetscCount *PETSC_RESTRICT jmap1, const PetscCount *PETSC_RESTRICT perm1, InsertMode imode, PetscScalar *PETSC_RESTRICT xv)
165: {
166: add_coo_values_impl(v, n, jmap1, perm1, imode, xv, [](PetscCount i) { return i; });
167: return;
168: }
169: PETSC_PRAGMA_DIAGNOSTIC_IGNORED_END()
170: } // namespace
172: } // namespace kernels
174: } // namespace impl
176: // ==========================================================================================
177: // VecSeq_CUPM - Implementations
178: // ==========================================================================================
180: template <device::cupm::DeviceType T>
181: inline PetscErrorCode VecCreateSeqCUPMAsync(MPI_Comm comm, PetscInt n, Vec *v) noexcept
182: {
183: PetscFunctionBegin;
184: PetscAssertPointer(v, 4);
185: PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPM(comm, 0, n, v, PETSC_TRUE));
186: PetscFunctionReturn(PETSC_SUCCESS);
187: }
189: template <device::cupm::DeviceType T>
190: inline PetscErrorCode VecCreateSeqCUPMWithArraysAsync(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept
191: {
192: PetscFunctionBegin;
193: if (n && cpuarray) PetscAssertPointer(cpuarray, 4);
194: PetscAssertPointer(v, 6);
195: PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(comm, bs, n, cpuarray, gpuarray, v));
196: PetscFunctionReturn(PETSC_SUCCESS);
197: }
199: template <PetscMemoryAccessMode mode, device::cupm::DeviceType T>
200: inline PetscErrorCode VecCUPMGetArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
201: {
202: PetscFunctionBegin;
204: PetscAssertPointer(a, 2);
205: PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
206: PetscCall(impl::VecSeq_CUPM<T>::template GetArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx));
207: PetscFunctionReturn(PETSC_SUCCESS);
208: }
210: template <PetscMemoryAccessMode mode, device::cupm::DeviceType T>
211: inline PetscErrorCode VecCUPMRestoreArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
212: {
213: PetscFunctionBegin;
215: PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
216: PetscCall(impl::VecSeq_CUPM<T>::template RestoreArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx));
217: PetscFunctionReturn(PETSC_SUCCESS);
218: }
220: template <device::cupm::DeviceType T>
221: inline PetscErrorCode VecCUPMGetArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
222: {
223: PetscFunctionBegin;
224: PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx));
225: PetscFunctionReturn(PETSC_SUCCESS);
226: }
228: template <device::cupm::DeviceType T>
229: inline PetscErrorCode VecCUPMRestoreArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
230: {
231: PetscFunctionBegin;
232: PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx));
233: PetscFunctionReturn(PETSC_SUCCESS);
234: }
236: template <device::cupm::DeviceType T>
237: inline PetscErrorCode VecCUPMGetArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
238: {
239: PetscFunctionBegin;
240: PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx));
241: PetscFunctionReturn(PETSC_SUCCESS);
242: }
244: template <device::cupm::DeviceType T>
245: inline PetscErrorCode VecCUPMRestoreArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
246: {
247: PetscFunctionBegin;
248: PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx));
249: PetscFunctionReturn(PETSC_SUCCESS);
250: }
252: template <device::cupm::DeviceType T>
253: inline PetscErrorCode VecCUPMGetArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
254: {
255: PetscFunctionBegin;
256: PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx));
257: PetscFunctionReturn(PETSC_SUCCESS);
258: }
260: template <device::cupm::DeviceType T>
261: inline PetscErrorCode VecCUPMRestoreArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
262: {
263: PetscFunctionBegin;
264: PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx));
265: PetscFunctionReturn(PETSC_SUCCESS);
266: }
268: template <device::cupm::DeviceType T>
269: inline PetscErrorCode VecCUPMPlaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept
270: {
271: PetscFunctionBegin;
273: PetscCall(impl::VecSeq_CUPM<T>::template PlaceArray<PETSC_MEMTYPE_DEVICE>(vin, a));
274: PetscFunctionReturn(PETSC_SUCCESS);
275: }
277: template <device::cupm::DeviceType T>
278: inline PetscErrorCode VecCUPMReplaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept
279: {
280: PetscFunctionBegin;
282: PetscCall(impl::VecSeq_CUPM<T>::template ReplaceArray<PETSC_MEMTYPE_DEVICE>(vin, a));
283: PetscFunctionReturn(PETSC_SUCCESS);
284: }
286: template <device::cupm::DeviceType T>
287: inline PetscErrorCode VecCUPMResetArrayAsync(Vec vin) noexcept
288: {
289: PetscFunctionBegin;
291: PetscCall(impl::VecSeq_CUPM<T>::template ResetArray<PETSC_MEMTYPE_DEVICE>(vin));
292: PetscFunctionReturn(PETSC_SUCCESS);
293: }
295: } // namespace cupm
297: } // namespace vec
299: } // namespace Petsc
// Explicit instantiation declarations for VecSeq_CUPM, one per enabled backend. An
// `extern template` declaration stops translation units that include this header from
// implicitly instantiating the class for that backend; the matching explicit instantiation
// definitions are expected to live in a source file elsewhere (not visible here -- confirm).
// Each declaration is only emitted when PetscDefined() reports the corresponding
// HAVE_CUDA / HAVE_HIP configuration flag.
#if PetscDefined(HAVE_CUDA)
extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecSeq_CUPM<::Petsc::device::cupm::DeviceType::CUDA>;
#endif
#if PetscDefined(HAVE_HIP)
extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecSeq_CUPM<::Petsc::device::cupm::DeviceType::HIP>;
#endif