Actual source code: vecseqcupm.hpp

  1: #pragma once

  3: #include <petsc/private/veccupmimpl.h>
  4: #include <petsc/private/cpp/utility.hpp>

  6: #include <../src/sys/objects/device/impls/cupm/kernels.hpp>
  7: #include <../src/vec/vec/impls/dvecimpl.h>

  9: namespace Petsc
 10: {

 12: namespace vec
 13: {

 15: namespace cupm
 16: {

 18: namespace impl
 19: {

 21: // ==========================================================================================
 22: // VecSeq_CUPM
 23: // ==========================================================================================

 // VecSeq_CUPM<T>: declarations of the sequential-vector operations for the CUPM device
 // backend selected by T. The class derives from the CRTP base
 // Vec_CUPMBase<T, VecSeq_CUPM<T>>; because it is declared with `class` and no access
 // specifier on the base, the inheritance is private. Every member function declared here is
 // static, and only declarations appear in this listing (the definitions live elsewhere).
 //
 // Naming pattern visible below: most operations come as a pair Foo(...) / FooAsync(...,
 // PetscDeviceContext), where the Async form takes one extra trailing device context.
 // PointwiseSignAsync is declared without a non-Async sibling; Dot, MDot, TDot, Norm,
 // ErrorWnorm, DotNorm2, Max, Min, Sum, SetRandom, BindToCPU and the COO routines are
 // declared without an Async sibling.
 25: template <device::cupm::DeviceType T>
 26: class VecSeq_CUPM : Vec_CUPMBase<T, VecSeq_CUPM<T>> {
 27: public:
       // Macro supplied by the included base header; it is given the alias name `base_type`,
       // the device type and this class. NOTE(review): presumably it declares base_type and
       // re-exports base members (e.g. GetArray/RestoreArray/PlaceArray used by the free
       // functions later in this file) -- confirm in veccupmimpl.h.
 28:   PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecSeq_CUPM<T>);

 30: private:
       // NOTE(review): the *IMPL* names below look like customization hooks consumed by the
       // CRTP base -- confirm against Vec_CUPMBase.
 31:   PETSC_NODISCARD static Vec_Seq          *VecIMPLCast_(Vec) noexcept;
 32:   PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept;
 33:   PETSC_NODISCARD static constexpr VecType VECIMPL_() noexcept;

 35:   static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept;
 36:   static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept;
 37:   static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept;
 38:   static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept;

 40:   static PetscErrorCode MaybeIncrementEmptyLocalVec(Vec) noexcept;

 42:   // common core for min and max
 43:   template <typename TupleFuncT, typename UnaryFuncT>
 44:   static PetscErrorCode MinMax_(TupleFuncT &&, UnaryFuncT &&, Vec, PetscInt *, PetscReal *) noexcept;
 45:   // common core for pointwise binary and pointwise unary thrust functions
       // (the trailing PetscDeviceContext parameter defaults to nullptr in all three)
 46:   template <typename BinaryFuncT>
 47:   static PetscErrorCode PointwiseBinary_(BinaryFuncT &&, Vec, Vec, Vec, PetscDeviceContext = nullptr) noexcept;
 48:   template <typename BinaryFuncT>
 49:   static PetscErrorCode PointwiseBinaryDispatch_(PetscErrorCode (*)(Vec, Vec, Vec), BinaryFuncT &&, Vec, Vec, Vec, PetscDeviceContext = nullptr) noexcept;
 50:   template <typename UnaryFuncT>
 51:   static PetscErrorCode PointwiseUnary_(UnaryFuncT &&, Vec, Vec, PetscDeviceContext = nullptr) noexcept;
 52:   // mdot dispatchers
       // (two overloads selected by a std::true_type / std::false_type tag argument)
 53:   static PetscErrorCode MDot_(/* use complex = */ std::true_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
 54:   static PetscErrorCode MDot_(/* use complex = */ std::false_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
       // Each *_kernel_dispatch_ name has two overloads: one parameterized on an index pack
       // and taking util::index_sequence<Idx...>, one parameterized on an int and taking a
       // PetscInt reference.
 55:   template <std::size_t... Idx>
 56:   static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, util::index_sequence<Idx...>) noexcept;
 57:   template <int>
 58:   static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, PetscInt &) noexcept;
 59:   template <std::size_t... Idx>
 60:   static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, util::index_sequence<Idx...>) noexcept;
 61:   template <int>
 62:   static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, PetscInt &) noexcept;
 63:   // common core for the various create routines
 64:   static PetscErrorCode CreateSeqCUPM_(Vec, PetscDeviceContext, PetscScalar * /*host_ptr*/ = nullptr, PetscScalar * /*device_ptr*/ = nullptr) noexcept;

 66: public:
 67:   // callable directly via a bespoke function
       // (both are called by the free function templates VecCreateSeqCUPMAsync and
       // VecCreateSeqCUPMWithArraysAsync later in this file)
 68:   static PetscErrorCode CreateSeqCUPM(MPI_Comm, PetscInt, PetscInt, Vec *, PetscBool) noexcept;
 69:   static PetscErrorCode CreateSeqCUPMWithBothArrays(MPI_Comm, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept;

 71:   static PetscErrorCode InitializeAsyncFunctions(Vec) noexcept;
 72:   static PetscErrorCode ClearAsyncFunctions(Vec) noexcept;

 74:   // callable indirectly via function pointers
 75:   static PetscErrorCode Duplicate(Vec, Vec *) noexcept;
 76:   static PetscErrorCode AYPX(Vec, PetscScalar, Vec) noexcept;
 77:   static PetscErrorCode AYPXAsync(Vec, PetscScalar, Vec, PetscDeviceContext) noexcept;
 78:   static PetscErrorCode AXPY(Vec, PetscScalar, Vec) noexcept;
 79:   static PetscErrorCode AXPYAsync(Vec, PetscScalar, Vec, PetscDeviceContext) noexcept;
 80:   static PetscErrorCode PointwiseDivide(Vec, Vec, Vec) noexcept;
 81:   static PetscErrorCode PointwiseDivideAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
 82:   static PetscErrorCode PointwiseMult(Vec, Vec, Vec) noexcept;
 83:   static PetscErrorCode PointwiseMultAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
 84:   static PetscErrorCode PointwiseMax(Vec, Vec, Vec) noexcept;
 85:   static PetscErrorCode PointwiseMaxAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
 86:   static PetscErrorCode PointwiseMaxAbs(Vec, Vec, Vec) noexcept;
 87:   static PetscErrorCode PointwiseMaxAbsAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
 88:   static PetscErrorCode PointwiseMin(Vec, Vec, Vec) noexcept;
 89:   static PetscErrorCode PointwiseMinAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
 90:   static PetscErrorCode PointwiseSignAsync(Vec, Vec, VecSignMode, PetscDeviceContext) noexcept;
 91:   static PetscErrorCode Reciprocal(Vec) noexcept;
 92:   static PetscErrorCode ReciprocalAsync(Vec, PetscDeviceContext) noexcept;
 93:   static PetscErrorCode Abs(Vec) noexcept;
 94:   static PetscErrorCode AbsAsync(Vec, PetscDeviceContext) noexcept;
 95:   static PetscErrorCode SqrtAbs(Vec) noexcept;
 96:   static PetscErrorCode SqrtAbsAsync(Vec, PetscDeviceContext) noexcept;
 97:   static PetscErrorCode Exp(Vec) noexcept;
 98:   static PetscErrorCode ExpAsync(Vec, PetscDeviceContext) noexcept;
 99:   static PetscErrorCode Log(Vec) noexcept;
100:   static PetscErrorCode LogAsync(Vec, PetscDeviceContext) noexcept;
101:   static PetscErrorCode WAXPY(Vec, PetscScalar, Vec, Vec) noexcept;
102:   static PetscErrorCode WAXPYAsync(Vec, PetscScalar, Vec, Vec, PetscDeviceContext) noexcept;
103:   static PetscErrorCode MAXPY(Vec, PetscInt, const PetscScalar[], Vec *) noexcept;
104:   static PetscErrorCode MAXPYAsync(Vec, PetscInt, const PetscScalar[], Vec *, PetscDeviceContext) noexcept;
105:   static PetscErrorCode Dot(Vec, Vec, PetscScalar *) noexcept;
106:   static PetscErrorCode MDot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept;
107:   static PetscErrorCode Set(Vec, PetscScalar) noexcept;
108:   static PetscErrorCode SetAsync(Vec, PetscScalar, PetscDeviceContext) noexcept;
109:   static PetscErrorCode Scale(Vec, PetscScalar) noexcept;
110:   static PetscErrorCode ScaleAsync(Vec, PetscScalar, PetscDeviceContext) noexcept;
111:   static PetscErrorCode TDot(Vec, Vec, PetscScalar *) noexcept;
112:   static PetscErrorCode Copy(Vec, Vec) noexcept;
113:   static PetscErrorCode CopyAsync(Vec, Vec, PetscDeviceContext) noexcept;
114:   static PetscErrorCode Swap(Vec, Vec) noexcept;
115:   static PetscErrorCode SwapAsync(Vec, Vec, PetscDeviceContext) noexcept;
116:   static PetscErrorCode AXPBY(Vec, PetscScalar, PetscScalar, Vec) noexcept;
117:   static PetscErrorCode AXPBYAsync(Vec, PetscScalar, PetscScalar, Vec, PetscDeviceContext) noexcept;
118:   static PetscErrorCode AXPBYPCZ(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec) noexcept;
119:   static PetscErrorCode AXPBYPCZAsync(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec, PetscDeviceContext) noexcept;
120:   static PetscErrorCode Norm(Vec, NormType, PetscReal *) noexcept;
121:   static PetscErrorCode ErrorWnorm(Vec, Vec, Vec, NormType, PetscReal, Vec, PetscReal, Vec, PetscReal, PetscReal *, PetscInt *, PetscReal *, PetscInt *, PetscReal *, PetscInt *) noexcept;
122:   static PetscErrorCode DotNorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept;
123:   static PetscErrorCode Conjugate(Vec) noexcept;
124:   static PetscErrorCode ConjugateAsync(Vec, PetscDeviceContext) noexcept;
       // Get/RestoreLocalVector are templated on the access mode (a PetscMemoryAccessMode
       // non-type template parameter) rather than taking it as a runtime argument.
125:   template <PetscMemoryAccessMode>
126:   static PetscErrorCode GetLocalVector(Vec, Vec) noexcept;
127:   template <PetscMemoryAccessMode>
128:   static PetscErrorCode RestoreLocalVector(Vec, Vec) noexcept;
129:   static PetscErrorCode Max(Vec, PetscInt *, PetscReal *) noexcept;
130:   static PetscErrorCode Min(Vec, PetscInt *, PetscReal *) noexcept;
131:   static PetscErrorCode Sum(Vec, PetscScalar *) noexcept;
132:   static PetscErrorCode Shift(Vec, PetscScalar) noexcept;
133:   static PetscErrorCode ShiftAsync(Vec, PetscScalar, PetscDeviceContext) noexcept;
134:   static PetscErrorCode SetRandom(Vec, PetscRandom) noexcept;
135:   static PetscErrorCode BindToCPU(Vec, PetscBool) noexcept;
136:   static PetscErrorCode SetPreallocationCOO(Vec, PetscCount, const PetscInt[]) noexcept;
137:   static PetscErrorCode SetValuesCOO(Vec, const PetscScalar[], InsertMode) noexcept;
138: };

// Device-side code for COO value insertion. NOTE(review): presumably launched from
// VecSeq_CUPM<T>::SetValuesCOO -- the call site is not part of this listing.
140: namespace kernels
141: {

// add_coo_values_impl: shared body of the COO kernels.
//
// For every index i that grid_stride_1D(n, ...) passes to the lambda:
//   sum      = vv[perm[k]] accumulated over k in [jmap[i], jmap[i + 1])
//   xv[idx]  = sum            when imode == INSERT_VALUES
//   xv[idx] += sum            for any other imode
// where idx = xvindex(i), i.e. the caller-supplied functor maps i to a position in xv.
//
// The lambda captures by copy ([=]), so the pointers, imode and the xvindex functor are
// copied into the closure.
// NOTE(review): jmap is read at index i + 1, so it presumably holds n + 1 entries, and
// grid_stride_1D presumably spreads the indices [0, n) over the launched threads -- confirm
// both against the caller and kernels.hpp.
143: template <typename F>
144: PETSC_DEVICE_INLINE_DECL void add_coo_values_impl(const PetscScalar *PETSC_RESTRICT vv, PetscCount n, const PetscCount *PETSC_RESTRICT jmap, const PetscCount *PETSC_RESTRICT perm, InsertMode imode, PetscScalar *PETSC_RESTRICT xv, F &&xvindex)
145: {
146:   ::Petsc::device::cupm::kernels::util::grid_stride_1D(n, [=](PetscCount i) {
147:     const auto  end = jmap[i + 1];
148:     const auto  idx = xvindex(i);
149:     PetscScalar sum = 0.0;

151:     for (auto k = jmap[i]; k < end; ++k) sum += vv[perm[k]];

153:     if (imode == INSERT_VALUES) {
154:       xv[idx] = sum;
155:     } else {
156:       xv[idx] += sum;
157:     }
158:   });
159:   return;
160: }

// The unnamed namespace gives add_coo_values internal linkage in every translation unit
// that includes this header; the surrounding pragma pair suppresses -Wunused-function for
// it (a translation unit that never references the kernel would otherwise warn).
162: namespace
163: {
164: PETSC_PRAGMA_DIAGNOSTIC_IGNORED_BEGIN("-Wunused-function")
// add_coo_values: kernel entry point for the sequential case. It forwards to
// add_coo_values_impl with an identity index map, so the i-th sum lands in xv[i].
165: PETSC_KERNEL_DECL void add_coo_values(const PetscScalar *PETSC_RESTRICT v, PetscCount n, const PetscCount *PETSC_RESTRICT jmap1, const PetscCount *PETSC_RESTRICT perm1, InsertMode imode, PetscScalar *PETSC_RESTRICT xv)
166: {
167:   add_coo_values_impl(v, n, jmap1, perm1, imode, xv, [](PetscCount i) { return i; });
168:   return;
169: }
170: PETSC_PRAGMA_DIAGNOSTIC_IGNORED_END()
171: } // namespace

173: } // namespace kernels

175: } // namespace impl

177: // ==========================================================================================
178: // VecSeq_CUPM - Implementations
179: // ==========================================================================================

// VecCreateSeqCUPMAsync: creates a sequential CUPM vector for device backend T.
//
// Parameters:
//   comm - communicator forwarded unchanged to CreateSeqCUPM
//   n    - forwarded as the third argument of CreateSeqCUPM (NOTE(review): presumably the
//          vector length -- confirm against CreateSeqCUPM's definition)
//   v    - output location for the new Vec; must not be null
//
// Returns PETSC_SUCCESS, or propagates the error raised by PetscAssertPointer/PetscCall.
//
// The call forwards a literal 0 as CreateSeqCUPM's second argument and PETSC_TRUE as its
// last. NOTE(review): by analogy with CreateSeqCUPMWithBothArrays(comm, bs, n, ...) the 0 is
// presumably the block size -- confirm.
181: template <device::cupm::DeviceType T>
182: inline PetscErrorCode VecCreateSeqCUPMAsync(MPI_Comm comm, PetscInt n, Vec *v) noexcept
183: {
184:   PetscFunctionBegin;
185:   PetscAssertPointer(v, 3); // v is the 3rd parameter; the index is what the error message reports
186:   PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPM(comm, 0, n, v, PETSC_TRUE));
187:   PetscFunctionReturn(PETSC_SUCCESS);
188: }

