Actual source code: vecseqcupm.hpp

  1: #pragma once

  3: #include <petsc/private/veccupmimpl.h>
  4: #include <petsc/private/cpp/utility.hpp>

  6: #include <../src/sys/objects/device/impls/cupm/kernels.hpp>
  7: #include <../src/vec/vec/impls/dvecimpl.h>

  9: namespace Petsc
 10: {

 12: namespace vec
 13: {

 15: namespace cupm
 16: {

 18: namespace impl
 19: {

 21: // ==========================================================================================
 22: // VecSeq_CUPM
 23: // ==========================================================================================

// VecSeq_CUPM - sequential vector implementation for one CUPM device backend
//
// The template parameter T selects the backend. Explicit instantiation declarations for
// DeviceType::CUDA and DeviceType::HIP appear at the end of this header.
//
// The class derives from Vec_CUPMBase<T, VecSeq_CUPM<T>>. The base is parameterized on this very
// class (CRTP), and the inheritance is private because that is the default for `class`. The
// members below are presumably the hooks and operations that the base dispatches to.
// NOTE(review): confirm against veccupmimpl.h.
//
// Only declarations appear here; the definitions live elsewhere.
template <device::cupm::DeviceType T>
class VecSeq_CUPM : Vec_CUPMBase<T, VecSeq_CUPM<T>> {
public:
  // Boilerplate supplied by the base-class macro. It presumably introduces base_type and the aliases
  // (e.g. cupmStream_t) used in the declarations below -- TODO confirm in veccupmimpl.h.
  PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecSeq_CUPM<T>);

private:
  // type hooks: VecIMPLCast_ returns the Vec_Seq view of a Vec (presumably its data pointer);
  // VECIMPLCUPM_/VECIMPL_ are compile-time VecType names for the device and host flavours
  PETSC_NODISCARD static Vec_Seq          *VecIMPLCast_(Vec) noexcept;
  PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept;
  PETSC_NODISCARD static constexpr VecType VECIMPL_() noexcept;

  // implementation hooks, presumably called from Vec_CUPMBase to reach the host (Vec_Seq) versions
  static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept;
  static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept;
  static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept;
  static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept;

  // NOTE(review): semantics are defined in the implementation, not visible from this header
  static PetscErrorCode MaybeIncrementEmptyLocalVec(Vec) noexcept;

  // common core for min and max
  template <typename TupleFuncT, typename UnaryFuncT>
  static PetscErrorCode MinMax_(TupleFuncT &&, UnaryFuncT &&, Vec, PetscInt *, PetscReal *) noexcept;
  // common core for pointwise binary and pointwise unary thrust functions
  // (the trailing PetscDeviceContext is optional and defaults to nullptr)
  template <typename BinaryFuncT>
  static PetscErrorCode PointwiseBinary_(BinaryFuncT &&, Vec, Vec, Vec, PetscDeviceContext = nullptr) noexcept;
  template <typename BinaryFuncT>
  static PetscErrorCode PointwiseBinaryDispatch_(PetscErrorCode (*)(Vec, Vec, Vec), BinaryFuncT &&, Vec, Vec, Vec, PetscDeviceContext = nullptr) noexcept;
  template <typename UnaryFuncT>
  static PetscErrorCode PointwiseUnary_(UnaryFuncT &&, Vec, Vec, PetscDeviceContext = nullptr) noexcept;
  // mdot dispatchers: the overload is selected by a std::true_type / std::false_type tag
  static PetscErrorCode MDot_(/* use complex = */ std::true_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
  static PetscErrorCode MDot_(/* use complex = */ std::false_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
  // kernel dispatchers: each comes as a pair -- one overload expanded over a compile-time
  // index_sequence, and one (template <int>) that takes a PetscInt & as its last argument
  template <std::size_t... Idx>
  static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, util::index_sequence<Idx...>) noexcept;
  template <int>
  static PetscErrorCode MDot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, PetscInt &) noexcept;
  template <std::size_t... Idx>
  static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, util::index_sequence<Idx...>) noexcept;
  template <int>
  static PetscErrorCode MAXPY_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, PetscInt &) noexcept;
  // common core for the various create routines
  static PetscErrorCode CreateSeqCUPM_(Vec, PetscDeviceContext, PetscScalar * /*host_ptr*/ = nullptr, PetscScalar * /*device_ptr*/ = nullptr) noexcept;

public:
  // callable directly via a bespoke function
  static PetscErrorCode CreateSeqCUPM(MPI_Comm, PetscInt, PetscInt, Vec *, PetscBool) noexcept;
  static PetscErrorCode CreateSeqCUPMWithBothArrays(MPI_Comm, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept;

  // install / remove the *Async operations on a Vec (NOTE(review): exact effect defined in the implementation)
  static PetscErrorCode InitializeAsyncFunctions(Vec) noexcept;
  static PetscErrorCode ClearAsyncFunctions(Vec) noexcept;

  // callable indirectly via function pointers
  // Where an XxxAsync variant exists, it takes the same arguments as Xxx plus a trailing PetscDeviceContext.
  static PetscErrorCode Duplicate(Vec, Vec *) noexcept;
  static PetscErrorCode AYPX(Vec, PetscScalar, Vec) noexcept;
  static PetscErrorCode AYPXAsync(Vec, PetscScalar, Vec, PetscDeviceContext) noexcept;
  static PetscErrorCode AXPY(Vec, PetscScalar, Vec) noexcept;
  static PetscErrorCode AXPYAsync(Vec, PetscScalar, Vec, PetscDeviceContext) noexcept;
  static PetscErrorCode PointwiseDivide(Vec, Vec, Vec) noexcept;
  static PetscErrorCode PointwiseDivideAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
  static PetscErrorCode PointwiseMult(Vec, Vec, Vec) noexcept;
  static PetscErrorCode PointwiseMultAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
  static PetscErrorCode PointwiseMax(Vec, Vec, Vec) noexcept;
  static PetscErrorCode PointwiseMaxAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
  static PetscErrorCode PointwiseMaxAbs(Vec, Vec, Vec) noexcept;
  static PetscErrorCode PointwiseMaxAbsAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
  static PetscErrorCode PointwiseMin(Vec, Vec, Vec) noexcept;
  static PetscErrorCode PointwiseMinAsync(Vec, Vec, Vec, PetscDeviceContext) noexcept;
  static PetscErrorCode Reciprocal(Vec) noexcept;
  static PetscErrorCode ReciprocalAsync(Vec, PetscDeviceContext) noexcept;
  static PetscErrorCode Abs(Vec) noexcept;
  static PetscErrorCode AbsAsync(Vec, PetscDeviceContext) noexcept;
  static PetscErrorCode SqrtAbs(Vec) noexcept;
  static PetscErrorCode SqrtAbsAsync(Vec, PetscDeviceContext) noexcept;
  static PetscErrorCode Exp(Vec) noexcept;
  static PetscErrorCode ExpAsync(Vec, PetscDeviceContext) noexcept;
  static PetscErrorCode Log(Vec) noexcept;
  static PetscErrorCode LogAsync(Vec, PetscDeviceContext) noexcept;
  static PetscErrorCode WAXPY(Vec, PetscScalar, Vec, Vec) noexcept;
  static PetscErrorCode WAXPYAsync(Vec, PetscScalar, Vec, Vec, PetscDeviceContext) noexcept;
  static PetscErrorCode MAXPY(Vec, PetscInt, const PetscScalar[], Vec *) noexcept;
  static PetscErrorCode MAXPYAsync(Vec, PetscInt, const PetscScalar[], Vec *, PetscDeviceContext) noexcept;
  static PetscErrorCode Dot(Vec, Vec, PetscScalar *) noexcept;
  static PetscErrorCode MDot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept;
  static PetscErrorCode Set(Vec, PetscScalar) noexcept;
  static PetscErrorCode SetAsync(Vec, PetscScalar, PetscDeviceContext) noexcept;
  static PetscErrorCode Scale(Vec, PetscScalar) noexcept;
  static PetscErrorCode ScaleAsync(Vec, PetscScalar, PetscDeviceContext) noexcept;
  static PetscErrorCode TDot(Vec, Vec, PetscScalar *) noexcept;
  static PetscErrorCode Copy(Vec, Vec) noexcept;
  static PetscErrorCode CopyAsync(Vec, Vec, PetscDeviceContext) noexcept;
  static PetscErrorCode Swap(Vec, Vec) noexcept;
  static PetscErrorCode SwapAsync(Vec, Vec, PetscDeviceContext) noexcept;
  static PetscErrorCode AXPBY(Vec, PetscScalar, PetscScalar, Vec) noexcept;
  static PetscErrorCode AXPBYAsync(Vec, PetscScalar, PetscScalar, Vec, PetscDeviceContext) noexcept;
  static PetscErrorCode AXPBYPCZ(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec) noexcept;
  static PetscErrorCode AXPBYPCZAsync(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec, PetscDeviceContext) noexcept;
  static PetscErrorCode Norm(Vec, NormType, PetscReal *) noexcept;
  static PetscErrorCode ErrorWnorm(Vec, Vec, Vec, NormType, PetscReal, Vec, PetscReal, Vec, PetscReal, PetscReal *, PetscInt *, PetscReal *, PetscInt *, PetscReal *, PetscInt *) noexcept;
  static PetscErrorCode DotNorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept;
  static PetscErrorCode Conjugate(Vec) noexcept;
  static PetscErrorCode ConjugateAsync(Vec, PetscDeviceContext) noexcept;
  // local-vector access, specialized on the requested memory access mode
  template <PetscMemoryAccessMode>
  static PetscErrorCode GetLocalVector(Vec, Vec) noexcept;
  template <PetscMemoryAccessMode>
  static PetscErrorCode RestoreLocalVector(Vec, Vec) noexcept;
  static PetscErrorCode Max(Vec, PetscInt *, PetscReal *) noexcept;
  static PetscErrorCode Min(Vec, PetscInt *, PetscReal *) noexcept;
  static PetscErrorCode Sum(Vec, PetscScalar *) noexcept;
  static PetscErrorCode Shift(Vec, PetscScalar) noexcept;
  static PetscErrorCode ShiftAsync(Vec, PetscScalar, PetscDeviceContext) noexcept;
  static PetscErrorCode SetRandom(Vec, PetscRandom) noexcept;
  static PetscErrorCode BindToCPU(Vec, PetscBool) noexcept;
  // COO assembly: SetValuesCOO presumably uses the add_coo_values kernel declared below -- confirm in the implementation
  static PetscErrorCode SetPreallocationCOO(Vec, PetscCount, const PetscInt[]) noexcept;
  static PetscErrorCode SetValuesCOO(Vec, const PetscScalar[], InsertMode) noexcept;
};

139: namespace kernels
140: {

142: template <typename F>
143: PETSC_DEVICE_INLINE_DECL void add_coo_values_impl(const PetscScalar *PETSC_RESTRICT vv, PetscCount n, const PetscCount *PETSC_RESTRICT jmap, const PetscCount *PETSC_RESTRICT perm, InsertMode imode, PetscScalar *PETSC_RESTRICT xv, F &&xvindex)
144: {
145:   ::Petsc::device::cupm::kernels::util::grid_stride_1D(n, [=](PetscCount i) {
146:     const auto  end = jmap[i + 1];
147:     const auto  idx = xvindex(i);
148:     PetscScalar sum = 0.0;

150:     for (auto k = jmap[i]; k < end; ++k) sum += vv[perm[k]];

152:     if (imode == INSERT_VALUES) {
153:       xv[idx] = sum;
154:     } else {
155:       xv[idx] += sum;
156:     }
157:   });
158:   return;
159: }

161: namespace
162: {
163: PETSC_PRAGMA_DIAGNOSTIC_IGNORED_BEGIN("-Wunused-function")
164: PETSC_KERNEL_DECL void add_coo_values(const PetscScalar *PETSC_RESTRICT v, PetscCount n, const PetscCount *PETSC_RESTRICT jmap1, const PetscCount *PETSC_RESTRICT perm1, InsertMode imode, PetscScalar *PETSC_RESTRICT xv)
165: {
166:   add_coo_values_impl(v, n, jmap1, perm1, imode, xv, [](PetscCount i) { return i; });
167:   return;
168: }
169: PETSC_PRAGMA_DIAGNOSTIC_IGNORED_END()
170: } // namespace

172: } // namespace kernels

174: } // namespace impl

176: // ==========================================================================================
177: // VecSeq_CUPM - Implementations
178: // ==========================================================================================

180: template <device::cupm::DeviceType T>
181: inline PetscErrorCode VecCreateSeqCUPMAsync(MPI_Comm comm, PetscInt n, Vec *v) noexcept
182: {
183:   PetscFunctionBegin;
184:   PetscAssertPointer(v, 4);
185:   PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPM(comm, 0, n, v, PETSC_TRUE));
186:   PetscFunctionReturn(PETSC_SUCCESS);
187: }

