Actual source code: vecseqcupm.hpp

  1: #pragma once

  3: #include <petsc/private/veccupmimpl.h>
  4: #include <petsc/private/cpp/utility.hpp>

  6: #include <../src/sys/objects/device/impls/cupm/kernels.hpp>
  7: #include <../src/vec/vec/impls/dvecimpl.h>

  9: namespace Petsc
 10: {

 12: namespace vec
 13: {

 15: namespace cupm
 16: {

 18: namespace impl
 19: {

 21: // ==========================================================================================
 22: // VecSeq_CUPM
 23: // ==========================================================================================

 25: template <device::cupm::DeviceType T>
     // Collection of static, noexcept routines for the Vec_Seq-backed vector on the device type
     // selected by T. The class hands itself to Vec_CUPMBase as its second template argument.
     // Only declarations appear here; the definitions are not part of this section.
 26: class VecSeq_CUPM : Vec_CUPMBase<T, VecSeq_CUPM<T>> {
 27: public:
 28:   PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecSeq_CUPM<T>);

 30: private:
       // NOTE(review): the *_IMPL_ / trailing-underscore names below look like customization hooks
       // consumed by Vec_CUPMBase -- the base class is not visible here, confirm before relying on it.
 31:   PETSC_NODISCARD static Vec_Seq          *VecIMPLCast_(Vec) noexcept;
 32:   PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept;
 33:   PETSC_NODISCARD static constexpr VecType VECIMPL_() noexcept;

 35:   static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept;
 36:   static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept;
 37:   static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept;
 38:   static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept;

 40:   static PetscErrorCode MaybeIncrementEmptyLocalVec(Vec) noexcept;

 42:   // common core for min and max
 43:   template <typename TupleFuncT, typename UnaryFuncT>
 44:   static PetscErrorCode MinMax_(TupleFuncT &&, UnaryFuncT &&, Vec, PetscInt *, PetscReal *) noexcept;
 45:   // common core for pointwise binary and pointwise unary thrust functions
       // (the trailing PetscDeviceContext parameter defaults to nullptr in all three helpers)
 46:   template <typename BinaryFuncT>
 47:   static PetscErrorCode PointwiseBinary_(BinaryFuncT &&, Vec, Vec, Vec, PetscDeviceContext = nullptr) noexcept;
 48:   template <typename BinaryFuncT>
 49:   static PetscErrorCode PointwiseBinaryDispatch_(PetscErrorCode (*)(Vec, Vec, Vec), BinaryFuncT &&, Vec, Vec, Vec, PetscDeviceContext = nullptr) noexcept;
 50:   template <typename UnaryFuncT>
 51:   static PetscErrorCode PointwiseUnary_(UnaryFuncT &&, Vec, Vec, PetscDeviceContext = nullptr) noexcept;
 52:   // mdot dispatchers
       // The two MDot_ overloads are selected by a std::true_type / std::false_type tag ("use complex").
 53:   static PetscErrorCode MDot_(/* use complex = */ std::true_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
 54:   static PetscErrorCode MDot_(/* use complex = */ std::false_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
       // Each *_kernel_dispatch_ helper has two overloads: one taking a util::index_sequence and one
       // templated on an int that takes a PetscInt & as its last argument.
       // NOTE(review): how the two overloads cooperate is only visible in the definitions.
 55:   template <std::size_t... Idx>
 56:   static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, util::index_sequence<Idx...>) noexcept;
 57:   template <int>
 58:   static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, PetscInt &) noexcept;
 59:   template <std::size_t... Idx>
 60:   static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, util::index_sequence<Idx...>) noexcept;
 61:   template <int>
 62:   static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, PetscInt &) noexcept;
 63:   // common core for the various create routines
 64:   static PetscErrorCode CreateSeqCUPM_(Vec, PetscDeviceContext, PetscScalar * /*host_ptr*/ = nullptr, PetscScalar * /*device_ptr*/ = nullptr) noexcept;

 66: public:
 67:   // callable directly via a bespoke function
 68:   static PetscErrorCode CreateSeqCUPM(MPI_Comm, PetscInt, PetscInt, Vec *, PetscBool) noexcept;
 69:   static PetscErrorCode CreateSeqCUPMWithBothArrays(MPI_Comm, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept;

 71:   static PetscErrorCode InitializeAsyncFunctions(Vec) noexcept;
 72:   static PetscErrorCode ClearAsyncFunctions(Vec) noexcept;

 74:   // callable indirectly via function pointers
       // Many operations are declared twice: Foo(...) and FooAsync(..., PetscDeviceContext), the latter
       // taking one extra device-context argument.
 75:   static PetscErrorCode Duplicate(Vec, Vec *) noexcept;
 76:   static PetscErrorCode AYPX(Vec, PetscScalar, Vec) noexcept;
 77:   static PetscErrorCode AYPXAsync(Vec, PetscScalar, Vec, PetscDeviceContext) noexcept;
 78:   static PetscErrorCode AXPY(Vec, PetscScalar, Vec) noexcept;
 79:   static PetscErrorCode AXPYAsync(Vec, PetscScalar, Vec, PetscDeviceContext) noexcept;
 80:   static PetscErrorCode PointwiseDivide(Vec, Vec, Vec) noexcept;
 81:   static PetscErrorCode PointwiseDivideAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
 82:   static PetscErrorCode PointwiseMult(Vec, Vec, Vec) noexcept;
 83:   static PetscErrorCode PointwiseMultAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
 84:   static PetscErrorCode PointwiseMax(Vec, Vec, Vec) noexcept;
 85:   static PetscErrorCode PointwiseMaxAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
 86:   static PetscErrorCode PointwiseMaxAbs(Vec, Vec, Vec) noexcept;
 87:   static PetscErrorCode PointwiseMaxAbsAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
 88:   static PetscErrorCode PointwiseMin(Vec, Vec, Vec) noexcept;
 89:   static PetscErrorCode PointwiseMinAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
 90:   static PetscErrorCode Reciprocal(Vec) noexcept;
 91:   static PetscErrorCode ReciprocalAsync(Vec, PetscDeviceContext) noexcept;
 92:   static PetscErrorCode Abs(Vec) noexcept;
 93:   static PetscErrorCode AbsAsync(Vec, PetscDeviceContext) noexcept;
 94:   static PetscErrorCode SqrtAbs(Vec) noexcept;
 95:   static PetscErrorCode SqrtAbsAsync(Vec, PetscDeviceContext) noexcept;
 96:   static PetscErrorCode Exp(Vec) noexcept;
 97:   static PetscErrorCode ExpAsync(Vec, PetscDeviceContext) noexcept;
 98:   static PetscErrorCode Log(Vec) noexcept;
 99:   static PetscErrorCode LogAsync(Vec, PetscDeviceContext) noexcept;
100:   static PetscErrorCode WAXPY(Vec, PetscScalar, Vec, Vec) noexcept;
101:   static PetscErrorCode WAXPYAsync(Vec, PetscScalar, Vec, Vec, PetscDeviceContext) noexcept;
102:   static PetscErrorCode MAXPY(Vec, PetscInt, const PetscScalar[], Vec *) noexcept;
103:   static PetscErrorCode MAXPYAsync(Vec, PetscInt, const PetscScalar[], Vec *, PetscDeviceContext) noexcept;
104:   static PetscErrorCode Dot(Vec, Vec, PetscScalar *) noexcept;
105:   static PetscErrorCode MDot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept;
106:   static PetscErrorCode Set(Vec, PetscScalar) noexcept;
107:   static PetscErrorCode SetAsync(Vec, PetscScalar, PetscDeviceContext) noexcept;
108:   static PetscErrorCode Scale(Vec, PetscScalar) noexcept;
109:   static PetscErrorCode ScaleAsync(Vec, PetscScalar, PetscDeviceContext) noexcept;
110:   static PetscErrorCode TDot(Vec, Vec, PetscScalar *) noexcept;
111:   static PetscErrorCode Copy(Vec, Vec) noexcept;
112:   static PetscErrorCode CopyAsync(Vec, Vec, PetscDeviceContext) noexcept;
113:   static PetscErrorCode Swap(Vec, Vec) noexcept;
114:   static PetscErrorCode SwapAsync(Vec, Vec, PetscDeviceContext) noexcept;
115:   static PetscErrorCode AXPBY(Vec, PetscScalar, PetscScalar, Vec) noexcept;
116:   static PetscErrorCode AXPBYAsync(Vec, PetscScalar, PetscScalar, Vec, PetscDeviceContext) noexcept;
117:   static PetscErrorCode AXPBYPCZ(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec) noexcept;
118:   static PetscErrorCode AXPBYPCZAsync(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec, PetscDeviceContext) noexcept;
119:   static PetscErrorCode Norm(Vec, NormType, PetscReal *) noexcept;
120:   static PetscErrorCode ErrorWnorm(Vec, Vec, Vec, NormType, PetscReal, Vec, PetscReal, Vec, PetscReal, PetscReal *, PetscInt *, PetscReal *, PetscInt *, PetscReal *, PetscInt *) noexcept;
121:   static PetscErrorCode DotNorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept;
122:   static PetscErrorCode Conjugate(Vec) noexcept;
123:   static PetscErrorCode ConjugateAsync(Vec, PetscDeviceContext) noexcept;
       // Get/RestoreLocalVector are parameterized at compile time on the memory access mode.
124:   template <PetscMemoryAccessMode>
125:   static PetscErrorCode GetLocalVector(Vec, Vec) noexcept;
126:   template <PetscMemoryAccessMode>
127:   static PetscErrorCode RestoreLocalVector(Vec, Vec) noexcept;
128:   static PetscErrorCode Max(Vec, PetscInt *, PetscReal *) noexcept;
129:   static PetscErrorCode Min(Vec, PetscInt *, PetscReal *) noexcept;
130:   static PetscErrorCode Sum(Vec, PetscScalar *) noexcept;
131:   static PetscErrorCode Shift(Vec, PetscScalar) noexcept;
132:   static PetscErrorCode ShiftAsync(Vec, PetscScalar, PetscDeviceContext) noexcept;
133:   static PetscErrorCode SetRandom(Vec, PetscRandom) noexcept;
134:   static PetscErrorCode BindToCPU(Vec, PetscBool) noexcept;
135:   static PetscErrorCode SetPreallocationCOO(Vec, PetscCount, const PetscInt[]) noexcept;
136:   static PetscErrorCode SetValuesCOO(Vec, const PetscScalar[], InsertMode) noexcept;
137: };

139: namespace kernels
140: {

142: template <typename F>
143: PETSC_DEVICE_INLINE_DECL void add_coo_values_impl(const PetscScalar *PETSC_RESTRICT vv, PetscCount n, const PetscCount *PETSC_RESTRICT jmap, const PetscCount *PETSC_RESTRICT perm, InsertMode imode, PetscScalar *PETSC_RESTRICT xv, F &&xvindex)
144: {
145:   ::Petsc::device::cupm::kernels::util::grid_stride_1D(n, [=](PetscCount i) {
146:     const auto  end = jmap[i + 1];
147:     const auto  idx = xvindex(i);
148:     PetscScalar sum = 0.0;

150:     for (auto k = jmap[i]; k < end; ++k) sum += vv[perm[k]];

152:     if (imode == INSERT_VALUES) {
153:       xv[idx] = sum;
154:     } else {
155:       xv[idx] += sum;
156:     }
157:   });
158:   return;
159: }

161: namespace
162: {

164: PETSC_KERNEL_DECL void add_coo_values(const PetscScalar *PETSC_RESTRICT v, PetscCount n, const PetscCount *PETSC_RESTRICT jmap1, const PetscCount *PETSC_RESTRICT perm1, InsertMode imode, PetscScalar *PETSC_RESTRICT xv)
165: {
166:   add_coo_values_impl(v, n, jmap1, perm1, imode, xv, [](PetscCount i) { return i; });
167:   return;
168: }

170: } // namespace

172: #if PetscDefined(USING_HCC)
173: namespace do_not_use
174: {

176: // Needed to silence clang warning:
177: //
178: // warning: function 'FUNCTION NAME' is not needed and will not be emitted
179: //
180: // The warning is silly, since the function *is* used, however the host compiler does not
181: // appear to see this. Likely because the function using it is in a template.
182: //
183: // This warning appeared in clang-11, and still persists until clang-15 (21/02/2023)
184: inline void silence_warning_function_add_coo_values_is_not_needed_and_will_not_be_emitted()
185: {
186:   (void)add_coo_values;
187: }

189: } // namespace do_not_use
190: #endif

192: } // namespace kernels

194: } // namespace impl

196: // ==========================================================================================
197: // VecSeq_CUPM - Implementations
198: // ==========================================================================================

200: template <device::cupm::DeviceType T>
201: inline PetscErrorCode VecCreateSeqCUPMAsync(MPI_Comm comm, PetscInt n, Vec *v) noexcept
202: {
203:   PetscFunctionBegin;
204:   PetscAssertPointer(v, 4);
205:   PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPM(comm, 0, n, v, PETSC_TRUE));
206:   PetscFunctionReturn(PETSC_SUCCESS);
207: }