// VecCreateSeqCUPMWithArraysAsync: creates a sequential CUPM vector for device backend T
// from caller-supplied arrays by forwarding every argument, in order, to
// VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays.
//
// Parameters:
//   comm, bs, n - forwarded unchanged
//   cpuarray    - forwarded unchanged; may be null (it is only pointer-checked when both n
//                 and cpuarray are non-zero/non-null)
//   gpuarray    - forwarded unchanged; not checked here
//   v           - output location for the new Vec; must not be null
//
// Returns PETSC_SUCCESS, or propagates the error raised by PetscAssertPointer/PetscCall.
// NOTE(review): array ownership and lifetime rules are decided by
// CreateSeqCUPMWithBothArrays, which is not shown in this listing.
190: template <device::cupm::DeviceType T>
191: inline PetscErrorCode VecCreateSeqCUPMWithArraysAsync(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept
192: {
193:   PetscFunctionBegin;
194:   if (n && cpuarray) PetscAssertPointer(cpuarray, 4); // cpuarray is parameter #4
195:   PetscAssertPointer(v, 6); // v is parameter #6
196:   PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(comm, bs, n, cpuarray, gpuarray, v));
197:   PetscFunctionReturn(PETSC_SUCCESS);
198: }

// VecCUPMGetArrayAsync_Private: common implementation behind the public
// VecCUPMGetArray{,Read,Write}Async wrappers in this file.
//
// Template parameters:
//   mode - access mode forwarded to GetArray (the wrappers pass READ_WRITE, READ or WRITE)
//   T    - device backend
// Parameters:
//   v    - vector whose array is requested
//   a    - output location for the array pointer; must not be null
//   dctx - device context; passed by address through
//          PetscDeviceContextGetOptionalNullContext_Internal before use
//          (NOTE(review): presumably that call substitutes a usable context when dctx is
//          null -- confirm)
//
// Calls VecSeq_CUPM<T>::GetArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx).
// Returns PETSC_SUCCESS or the propagated error.
//
// NOTE(review): the embedded listing numbering jumps from 203 to 205, so one line of the
// original body may be missing from this listing.
200: template <PetscMemoryAccessMode mode, device::cupm::DeviceType T>
201: inline PetscErrorCode VecCUPMGetArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
202: {
203:   PetscFunctionBegin;
205:   PetscAssertPointer(a, 2);
206:   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
207:   PetscCall(impl::VecSeq_CUPM<T>::template GetArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx));
208:   PetscFunctionReturn(PETSC_SUCCESS);
209: }