189: template <device::cupm::DeviceType T>
190: inline PetscErrorCode VecCreateSeqCUPMWithArraysAsync(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept
191: {
192:   PetscFunctionBegin;
193:   if (n && cpuarray) PetscAssertPointer(cpuarray, 4);
194:   PetscAssertPointer(v, 6);
195:   PetscCall(impl::VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(comm, bs, n, cpuarray, gpuarray, v));
196:   PetscFunctionReturn(PETSC_SUCCESS);
197: }

199: template <PetscMemoryAccessMode mode, device::cupm::DeviceType T>
200: inline PetscErrorCode VecCUPMGetArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
201: {
202:   PetscFunctionBegin;
204:   PetscAssertPointer(a, 2);
205:   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
206:   PetscCall(impl::VecSeq_CUPM<T>::template GetArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx));
207:   PetscFunctionReturn(PETSC_SUCCESS);
208: }

210: template <PetscMemoryAccessMode mode, device::cupm::DeviceType T>
211: inline PetscErrorCode VecCUPMRestoreArrayAsync_Private(Vec v, PetscScalar **a, PetscDeviceContext dctx) noexcept
212: {
213:   PetscFunctionBegin;
215:   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
216:   PetscCall(impl::VecSeq_CUPM<T>::template RestoreArray<PETSC_MEMTYPE_DEVICE, mode>(v, a, dctx));
217:   PetscFunctionReturn(PETSC_SUCCESS);
218: }