209: template <device::cupm::DeviceType T>
210: inline PetscErrorCode VecCreateSeqCUPMWithArraysAsync(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept
211: {
212:   PetscFunctionBegin;
213:   if (n && cpuarray) PetscAssertPointer(cpuarray, 4);
214:   PetscAssertPointer(v, 6);
215:   PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(comm, bs, n, cpuarray, gpuarray, v));
216:   PetscFunctionReturn(PETSC_SUCCESS);
217: }

219: template <PetscMemoryAccessMode mode, device::cupm::DeviceType T>
220: inline PetscErrorCode VecCUPMGetArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
221: {
222:   PetscFunctionBegin;
224:   PetscAssertPointer(a, 2);
225:   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
226:   PetscCall(impl::VecSeq_CUPM<T>::template GetArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx));
227:   PetscFunctionReturn(PETSC_SUCCESS);
228: }

230: template <PetscMemoryAccessMode mode, device::cupm::DeviceType T>
231: inline PetscErrorCode VecCUPMRestoreArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
232: {
233:   PetscFunctionBegin;
235:   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
236:   PetscCall(impl::VecSeq_CUPM<T>::template RestoreArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx));
237:   PetscFunctionReturn(PETSC_SUCCESS);
238: }

240: template <device::cupm::DeviceType T>
241: inline PetscErrorCode VecCUPMGetArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
242: {
243:   PetscFunctionBegin;
244:   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx));
245:   PetscFunctionReturn(PETSC_SUCCESS);
246: }