// VecCUPMRestoreArrayAsync_Private: common implementation behind the public
// VecCUPMRestoreArray{,Read,Write}Async wrappers in this file; counterpart of
// VecCUPMGetArrayAsync_Private.
//
// Template parameters:
//   mode - access mode forwarded to RestoreArray
//   T    - device backend
// Parameters:
//   v    - vector whose array is being handed back
//   a    - address of the array pointer; forwarded as-is (unlike the Get variant, no
//          PetscAssertPointer check on `a` is visible here)
//   dctx - device context; passed by address through
//          PetscDeviceContextGetOptionalNullContext_Internal before use
//
// Calls VecSeq_CUPM<T>::RestoreArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx).
// Returns PETSC_SUCCESS or the propagated error.
//
// NOTE(review): the embedded listing numbering jumps from 214 to 216, so one line of the
// original body may be missing from this listing.
211: template <PetscMemoryAccessMode mode, device::cupm::DeviceType T>
212: inline PetscErrorCode VecCUPMRestoreArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
213: {
214:   PetscFunctionBegin;
216:   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
217:   PetscCall(impl::VecSeq_CUPM<T>::template RestoreArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx));
218:   PetscFunctionReturn(PETSC_SUCCESS);
219: }

// VecCUPMGetArrayAsync: obtains the vector's array with read-write access by delegating to
// VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>.
//
// Parameters:
//   v    - vector whose array is requested
//   a    - output location for the array pointer
//   dctx - optional device context (defaults to nullptr)
// Returns PETSC_SUCCESS or the propagated error.
221: template <device::cupm::DeviceType T>
222: inline PetscErrorCode VecCUPMGetArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
223: {
224:   PetscFunctionBegin;
225:   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx));
226:   PetscFunctionReturn(PETSC_SUCCESS);
227: }