220: template <device::cupm::DeviceType T>
221: inline PetscErrorCode VecCUPMGetArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
222: {
223:   PetscFunctionBegin;
224:   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx));
225:   PetscFunctionReturn(PETSC_SUCCESS);
226: }

228: template <device::cupm::DeviceType T>
229: inline PetscErrorCode VecCUPMRestoreArrayAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
230: {
231:   PetscFunctionBegin;
232:   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE, T>(v, a, dctx));
233:   PetscFunctionReturn(PETSC_SUCCESS);
234: }

236: template <device::cupm::DeviceType T>
237: inline PetscErrorCode VecCUPMGetArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
238: {
239:   PetscFunctionBegin;
240:   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx));
241:   PetscFunctionReturn(PETSC_SUCCESS);
242: }

244: template <device::cupm::DeviceType T>
245: inline PetscErrorCode VecCUPMRestoreArrayReadAsync(Vec v, const PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
246: {
247:   PetscFunctionBegin;
248:   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ, T>(v, const_cast<PetscScalar **>(a), dctx));
249:   PetscFunctionReturn(PETSC_SUCCESS);
250: }

252: template <device::cupm::DeviceType T>
253: inline PetscErrorCode VecCUPMGetArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
254: {
255:   PetscFunctionBegin;
256:   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx));
257:   PetscFunctionReturn(PETSC_SUCCESS);
258: }