248: template <device::cupm::DeviceType T>
249: inline PetscErrorCode VecCUPMRestoreArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
250: {
251:   PetscFunctionBegin;
252:   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx));
253:   PetscFunctionReturn(PETSC_SUCCESS);
254: }

256: template <device::cupm::DeviceType T>
257: inline PetscErrorCode VecCUPMGetArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
258: {
259:   PetscFunctionBegin;
260:   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx));
261:   PetscFunctionReturn(PETSC_SUCCESS);
262: }

264: template <device::cupm::DeviceType T>
265: inline PetscErrorCode VecCUPMRestoreArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
266: {
267:   PetscFunctionBegin;
268:   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx));
269:   PetscFunctionReturn(PETSC_SUCCESS);
270: }

272: template <device::cupm::DeviceType T>
273: inline PetscErrorCode VecCUPMGetArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
274: {
275:   PetscFunctionBegin;
276:   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx));
277:   PetscFunctionReturn(PETSC_SUCCESS);
278: }

280: template <device::cupm::DeviceType T>
281: inline PetscErrorCode VecCUPMRestoreArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
282: {
283:   PetscFunctionBegin;
284:   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx));
285:   PetscFunctionReturn(PETSC_SUCCESS);
286: }

288: template <device::cupm::DeviceType T>
289: inline PetscErrorCode VecCUPMPlaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept
290: {
291:   PetscFunctionBegin;
293:   PetscCall(impl::VecSeq_CUPM<T>::template PlaceArray<PETSC_MEMTYPE_DEVICE>(vin, a));
294:   PetscFunctionReturn(PETSC_SUCCESS);
295: }

