Actual source code: aijcupm.hpp

  1: #pragma once

  3: /* Shared CUPM (CUDA/HIP) implementations for SeqAIJCUSPARSE and SeqAIJHIPSPARSE
  4:    that do not depend on the cuSPARSE/hipSPARSE library proper.

  6:    Include ordering requirement: the vendor-specific impl header
  7:    (cusparsematimpl.h or hipsparsematimpl.h) must be included before this
  8:    header so that CsrMatrix, THRUSTINTARRAY*, THRUSTARRAY and all device-specific
  9:    struct types are visible when this header is processed.

 11:    Instantiated by:
 12:      aijcusparse.cu      (DeviceType::CUDA, using MatSeqAIJCUSPARSE_Policy)
 13:      aijhipsparse.hip.cxx (DeviceType::HIP,  using MatSeqAIJHIPSPARSE_Policy) */

 15: #include <petsc/private/cupmobject.hpp>
 16: #include <petsc/private/cupmblasinterface.hpp>
 17: #include <petsc/private/matimpl.h>
 18: #include <../src/sys/objects/device/impls/cupm/cupmthrustutility.hpp>
 19: #include <../src/mat/impls/aij/seq/aij.h>

 21: #include <thrust/device_ptr.h>
 22: #include <thrust/iterator/counting_iterator.h>
 23: #include <thrust/iterator/permutation_iterator.h>
 24: #include <thrust/functional.h>
 25: #include <thrust/fill.h>
 26: #include <thrust/tuple.h>
 27: #include <thrust/transform.h>
 28: #include <thrust/for_each.h>
 29: #include <thrust/equal.h>

 31: /* Forward declaration of SeqAIJ fallback function used inside template */
 32: PETSC_INTERN PetscErrorCode MatGetDiagonal_SeqAIJ(Mat, Vec);

 34: namespace Petsc
 35: {

 37: namespace mat
 38: {

 40: namespace aij
 41: {

 43: namespace cupm
 44: {

 46: namespace impl
 47: {

 49: // expand row_offsets to uncompressed row indices coo_i[]
 50: struct Csr2coo {
 51:   const PetscInt *row_offsets;
 52:   PetscInt       *coo_i;

 54:   Csr2coo(const PetscInt *roff, PetscInt *cooi) : row_offsets(roff), coo_i(cooi) { }

 56:   PETSC_HOSTDEVICE_INLINE_DECL void operator()(PetscInt i) const
 57:   {
 58:     for (PetscInt j = row_offsets[i]; j < row_offsets[i + 1]; ++j) coo_i[j] = i;
 59:   }
 60: };

 62: struct PetscIntToCInt {
 63:   // Caller should check overflow
 64:   PETSC_HOSTDEVICE_INLINE_DECL int operator()(PetscInt i) const { return static_cast<int>(i); }
 65: };

 67: /* --------------------------------------------------------------------------
 68:    Shared device functor: left-scale CSR rows.
 69:    cprow[i] gives the logical row index for compressed row i; NULL = identity.
 70:    -------------------------------------------------------------------------- */
 71: struct DiagonalScaleLeft_CSR_Functor {
 72:   const PetscInt    *row_ptr;
 73:   PetscScalar       *val_ptr;
 74:   const PetscScalar *lv_ptr;
 75:   const PetscInt    *cprow;

 77:   PETSC_HOSTDEVICE_INLINE_DECL void operator()(PetscInt i) const
 78:   {
 79:     const PetscInt    row = cprow ? cprow[i] : i;
 80:     const PetscScalar s   = lv_ptr[row];
 81:     for (PetscInt j = row_ptr[i]; j < row_ptr[i + 1]; j++) val_ptr[j] *= s;
 82:   }
 83: };

 85: /* --------------------------------------------------------------------------
 86:    Shared device functor: get<1>(t) = get<0>(t).
 87:    Replaces the identical VecCUDAEquals / VecHIPEquals structs.
 88:    -------------------------------------------------------------------------- */
 89: struct VecCUPMEquals {
 90:   template <typename Tuple>
 91:   PETSC_HOSTDEVICE_INLINE_DECL void operator()(Tuple t) const
 92:   {
 93:     thrust::get<1>(t) = thrust::get<0>(t);
 94:   }
 95: };

 97: /* --------------------------------------------------------------------------
 98:    Shared __global__ kernel: accumulate COO values into a CSR array.
 99:    __global__ is valid for both nvcc and hipcc; the body is identical.
100:    -------------------------------------------------------------------------- */
101: __global__ static void MatAddCOOValues(const PetscScalar kv[], PetscCount nnz, const PetscCount jmap[], const PetscCount perm[], InsertMode imode, PetscScalar a[])
102: {
103:   PetscCount       i         = blockIdx.x * blockDim.x + threadIdx.x;
104:   const PetscCount grid_size = gridDim.x * blockDim.x;
105:   for (; i < nnz; i += grid_size) {
106:     PetscScalar sum = 0.0;
107:     for (PetscCount k = jmap[i]; k < jmap[i + 1]; k++) sum += kv[perm[k]];
108:     a[i] = (imode == INSERT_VALUES ? (PetscScalar)0.0 : a[i]) + sum;
109:   }
110: }

112: /* --------------------------------------------------------------------------
113:    Shared __global__ kernel: extract the CSR diagonal.
114:    -------------------------------------------------------------------------- */
115: __global__ void GetDiagonal_CSR(const PetscInt *row, const PetscInt *col, const PetscScalar *val, const PetscInt len, PetscScalar *diag)
116: {
117:   const size_t x = blockIdx.x * blockDim.x + threadIdx.x;

119:   if (x < (size_t)len) {
120:     const PetscInt rowx = row[x], num_non0_row = row[x + 1] - rowx;
121:     PetscScalar    d = 0.0;

123:     for (PetscInt i = 0; i < num_non0_row; i++) {
124:       if (col[i + rowx] == (PetscInt)x) {
125:         d = val[i + rowx];
126:         break;
127:       }
128:     }
129:     diag[x] = d;
130:   }
131: }

133: /* ==========================================================================
134:    MatSeqAIJCUSPARSE_CUPM<T, Policy>

136:    Policy (C++11 traits class) requirements - all static methods:

138:      // Device struct types
139:      typedef ... mat_struct_type;       // Mat_SeqAIJCUSPARSE / Mat_SeqAIJHIPSPARSE
140:      typedef ... mult_struct_type;      // ...MultStruct equivalent

142:      // Storage-format constants (value of each format enumerator)
143:      static int storage_format_csr();
144:      static int storage_format_ell();
145:      static int storage_format_hyb();

147:      // Bookkeeping helpers (device-type specific)
148:      static PetscErrorCode CopyToGPU(Mat);
149:      static PetscErrorCode CopyFromGPU(Mat);
150:      static PetscErrorCode InvalidateTranspose(Mat, PetscBool);
151:      static PetscErrorCode ConvertFromSeqAIJ(Mat, MatType, MatReuse, Mat *);
152:      static const char    *mat_type_name;   // "seqaijcusparse" / "seqaijhipsparse"

154:      // Destruction helpers (device-type specific)
155:      static PetscErrorCode Destroy(Mat);
156:      static PetscErrorCode TriFactorsDestroy(void **);

158:      // Compose-function keys that differ between CUDA and HIP
159:      static const char *set_format_c;          // "MatCUSPARSESetFormat_C"        / "MatHIPSPARSESetFormat_C"
160:      static const char *set_use_cpu_solve_c;   // "MatCUSPARSESetUseCPUSolve_C"   / "MatHIPSPARSESetUseCPUSolve_C"
161:      static const char *product_seqdense_device_c; // "...seqdensecuda_C"          / "...seqdensehip_C"
162:      static const char *product_seqdense_c;    // "...seqdense_C"
163:      static const char *product_self_c;        // "...seqaijcusparse_C"           / "...seqaijhipsparse_C"
164:      static const char *seq_convert_hypre_c;   // "MatConvert_seqaijcusparse_hypre_C" / "_seqaijhipsparse_hypre_C"

166:      // Vec device-array access (device-type specific)
167:      static PetscErrorCode VecGetArrayRead  (Vec, const PetscScalar **);
168:      static PetscErrorCode VecRestoreArrayRead(Vec, const PetscScalar **);
169:      static PetscErrorCode VecGetArrayWrite (Vec, PetscScalar **);
170:      static PetscErrorCode VecRestoreArrayWrite(Vec, PetscScalar **);
171:    ========================================================================== */

173: template <device::cupm::DeviceType T, typename Policy>
174: struct MatSeqAIJCUSPARSE_CUPM : device::cupm::impl::CUPMObject<T> {
175:   PETSC_CUPMOBJECT_HEADER(T);

177:   typedef typename Policy::mat_struct_type  MatStructType;
178:   typedef typename Policy::mult_struct_type MultStructType;

180:   /* -------------------------------------------------------------------
181:      Tier 1 - Trivial
182:      ------------------------------------------------------------------- */

184:   /* MatAssemblyEnd: delegation to SeqAIJ */
185:   static PetscErrorCode AssemblyEnd(Mat A, MatAssemblyType mode) noexcept
186:   {
187:     PetscFunctionBegin;
188:     PetscCall(MatAssemblyEnd_SeqAIJ(A, mode));
189:     PetscFunctionReturn(PETSC_SUCCESS);
190:   }

192:   /* MatDuplicate */
193:   static PetscErrorCode Duplicate(Mat A, MatDuplicateOption cpvalues, Mat *B) noexcept
194:   {
195:     PetscFunctionBegin;
196:     PetscCall(MatDuplicate_SeqAIJ(A, cpvalues, B));
197:     PetscCall(Policy::ConvertFromSeqAIJ(*B, Policy::mat_type_name, MAT_INPLACE_MATRIX, B));
198:     PetscFunctionReturn(PETSC_SUCCESS);
199:   }

201:   /* MatGetCurrentMemType */
202:   static PetscErrorCode GetCurrentMemType(PETSC_UNUSED Mat A, PetscMemType *m) noexcept
203:   {
204:     PetscFunctionBegin;
205:     *m = PETSC_MEMTYPE_CUPM();
206:     PetscFunctionReturn(PETSC_SUCCESS);
207:   }

209:   /* MatCOOStructDestroy: free device jmap and perm fields */
210:   static PetscErrorCode COOStructDestroy(PetscCtxRt ctx) noexcept
211:   {
212:     MatCOOStruct_SeqAIJ *coo = *(MatCOOStruct_SeqAIJ **)ctx;

214:     PetscFunctionBegin;
215:     PetscCallCUPM(cupmFree(coo->perm));
216:     PetscCallCUPM(cupmFree(coo->jmap));
217:     PetscCall(PetscFree(coo));
218:     PetscFunctionReturn(PETSC_SUCCESS);
219:   }

221:   /* -------------------------------------------------------------------
222:      Tier 2 - Straightforward
223:      ------------------------------------------------------------------- */

225:   /* MatZeroEntries: fill device CSR values with zero */
226:   static PetscErrorCode ZeroEntries(Mat A) noexcept
227:   {
228:     PetscBool      gpu = PETSC_FALSE;
229:     Mat_SeqAIJ    *a   = (Mat_SeqAIJ *)A->data;
230:     MatStructType *spptr;

232:     PetscFunctionBegin;
233:     if (A->factortype == MAT_FACTOR_NONE) {
234:       spptr = (MatStructType *)A->spptr;
235:       if (spptr->mat) {
236:         CsrMatrix *matrix = (CsrMatrix *)spptr->mat->mat;
237:         if (matrix->values) {
238:           gpu = PETSC_TRUE;
239:           PetscCallThrust(thrust::fill(thrust::device, matrix->values->begin(), matrix->values->end(), (PetscScalar)0.));
240:         }
241:       }
242:       if (spptr->matTranspose) {
243:         CsrMatrix *matrix = (CsrMatrix *)spptr->matTranspose->mat;
244:         if (matrix->values) PetscCallThrust(thrust::fill(thrust::device, matrix->values->begin(), matrix->values->end(), (PetscScalar)0.));
245:       }
246:     }
247:     if (gpu) A->offloadmask = PETSC_OFFLOAD_GPU;
248:     else {
249:       PetscCall(PetscArrayzero(a->a, a->i[A->rmap->n]));
250:       A->offloadmask = PETSC_OFFLOAD_CPU;
251:     }
252:     PetscFunctionReturn(PETSC_SUCCESS);
253:   }

255:   /* MatScale: cupmBlasXscal on the device CSR values */
256:   static PetscErrorCode Scale(Mat Y, PetscScalar a) noexcept
257:   {
258:     Mat_SeqAIJ      *y  = (Mat_SeqAIJ *)Y->data;
259:     PetscScalar     *ay = nullptr;
260:     cupmBlasHandle_t blashandle;
261:     PetscBLASInt     one = 1, bnz = 1;

263:     PetscFunctionBegin;
264:     PetscCall(GetArray(Y, &ay));
265:     PetscCall(GetHandles_(&blashandle));
266:     PetscCall(PetscBLASIntCast(y->nz, &bnz));
267:     PetscCall(PetscLogGpuTimeBegin());
268:     PetscCallCUPMBLAS(cupmBlasXscal(blashandle, bnz, cupmScalarPtrCast(&a), cupmScalarPtrCast(ay), one));
269:     PetscCall(PetscLogGpuFlops(bnz));
270:     PetscCall(PetscLogGpuTimeEnd());
271:     PetscCall(RestoreArray(Y, &ay));
272:     PetscFunctionReturn(PETSC_SUCCESS);
273:   }

275:   /* MatDiagonalScale: Thrust-based left and right scaling of CSR values */
276:   static PetscErrorCode DiagonalScale(Mat A, Vec ll, Vec rr) noexcept
277:   {
278:     Mat_SeqAIJ    *aij = (Mat_SeqAIJ *)A->data;
279:     MatStructType *devstruct;
280:     CsrMatrix     *csr;
281:     PetscScalar   *av = nullptr;
282:     PetscInt       m, n, nz = aij->nz;
283:     cupmStream_t   stream;

285:     PetscFunctionBegin;
286:     PetscCall(GetHandles_(&stream));
287:     PetscCall(PetscLogGpuTimeBegin());
288:     PetscCall(GetArray(A, &av));
289:     devstruct = (MatStructType *)A->spptr;
290:     csr       = (CsrMatrix *)devstruct->mat->mat;
291:     if (ll) {
292:       const PetscScalar *lv;
293:       PetscCall(VecGetLocalSize(ll, &m));
294:       PetscCheck(m == A->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Left scaling vector wrong length");
295:       PetscCall(Policy::VecGetArrayRead(ll, &lv));
296:       {
297:         const PetscInt               *cprow   = devstruct->mat->cprowIndices ? devstruct->mat->cprowIndices->data().get() : NULL;
298:         DiagonalScaleLeft_CSR_Functor functor = {csr->row_offsets->data().get(), av, lv, cprow};
299:         PetscCallThrust(THRUST_CALL(thrust::for_each, stream, thrust::counting_iterator<PetscInt>(0), thrust::counting_iterator<PetscInt>(csr->num_rows), functor));
300:       }
301:       PetscCall(Policy::VecRestoreArrayRead(ll, &lv));
302:       PetscCall(PetscLogGpuFlops(nz));
303:     }
304:     if (rr) {
305:       const PetscScalar *rv;
306:       PetscCall(VecGetLocalSize(rr, &n));
307:       PetscCheck(n == A->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Right scaling vector wrong length");
308:       PetscCall(Policy::VecGetArrayRead(rr, &rv));
309: #if PetscDefined(USING_NVCC) && CCCL_VERSION >= 3001000
310:       PetscCallThrust(THRUST_CALL(thrust::transform, stream, csr->values->begin(), csr->values->end(), thrust::make_permutation_iterator(thrust::device_pointer_cast(rv), csr->column_indices->begin()), csr->values->begin(), cuda::std::multiplies<PetscScalar>()));
311: #else
312:       PetscCallThrust(THRUST_CALL(thrust::transform, stream, csr->values->begin(), csr->values->end(), thrust::make_permutation_iterator(thrust::device_pointer_cast(rv), csr->column_indices->begin()), csr->values->begin(), thrust::multiplies<PetscScalar>()));
313: #endif
314:       PetscCall(Policy::VecRestoreArrayRead(rr, &rv));
315:       PetscCall(PetscLogGpuFlops(nz));
316:     }
317:     PetscCall(RestoreArray(A, &av));
318:     PetscCall(PetscLogGpuTimeEnd());
319:     PetscFunctionReturn(PETSC_SUCCESS);
320:   }

322:   /* MatSeqAIJGetIJ: return device CSR row-pointer and column-index arrays */
323:   static PetscErrorCode GetIJ(Mat A, PetscBool compressed, const PetscInt **i, const PetscInt **j) noexcept
324:   {
325:     MatStructType *cusp = (MatStructType *)A->spptr;
326:     Mat_SeqAIJ    *a    = (Mat_SeqAIJ *)A->data;
327:     CsrMatrix     *csr;

329:     PetscFunctionBegin;
331:     if (!i || !j) PetscFunctionReturn(PETSC_SUCCESS);
332:     PetscCheckTypeName(A, Policy::mat_type_name);
333:     PetscCheck(cusp->format != (decltype(cusp->format))Policy::storage_format_ell() && cusp->format != (decltype(cusp->format))Policy::storage_format_hyb(), PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented");
334:     PetscCall(Policy::CopyToGPU(A));
335:     PetscCheck(cusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing MultStruct");
336:     csr = (CsrMatrix *)cusp->mat->mat;
337:     if (i) {
338:       if (!compressed && a->compressedrow.use) { /* need full row offset */
339:         if (!cusp->rowoffsets_gpu) {
340:           cusp->rowoffsets_gpu = new THRUSTINTARRAY(A->rmap->n + 1);
341:           cusp->rowoffsets_gpu->assign(a->i, a->i + A->rmap->n + 1);
342:           PetscCall(PetscLogCpuToGpu((A->rmap->n + 1) * sizeof(PetscInt)));
343:         }
344:         *i = cusp->rowoffsets_gpu->data().get();
345:       } else *i = csr->row_offsets->data().get();
346:     }
347:     if (j) *j = csr->column_indices->data().get();
348:     PetscFunctionReturn(PETSC_SUCCESS);
349:   }

351:   /* MatSeqAIJRestoreIJ: nullify the pointers previously obtained with GetIJ */
352:   static PetscErrorCode RestoreIJ(Mat A, PetscBool compressed, const PetscInt **i, const PetscInt **j) noexcept
353:   {
354:     PetscFunctionBegin;
356:     PetscCheckTypeName(A, Policy::mat_type_name);
357:     if (i) *i = NULL;
358:     if (j) *j = NULL;
359:     (void)compressed;
360:     PetscFunctionReturn(PETSC_SUCCESS);
361:   }

363:   /* MatSetPreallocationCOO: copy COO bookkeeping struct to device */
364:   static PetscErrorCode SetPreallocationCOO(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[]) noexcept
365:   {
366:     PetscBool            dev_ij = PETSC_FALSE;
367:     PetscMemType         mtype  = PETSC_MEMTYPE_HOST;
368:     PetscInt            *i, *j;
369:     PetscContainer       container_h;
370:     MatCOOStruct_SeqAIJ *coo_h, *coo_d;

372:     PetscFunctionBegin;
373:     PetscCall(PetscGetMemType(coo_i, &mtype));
374:     if (PetscMemTypeDevice(mtype)) {
375:       dev_ij = PETSC_TRUE;
376:       PetscCall(PetscMalloc2(coo_n, &i, coo_n, &j));
377:       PetscCallCUPM(cupmMemcpy(i, coo_i, coo_n * sizeof(PetscInt), cupmMemcpyDeviceToHost));
378:       PetscCallCUPM(cupmMemcpy(j, coo_j, coo_n * sizeof(PetscInt), cupmMemcpyDeviceToHost));
379:     } else {
380:       i = coo_i;
381:       j = coo_j;
382:     }
383:     PetscCall(MatSetPreallocationCOO_SeqAIJ(mat, coo_n, i, j));
384:     if (dev_ij) PetscCall(PetscFree2(i, j));
385:     mat->offloadmask = PETSC_OFFLOAD_CPU;
386:     /* Create the GPU memory */
387:     PetscCall(Policy::CopyToGPU(mat));

389:     /* Copy the COO struct to device */
390:     PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
391:     PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h));
392:     PetscCall(PetscMalloc1(1, &coo_d));
393:     *coo_d = *coo_h; /* shallow copy; device fields amended below */
394:     PetscCallCUPM(cupmMalloc((void **)&coo_d->jmap, (coo_h->nz + 1) * sizeof(PetscCount)));
395:     PetscCallCUPM(cupmMemcpy(coo_d->jmap, coo_h->jmap, (coo_h->nz + 1) * sizeof(PetscCount), cupmMemcpyHostToDevice));
396:     PetscCallCUPM(cupmMalloc((void **)&coo_d->perm, coo_h->Atot * sizeof(PetscCount)));
397:     PetscCallCUPM(cupmMemcpy(coo_d->perm, coo_h->perm, coo_h->Atot * sizeof(PetscCount), cupmMemcpyHostToDevice));

399:     PetscCall(PetscObjectContainerCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", coo_d, MatSeqAIJCUSPARSE_CUPM::COOStructDestroy));
400:     PetscFunctionReturn(PETSC_SUCCESS);
401:   }

403:   /* MatSetValuesCOO: launch MatAddCOOValues kernel */
404:   static PetscErrorCode SetValuesCOO(Mat A, const PetscScalar v[], InsertMode imode) noexcept
405:   {
406:     Mat_SeqAIJ          *seq  = (Mat_SeqAIJ *)A->data;
407:     MatStructType       *dev  = (MatStructType *)A->spptr;
408:     PetscCount           Annz = seq->nz;
409:     PetscMemType         memtype;
410:     const PetscScalar   *v1 = v;
411:     PetscScalar         *Aa = nullptr;
412:     PetscContainer       container;
413:     MatCOOStruct_SeqAIJ *coo;
414:     cupmStream_t         stream;

416:     PetscFunctionBegin;
417:     if (!dev->mat) PetscCall(Policy::CopyToGPU(A));

419:     PetscCall(PetscObjectQuery((PetscObject)A, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
420:     PetscCall(PetscContainerGetPointer(container, (void **)&coo));

422:     PetscCall(PetscGetMemType(v, &memtype));
423:     if (PetscMemTypeHost(memtype)) { /* copy host values to device */
424:       PetscCallCUPM(cupmMalloc((void **)&v1, coo->n * sizeof(PetscScalar)));
425:       PetscCallCUPM(cupmMemcpy((void *)v1, v, coo->n * sizeof(PetscScalar), cupmMemcpyHostToDevice));
426:     }

428:     if (imode == INSERT_VALUES) PetscCall(GetArrayWrite(A, &Aa));
429:     else PetscCall(GetArray(A, &Aa));

431:     PetscCall(GetHandles_(&stream));
432:     PetscCall(PetscLogGpuTimeBegin());
433:     if (Annz) {
434:       PetscCallCUPM(cupmLaunchKernel(MatAddCOOValues, (unsigned int)((Annz + 255) / 256), 256u, (size_t)0, stream, v1, Annz, coo->jmap, coo->perm, imode, Aa));
435:       PetscCallCUPM(cupmGetLastError());
436:     }
437:     PetscCall(PetscLogGpuTimeEnd());

439:     if (imode == INSERT_VALUES) PetscCall(RestoreArrayWrite(A, &Aa));
440:     else PetscCall(RestoreArray(A, &Aa));

442:     if (PetscMemTypeHost(memtype)) {
443:       void *v1_device = (void *)v1;
444:       PetscCallCUPM(cupmFree(v1_device));
445:     }
446:     PetscFunctionReturn(PETSC_SUCCESS);
447:   }

449:   /* MatSeqAIJCopySubArray: scatter-gather a sub-array of CSR values */
450:   static PetscErrorCode CopySubArray(Mat A, PetscInt n, const PetscInt idx[], PetscScalar v[]) noexcept
451:   {
452:     const PetscScalar *av = nullptr;
453:     PetscMemType       mtype;
454:     PetscBool          dmem;

456:     PetscFunctionBegin;
457:     PetscCall(PetscCUPMGetMemType(v, &mtype));
458:     dmem = PetscMemTypeDevice(mtype);
459:     PetscCall(GetArrayRead(A, &av));
460:     if (n && idx) {
461:       THRUSTINTARRAY widx(n);
462:       widx.assign(idx, idx + n);
463:       PetscCall(PetscLogCpuToGpu(n * sizeof(PetscInt)));

465:       THRUSTARRAY                    *w = NULL;
466:       thrust::device_ptr<PetscScalar> dv;
467:       if (dmem) {
468:         dv = thrust::device_pointer_cast(v);
469:       } else {
470:         w  = new THRUSTARRAY(n);
471:         dv = w->data();
472:       }
473:       {
474:         thrust::device_ptr<const PetscScalar> dav   = thrust::device_pointer_cast(av);
475:         auto                                  zibit = thrust::make_zip_iterator(thrust::make_tuple(thrust::make_permutation_iterator(dav, widx.begin()), dv));
476:         auto                                  zieit = thrust::make_zip_iterator(thrust::make_tuple(thrust::make_permutation_iterator(dav, widx.end()), dv + n));
477:         PetscCallThrust(thrust::for_each(zibit, zieit, VecCUPMEquals{}));
478:       }
479:       if (w) PetscCallCUPM(cupmMemcpy(v, w->data().get(), n * sizeof(PetscScalar), cupmMemcpyDeviceToHost));
480:       delete w;
481:     } else {
482:       PetscCallCUPM(cupmMemcpy(v, av, n * sizeof(PetscScalar), dmem ? cupmMemcpyDeviceToDevice : cupmMemcpyDeviceToHost));
483:     }
484:     if (!dmem) PetscCall(PetscLogCpuToGpu(n * sizeof(PetscScalar)));
485:     PetscCall(RestoreArrayRead(A, &av));
486:     PetscFunctionReturn(PETSC_SUCCESS);
487:   }

489:   /* -------------------------------------------------------------------
490:      Tier 3 - AXPY shared branches (SAME_NZ and DIFFERENT_NZ only).
491:      The SUBSET_NZ branch calls cuSPARSE/hipSPARSE and stays in the caller.
492:      ------------------------------------------------------------------- */

494:   /* AXPY SAME_NONZERO_PATTERN branch: cupmBlasXaxpy */
495:   static PetscErrorCode AXPY_SameNZ(Mat Y, PetscScalar a, Mat X) noexcept
496:   {
497:     Mat_SeqAIJ        *x  = (Mat_SeqAIJ *)X->data;
498:     const PetscScalar *ax = nullptr;
499:     PetscScalar       *ay = nullptr;
500:     cupmBlasHandle_t   blashandle;
501:     PetscBLASInt       one = 1, bnz = 1;

503:     PetscFunctionBegin;
504:     PetscCall(GetArrayRead(X, &ax));
505:     PetscCall(GetArray(Y, &ay));
506:     PetscCall(GetHandles_(&blashandle));
507:     PetscCall(PetscBLASIntCast(x->nz, &bnz));
508:     PetscCall(PetscLogGpuTimeBegin());
509:     PetscCallCUPMBLAS(cupmBlasXaxpy(blashandle, bnz, cupmScalarPtrCast(&a), cupmScalarPtrCast(ax), one, cupmScalarPtrCast(ay), one));
510:     PetscCall(PetscLogGpuFlops(2.0 * bnz));
511:     PetscCall(PetscLogGpuTimeEnd());
512:     PetscCall(RestoreArrayRead(X, &ax));
513:     PetscCall(RestoreArray(Y, &ay));
514:     PetscFunctionReturn(PETSC_SUCCESS);
515:   }

517:   /* GetDiagonal: kernel-based extraction of the CSR diagonal */
518:   static PetscErrorCode GetDiagonal(Mat A, Vec diag) noexcept
519:   {
520:     MatStructType  *devstruct = (MatStructType *)A->spptr;
521:     MultStructType *matstruct = (MultStructType *)devstruct->mat;
522:     PetscScalar    *darray;
523:     cupmStream_t    stream;

525:     PetscFunctionBegin;
526:     if (A->offloadmask == PETSC_OFFLOAD_BOTH || A->offloadmask == PETSC_OFFLOAD_GPU) {
527:       PetscInt   n   = A->rmap->n;
528:       CsrMatrix *mat = (CsrMatrix *)matstruct->mat;

530:       PetscCheck(devstruct->format == (decltype(devstruct->format))Policy::storage_format_csr(), PETSC_COMM_SELF, PETSC_ERR_SUP, "Only CSR format supported");
531:       if (n > 0) {
532:         PetscCall(Policy::VecGetArrayWrite(diag, &darray));
533:         PetscCall(GetHandles_(&stream));
534:         PetscCallCUPM(cupmLaunchKernel(GetDiagonal_CSR, (unsigned int)((n + 255) / 256), 256u, (size_t)0, stream, mat->row_offsets->data().get(), mat->column_indices->data().get(), mat->values->data().get(), n, darray));
535:         PetscCallCUPM(cupmGetLastError());
536:         PetscCall(Policy::VecRestoreArrayWrite(diag, &darray));
537:       }
538:     } else {
539:       PetscCall(MatGetDiagonal_SeqAIJ(A, diag));
540:     }
541:     PetscFunctionReturn(PETSC_SUCCESS);
542:   }

544:   /* -------------------------------------------------------------------
545:      Tier 4 - Device array access (moved here from vendor files so both
546:      SeqAIJCUSPARSE and SeqAIJHIPSPARSE share one implementation).
547:      ------------------------------------------------------------------- */

549:   /* GetArrayRead: read-only access to device CSR value array */
550:   static PetscErrorCode GetArrayRead(Mat A, const PetscScalar **a) noexcept
551:   {
552:     MatStructType *cusp = (MatStructType *)A->spptr;
553:     CsrMatrix     *csr;

555:     PetscFunctionBegin;
557:     PetscAssertPointer(a, 2);
558:     PetscCheckTypeName(A, Policy::mat_type_name);
559:     PetscCheck(cusp->format != (decltype(cusp->format))Policy::storage_format_ell() && cusp->format != (decltype(cusp->format))Policy::storage_format_hyb(), PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented");
560:     PetscCall(Policy::CopyToGPU(A));
561:     PetscCheck(cusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing MultStruct");
562:     csr = (CsrMatrix *)cusp->mat->mat;
563:     PetscCheck(csr->values, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing device memory");
564:     *a = csr->values->data().get();
565:     PetscFunctionReturn(PETSC_SUCCESS);
566:   }

568:   /* RestoreArrayRead: release read-only access obtained from GetArrayRead */
569:   static PetscErrorCode RestoreArrayRead(Mat A, const PetscScalar **a) noexcept
570:   {
571:     PetscFunctionBegin;
573:     PetscAssertPointer(a, 2);
574:     PetscCheckTypeName(A, Policy::mat_type_name);
575:     *a = NULL;
576:     PetscFunctionReturn(PETSC_SUCCESS);
577:   }

579:   /* GetArray: read-write access to device CSR value array */
580:   static PetscErrorCode GetArray(Mat A, PetscScalar **a) noexcept
581:   {
582:     MatStructType *cusp = (MatStructType *)A->spptr;
583:     CsrMatrix     *csr;

585:     PetscFunctionBegin;
587:     PetscAssertPointer(a, 2);
588:     PetscCheckTypeName(A, Policy::mat_type_name);
589:     PetscCheck(cusp->format != (decltype(cusp->format))Policy::storage_format_ell() && cusp->format != (decltype(cusp->format))Policy::storage_format_hyb(), PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented");
590:     PetscCall(Policy::CopyToGPU(A));
591:     PetscCheck(cusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing MultStruct");
592:     csr = (CsrMatrix *)cusp->mat->mat;
593:     PetscCheck(csr->values, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing device memory");
594:     *a             = csr->values->data().get();
595:     A->offloadmask = PETSC_OFFLOAD_GPU;
596:     PetscCall(Policy::InvalidateTranspose(A, PETSC_FALSE));
597:     PetscFunctionReturn(PETSC_SUCCESS);
598:   }

600:   /* RestoreArray: restore read-write access obtained from GetArray */
601:   static PetscErrorCode RestoreArray(Mat A, PetscScalar **a) noexcept
602:   {
603:     PetscFunctionBegin;
605:     PetscAssertPointer(a, 2);
606:     PetscCheckTypeName(A, Policy::mat_type_name);
607:     PetscCall(PetscObjectStateIncrease((PetscObject)A));
608:     *a = NULL;
609:     PetscFunctionReturn(PETSC_SUCCESS);
610:   }

612:   /* GetArrayWrite: write-only access to device CSR value array (no host-to-device copy) */
613:   static PetscErrorCode GetArrayWrite(Mat A, PetscScalar **a) noexcept
614:   {
615:     MatStructType *cusp = (MatStructType *)A->spptr;
616:     CsrMatrix     *csr;

618:     PetscFunctionBegin;
620:     PetscAssertPointer(a, 2);
621:     PetscCheckTypeName(A, Policy::mat_type_name);
622:     PetscCheck(cusp->format != (decltype(cusp->format))Policy::storage_format_ell() && cusp->format != (decltype(cusp->format))Policy::storage_format_hyb(), PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented");
623:     PetscCheck(cusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing MultStruct");
624:     csr = (CsrMatrix *)cusp->mat->mat;
625:     PetscCheck(csr->values, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing device memory");
626:     *a             = csr->values->data().get();
627:     A->offloadmask = PETSC_OFFLOAD_GPU;
628:     PetscCall(Policy::InvalidateTranspose(A, PETSC_FALSE));
629:     PetscFunctionReturn(PETSC_SUCCESS);
630:   }

632:   /* RestoreArrayWrite: restore write-only access obtained from GetArrayWrite */
633:   static PetscErrorCode RestoreArrayWrite(Mat A, PetscScalar **a) noexcept
634:   {
635:     PetscFunctionBegin;
637:     PetscAssertPointer(a, 2);
638:     PetscCheckTypeName(A, Policy::mat_type_name);
639:     PetscCall(PetscObjectStateIncrease((PetscObject)A));
640:     *a = NULL;
641:     PetscFunctionReturn(PETSC_SUCCESS);
642:   }

644:   /* SeqAIJGetArray: copy GPU-to-CPU then return host value array (ops->getarray) */
645:   static PetscErrorCode SeqAIJGetArray(Mat A, PetscScalar *array[]) noexcept
646:   {
647:     PetscFunctionBegin;
648:     PetscCall(Policy::CopyFromGPU(A));
649:     *array = ((Mat_SeqAIJ *)A->data)->a;
650:     PetscFunctionReturn(PETSC_SUCCESS);
651:   }

653:   /* SeqAIJRestoreArray: mark matrix data CPU-valid (ops->restorearray) */
654:   static PetscErrorCode SeqAIJRestoreArray(Mat A, PetscScalar *array[]) noexcept
655:   {
656:     PetscFunctionBegin;
657:     A->offloadmask = PETSC_OFFLOAD_CPU;
658:     *array         = NULL;
659:     PetscFunctionReturn(PETSC_SUCCESS);
660:   }

662:   /* SeqAIJGetArrayRead: copy GPU-to-CPU then return host value array read-only (ops->getarrayread) */
663:   static PetscErrorCode SeqAIJGetArrayRead(Mat A, const PetscScalar *array[]) noexcept
664:   {
665:     PetscFunctionBegin;
666:     PetscCall(Policy::CopyFromGPU(A));
667:     *array = ((Mat_SeqAIJ *)A->data)->a;
668:     PetscFunctionReturn(PETSC_SUCCESS);
669:   }

671:   /* SeqAIJRestoreArrayRead: release read-only host array (ops->restorearrayread) */
672:   static PetscErrorCode SeqAIJRestoreArrayRead(Mat /*A*/, const PetscScalar *array[]) noexcept
673:   {
674:     PetscFunctionBegin;
675:     *array = NULL;
676:     PetscFunctionReturn(PETSC_SUCCESS);
677:   }

679:   /* SeqAIJGetArrayWrite: return host value array for write-only access (ops->getarraywrite) */
680:   static PetscErrorCode SeqAIJGetArrayWrite(Mat A, PetscScalar *array[]) noexcept
681:   {
682:     PetscFunctionBegin;
683:     *array = ((Mat_SeqAIJ *)A->data)->a;
684:     PetscFunctionReturn(PETSC_SUCCESS);
685:   }

687:   /* SeqAIJRestoreArrayWrite: mark matrix data CPU-valid after write (ops->restorearraywrite) */
688:   static PetscErrorCode SeqAIJRestoreArrayWrite(Mat A, PetscScalar *array[]) noexcept
689:   {
690:     PetscFunctionBegin;
691:     A->offloadmask = PETSC_OFFLOAD_CPU;
692:     *array         = NULL;
693:     PetscFunctionReturn(PETSC_SUCCESS);
694:   }

696:   /* CreateSeqAIJ: allocate and preallocate a seq sparse matrix of this type */
697:   static PetscErrorCode CreateSeqAIJ(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt nz, const PetscInt nnz[], Mat *A) noexcept
698:   {
699:     PetscFunctionBegin;
700:     PetscCall(MatCreate(comm, A));
701:     PetscCall(MatSetSizes(*A, m, n, m, n));
702:     PetscCall(MatSetType(*A, Policy::mat_type_name));
703:     PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(*A, nz, (PetscInt *)nnz));
704:     PetscFunctionReturn(PETSC_SUCCESS);
705:   }

707:   /* MatDestroy: free vendor-specific state, deregister composed functions */
708:   static PetscErrorCode Destroy(Mat A) noexcept
709:   {
710:     PetscFunctionBegin;
711:     if (A->factortype == MAT_FACTOR_NONE) PetscCall(Policy::Destroy(A));
712:     else PetscCall(Policy::TriFactorsDestroy(&A->spptr));
713:     PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSeqAIJCopySubArray_C", NULL));
714:     PetscCall(PetscObjectComposeFunction((PetscObject)A, Policy::set_format_c, NULL));
715:     PetscCall(PetscObjectComposeFunction((PetscObject)A, Policy::set_use_cpu_solve_c, NULL));
716:     PetscCall(PetscObjectComposeFunction((PetscObject)A, Policy::product_seqdense_device_c, NULL));
717:     PetscCall(PetscObjectComposeFunction((PetscObject)A, Policy::product_seqdense_c, NULL));
718:     PetscCall(PetscObjectComposeFunction((PetscObject)A, Policy::product_self_c, NULL));
719:     PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatFactorGetSolverType_C", NULL));
720:     PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
721:     PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
722:     PetscCall(PetscObjectComposeFunction((PetscObject)A, Policy::seq_convert_hypre_c, NULL));
723:     PetscCall(MatDestroy_SeqAIJ(A));
724:     PetscFunctionReturn(PETSC_SUCCESS);
725:   }
726: };

728: } // namespace impl

730: } // namespace cupm

732: } // namespace aij

734: } // namespace mat

736: } // namespace Petsc