260: template <device::cupm::DeviceType T>
261: inline PetscErrorCode VecCUPMRestoreArrayWriteAsync(Vec v, PetscScalar **a, PetscDeviceContext dctx = nullptr) noexcept
262: {
263:   PetscFunctionBegin;
264:   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE, T>(v, a, dctx));
265:   PetscFunctionReturn(PETSC_SUCCESS);
266: }

268: template <device::cupm::DeviceType T>
269: inline PetscErrorCode VecCUPMPlaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept
270: {
271:   PetscFunctionBegin;
273:   PetscCall(impl::VecSeq_CUPM<T>::template PlaceArray<PETSC_MEMTYPE_DEVICE>(vin, a));
274:   PetscFunctionReturn(PETSC_SUCCESS);
275: }

277: template <device::cupm::DeviceType T>
278: inline PetscErrorCode VecCUPMReplaceArrayAsync(Vec vin, const PetscScalar a[]) noexcept
279: {
280:   PetscFunctionBegin;
282:   PetscCall(impl::VecSeq_CUPM<T>::template ReplaceArray<PETSC_MEMTYPE_DEVICE>(vin, a));
283:   PetscFunctionReturn(PETSC_SUCCESS);
284: }

286: template <device::cupm::DeviceType T>
287: inline PetscErrorCode VecCUPMResetArrayAsync(Vec vin) noexcept
288: {
289:   PetscFunctionBegin;
291:   PetscCall(impl::VecSeq_CUPM<T>::template ResetArray<PETSC_MEMTYPE_DEVICE>(vin));
292:   PetscFunctionReturn(PETSC_SUCCESS);
293: }

295: } // namespace cupm

297: } // namespace vec

299: } // namespace Petsc

// Explicit instantiation declarations, one per enabled backend. They tell every including
// translation unit that VecSeq_CUPM<CUDA>/<HIP> is explicitly instantiated elsewhere (presumably
// in the backend-specific source files), which suppresses implicit instantiation of its non-inline
// members here. These declarations must stay consistent with the matching explicit instantiation
// definitions, including the visibility macro.
#if PetscDefined(HAVE_CUDA)
extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecSeq_CUPM<::Petsc::device::cupm::DeviceType::CUDA>;
#endif

#if PetscDefined(HAVE_HIP)
extern template class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL ::Petsc::vec::cupm::impl::VecSeq_CUPM<::Petsc::device::cupm::DeviceType::HIP>;
#endif