// VecCUPMRestoreArrayAsync: hands back an array obtained with read-write access by
// delegating to VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>.
//
// Parameters:
//   v    - vector the array belongs to
//   a    - address of the array pointer previously obtained
//   dctx - optional device context (defaults to nullptr)
// Returns PETSC_SUCCESS or the propagated error.
229: template <device::cupm::DeviceType T>
230: inline PetscErrorCode VecCUPMRestoreArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
231: {
232:   PetscFunctionBegin;
233:   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx));
234:   PetscFunctionReturn(PETSC_SUCCESS);
235: }

// VecCUPMGetArrayReadAsync: obtains the vector's array with read-only access by delegating
// to VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>.
//
// Parameters:
//   v    - vector whose array is requested
//   a    - output location for a pointer-to-const array
//   dctx - optional device context (defaults to nullptr)
// Returns PETSC_SUCCESS or the propagated error.
237: template <device::cupm::DeviceType T>
238: inline PetscErrorCode VecCUPMGetArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
239: {
240:   PetscFunctionBegin;
241:   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx)); // const_cast only so the shared helper (which takes PetscScalar **) can be reused
242:   PetscFunctionReturn(PETSC_SUCCESS);
243: }

// VecCUPMRestoreArrayReadAsync: hands back an array obtained with read-only access by
// delegating to VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>.
//
// Parameters:
//   v    - vector the array belongs to
//   a    - address of the pointer-to-const array previously obtained
//   dctx - optional device context (defaults to nullptr)
// Returns PETSC_SUCCESS or the propagated error.
245: template <device::cupm::DeviceType T>
246: inline PetscErrorCode VecCUPMRestoreArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
247: {
248:   PetscFunctionBegin;
249:   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx)); // const_cast only so the shared helper (which takes PetscScalar **) can be reused
250:   PetscFunctionReturn(PETSC_SUCCESS);
251: }

