Actual source code: mpiaijkok.kokkos.cxx

  1: #include <petsc_kokkos.hpp>
  2: #include <petscvec_kokkos.hpp>
  3: #include <petscpkg_version.h>
  4: #include <petsc/private/sfimpl.h>
  5: #include <petsc/private/kokkosimpl.hpp>
  6: #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
  7: #include <../src/mat/impls/aij/mpi/mpiaij.h>
  8: #include <KokkosSparse_spadd.hpp>
  9: #include <KokkosSparse_spgemm.hpp>

 11: static PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode)
 12: {
 13:   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;

 15:   PetscFunctionBegin;
 16:   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
 17:   /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS.
 18:      Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases.
 19:    */
 20:   if (mode == MAT_FINAL_ASSEMBLY) {
 21:     PetscScalarKokkosView v;

 23:     PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
 24:     PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
 25:     PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS));  // lvec is init'ed on host, without copying to device
 26:     PetscCall(VecGetKokkosViewWrite(mpiaij->lvec, &v)); // mark lvec updated on device, as we never need to init lvec on device
 27:     PetscCall(VecRestoreKokkosViewWrite(mpiaij->lvec, &v));
 28:   }
 29:   PetscFunctionReturn(PETSC_SUCCESS);
 30: }
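/*
   A minimal sketch of the idiom used above (x here is a stand-in for any Kokkos-type Vec whose
   host values we do not care about): getting and immediately restoring a write-only Kokkos view
   marks the vector's device data as current, so PETSc will not later copy stale host data over:

     PetscScalarKokkosView v;
     PetscCall(VecGetKokkosViewWrite(x, &v));     // declare intent to overwrite on device
     PetscCall(VecRestoreKokkosViewWrite(x, &v)); // device side is now considered up to date
*/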

 32: static PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[])
 33: {
 34:   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;

 36:   PetscFunctionBegin;
 37:   // If mat was set to use the "set values with a hash table" mechanism, discard it and restore the cached ops
 38:   if (mat->hash_active) {
 39:     mat->ops[0]      = mpiaij->cops;
 40:     mat->hash_active = PETSC_FALSE;
 41:   }

 43:   PetscCall(PetscLayoutSetUp(mat->rmap));
 44:   PetscCall(PetscLayoutSetUp(mat->cmap));
 45: #if defined(PETSC_USE_DEBUG)
 46:   if (d_nnz) {
 47:     PetscInt i;
 48:     for (i = 0; i < mat->rmap->n; i++) PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]);
 49:   }
 50:   if (o_nnz) {
 51:     PetscInt i;
 52:     for (i = 0; i < mat->rmap->n; i++) PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, o_nnz[i]);
 53:   }
 54: #endif
 55: #if defined(PETSC_USE_CTABLE)
 56:   PetscCall(PetscHMapIDestroy(&mpiaij->colmap));
 57: #else
 58:   PetscCall(PetscFree(mpiaij->colmap));
 59: #endif
 60:   PetscCall(PetscFree(mpiaij->garray));
 61:   PetscCall(VecDestroy(&mpiaij->lvec));
 62:   PetscCall(VecScatterDestroy(&mpiaij->Mvctx));
 63:   /* Because B will have been resized, we simply destroy it and create a new one each time */
 64:   PetscCall(MatDestroy(&mpiaij->B));

 66:   if (!mpiaij->A) {
 67:     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->A));
 68:     PetscCall(MatSetSizes(mpiaij->A, mat->rmap->n, mat->cmap->n, mat->rmap->n, mat->cmap->n));
 69:   }
 70:   if (!mpiaij->B) {
 71:     PetscMPIInt size;
 72:     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size));
 73:     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->B));
 74:     PetscCall(MatSetSizes(mpiaij->B, mat->rmap->n, size > 1 ? mat->cmap->N : 0, mat->rmap->n, size > 1 ? mat->cmap->N : 0));
 75:   }
 76:   PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
 77:   PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
 78:   PetscCall(MatSeqAIJSetPreallocation(mpiaij->A, d_nz, d_nnz));
 79:   PetscCall(MatSeqAIJSetPreallocation(mpiaij->B, o_nz, o_nnz));
 80:   mat->preallocated = PETSC_TRUE;
 81:   PetscFunctionReturn(PETSC_SUCCESS);
 82: }
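/*
   A minimal usage sketch (N and the per-row counts below are placeholder values): the generic
   MatMPIAIJSetPreallocation() entry point dispatches to the routine above when the type is
   MATMPIAIJKOKKOS:

     Mat A;
     PetscCall(MatCreate(PETSC_COMM_WORLD, &A));
     PetscCall(MatSetSizes(A, PETSC_DECIDE, PETSC_DECIDE, N, N));
     PetscCall(MatSetType(A, MATMPIAIJKOKKOS));
     PetscCall(MatMPIAIJSetPreallocation(A, 3, NULL, 2, NULL)); // ~3 diag and ~2 off-diag nonzeros per row
     // ... MatSetValues(), MatAssemblyBegin(), MatAssemblyEnd() as usual ...
*/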

 84: static PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
 85: {
 86:   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
 87:   PetscInt    nt;

 89:   PetscFunctionBegin;
 90:   PetscCall(VecGetLocalSize(xx, &nt));
 91:   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
 92:   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
 93:   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
 94:   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
 95:   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
 96:   PetscFunctionReturn(PETSC_SUCCESS);
 97: }

 99: static PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz)
100: {
101:   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
102:   PetscInt    nt;

104:   PetscFunctionBegin;
105:   PetscCall(VecGetLocalSize(xx, &nt));
106:   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
107:   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
108:   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
109:   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
110:   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
111:   PetscFunctionReturn(PETSC_SUCCESS);
112: }

114: static PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
115: {
116:   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
117:   PetscInt    nt;

119:   PetscFunctionBegin;
120:   PetscCall(VecGetLocalSize(xx, &nt));
121:   PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
122:   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
123:   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
124:   PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
125:   PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
126:   PetscFunctionReturn(PETSC_SUCCESS);
127: }

129: /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
130:    A is put before B, so C's size is A->rmap->n by (A->cmap->n + B->cmap->n).
131:    C still uses local column ids. Their corresponding global column ids are returned in glob.
132: */
133: static PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C)
134: {
135:   Mat             Ad, Ao;
136:   const PetscInt *cmap;

138:   PetscFunctionBegin;
139:   PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
140:   PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
141:   if (glob) {
142:     PetscInt cst, i, dn, on, *gidx;
143:     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
144:     PetscCall(MatGetLocalSize(Ao, NULL, &on));
145:     PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
146:     PetscCall(PetscMalloc1(dn + on, &gidx));
147:     for (i = 0; i < dn; i++) gidx[i] = cst + i;
148:     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
149:     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
150:   }
151:   PetscFunctionReturn(PETSC_SUCCESS);
152: }
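/*
   A minimal usage sketch (assuming A is an assembled MATMPIAIJKOKKOS matrix), going through the
   public MatMPIAIJGetLocalMatMerge() entry point, which dispatches to the routine above:

     Mat             C;    // sequential [Ad | Ao] with local column ids
     IS              glob; // global column ids corresponding to C's local columns
     const PetscInt *gidx;

     PetscCall(MatMPIAIJGetLocalMatMerge(A, MAT_INITIAL_MATRIX, &glob, &C));
     PetscCall(ISGetIndices(glob, &gidx)); // gidx[0..dn) = cst..cst+dn-1, gidx[dn..dn+on) = cmap[]
     // ... use C and gidx ...
     PetscCall(ISRestoreIndices(glob, &gidx));
     PetscCall(ISDestroy(&glob));
     PetscCall(MatDestroy(&C));
*/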

154: /* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */
155: struct MatMatStruct {
156:   PetscInt            n, *garray;     // C's garray and its size.
 157:   KokkosCsrMatrix     Cd, Co;         // C in split form (all in local column indices)
158:   KokkosCsrMatrix     C1, C2, C3, C4; // intermediate mat products
159:   KokkosCsrMatrix     C2_mid, C4_mid; // alias of C2, C4; share their a[], i[], but with different j[] (hence column size)
160:   PetscIntKokkosView  E_NzLeft;
161:   PetscSF             sf = nullptr; // SF to bcast or reduce matrices E to F
162:   MatScalarKokkosView rootBuf, leafBuf;
163:   KokkosCsrMatrix     Fd, Fo; // F in split form

165:   KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd
166:   KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo
167:   KernelHandle kh3; // compute C3
168:   KernelHandle kh4; // compute C4

170:   PetscInt E_TeamSize; // kernel launching parameters in merging E or splitting F
171:   PetscInt E_VectorLength;
172:   PetscInt E_RowsPerTeam;
173:   PetscInt F_TeamSize;
174:   PetscInt F_VectorLength;
175:   PetscInt F_RowsPerTeam;

177:   ~MatMatStruct()
178:   {
179:     PetscFunctionBegin;
180:     PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf));
181:     PetscFunctionReturnVoid();
182:   }
183: };
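/*
   The split form referred to above, illustrated with made-up numbers: on a rank owning global
   columns [100, 200), a row with global columns {3, 120, 150, 450} is stored as

     Cd: local columns {20, 50}                (i.e., 120-100 and 150-100)
     Co: local columns {0, 1}, garray = {3, 450}

   so Co's column indices index into garray[], which maps them back to global column ids.
*/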

185: struct MatMatStruct_AB : public MatMatStruct {
 186:   PetscIntKokkosView F_NzLeft; // plan to split F (in leafBuf) into Fd, Fo
187:   PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf
188:   PetscIntKokkosView rowoffset;
189: };

191: struct MatMatStruct_AtB : public MatMatStruct {
192:   MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo
193:   MatColIdxKokkosView Fdjperm;
194:   MatColIdxKokkosView Fojmap;
195:   MatColIdxKokkosView Fojperm;
196: };

198: struct MatProductData_MPIAIJKokkos {
199:   MatMatStruct_AB  *mmAB     = nullptr;
200:   MatMatStruct_AtB *mmAtB    = nullptr;
201:   PetscBool         reusesym = PETSC_FALSE;
202:   Mat               Z        = nullptr; // store Z=AB in computing BtAB

204:   ~MatProductData_MPIAIJKokkos()
205:   {
206:     delete mmAB;
207:     delete mmAtB;
208:     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z));
209:   }
210: };

212: static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
213: {
214:   PetscFunctionBegin;
215:   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
216:   PetscFunctionReturn(PETSC_SUCCESS);
217: }

219: /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
220:    It is similar to MatCreateMPIAIJWithSplitArrays.

222:   Input Parameters:
223: +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
224: .  A     - the diag matrix using local col ids
225: -  B     - the offdiag matrix using global col ids

227:   Output Parameter:
228: .  mat   - the updated MATMPIAIJKOKKOS matrix
229: */
230: static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B, PetscInt *garray)
231: {
232:   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
233:   PetscInt    m, n, M, N, Am, An, Bm, Bn;

235:   PetscFunctionBegin;
236:   PetscCall(MatGetSize(mat, &M, &N));
237:   PetscCall(MatGetLocalSize(mat, &m, &n));
238:   PetscCall(MatGetLocalSize(A, &Am, &An));
239:   PetscCall(MatGetLocalSize(B, &Bm, &Bn));

 241:   PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows does not match");
 242:   PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns does not match");
243:   // PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match");
244:   PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty");
245:   mpiaij->A      = A;
246:   mpiaij->B      = B;
247:   mpiaij->garray = garray;

249:   mat->preallocated     = PETSC_TRUE;
250:   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */

252:   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
253:   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
254:   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
255:     also gets mpiaij->B compacted, with its col ids and size reduced
256:   */
257:   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
258:   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
259:   PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
260:   PetscFunctionReturn(PETSC_SUCCESS);
261: }
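/*
   A minimal internal-usage sketch (mat, Ad, Ao and garray are placeholders for the caller's
   objects): callers in this file build the two SEQAIJKOKKOS blocks and the global column map of
   the off-diag block first, then hand ownership of all three to mat:

     Mat       mat;     // type MATMPIAIJKOKKOS, sizes/layout already set
     Mat       Ad, Ao;  // SEQAIJKOKKOS; Ad uses local column ids, Ao uses global column ids
     PetscInt *garray;  // PetscMalloc1()'ed global column ids for Ao's columns
     // ... build Ad, Ao, garray ...
     PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(mat, Ad, Ao, garray));
     // mat now owns Ad, Ao and garray; do not destroy/free them separately
*/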

 263: // Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters for the Kokkos nested loops that we use to merge or
 264: // split CSR matrices. The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block); see the worked example after this routine.
265: template <class ExecutionSpace>
266: static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team)
267: {
268:   Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO);

270:   PetscFunctionBegin;
 271:   PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // guard against empty matrices

273:   if (nnz_per_row < 1) nnz_per_row = 1;

275:   int max_vector_length = teamPolicy.vector_length_max();

277:   if (vector_length < 1) {
278:     vector_length = 1;
279:     while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2;
280:   }

282:   // Determine rows per thread
283:   if (rows_per_thread < 1) {
284:     if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) rows_per_thread = 1;
285:     else {
286:       if (nnz_per_row < 20 && nnz > 5000000) {
287:         rows_per_thread = 256;
288:       } else rows_per_thread = 64;
289:     }
290:   }

292:   if (team_size < 1) {
293:     if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) {
294:       team_size = 256 / vector_length;
295:     } else {
296:       team_size = 1;
297:     }
298:   }

300:   rows_per_team = rows_per_thread * team_size;

302:   if (rows_per_team < 0) {
303:     PetscInt nnz_per_team = 4096;
304:     PetscInt conc         = ExecutionSpace().concurrency();
305:     while ((conc * nnz_per_team * 4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2;
306:     rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row;
307:   }
308:   PetscFunctionReturn(PETSC_SUCCESS);
309: }
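/*
   A worked example of the heuristic above (GPU execution space, made-up sizes): numRows = 100000
   and nnz = 5000000 give nnz_per_row = 50. vector_length doubles 1 -> 2 -> 4 -> 8 -> 16 and stops
   at 16 (since 16*6 >= 50); team_size = 256/16 = 16; rows_per_thread = 1 on GPUs; therefore
   rows_per_team = 1*16 = 16, i.e., each 16x16 = 256-thread team merges/splits 16 rows.
*/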

311: /*
312:   Reduce two sets of global indices into local ones

314:   Input Parameters:
315: +  n1          - size of garray1[], the first set
316: .  garray1[n1] - a sorted global index array (without duplicates)
317: .  m           - size of indices[], the second set
 318: -  indices[m]  - an unsorted global index array (might have duplicates), which on output is overwritten with local indices

320:   Output Parameters:
321: +  n2          - size of garray2[], the merged set, which combines garray1[] and indices[]
322: .  garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it.
323: .  map[n1]     - allocated by caller. It gives garray1[i] = garray2[map[i]]
 324: -  indices[m]  - on output, global indices in this array are rewritten with local ones, i.e., indices_input[i] = garray2[indices_output[i]]

326:    Example, say
327:     n1         = 5
328:     garray1[5] = {1, 4, 7, 8, 10}
329:     m          = 4
330:     indices[4] = {2, 4, 8, 9}

332:    Combining them together, we have 7 global indices in garray2[]
333:     n2         = 7
334:     garray2[7] = {1, 2, 4, 7, 8, 9, 10}

336:    And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)"
337:     map[5] = {0, 2, 3, 4, 6}

339:    On output, indices[] is updated with local indices
340:     indices[4] = {1, 2, 4, 5}
341: */
342: static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map)
343: {
344:   PetscHMapI    g2l = nullptr;
345:   PetscHashIter iter;
346:   PetscInt      tot, key, val; // total unique global indices. key is global id; val is local id
347:   PetscInt      n2, *garray2;

349:   PetscFunctionBegin;
350:   tot = 0;
351:   PetscCall(PetscHMapICreateWithSize(n1, &g2l));
352:   for (PetscInt i = 0; i < m; i++) {                                // insert those in indices[]
 353:     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if the key does not exist, val is set to -1
354:     if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++));  // val < 0 means gid is not in the hash table yet
355:   }

357:   for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[]
358:     PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val));
359:     if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++));
360:   }

362:   // Pull out (unique) globals in the hash table and put them in garray2[]
363:   n2 = tot;
364:   PetscCall(PetscMalloc1(n2, &garray2));
365:   tot = 0;
366:   PetscHashIterBegin(g2l, iter);
367:   while (!PetscHashIterAtEnd(g2l, iter)) {
368:     PetscHashIterGetKey(g2l, iter, key);
369:     PetscHashIterNext(g2l, iter);
370:     garray2[tot++] = key;
371:   }

373:   // Sort garray2[] and then map them to local indices starting from 0
374:   PetscCall(PetscSortInt(n2, garray2));
375:   PetscCall(PetscHMapIClear(g2l));
376:   for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id

378:   // Rewrite indices[] with local indices
379:   for (PetscInt i = 0; i < m; i++) {
380:     PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val));
381:     PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index");
382:     indices[i] = val;
383:   }
384:   // Record the map that maps garray1[i] to garray2[map[i]]
385:   for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i]));
386:   PetscCall(PetscHMapIDestroy(&g2l));
387:   *n2_      = n2;
388:   *garray2_ = garray2;
389:   PetscFunctionReturn(PETSC_SUCCESS);
390: }
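/*
   A minimal usage sketch mirroring the worked example in the comment above:

     PetscInt  garray1[] = {1, 4, 7, 8, 10};
     PetscInt  indices[] = {2, 4, 8, 9};
     PetscInt  map[5], n2, *garray2;

     PetscCall(ReduceTwoSetsOfGlobalIndices(5, garray1, 4, indices, &n2, &garray2, map));
     // n2 == 7, garray2[] == {1, 2, 4, 7, 8, 9, 10}, map[] == {0, 2, 3, 4, 6},
     // and indices[] has been rewritten in place to the local ids {1, 2, 4, 5}
     PetscCall(PetscFree(garray2)); // garray2[] was allocated by the callee
*/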

392: /*
393:   MatMPIAIJKokkosReduce - Reduce rows of a MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm)

395:   It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E.

 397:   Think of each row of E as a leaf; the given ownerSF specifies roots for the leaves. Roots may connect to multiple leaves.
398:   In this routine, we sparse-merge leaves (rows) at their roots to form potentially longer rows in F. F's number of rows will be nroots of ownerSF.

400:   Input Parameters:
401: +  comm       - MPI communicator of E
402: .  A          - diag block of E, using local column indices
403: .  B          - off-diag block of E, using local column indices
 404: .  cstart      - (global) start column of Ed (i.e., A, the diag block of E)
 405: .  cend        - (global) end column + 1 of Ed. In other words, E's column ownership is in the range [cstart, cend)
 406: .  garray1[n1] - global column indices of Eo (i.e., B, the off-diag block of E). Here n1 is Eo's column size.
407: .  ownerSF     - the SF specifies ownership (root) of rows in E
408: .  reuse       - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
409: -  mm          - to stash intermediate data structures for reuse

411:   Output Parameters:
412: +  map[n1]  - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices().
413: -  mm       - contains various info, such as garray2[], F (Fd, Fo) etc.

415:   Notes:
416:   When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant.

418:  */
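/*
   A minimal calling sketch (C3, C4, cstart, cend, garray1, ownerSF, map and mm stand for the
   caller's objects, as in the AtB product below); the Begin/End split lets independent work
   overlap the PetscSF communication:

     PetscCall(MatMPIAIJKokkosReduceBegin(comm, C3, C4, cstart, cend, garray1, ownerSF, MAT_INITIAL_MATRIX, map, mm));
     // ... do work that does not need F here, overlapping the reduction ...
     PetscCall(MatMPIAIJKokkosReduceEnd(comm, C3, C4, cstart, cend, garray1, ownerSF, MAT_INITIAL_MATRIX, map, mm));
     // mm->Fd and mm->Fo now hold F in split form; mm->garray/mm->n describe Fo's columns
*/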
419: static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
420: {
421:   PetscFunctionBegin;
422:   if (reuse == MAT_INITIAL_MATRIX) {
423:     PetscInt Em = A.numRows(), Fm;
424:     PetscInt n1 = B.numCols();

426:     PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF

428:     // Do the analysis on host
429:     auto                 Ai_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.row_map);
430:     auto                 Aj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.entries);
431:     auto                 Bi_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.row_map);
432:     auto                 Bj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.entries);
433:     const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data();
434:     const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data();

 436:     // Count how many nonzeros of each row in E are to the left of the diag block [cstart,cend)
437:     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
438:     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
439:     for (PetscInt i = 0; i < Em; i++) {
440:       const PetscInt *first, *last, *it;
441:       PetscInt        count, step;
442:       // std::lower_bound(first,last,cstart), but need to use global column indices
443:       first = Bj + Bi[i];
444:       last  = Bj + Bi[i + 1];
445:       count = last - first;
446:       while (count > 0) {
447:         it   = first;
448:         step = count / 2;
449:         it += step;
450:         if (garray1[*it] < cstart) { // map local to global
451:           first = ++it;
452:           count -= step + 1;
453:         } else count = step;
454:       }
455:       E_NzLeft[i] = first - (Bj + Bi[i]);
456:       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
457:     }

459:     // Get length of rows (i.e., sizes of leaves) that contribute to my roots
460:     const PetscMPIInt *iranks, *ranks;
461:     const PetscInt    *ioffset, *irootloc, *roffset, *rmine;
462:     PetscInt           niranks, nranks;
463:     MPI_Request       *reqs;
464:     PetscMPIInt        tag;
465:     PetscSF            reduceSF;
466:     PetscInt          *sdisp, *rdisp;

468:     PetscCall(PetscCommGetNewTag(comm, &tag));
469:     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc));  // get leaf ranks connecting to roots on this process (I'll recv from them)
470:     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them)

472:     // Find out length of each row I will receive. Even for the same row index, when they are from
473:     // different senders, they might have different lengths (and sparsity patterns)
474:     PetscInt  sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks];
 475:     PetscInt *sendRowLen, *recvRowLen; // lengths of rows I need to send/recv, per process

477:     PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs));

479:     for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]];
480:     recvRowLen[0] = 0; // since we will make it in CSR format later
481:     recvRowLen++;      // advance the pointer now
 482:     for (PetscInt i = 0; i < niranks; i++) { PetscCallMPI(MPI_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i])); }
 483:     for (PetscInt i = 0; i < nranks; i++) { PetscCallMPI(MPI_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i])); }
484:     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));

486:     // Build the real PetscSF for reducing E rows (buffer to buffer)
487:     rdisp[0] = 0;
488:     for (PetscInt i = 0; i < niranks; i++) {
489:       rdisp[i + 1] = rdisp[i];
490:       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { rdisp[i + 1] += recvRowLen[j]; }
491:     }
492:     recvRowLen--; // put it back into csr format
493:     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 1] += recvRowLen[i];

 495:     for (PetscInt i = 0; i < nranks; i++) { PetscCallMPI(MPI_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i])); }
 496:     for (PetscInt i = 0; i < niranks; i++) { PetscCallMPI(MPI_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i])); }
497:     PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));

499:     PetscInt     nleaves = 0, Enz = 0;    // leaves are nonzeros I will send
500:     PetscInt     nroots = rdisp[niranks]; // roots are nonzeros I will recv
501:     PetscSFNode *iremote;

503:     for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i];
504:     PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B");
 505:     PetscCall(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF

507:     for (PetscInt i = 0; i < nranks; i++) {
508:       PetscInt count = 0;
509:       for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]];
510:       for (PetscInt j = 0; j < count; j++) {
511:         iremote[nleaves + j].rank  = ranks[i];
512:         iremote[nleaves + j].index = sdisp[i] + j;
513:       }
514:       nleaves += count;
515:     }
516:     PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz");

518:     PetscCall(PetscSFCreate(comm, &reduceSF));
519:     PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));

521:     // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[]
522:     PetscInt *sendCol, *recvCol;
523:     PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol));
524:     for (PetscInt k = 0; k < roffset[nranks]; k++) {
525:       PetscInt  i      = rmine[k]; // row to be copied
526:       PetscInt *buf    = &sendCol[Ai[i] + Bi[i]];
527:       PetscInt  nzLeft = E_NzLeft[i];
528:       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
529:       for (PetscInt j = 0; j < alen + blen; j++) {
530:         if (j < nzLeft) {
531:           buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global
532:         } else if (j < nzLeft + alen) {
533:           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
534:         } else {
535:           buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global
536:         }
537:       }
538:     }
539:     PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE));
540:     PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE));

542:     // With recvCol[], we do a series of analysis to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo
543:     PetscInt *recvRowPerm, *recvColSorted;
544:     PetscInt *recvNzPerm, *recvNzPermSorted;
545:     PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted));

547:     for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i;                   // numbering all received nonzeros
 548:     for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i;              // set up a permutation array, so that after sorting we know where to get a row in recvCol[]
549:     PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed

551:     // i[] array, nz are always easiest to compute
552:     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1);
553:     MatRowMapType          *Fdi, *Foi;
554:     PetscInt                FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo
555:     PetscInt                iter;

557:     Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them
558:     Kokkos::deep_copy(Foi_h, 0);
559:     Fdi  = Fdi_h.data() + 1; // +1 for easy indexing in code below
560:     Foi  = Foi_h.data() + 1;
561:     iter = 0;
562:     while (iter < recvRowCnt) { // iter over received rows
563:       PetscInt curRowIdx = irootloc[recvRowPerm[iter]];
564:       PetscInt dupRows   = 1; // current row has this many contributing rows (of various sparsity patterns)

566:       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;

568:       // Copy column indices (and their permutation) of these rows into recvColSorted & recvNzPermSorted
569:       PetscInt  nz    = 0; // nz (with dups) in the current row
570:       PetscInt *jbuf  = recvColSorted + FnzDups;
571:       PetscInt *pbuf  = recvNzPermSorted + FnzDups;
572:       PetscInt *jbuf2 = jbuf; // temp pointers
573:       PetscInt *pbuf2 = pbuf;
574:       for (PetscInt d = 0; d < dupRows; d++) {
575:         PetscInt i   = recvRowPerm[iter + d];
576:         PetscInt len = recvRowLen[i + 1] - recvRowLen[i];
577:         PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len));
578:         PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len));
579:         jbuf2 += len;
580:         pbuf2 += len;
581:         nz += len;
582:       }
583:       PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted

585:       // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo
586:       PetscInt cur = 0;
587:       while (cur < nz) {
588:         PetscInt curColIdx = jbuf[cur];
589:         PetscInt dups      = 1;

591:         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
592:         if (curColIdx >= cstart && curColIdx < cend) {
593:           Fdi[curRowIdx]++;
594:           FdnzDups += dups;
595:         } else {
596:           Foi[curRowIdx]++;
597:           FonzDups += dups;
598:         }
599:         cur += dups;
600:       }

602:       FnzDups += nz;
603:       iter += dupRows; // Move to next unique row
604:     }

606:     Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR
607:     Foi = Foi_h.data();
608:     for (PetscInt i = 0; i < Fm; i++) {
609:       Fdi[i + 1] += Fdi[i];
610:       Foi[i + 1] += Foi[i];
611:     }
612:     Fdnz = Fdi[Fm];
613:     Fonz = Foi[Fm];
614:     PetscCall(PetscFree2(sendCol, recvCol));

616:     // Allocate j, jmap, jperm for Fd and Fo
617:     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
618:     MatRowMapKokkosViewHost Fdjmap_h(NoInit("Fdjmap_h"), Fdnz + 1), Fojmap_h(NoInit("Fojmap_h"), Fonz + 1); // +1 to make csr
619:     MatRowMapKokkosViewHost Fdjperm_h(NoInit("Fdjperm_h"), FdnzDups), Fojperm_h(NoInit("Fojperm_h"), FonzDups);
620:     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data();
621:     MatRowMapType          *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data();
622:     MatRowMapType          *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data();

624:     // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo
625:     Fdjmap[0] = 0;
626:     Fojmap[0] = 0;
627:     FnzDups   = 0;
628:     Fdnz      = 0;
629:     Fonz      = 0;
630:     iter      = 0; // iter over received rows
631:     while (iter < recvRowCnt) {
632:       PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx
633:       PetscInt dupRows   = 1;                           // It has this many contributing rows (of various lengths)
634:       PetscInt nz        = 0;                           // nz (with dups) in the current row

636:       while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
637:       for (PetscInt d = 0; d < dupRows; d++) {
638:         PetscInt i = recvRowPerm[iter + d];
639:         nz += recvRowLen[i + 1] - recvRowLen[i];
640:       }

642:       PetscInt *jbuf = recvColSorted + FnzDups;
643:       // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo
644:       PetscInt cur = 0;
645:       while (cur < nz) {
646:         PetscInt curColIdx = jbuf[cur];
647:         PetscInt dups      = 1;

649:         while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
650:         if (curColIdx >= cstart && curColIdx < cend) {
651:           Fdj[Fdnz]        = curColIdx - cstart; // easily convert to local
652:           Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups;
653:           for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j];
654:           FdnzDups += dups;
655:           Fdnz++;
656:         } else {
657:           Foj[Fonz]        = curColIdx; // in global
658:           Fojmap[Fonz + 1] = Fojmap[Fonz] + dups;
659:           for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j];
660:           FonzDups += dups;
661:           Fonz++;
662:         }
663:         cur += dups;
664:         FnzDups += dups;
665:       }
666:       iter += dupRows; // Move to next unique row
667:     }
668:     PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted));
669:     PetscCall(PetscFree5(sendRowLen, recvRowLen, sdisp, rdisp, reqs));

671:     // Combine global column indices in garray1[] and Foj[]
672:     PetscInt n2, *garray2;

674:     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
675:     mm->sf       = reduceSF;
676:     mm->leafBuf  = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
677:     mm->rootBuf  = MatScalarKokkosView(NoInit("rootBuf"), nroots);
678:     mm->garray   = garray2; // give ownership, so no free
679:     mm->n        = n2;
680:     mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
681:     mm->Fdjmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h);
682:     mm->Fdjperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h);
683:     mm->Fojmap   = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h);
684:     mm->Fojperm  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h);

686:     // Output Fd and Fo in KokkosCsrMatrix format
687:     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz);
688:     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
689:     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
690:     MatScalarKokkosView Foa_d(NoInit("Foa_d"), Fonz);
691:     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
692:     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);

694:     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
695:     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[]

697:     // Compute kernel launch parameters in merging E
698:     PetscInt teamSize, vectorLength, rowsPerTeam;

700:     teamSize = vectorLength = rowsPerTeam = -1;
701:     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam));
702:     mm->E_TeamSize     = teamSize;
703:     mm->E_VectorLength = vectorLength;
704:     mm->E_RowsPerTeam  = rowsPerTeam;
705:   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);

707:   // Handy aliases
708:   auto       &Aa           = A.values;
709:   auto       &Ba           = B.values;
710:   const auto &Ai           = A.graph.row_map;
711:   const auto &Bi           = B.graph.row_map;
712:   const auto &E_NzLeft     = mm->E_NzLeft;
713:   auto       &leafBuf      = mm->leafBuf;
714:   auto       &rootBuf      = mm->rootBuf;
715:   PetscSF     reduceSF     = mm->sf;
716:   PetscInt    Em           = A.numRows();
717:   PetscInt    teamSize     = mm->E_TeamSize;
718:   PetscInt    vectorLength = mm->E_VectorLength;
719:   PetscInt    rowsPerTeam  = mm->E_RowsPerTeam;
720:   PetscInt    workSets     = (Em + rowsPerTeam - 1) / rowsPerTeam;

 722:   // Copy rows in A/B of E to leafBuf, then reduce it to rootBuf
723:   PetscCallCXX(Kokkos::parallel_for(
724:     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
725:       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
 726:         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row of E
727:         if (i < Em) {
728:           PetscInt disp   = Ai(i) + Bi(i);
729:           PetscInt alen   = Ai(i + 1) - Ai(i);
730:           PetscInt blen   = Bi(i + 1) - Bi(i);
731:           PetscInt nzleft = E_NzLeft(i);

733:           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
734:             MatScalar &val = leafBuf(disp + j);
735:             if (j < nzleft) { // B left
736:               val = Ba(Bi(i) + j);
737:             } else if (j < nzleft + alen) { // diag A
738:               val = Aa(Ai(i) + j - nzleft);
739:             } else { // B right
740:               val = Ba(Bi(i) + j - alen);
741:             }
742:           });
743:         }
744:       });
745:     }));
746:   PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE));
747:   PetscFunctionReturn(PETSC_SUCCESS);
748: }

750: // To finish MatMPIAIJKokkosReduce.
751: static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
752: {
753:   auto       &leafBuf  = mm->leafBuf;
754:   auto       &rootBuf  = mm->rootBuf;
755:   auto       &Fda      = mm->Fd.values;
756:   const auto &Fdjmap   = mm->Fdjmap;
757:   const auto &Fdjperm  = mm->Fdjperm;
758:   auto        Fdnz     = mm->Fd.nnz();
759:   auto       &Foa      = mm->Fo.values;
760:   const auto &Fojmap   = mm->Fojmap;
761:   const auto &Fojperm  = mm->Fojperm;
762:   auto        Fonz     = mm->Fo.nnz();
763:   PetscSF     reduceSF = mm->sf;

765:   PetscFunctionBegin;
766:   PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE));

768:   // Reduce data in rootBuf to Fd and Fo
769:   PetscCallCXX(Kokkos::parallel_for(
770:     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fdnz), KOKKOS_LAMBDA(const MatRowMapType i) {
771:       PetscScalar sum = 0.0;
772:       for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k));
773:       Fda(i) = sum;
774:     }));

776:   PetscCallCXX(Kokkos::parallel_for(
777:     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fonz), KOKKOS_LAMBDA(const MatRowMapType i) {
778:       PetscScalar sum = 0.0;
779:       for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k));
780:       Foa(i) = sum;
781:     }));
782:   PetscFunctionReturn(PETSC_SUCCESS);
783: }

785: /*
786:   MatMPIAIJKokkosBcast - Bcast local rows of a MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form

788:   This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports
 789:   device execution and involves various index mappings.

791:   In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves.
 792:   If F's j-th row is connected to a root identified by PetscSFNode (k,i), then we need to bcast the i-th row of E on rank k
793:   to j-th row of F. ownerSF is not an arbitrary SF, instead it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E.
794:   F has the same column layout as E.

 796:   Conceptually F has global column indices. In this routine, we split F into diagonal Fd and off-diagonal Fo.
797:   Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices.
798:   Fo had global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global
799:   column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with
800:   column indices in Fo and update Fo with local indices.

802:    Input Parameters:
803: +   E       - the MPIAIJKOKKOS matrix
804: .   ownerSF - the ownership SF (insignificant in MAT_REUSE_MATRIX)
805: .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
806: -   mm      - to stash matproduct intermediate data structures

808:     Output Parameters:
809: +   map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices.
810: -   mm      - contains various info, such as garray2[], Fd, Fo, etc.

812:     Notes:
813:     When reuse = MAT_REUSE_MATRIX, ownerSF, map are not significant.
 814:     The routine is split into MatMPIAIJKokkosBcastBegin/End() to provide opportunities for computation/communication overlap.
815: */
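/*
   A minimal calling sketch (E, ownerSF, map and mm stand for the caller's objects): as with the
   reduce pair above, the Begin/End split lets local work overlap the PetscSF broadcast:

     PetscCall(MatMPIAIJKokkosBcastBegin(E, ownerSF, MAT_INITIAL_MATRIX, map, mm));
     // ... do work that does not need F here, overlapping the broadcast ...
     PetscCall(MatMPIAIJKokkosBcastEnd(E, ownerSF, MAT_INITIAL_MATRIX, map, mm));
     // mm->Fd and mm->Fo now hold the needed rows of E (i.e., F) in split form
*/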
816: static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
817: {
818:   Mat_MPIAIJ       *empi = static_cast<Mat_MPIAIJ *>(E->data);
819:   Mat               A = empi->A, B = empi->B; // diag and off-diag
820:   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
821:   PetscInt          Em = E->rmap->n; // #local rows
822:   MPI_Comm          comm;

824:   PetscFunctionBegin;
 825:   PetscCall(PetscObjectGetComm((PetscObject)E, &comm));
826:   if (reuse == MAT_INITIAL_MATRIX) {
827:     Mat_SeqAIJ     *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data);
828:     PetscInt        n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j;
829:     const PetscInt *garray1 = empi->garray; // its size is n1
830:     PetscInt        cstart, cend;
831:     PetscSF         bcastSF;

833:     PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend));

 835:     // Count how many nonzeros of each row in E are to the left of the diag block [cstart,cend)
836:     PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
837:     PetscInt              *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
838:     for (PetscInt i = 0; i < Em; i++) {
839:       const PetscInt *first, *last, *it;
840:       PetscInt        count, step;
841:       // std::lower_bound(first,last,cstart), but need to use global column indices
842:       first = Bj + Bi[i];
843:       last  = Bj + Bi[i + 1];
844:       count = last - first;
845:       while (count > 0) {
846:         it   = first;
847:         step = count / 2;
848:         it += step;
849:         if (empi->garray[*it] < cstart) { // map local to global
850:           first = ++it;
851:           count -= step + 1;
852:         } else count = step;
853:       }
854:       E_NzLeft[i] = first - (Bj + Bi[i]);
855:       E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
856:     }

858:     // Compute row pointer Fi of F
859:     PetscInt *Fi, Fm, Fnz;
860:     PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF
861:     PetscCall(PetscMalloc1(Fm + 1, &Fi));
862:     Fi[0] = 0;
863:     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE));
864:     PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE));
865:     for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i];
866:     Fnz = Fi[Fm];

868:     // Build the real PetscSF for bcasting E rows (buffer to buffer)
869:     const PetscMPIInt *iranks, *ranks;
870:     const PetscInt    *ioffset, *irootloc, *roffset;
871:     PetscInt           niranks, nranks, *sdisp, *rdisp;
872:     MPI_Request       *reqs;
873:     PetscMPIInt        tag;

875:     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process
876:     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL));  // recv info
877:     PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));

879:     sdisp[0] = 0; // send displacement
880:     for (PetscInt i = 0; i < niranks; i++) {
881:       sdisp[i + 1] = sdisp[i];
882:       for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) {
883:         PetscInt r = irootloc[j]; // row to be sent
884:         sdisp[i + 1] += E_RowLen[r];
885:       }
886:     }

 888:     PetscCall(PetscCommGetNewTag(comm, &tag));
889:     for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
890:     for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
891:     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));

893:     PetscInt     nleaves = Fnz;            // leaves are nonzeros I will receive
894:     PetscInt     nroots  = sdisp[niranks]; // roots are nonzeros I will send
895:     PetscSFNode *iremote;                  // give ownership to bcastSF
896:     PetscCall(PetscMalloc1(nleaves, &iremote));
897:     for (PetscInt i = 0; i < nranks; i++) { // for each sender rank
898:       PetscInt k = 0;
899:       for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], roffset[i+1]) of F from ranks[i]
900:         iremote[j].rank  = ranks[i];
901:         iremote[j].index = rdisp[i] + k; // their root location
902:         k++;
903:       }
904:     }
905:     PetscCall(PetscSFCreate(comm, &bcastSF));
906:     PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
907:     PetscCall(PetscFree3(sdisp, rdisp, reqs));

909:     // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel
910:     PetscIntKokkosViewHost rowoffset_h(NoInit("rowoffset_h"), ioffset[niranks] + 1);
 911:     PetscInt              *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destination offset for the copy
912:     rowoffset[0]                     = 0;
913:     for (PetscInt i = 0; i < ioffset[niranks]; i++) { rowoffset[i + 1] = rowoffset[i] + E_RowLen[irootloc[i]]; }

915:     // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[]
916:     PetscInt *jbuf, *Fj;
917:     PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj));
918:     for (PetscInt k = 0; k < ioffset[niranks]; k++) {
919:       PetscInt  i      = irootloc[k]; // row to be copied
920:       PetscInt *buf    = &jbuf[rowoffset[k]];
921:       PetscInt  nzLeft = E_NzLeft[i];
922:       PetscInt  alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
923:       for (PetscInt j = 0; j < alen + blen; j++) {
924:         if (j < nzLeft) {
925:           buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global
926:         } else if (j < nzLeft + alen) {
927:           buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
928:         } else {
929:           buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global
930:         }
931:       }
932:     }
933:     PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE));
934:     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE));

936:     // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo
937:     MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); // row pointer of Fd, Fo
938:     MatColIdxKokkosViewHost F_NzLeft_h(NoInit("F_NzLeft_h"), Fm);                           // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag.
939:     MatRowMapType          *Fdi = Fdi_h.data(), *Foi = Foi_h.data();
940:     MatColIdxType          *F_NzLeft = F_NzLeft_h.data();

942:     Fdi[0] = Foi[0] = 0;
943:     for (PetscInt i = 0; i < Fm; i++) {
944:       PetscInt *first, *last, *lb1, *lb2;
945:       // cut the row into: Left, [cstart, cend), Right
946:       first       = Fj + Fi[i];
947:       last        = Fj + Fi[i + 1];
948:       lb1         = std::lower_bound(first, last, cstart);
949:       F_NzLeft[i] = lb1 - first;
950:       lb2         = std::lower_bound(first, last, cend);
951:       Fdi[i + 1]  = lb2 - lb1;                        // row i length in Fdi
952:       Foi[i + 1]  = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi
953:     }
954:     for (PetscInt i = 0; i < Fm; i++) {
955:       Fdi[i + 1] += Fdi[i];
956:       Foi[i + 1] += Foi[i];
957:     }

959:     // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet.
960:     PetscInt                Fdnz = Fdi[Fm], Fonz = Foi[Fm];
961:     MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
962:     MatColIdxType          *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid;

964:     for (PetscInt i = 0; i < Fm; i++) {
965:       PetscInt nzLeft = F_NzLeft[i];
966:       PetscInt len    = Fdi[i + 1] - Fdi[i]; // diag row len
967:       for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) {
968:         gid = Fj[Fi[i] + j];
969:         if (j < nzLeft) { // left, in global
970:           Foj[Foi[i] + j] = gid;
971:         } else if (j < nzLeft + len) { // diag, in local
972:           Fdj[Fdi[i] + j - nzLeft] = gid - cstart;
973:         } else { // right, in global
974:           Foj[Foi[i] + j - len] = gid;
975:         }
976:       }
977:     }
978:     PetscCall(PetscFree2(jbuf, Fj));
979:     PetscCall(PetscFree(Fi));

981:     // Reduce global indices in Foj[] and garray1[] into local ones
982:     PetscInt n2, *garray2;
983:     PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));

985:     // Record the plans built above, for reuse
986:     PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. We create a copy for safety
987:     PetscIntKokkosViewHost irootloc_h(NoInit("irootloc_h"), ioffset[niranks]);
988:     Kokkos::deep_copy(irootloc_h, tmp);
989:     mm->sf        = bcastSF;
990:     mm->E_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
991:     mm->F_NzLeft  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h);
992:     mm->irootloc  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h);
993:     mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h);
994:     mm->rootBuf   = MatScalarKokkosView(NoInit("rootBuf"), nroots);
995:     mm->leafBuf   = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
996:     mm->garray    = garray2;
997:     mm->n         = n2;

999:     // Output Fd and Fo in KokkosCsrMatrix format
1000:     MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz), Foa_d(NoInit("Foa_d"), Fonz);
1001:     MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
1002:     MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
1003:     MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
1004:     MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);

1006:     PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
1007:     PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d));

1009:     // Compute kernel launch parameters in merging E or splitting F
1010:     PetscInt teamSize, vectorLength, rowsPerTeam;

1012:     teamSize = vectorLength = rowsPerTeam = -1;
1013:     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam));
1014:     mm->E_TeamSize     = teamSize;
1015:     mm->E_VectorLength = vectorLength;
1016:     mm->E_RowsPerTeam  = rowsPerTeam;

1018:     teamSize = vectorLength = rowsPerTeam = -1;
1019:     PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, Fnz, -1, teamSize, vectorLength, rowsPerTeam));
1020:     mm->F_TeamSize     = teamSize;
1021:     mm->F_VectorLength = vectorLength;
1022:     mm->F_RowsPerTeam  = rowsPerTeam;
1023:   } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);

1025:   // Sync E's values to device
1026:   akok->a_dual.sync_device();
1027:   bkok->a_dual.sync_device();

1029:   // Handy aliases
1030:   const auto &Aa = akok->a_dual.view_device();
1031:   const auto &Ba = bkok->a_dual.view_device();
1032:   const auto &Ai = akok->i_dual.view_device();
1033:   const auto &Bi = bkok->i_dual.view_device();

1035:   // Fetch the plans
1036:   PetscIntKokkosView  &E_NzLeft  = mm->E_NzLeft;
1037:   PetscSF             &bcastSF   = mm->sf;
1038:   MatScalarKokkosView &rootBuf   = mm->rootBuf;
1039:   MatScalarKokkosView &leafBuf   = mm->leafBuf;
1040:   PetscIntKokkosView  &irootloc  = mm->irootloc;
1041:   PetscIntKokkosView  &rowoffset = mm->rowoffset;

1043:   PetscInt teamSize     = mm->E_TeamSize;
1044:   PetscInt vectorLength = mm->E_VectorLength;
1045:   PetscInt rowsPerTeam  = mm->E_RowsPerTeam;
1046:   PetscInt workSets     = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam;

1048:   // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf
1049:   PetscCallCXX(Kokkos::parallel_for(
1050:     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1051:       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
1052:         size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[]
1053:         if (r < irootloc.extent(0)) {
1054:           PetscInt i      = irootloc(r); // row i of E
1055:           PetscInt disp   = rowoffset(r);
1056:           PetscInt alen   = Ai(i + 1) - Ai(i);
1057:           PetscInt blen   = Bi(i + 1) - Bi(i);
1058:           PetscInt nzleft = E_NzLeft(i);

1060:           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
1061:             if (j < nzleft) { // B left
1062:               rootBuf(disp + j) = Ba(Bi(i) + j);
1063:             } else if (j < nzleft + alen) { // diag A
1064:               rootBuf(disp + j) = Aa(Ai(i) + j - nzleft);
1065:             } else { // B right
1066:               rootBuf(disp + j) = Ba(Bi(i) + j - alen);
1067:             }
1068:           });
1069:         }
1070:       });
1071:     }));
1072:   PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE));
1073:   PetscFunctionReturn(PETSC_SUCCESS);
1074: }

1076: // To finish MatMPIAIJKokkosBcast.
1077: static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
1078: {
1079:   PetscFunctionBegin;
1080:   const auto &Fd  = mm->Fd;
1081:   const auto &Fo  = mm->Fo;
1082:   const auto &Fdi = Fd.graph.row_map;
1083:   const auto &Foi = Fo.graph.row_map;
1084:   auto       &Fda = Fd.values;
1085:   auto       &Foa = Fo.values;
1086:   auto        Fm  = Fd.numRows();

1088:   PetscIntKokkosView  &F_NzLeft     = mm->F_NzLeft;
1089:   PetscSF             &bcastSF      = mm->sf;
1090:   MatScalarKokkosView &rootBuf      = mm->rootBuf;
1091:   MatScalarKokkosView &leafBuf      = mm->leafBuf;
1092:   PetscInt             teamSize     = mm->F_TeamSize;
1093:   PetscInt             vectorLength = mm->F_VectorLength;
1094:   PetscInt             rowsPerTeam  = mm->F_RowsPerTeam;
1095:   PetscInt             workSets     = (Fm + rowsPerTeam - 1) / rowsPerTeam;

1097:   PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE));

1099:   // Update Fda and Foa with new data in leafBuf (as if it is Fa)
1100:   PetscCallCXX(Kokkos::parallel_for(
1101:     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1102:       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
1103:         PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
1104:         if (i < Fm) {
1105:           PetscInt nzLeft = F_NzLeft(i);
1106:           PetscInt alen   = Fdi(i + 1) - Fdi(i);
1107:           PetscInt blen   = Foi(i + 1) - Foi(i);
1108:           PetscInt Fii    = Fdi(i) + Foi(i);

1110:           Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
1111:             PetscScalar val = leafBuf(Fii + j);
1112:             if (j < nzLeft) { // left
1113:               Foa(Foi(i) + j) = val;
1114:             } else if (j < nzLeft + alen) { // diag
1115:               Fda(Fdi(i) + j - nzLeft) = val;
1116:             } else { // right
1117:               Foa(Foi(i) + j - alen) = val;
1118:             }
1119:           });
1120:         }
1121:       });
1122:     }));
1123:   PetscFunctionReturn(PETSC_SUCCESS);
1124: }

1126: static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
1127: {
1128:   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1129:   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1130:   KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo;
1131:   PetscInt        cstart, cend;
1132:   MPI_Comm        comm;

1134:   PetscFunctionBegin;
1135:   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
1136:   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
1137:   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
1138:   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1139:   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1140:   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1141:   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));

1143:   // TODO: add command line options to select spgemm algorithms
1144:   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK

1146:   // cuSPARSE's spgemm in CUDA 10.2 has bugs, and the SpGEMMreuse APIs we prefer were only introduced in CUDA 11.4, so fall back to Kokkos-Kernels' native algorithm for older CUDA
1147: #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
1148:   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
1149:   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1150:   #endif
1151: #endif

1153:   PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg));
1154:   PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg));
1155:   PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg));
1156:   PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg));

1158:   // Aot * (B's diag + B's off-diag)
1159:   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3));
1160:   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4));
1161:   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
1162:   // TODO: Remove the fake spgemm_numeric() once KK fixes this problem.
1163:   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
1164:   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
1165: #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)

1167:   PetscCallCXX(sort_crs_matrix(mm->C3));
1168:   PetscCallCXX(sort_crs_matrix(mm->C4));
1169: #endif

1171:   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication with local computation
1172:   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
1173:   PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend));
1174:   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));

1176:   // Adt * (B's diag + B's off-diag)
1177:   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1));
1178:   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1179:   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
1180:   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1181: #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1182:   PetscCallCXX(sort_crs_matrix(mm->C1));
1183:   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
1184: #endif

1186:   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));

1188:   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
1189:   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
1190:   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1191:   PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
1192:   PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj));

1194:   // C = (C1+Fd, C2+Fo)
1195:   PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted
1196:   PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted
1197:   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd));
1198:   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co));
1199:   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
1200:   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
1201:   PetscFunctionReturn(PETSC_SUCCESS);
1202: }

1204: static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
1205: {
1206:   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1207:   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1208:   KokkosCsrMatrix Adt, Aot, Bd, Bo;
1209:   MPI_Comm        comm;

1211:   PetscFunctionBegin;
1212:   PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
1213:   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
1214:   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
1215:   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1216:   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));

1218:   // Aot * (B's diag + B's off-diag)
1219:   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
1220:   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));

1222:   // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication with local computation
1223:   PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));

1225:   // Adt * (B's diag + B's off-diag)
1226:   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
1227:   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));

1229:   PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));

1231:   // C = (C1+Fd, C2+Fo)
1232:   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
1233:   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
1234:   PetscFunctionReturn(PETSC_SUCCESS);
1235: }

1237: /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos

1239:   Input Parameters:
1240: +  product  - the Mat_Product that carries out the computation; passed in to access info about this mat product.
1241: .  A        - an MPIAIJKOKKOS matrix
1242: .  B        - an MPIAIJKOKKOS matrix
1243: -  mm       - a struct used to stash intermediate data when computing AB; it persists from the symbolic to the numeric operation.
1244: */
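1245: /* In outline: with A = (Ad, Ao) and B = (Bd, Bo) split into diag/off-diag blocks, the rows of B needed by Ao
   are brought over to form F = (Fd, Fo); the local products are C1 = Ad*Bd, C2 = Ad*Bo (with columns remapped),
   C3 = Ao*Fd and C4 = Ao*Fo, and the result is assembled as Cd = C1 + C3, Co = C2 + C4.
 */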
1245: static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1246: {
1247:   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1248:   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1249:   KokkosCsrMatrix Ad, Ao, Bd, Bo;

1251:   PetscFunctionBegin;
1252:   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1253:   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1254:   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1255:   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));

1257:   // TODO: add command line options to select spgemm algorithms
1258:   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK

1260:   // cuSPARSE's spgemm in CUDA 10.2 has bugs, and the SpGEMMreuse APIs we prefer were only introduced in CUDA 11.4, so fall back to Kokkos-Kernels' native algorithm for older CUDA
1261: #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
1262:   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
1263:   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1264:   #endif
1265: #endif

1267:   PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg));
1268:   PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg));
1269:   PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg));
1270:   PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg));

1272:   // Bcast B's rows to form F, and overlap the communication with local computation
1273:   PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
1274:   PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));

1276:   // A's diag * (B's diag + B's off-diag)
1277:   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1));
1278:   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices
1279:   // KK spgemm_symbolic() only populates the result's row map, but not its columns.
1280:   // TODO: Remove the fake spgemm_numeric() once KK fixes this problem.
1281:   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
1282:   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1283: #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1284:   PetscCallCXX(sort_crs_matrix(mm->C1));
1285:   PetscCallCXX(sort_crs_matrix(mm->C2_mid));
1286: #endif

1288:   PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));

1290:   // A's off-diag * (F's diag + F's off-diag)
1291:   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1292:   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1293:   PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1294:   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1295: #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1296:   PetscCallCXX(sort_crs_matrix(mm->C3));
1297:   PetscCallCXX(sort_crs_matrix(mm->C4));
1298: #endif

1300:   // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
1301:   MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
1302:   PetscIntKokkosView  map  = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1303:   PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
1304:   PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj));

1306:   // C = (Cd, Co) = (C1+C3, C2+C4)
1307:   PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, C3 are sorted
1308:   PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, C4 are sorted
1309:   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd));
1310:   PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co));
1311:   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
1312:   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
1313:   PetscFunctionReturn(PETSC_SUCCESS);
1314: }

1316: static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1317: {
1318:   Mat_MPIAIJ     *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1319:   Mat_MPIAIJ     *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1320:   KokkosCsrMatrix Ad, Ao, Bd, Bo;

1322:   PetscFunctionBegin;
1323:   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1324:   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1325:   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1326:   PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));

1328:   // Bcast B's rows to form F, and overlap the communication with local computation
1329:   PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm));

1331:   // A's diag * (B's diag + B's off-diag)
1332:   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
1333:   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));

1335:   PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm));

1337:   // A's off-diag * (F's diag + F's off-diag)
1338:   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1339:   PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));

1341:   // C = (Cd, Co) = (C1+C3, C2+C4)
1342:   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
1343:   PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
1344:   PetscFunctionReturn(PETSC_SUCCESS);
1345: }

1347: static PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
1348: {
1349:   Mat_MPIAIJ                  *cmpi = static_cast<Mat_MPIAIJ *>(C->data);
1350:   Mat_Product                 *product;
1351:   MatProductData_MPIAIJKokkos *pdata;
1352:   MatProductType               ptype;
1353:   Mat                          A, B;

1355:   PetscFunctionBegin;
1356:   MatCheckProduct(C, 1); // make sure C is a product
1357:   product = C->product;
1358:   pdata   = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
1359:   ptype   = product->type;
1360:   A       = product->A;
1361:   B       = product->B;

1363:   // See if the numeric phase has already been done during symbolic (e.g., the user called MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
1364:   // If so, skip the numeric phase, but reset the flag so that the next time the user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
1365:   // we still do the numeric phase.
1366:   if (pdata->reusesym) { // numeric reuses results from symbolic
1367:     pdata->reusesym = PETSC_FALSE;
1368:     PetscFunctionReturn(PETSC_SUCCESS);
1369:   }

1371:   if (ptype == MATPRODUCT_AB) {
1372:     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1373:   } else if (ptype == MATPRODUCT_AtB) {
1374:     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB));
1375:   } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed by Z = AB; C = BtZ
1376:     PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1377:     PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB));
1378:   }

1380:   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that C's diag (cmpi->A) and off-diag (cmpi->B) blocks are modified on device
1381:   PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B));
1382:   PetscFunctionReturn(PETSC_SUCCESS);
1383: }
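
/*
  A minimal sketch (illustration only, not part of this file's build) of how these product callbacks are
  typically reached; A and B are assumed to be already-assembled MATMPIAIJKOKKOS matrices:

    Mat C;
    PetscCall(MatMatMult(A, B, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &C)); // symbolic (which also does numeric); reusesym is set
    // ... change the values (not the nonzero pattern) of A and/or B ...
    PetscCall(MatMatMult(A, B, MAT_REUSE_MATRIX, PETSC_DEFAULT, &C));   // numeric only, reusing the stashed MatMatStruct

  or, equivalently, through the MatProduct interface:

    Mat D;
    PetscCall(MatProductCreate(A, B, NULL, &D));
    PetscCall(MatProductSetType(D, MATPRODUCT_AB));
    PetscCall(MatProductSetFromOptions(D)); // installs MatProductSymbolic_MPIAIJKokkos() below
    PetscCall(MatProductSymbolic(D));
    PetscCall(MatProductNumeric(D));
*/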

1385: static PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
1386: {
1387:   Mat                          A, B;
1388:   Mat_Product                 *product;
1389:   MatProductType               ptype;
1390:   MatProductData_MPIAIJKokkos *pdata;
1391:   MatMatStruct                *mm = NULL;
1392:   PetscInt                     m, n, M, N;
1393:   Mat                          Cd, Co;
1394:   MPI_Comm                     comm;

1396:   PetscFunctionBegin;
1397:   PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
1398:   MatCheckProduct(C, 1);
1399:   product = C->product;
1400:   PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
1401:   ptype = product->type;
1402:   A     = product->A;
1403:   B     = product->B;

1405:   switch (ptype) {
1406:   case MATPRODUCT_AB:
1407:     m = A->rmap->n;
1408:     n = B->cmap->n;
1409:     M = A->rmap->N;
1410:     N = B->cmap->N;
1411:     break;
1412:   case MATPRODUCT_AtB:
1413:     m = A->cmap->n;
1414:     n = B->cmap->n;
1415:     M = A->cmap->N;
1416:     N = B->cmap->N;
1417:     break;
1418:   case MATPRODUCT_PtAP:
1419:     m = B->cmap->n;
1420:     n = B->cmap->n;
1421:     M = B->cmap->N;
1422:     N = B->cmap->N;
1423:     break; /* BtAB */
1424:   default:
1425:     SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
1426:   }

1428:   PetscCall(MatSetSizes(C, m, n, M, N));
1429:   PetscCall(PetscLayoutSetUp(C->rmap));
1430:   PetscCall(PetscLayoutSetUp(C->cmap));
1431:   PetscCall(MatSetType(C, ((PetscObject)A)->type_name));

1433:   pdata           = new MatProductData_MPIAIJKokkos();
1434:   pdata->reusesym = product->api_user;

1436:   if (ptype == MATPRODUCT_AB) {
1437:     auto mmAB = new MatMatStruct_AB();
1438:     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB));
1439:     mm = pdata->mmAB = mmAB;
1440:   } else if (ptype == MATPRODUCT_AtB) {
1441:     auto mmAtB = new MatMatStruct_AtB();
1442:     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB));
1443:     mm = pdata->mmAtB = mmAtB;
1444:   } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C = BtZ
1445:     Mat Zd, Zo, Z;                       // Zd, Zo are owned by pdata->Z

1447:     auto mmAB = new MatMatStruct_AB();
1448:     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co}
1449:     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd));
1450:     PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo));
1451:     pdata->mmAB = mmAB;

1453:     m = A->rmap->n; // Z's layout
1454:     n = B->cmap->n;
1455:     M = A->rmap->N;
1456:     N = B->cmap->N;
1457:     PetscCall(MatCreate(comm, &Z));
1458:     PetscCall(MatSetSizes(Z, m, n, M, N));
1459:     PetscCall(PetscLayoutSetUp(Z->rmap));
1460:     PetscCall(PetscLayoutSetUp(Z->cmap));
1461:     PetscCall(MatSetType(Z, MATMPIAIJKOKKOS));
1462:     PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Z, Zd, Zo, mmAB->garray));

1464:     auto mmAtB = new MatMatStruct_AtB();
1465:     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co}

1467:     pdata->Z = Z; // give ownership to pdata
1468:     mm = pdata->mmAtB = mmAtB;
1469:   }

1471:   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd));
1472:   PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co));
1473:   PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co, mm->garray));

1475:   C->product->data       = pdata;
1476:   C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
1477:   C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
1478:   PetscFunctionReturn(PETSC_SUCCESS);
1479: }

1481: PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1482: {
1483:   Mat_Product *product = mat->product;
1484:   PetscBool    match   = PETSC_FALSE;
1485:   PetscBool    usecpu  = PETSC_FALSE;

1487:   PetscFunctionBegin;
1488:   MatCheckProduct(mat, 1);
1489:   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1490:   if (match) { /* we can always fall back to the CPU if requested */
1491:     switch (product->type) {
1492:     case MATPRODUCT_AB:
1493:       if (product->api_user) {
1494:         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
1495:         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1496:         PetscOptionsEnd();
1497:       } else {
1498:         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
1499:         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1500:         PetscOptionsEnd();
1501:       }
1502:       break;
1503:     case MATPRODUCT_AtB:
1504:       if (product->api_user) {
1505:         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
1506:         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1507:         PetscOptionsEnd();
1508:       } else {
1509:         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
1510:         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1511:         PetscOptionsEnd();
1512:       }
1513:       break;
1514:     case MATPRODUCT_PtAP:
1515:       if (product->api_user) {
1516:         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
1517:         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1518:         PetscOptionsEnd();
1519:       } else {
1520:         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
1521:         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1522:         PetscOptionsEnd();
1523:       }
1524:       break;
1525:     default:
1526:       break;
1527:     }
1528:     match = (PetscBool)!usecpu;
1529:   }
1530:   if (match) {
1531:     switch (product->type) {
1532:     case MATPRODUCT_AB:
1533:     case MATPRODUCT_AtB:
1534:     case MATPRODUCT_PtAP:
1535:       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1536:       break;
1537:     default:
1538:       break;
1539:     }
1540:   }
1541:   /* fall back to MPIAIJ ops */
1542:   if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
1543:   PetscFunctionReturn(PETSC_SUCCESS);
1544: }
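
/*
  A brief sketch of using the CPU-fallback options handled above (illustration only; "./app" stands for any
  PETSc program using this matrix type):

    ./app -mat_type aijkokkos -matmatmult_backend_cpu            # MatMatMult() falls back to the MPIAIJ (CPU) path
    ./app -mat_type aijkokkos -mat_product_algorithm_backend_cpu # same, when the MatProduct interface is used directly

  or programmatically, before the product is set up:

    PetscCall(PetscOptionsSetValue(NULL, "-matptap_backend_cpu", "true"));
    PetscCall(MatPtAP(A, P, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &C));
*/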

1546: // Mirror of MatCOOStruct_MPIAIJ on device
1547: struct MatCOOStruct_MPIAIJKokkos {
1548:   PetscCount           n;
1549:   PetscSF              sf;
1550:   PetscCount           Annz, Bnnz;
1551:   PetscCount           Annz2, Bnnz2;
1552:   PetscCountKokkosView Ajmap1, Aperm1;
1553:   PetscCountKokkosView Bjmap1, Bperm1;
1554:   PetscCountKokkosView Aimap2, Ajmap2, Aperm2;
1555:   PetscCountKokkosView Bimap2, Bjmap2, Bperm2;
1556:   PetscCountKokkosView Cperm1;
1557:   MatScalarKokkosView  sendbuf, recvbuf;

1559:   MatCOOStruct_MPIAIJKokkos(const MatCOOStruct_MPIAIJ *coo_h)
1560:   {
1561:     auto &exec = PetscGetKokkosExecutionSpace();

1563:     n       = coo_h->n;
1564:     sf      = coo_h->sf;
1565:     Annz    = coo_h->Annz;
1566:     Bnnz    = coo_h->Bnnz;
1567:     Annz2   = coo_h->Annz2;
1568:     Bnnz2   = coo_h->Bnnz2;
1569:     Ajmap1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap1, coo_h->Annz + 1));
1570:     Aperm1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm1, coo_h->Atot1));
1571:     Bjmap1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap1, coo_h->Bnnz + 1));
1572:     Bperm1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm1, coo_h->Btot1));
1573:     Aimap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aimap2, coo_h->Annz2));
1574:     Ajmap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap2, coo_h->Annz2 + 1));
1575:     Aperm2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm2, coo_h->Atot2));
1576:     Bimap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bimap2, coo_h->Bnnz2));
1577:     Bjmap2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap2, coo_h->Bnnz2 + 1));
1578:     Bperm2  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm2, coo_h->Btot2));
1579:     Cperm1  = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Cperm1, coo_h->sendlen));
1580:     sendbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->sendbuf, coo_h->sendlen));
1581:     recvbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->recvbuf, coo_h->recvlen));
1582:     PetscCallVoid(PetscObjectReference((PetscObject)sf));
1583:   }

1585:   ~MatCOOStruct_MPIAIJKokkos() { PetscCallVoid(PetscSFDestroy(&sf)); }
1586: };

1588: static PetscErrorCode MatCOOStructDestroy_MPIAIJKokkos(void *data)
1589: {
1590:   PetscFunctionBegin;
1591:   PetscCallCXX(delete static_cast<MatCOOStruct_MPIAIJKokkos *>(data));
1592:   PetscFunctionReturn(PETSC_SUCCESS);
1593: }

1595: static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1596: {
1597:   PetscContainer             container_h, container_d;
1598:   MatCOOStruct_MPIAIJ       *coo_h;
1599:   MatCOOStruct_MPIAIJKokkos *coo_d;

1601:   PetscFunctionBegin;
1602:   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */
1603:   mat->preallocated = PETSC_TRUE;
1604:   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
1605:   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
1606:   PetscCall(MatZeroEntries(mat));

1608:   // Copy the COO struct to device
1609:   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
1610:   PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h));
1611:   PetscCallCXX(coo_d = new MatCOOStruct_MPIAIJKokkos(coo_h));

1613:   // Put the COO struct in a container and then attach that to the matrix
1614:   PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &container_d));
1615:   PetscCall(PetscContainerSetPointer(container_d, coo_d));
1616:   PetscCall(PetscContainerSetUserDestroy(container_d, MatCOOStructDestroy_MPIAIJKokkos));
1617:   PetscCall(PetscObjectCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject)container_d));
1618:   PetscCall(PetscContainerDestroy(&container_d));
1619:   PetscFunctionReturn(PETSC_SUCCESS);
1620: }

1622: static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode)
1623: {
1624:   Mat_MPIAIJ                    *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
1625:   Mat                            A = mpiaij->A, B = mpiaij->B;
1626:   MatScalarKokkosView            Aa, Ba;
1627:   MatScalarKokkosView            v1;
1628:   PetscMemType                   memtype;
1629:   PetscContainer                 container;
1630:   MatCOOStruct_MPIAIJKokkos     *coo;
1631:   Kokkos::DefaultExecutionSpace &exec = PetscGetKokkosExecutionSpace();

1633:   PetscFunctionBegin;
1634:   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
1635:   PetscCall(PetscContainerGetPointer(container, (void **)&coo));

1637:   const auto &n      = coo->n;
1638:   const auto &Annz   = coo->Annz;
1639:   const auto &Annz2  = coo->Annz2;
1640:   const auto &Bnnz   = coo->Bnnz;
1641:   const auto &Bnnz2  = coo->Bnnz2;
1642:   const auto &vsend  = coo->sendbuf;
1643:   const auto &v2     = coo->recvbuf;
1644:   const auto &Ajmap1 = coo->Ajmap1;
1645:   const auto &Ajmap2 = coo->Ajmap2;
1646:   const auto &Aimap2 = coo->Aimap2;
1647:   const auto &Bjmap1 = coo->Bjmap1;
1648:   const auto &Bjmap2 = coo->Bjmap2;
1649:   const auto &Bimap2 = coo->Bimap2;
1650:   const auto &Aperm1 = coo->Aperm1;
1651:   const auto &Aperm2 = coo->Aperm2;
1652:   const auto &Bperm1 = coo->Bperm1;
1653:   const auto &Bperm2 = coo->Bperm2;
1654:   const auto &Cperm1 = coo->Cperm1;

1656:   PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
1657:   if (PetscMemTypeHost(memtype)) {         /* If the user gave v[] on the host, copy it to device */
1658:     v1 = Kokkos::create_mirror_view_and_copy(exec, MatScalarKokkosViewHost((PetscScalar *)v, n));
1659:   } else {
1660:     v1 = MatScalarKokkosView((PetscScalar *)v, n); /* Directly use v[]'s memory */
1661:   }

1663:   if (imode == INSERT_VALUES) {
1664:     PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1665:     PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
1666:   } else {
1667:     PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
1668:     PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
1669:   }

1671:   PetscCall(PetscLogGpuTimeBegin());
1672:   /* Pack entries to be sent to remote processes */
1673:   Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, vsend.extent(0)), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });

1675:   /* Send remote entries to their owner and overlap the communication with local computation */
1676:   PetscCall(PetscSFReduceWithMemTypeBegin(coo->sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
1677:   /* Add local entries to A and B in one kernel */
1678:   Kokkos::parallel_for(
1679:     Kokkos::RangePolicy<>(exec, 0, Annz + Bnnz), KOKKOS_LAMBDA(PetscCount i) {
1680:       PetscScalar sum = 0.0;
1681:       if (i < Annz) {
1682:         for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
1683:         Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1684:       } else {
1685:         i -= Annz;
1686:         for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
1687:         Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
1688:       }
1689:     });
1690:   PetscCall(PetscSFReduceEnd(coo->sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));

1692:   /* Add received remote entries to A and B in one kernel */
1693:   Kokkos::parallel_for(
1694:     Kokkos::RangePolicy<>(exec, 0, Annz2 + Bnnz2), KOKKOS_LAMBDA(PetscCount i) {
1695:       if (i < Annz2) {
1696:         for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
1697:       } else {
1698:         i -= Annz2;
1699:         for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
1700:       }
1701:     });
1702:   PetscCall(PetscLogGpuTimeEnd());

1704:   if (imode == INSERT_VALUES) {
1705:     PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
1706:     PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
1707:   } else {
1708:     PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
1709:     PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
1710:   }
1711:   PetscFunctionReturn(PETSC_SUCCESS);
1712: }
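
/*
  A minimal sketch of the COO assembly path implemented above (illustration only; the indices and values are
  made up, and A is assumed to be a MATMPIAIJKOKKOS matrix with its layouts already set):

    PetscInt    coo_i[] = {0, 0, 1};       // global row indices of this rank's contributions
    PetscInt    coo_j[] = {0, 2, 1};       // global column indices
    PetscScalar coo_v[] = {1.0, 2.0, 3.0}; // values; may reside in host or device (Kokkos) memory

    PetscCall(MatSetPreallocationCOO(A, 3, coo_i, coo_j)); // dispatches to MatSetPreallocationCOO_MPIAIJKokkos()
    PetscCall(MatSetValuesCOO(A, coo_v, INSERT_VALUES));   // dispatches to MatSetValuesCOO_MPIAIJKokkos()
    // repeated MatSetValuesCOO() calls (INSERT_VALUES or ADD_VALUES) reuse the same COO structure
*/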

1714: static PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1715: {
1716:   PetscFunctionBegin;
1717:   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
1718:   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
1719:   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
1720:   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
1721:   PetscCall(MatDestroy_MPIAIJ(A));
1722:   PetscFunctionReturn(PETSC_SUCCESS);
1723: }

1725: static PetscErrorCode MatShift_MPIAIJKokkos(Mat A, PetscScalar a)
1726: {
1727:   Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(A->data);
1728:   PetscBool   congruent;

1730:   PetscFunctionBegin;
1731:   PetscCall(MatHasCongruentLayouts(A, &congruent));
1732:   if (congruent) { // square matrix and the diagonals are solely in the diag block
1733:     PetscCall(MatShift(mpiaij->A, a));
1734:   } else { // too hard, use the general version
1735:     PetscCall(MatShift_Basic(A, a));
1736:   }
1737:   PetscFunctionReturn(PETSC_SUCCESS);
1738: }

1740: static PetscErrorCode MatSetOps_MPIAIJKokkos(Mat B)
1741: {
1742:   PetscFunctionBegin;
1743:   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
1744:   B->ops->mult                  = MatMult_MPIAIJKokkos;
1745:   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
1746:   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
1747:   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1748:   B->ops->destroy               = MatDestroy_MPIAIJKokkos;
1749:   B->ops->shift                 = MatShift_MPIAIJKokkos;

1751:   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
1752:   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
1753:   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
1754:   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
1755:   PetscFunctionReturn(PETSC_SUCCESS);
1756: }

1758: PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
1759: {
1760:   Mat         B;
1761:   Mat_MPIAIJ *a;

1763:   PetscFunctionBegin;
1764:   if (reuse == MAT_INITIAL_MATRIX) {
1765:     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
1766:   } else if (reuse == MAT_REUSE_MATRIX) {
1767:     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
1768:   }
1769:   B = *newmat;

1771:   B->boundtocpu = PETSC_FALSE;
1772:   PetscCall(PetscFree(B->defaultvectype));
1773:   PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
1774:   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));

1776:   a = static_cast<Mat_MPIAIJ *>(A->data);
1777:   if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
1778:   if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
1779:   if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
1780:   PetscCall(MatSetOps_MPIAIJKokkos(B));
1781:   PetscFunctionReturn(PETSC_SUCCESS);
1782: }

1784: /*MC
1785:    MATAIJKOKKOS - "mpiaijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos

1787:    A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types

1789:    Options Database Key:
1790: .  -mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS`

1792:   Level: beginner

1794: .seealso: [](ch_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
1795: M*/
1796: PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
1797: {
1798:   PetscFunctionBegin;
1799:   PetscCall(PetscKokkosInitializeCheck());
1800:   PetscCall(MatCreate_MPIAIJ(A));
1801:   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
1802:   PetscFunctionReturn(PETSC_SUCCESS);
1803: }
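
/*
  A short sketch of selecting this type at runtime, matching the options key documented above (illustration
  only; M and N are global sizes chosen by the caller):

    Mat A;
    PetscCall(MatCreate(PETSC_COMM_WORLD, &A));
    PetscCall(MatSetSizes(A, PETSC_DECIDE, PETSC_DECIDE, M, N));
    PetscCall(MatSetType(A, MATAIJKOKKOS)); // or run with -mat_type aijkokkos and rely on MatSetFromOptions(A)
    PetscCall(MatSetUp(A));
    // ... MatSetValues()/MatSetValuesCOO(), MatAssemblyBegin()/MatAssemblyEnd(), then use A ...
*/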

1805: /*@C
1806:   MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKKOS` (compressed row) format
1807:   (the default parallel PETSc format).  This matrix will ultimately be pushed down
1808:   to Kokkos for calculations.

1810:   Collective

1812:   Input Parameters:
1813: + comm  - MPI communicator
1814: . m     - number of local rows (or `PETSC_DECIDE` to have calculated if `M` is given)
1815:            This value should be the same as the local size used in creating the
1816:            y vector for the matrix-vector product y = Ax.
1817: . n     - number of local columns (or `PETSC_DECIDE` to have calculated if `N` is given)
1818:            This value should be the same as the local size used in creating the
1819:            x vector for the matrix-vector product y = Ax. For square matrices `n` is almost always `m`.
1820: . M     - number of global rows (or `PETSC_DETERMINE` to have calculated if `m` is given)
1821: . N     - number of global columns (or `PETSC_DETERMINE` to have calculated if `n` is given)
1822: . d_nz  - number of nonzeros per row in DIAGONAL portion of local submatrix
1823:            (same value is used for all local rows)
1824: . d_nnz - array containing the number of nonzeros in the various rows of the
1825:            DIAGONAL portion of the local submatrix (possibly different for each row)
1826:            or `NULL`, if `d_nz` is used to specify the nonzero structure.
1827:            The size of this array is equal to the number of local rows, i.e `m`.
1828:            For matrices you plan to factor you must leave room for the diagonal entry and
1829:            put in the entry even if it is zero.
1830: . o_nz  - number of nonzeros per row in the OFF-DIAGONAL portion of local
1831:            submatrix (same value is used for all local rows).
1832: - o_nnz - array containing the number of nonzeros in the various rows of the
1833:            OFF-DIAGONAL portion of the local submatrix (possibly different for
1834:            each row) or `NULL`, if `o_nz` is used to specify the nonzero
1835:            structure. The size of this array is equal to the number
1836:            of local rows, i.e `m`.

1838:   Output Parameter:
1839: . A - the matrix

1841:   Level: intermediate

1843:   Notes:
1844:   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
1845:   MatXXXXSetPreallocation() paradigm instead of this routine directly.
1846:   [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]

1848:   The AIJ format (also called compressed row storage) is fully compatible with standard Fortran
1849:   storage.  That is, the stored row and column indices can begin at
1850:   either one (as in Fortran) or zero.

1852: .seealso: [](ch_matrices), `Mat`, `MATAIJKOKKOS`, `MATSEQAIJKOKKOS`, `MATMPIAIJKOKKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`,
1853:           `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
1854: @*/
1855: PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A)
1856: {
1857:   PetscMPIInt size;

1859:   PetscFunctionBegin;
1860:   PetscCall(MatCreate(comm, A));
1861:   PetscCall(MatSetSizes(*A, m, n, M, N));
1862:   PetscCallMPI(MPI_Comm_size(comm, &size));
1863:   if (size > 1) {
1864:     PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
1865:     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
1866:   } else {
1867:     PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
1868:     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
1869:   }
1870:   PetscFunctionReturn(PETSC_SUCCESS);
1871: }
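
/*
  A minimal usage sketch for MatCreateAIJKokkos() (illustration only; the sizes and preallocation numbers are
  made up):

    Mat A;
    PetscCall(MatCreateAIJKokkos(PETSC_COMM_WORLD, PETSC_DECIDE, PETSC_DECIDE, 100, 100, 5, NULL, 2, NULL, &A));
    // fill with MatSetValues() or MatSetValuesCOO(), then assemble and use
    PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
    PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
    PetscCall(MatDestroy(&A));
*/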