297: template <device::cupm::DeviceType T>
298: inline PetscErrorCode VecCUPMReplaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept
299: {
300:   PetscFunctionBegin;
302:   PetscCall(impl::VecSeq_CUPM<T>::template ReplaceArray<PETSC_MEMTYPE_DEVICE>(vin, a));
303:   PetscFunctionReturn(PETSC_SUCCESS);
304: }

306: template <device::cupm::DeviceType T>
307: inline PetscErrorCode VecCUPMResetArrayAsync(Vec vin) noexcept
308: {
309:   PetscFunctionBegin;
311:   PetscCall(impl::VecSeq_CUPM<T>::template ResetArray<PETSC_MEMTYPE_DEVICE>(vin));
312:   PetscFunctionReturn(PETSC_SUCCESS);
313: }

315: } // namespace cupm

317: } // namespace vec

319: } // namespace Petsc

321: #if PetscDefined(HAVE_CUDA)
     // Explicit instantiation declaration for the CUDA device type: suppresses implicit
     // instantiation of the class in translation units that include this header. A matching
     // explicit instantiation definition is expected to exist elsewhere (not visible here).
322: extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecSeq_CUPM<::Petsc::device::cupm::DeviceType::CUDA>;
323: #endif

325: #if PetscDefined(HAVE_HIP)
     // Same as above, for the HIP device type.
326: extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecSeq_CUPM<::Petsc::device::cupm::DeviceType::HIP>;
327: #endif