// VecCUPMGetArrayWriteAsync: obtains the vector's array with write access by delegating to
// VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>.
//
// Parameters:
//   v    - vector whose array is requested
//   a    - output location for the array pointer
//   dctx - optional device context (defaults to nullptr)
// Returns PETSC_SUCCESS or the propagated error.
253: template <device::cupm::DeviceType T>
254: inline PetscErrorCode VecCUPMGetArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
255: {
256:   PetscFunctionBegin;
257:   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx));
258:   PetscFunctionReturn(PETSC_SUCCESS);
259: }

// VecCUPMRestoreArrayWriteAsync: hands back an array obtained with write access by
// delegating to VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>.
//
// Parameters:
//   v    - vector the array belongs to
//   a    - address of the array pointer previously obtained
//   dctx - optional device context (defaults to nullptr)
// Returns PETSC_SUCCESS or the propagated error.
261: template <device::cupm::DeviceType T>
262: inline PetscErrorCode VecCUPMRestoreArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
263: {
264:   PetscFunctionBegin;
265:   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx));
266:   PetscFunctionReturn(PETSC_SUCCESS);
267: }

// VecCUPMPlaceArrayAsync: forwards to VecSeq_CUPM<T>::PlaceArray<PETSC_MEMTYPE_DEVICE>(vin, a).
//
// Parameters:
//   vin - vector to operate on
//   a   - array handed to PlaceArray (tagged as device memory by the template argument)
// Returns PETSC_SUCCESS or the propagated error.
//
// NOTE(review): despite the Async suffix, no PetscDeviceContext parameter is taken here.
// NOTE(review): the embedded listing numbering jumps from 272 to 274, so one line of the
// original body may be missing from this listing.
269: template <device::cupm::DeviceType T>
270: inline PetscErrorCode VecCUPMPlaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept
271: {
272:   PetscFunctionBegin;
274:   PetscCall(impl::VecSeq_CUPM<T>::template PlaceArray<PETSC_MEMTYPE_DEVICE>(vin, a));
275:   PetscFunctionReturn(PETSC_SUCCESS);
276: }

// VecCUPMReplaceArrayAsync: forwards to
// VecSeq_CUPM<T>::ReplaceArray<PETSC_MEMTYPE_DEVICE>(vin, a).
//
// Parameters:
//   vin - vector to operate on
//   a   - array handed to ReplaceArray (tagged as device memory by the template argument)
// Returns PETSC_SUCCESS or the propagated error.
//
// NOTE(review): what happens to the previously held array is decided by ReplaceArray, which
// is not shown in this listing.
// NOTE(review): the embedded listing numbering jumps from 281 to 283, so one line of the
// original body may be missing from this listing.
278: template <device::cupm::DeviceType T>
279: inline PetscErrorCode VecCUPMReplaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept
280: {
281:   PetscFunctionBegin;
283:   PetscCall(impl::VecSeq_CUPM<T>::template ReplaceArray<PETSC_MEMTYPE_DEVICE>(vin, a));
284:   PetscFunctionReturn(PETSC_SUCCESS);
285: }

// VecCUPMResetArrayAsync: forwards to VecSeq_CUPM<T>::ResetArray<PETSC_MEMTYPE_DEVICE>(vin).
//
// Parameters:
//   vin - vector to operate on
// Returns PETSC_SUCCESS or the propagated error.
//
// NOTE(review): presumably the counterpart of VecCUPMPlaceArrayAsync -- confirm against
// ResetArray, which is not shown in this listing.
// NOTE(review): the embedded listing numbering jumps from 290 to 292, so one line of the
// original body may be missing from this listing.
287: template <device::cupm::DeviceType T>
288: inline PetscErrorCode VecCUPMResetArrayAsync(Vec vin) noexcept
289: {
290:   PetscFunctionBegin;
292:   PetscCall(impl::VecSeq_CUPM<T>::template ResetArray<PETSC_MEMTYPE_DEVICE>(vin));
293:   PetscFunctionReturn(PETSC_SUCCESS);
294: }

296: } // namespace cupm

298: } // namespace vec

300: } // namespace Petsc

// Explicit instantiation declarations: when the corresponding backend is enabled, tell every
// translation unit that includes this header not to implicitly instantiate VecSeq_CUPM for
// that device type. A matching explicit instantiation definition must therefore exist in
// exactly one translation unit (not part of this listing).
302: #if PetscDefined(HAVE_CUDA)
303: extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecSeq_CUPM<::Petsc::device::cupm::DeviceType::CUDA>;
304: #endif

306: #if PetscDefined(HAVE_HIP)
307: extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecSeq_CUPM<::Petsc::device::cupm::DeviceType::HIP>;
308: #endif