Actual source code: mpiaijkok.kokkos.cxx
1: #include <petsc_kokkos.hpp>
2: #include <petscvec_kokkos.hpp>
3: #include <petscpkg_version.h>
4: #include <petsc/private/sfimpl.h>
5: #include <petsc/private/kokkosimpl.hpp>
6: #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
7: #include <../src/mat/impls/aij/mpi/mpiaij.h>
8: #include <KokkosSparse_spadd.hpp>
9: #include <KokkosSparse_spgemm.hpp>
11: static PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode)
12: {
13: Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
15: PetscFunctionBegin;
16: PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
17: /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS.
18: Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases.
19: */
20: if (mode == MAT_FINAL_ASSEMBLY) {
21: PetscScalarKokkosView v;
23: PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
24: PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
25: PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS)); // lvec is init'ed on host, without copying to device
26: PetscCall(VecGetKokkosViewWrite(mpiaij->lvec, &v)); // mark lvec updated on device, as we never need to init lvec on device
27: PetscCall(VecRestoreKokkosViewWrite(mpiaij->lvec, &v));
28: }
29: PetscFunctionReturn(PETSC_SUCCESS);
30: }
32: static PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[])
33: {
34: Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
36: PetscFunctionBegin;
37: // If mat was set to use the "set values with a hash table" mechanism, discard it and restore the cached ops
38: if (mat->hash_active) {
39: mat->ops[0] = mpiaij->cops;
40: mat->hash_active = PETSC_FALSE;
41: }
43: PetscCall(PetscLayoutSetUp(mat->rmap));
44: PetscCall(PetscLayoutSetUp(mat->cmap));
45: #if defined(PETSC_USE_DEBUG)
46: if (d_nnz) {
47: PetscInt i;
48: for (i = 0; i < mat->rmap->n; i++) PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]);
49: }
50: if (o_nnz) {
51: PetscInt i;
52: for (i = 0; i < mat->rmap->n; i++) PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, o_nnz[i]);
53: }
54: #endif
55: #if defined(PETSC_USE_CTABLE)
56: PetscCall(PetscHMapIDestroy(&mpiaij->colmap));
57: #else
58: PetscCall(PetscFree(mpiaij->colmap));
59: #endif
60: PetscCall(PetscFree(mpiaij->garray));
61: PetscCall(VecDestroy(&mpiaij->lvec));
62: PetscCall(VecScatterDestroy(&mpiaij->Mvctx));
63: /* Because B will have been resized, we simply destroy it and create a new one each time */
64: PetscCall(MatDestroy(&mpiaij->B));
66: if (!mpiaij->A) {
67: PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->A));
68: PetscCall(MatSetSizes(mpiaij->A, mat->rmap->n, mat->cmap->n, mat->rmap->n, mat->cmap->n));
69: }
70: if (!mpiaij->B) {
71: PetscMPIInt size;
72: PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size));
73: PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->B));
74: PetscCall(MatSetSizes(mpiaij->B, mat->rmap->n, size > 1 ? mat->cmap->N : 0, mat->rmap->n, size > 1 ? mat->cmap->N : 0));
75: }
76: PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
77: PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
78: PetscCall(MatSeqAIJSetPreallocation(mpiaij->A, d_nz, d_nnz));
79: PetscCall(MatSeqAIJSetPreallocation(mpiaij->B, o_nz, o_nnz));
80: mat->preallocated = PETSC_TRUE;
81: PetscFunctionReturn(PETSC_SUCCESS);
82: }
84: static PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
85: {
86: Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
87: PetscInt nt;
89: PetscFunctionBegin;
90: PetscCall(VecGetLocalSize(xx, &nt));
91: PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
92: PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
93: PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
94: PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
95: PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
96: PetscFunctionReturn(PETSC_SUCCESS);
97: }
99: static PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz)
100: {
101: Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
102: PetscInt nt;
104: PetscFunctionBegin;
105: PetscCall(VecGetLocalSize(xx, &nt));
106: PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
107: PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
108: PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
109: PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
110: PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
111: PetscFunctionReturn(PETSC_SUCCESS);
112: }
114: static PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
115: {
116: Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
117: PetscInt nt;
119: PetscFunctionBegin;
120: PetscCall(VecGetLocalSize(xx, &nt));
121: PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
122: PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
123: PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
124: PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
125: PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
126: PetscFunctionReturn(PETSC_SUCCESS);
127: }
129: /* Merge the "A, B" matrices of mat into a matrix C. mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
130: A is put before B. C's size is A->rmap->n by (A->cmap->n + B->cmap->n).
131: C still uses local column ids. Their corresponding global column ids are returned in glob.
132: */
133: static PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C)
134: {
135: Mat Ad, Ao;
136: const PetscInt *cmap;
138: PetscFunctionBegin;
139: PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
140: PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
141: if (glob) {
142: PetscInt cst, i, dn, on, *gidx;
143: PetscCall(MatGetLocalSize(Ad, NULL, &dn));
144: PetscCall(MatGetLocalSize(Ao, NULL, &on));
145: PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
146: PetscCall(PetscMalloc1(dn + on, &gidx));
147: for (i = 0; i < dn; i++) gidx[i] = cst + i;
148: for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
149: PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
150: }
151: PetscFunctionReturn(PETSC_SUCCESS);
152: }
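/* A minimal usage sketch of the public interface the routine above implements
   (MatMPIAIJGetLocalMatMerge); cleanup of both C and glob is assumed to be the
   caller's responsibility here:

     IS  glob;
     Mat C;
     PetscCall(MatMPIAIJGetLocalMatMerge(mat, MAT_INITIAL_MATRIX, &glob, &C));
     // local column j of C corresponds to global column glob[j] of mat
     PetscCall(ISDestroy(&glob));
     PetscCall(MatDestroy(&C));
*/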
154: /* Structs used in matrix products of type C=AB, C=A^tB and C=B^tAB */
155: struct MatMatStruct {
156: PetscInt n, *garray; // C's garray and its size.
157: KokkosCsrMatrix Cd, Co; // C stored in split form (all using local column indices)
158: KokkosCsrMatrix C1, C2, C3, C4; // intermediate mat products
159: KokkosCsrMatrix C2_mid, C4_mid; // aliases of C2, C4; they share a[], i[], but have different j[] (hence a different column size)
160: PetscIntKokkosView E_NzLeft;
161: PetscSF sf = nullptr; // SF to bcast or reduce matrices E to F
162: MatScalarKokkosView rootBuf, leafBuf;
163: KokkosCsrMatrix Fd, Fo; // F in split form
165: KernelHandle kh1; // compute C1, add C1+C3 or C1+Fd
166: KernelHandle kh2; // compute C2, add C2+C4 or C2+Fo
167: KernelHandle kh3; // compute C3
168: KernelHandle kh4; // compute C4
170: PetscInt E_TeamSize; // kernel launch parameters for merging E or splitting F
171: PetscInt E_VectorLength;
172: PetscInt E_RowsPerTeam;
173: PetscInt F_TeamSize;
174: PetscInt F_VectorLength;
175: PetscInt F_RowsPerTeam;
177: ~MatMatStruct()
178: {
179: PetscFunctionBegin;
180: PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf));
181: PetscFunctionReturnVoid();
182: }
183: };
185: struct MatMatStruct_AB : public MatMatStruct {
186: PetscIntKokkosView F_NzLeft; // plans to split F (in leafbuf) into Fd, Fo
187: PetscIntKokkosView irootloc; // plans to put E (i.e., Bd, Bo) into rootBuf
188: PetscIntKokkosView rowoffset;
189: };
191: struct MatMatStruct_AtB : public MatMatStruct {
192: MatColIdxKokkosView Fdjmap; // plans to reduce data in rootBuf to Fd, Fo
193: MatColIdxKokkosView Fdjperm;
194: MatColIdxKokkosView Fojmap;
195: MatColIdxKokkosView Fojperm;
196: };
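/* An illustrative sketch of the split form used by the structs above, assuming a rank
   owns global columns [cstart, cend): a local row of C with global column indices

     ... 5 ... | cstart ... cend-1 | ... 42 ...

   is stored as one row of Cd (the middle part, shifted to 0-based local indices) plus
   one row of Co, where Co's local column j stands for global column garray[j].
*/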
198: struct MatProductData_MPIAIJKokkos {
199: MatMatStruct_AB *mmAB = nullptr;
200: MatMatStruct_AtB *mmAtB = nullptr;
201: PetscBool reusesym = PETSC_FALSE;
202: Mat Z = nullptr; // stores Z = AB when computing B^tAB
204: ~MatProductData_MPIAIJKokkos()
205: {
206: delete mmAB;
207: delete mmAtB;
208: PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&Z));
209: }
210: };
212: static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
213: {
214: PetscFunctionBegin;
215: PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
216: PetscFunctionReturn(PETSC_SUCCESS);
217: }
219: /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
220: It is similar to MatCreateMPIAIJWithSplitArrays.
222: Input Parameters:
223: + mat - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
224: . A - the diag matrix using local col ids
225: - B - the offdiag matrix using global col ids
227: Output Parameter:
228: . mat - the updated MATMPIAIJKOKKOS matrix
229: */
230: static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B, PetscInt *garray)
231: {
232: Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
233: PetscInt m, n, M, N, Am, An, Bm, Bn;
235: PetscFunctionBegin;
236: PetscCall(MatGetSize(mat, &M, &N));
237: PetscCall(MatGetLocalSize(mat, &m, &n));
238: PetscCall(MatGetLocalSize(A, &Am, &An));
239: PetscCall(MatGetLocalSize(B, &Bm, &Bn));
241: PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match");
242: PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match");
243: // PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match");
244: PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty");
245: mpiaij->A = A;
246: mpiaij->B = B;
247: mpiaij->garray = garray;
249: mat->preallocated = PETSC_TRUE;
250: mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */
252: PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
253: PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
254: /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to those of A and B, and
255: also gets mpiaij->B compacted, with its col ids and size reduced
256: */
257: PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
258: PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
259: PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));
260: PetscFunctionReturn(PETSC_SUCCESS);
261: }
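/* A minimal usage sketch of the routine above, with hypothetical A, B, garray already
   built (A, B of type MATSEQAIJKOKKOS; garray allocated with PetscMalloc1(), whose
   ownership is assumed to pass to mat):

     Mat mat;
     PetscCall(MatCreate(comm, &mat));
     PetscCall(MatSetSizes(mat, m, n, M, N));
     PetscCall(MatSetType(mat, MATMPIAIJKOKKOS));
     PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(mat, A, B, garray));
*/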
263: // Adapted from Kokkos-Kernels spmv_launch_parameters(), to get parameters in Kokkos nested loops which we use to merge or
264: // split csr matrices. The rule is to have "vector_length * team_size" be around 256 on GPUs (e.g., for a CUDA thread block)
265: template <class ExecutionSpace>
266: static PetscErrorCode MatMergeGetLaunchParameters(PetscInt numRows, PetscInt nnz, PetscInt rows_per_thread, PetscInt &team_size, PetscInt &vector_length, PetscInt &rows_per_team)
267: {
268: Kokkos::TeamPolicy<ExecutionSpace> teamPolicy(128, Kokkos::AUTO);
270: PetscFunctionBegin;
271: PetscInt nnz_per_row = numRows ? (nnz / numRows) : 0; // we might encounter empty matrices
273: if (nnz_per_row < 1) nnz_per_row = 1;
275: int max_vector_length = teamPolicy.vector_length_max();
277: if (vector_length < 1) {
278: vector_length = 1;
279: while (vector_length < max_vector_length && vector_length * 6 < nnz_per_row) vector_length *= 2;
280: }
282: // Determine rows per thread
283: if (rows_per_thread < 1) {
284: if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) rows_per_thread = 1;
285: else {
286: if (nnz_per_row < 20 && nnz > 5000000) {
287: rows_per_thread = 256;
288: } else rows_per_thread = 64;
289: }
290: }
292: if (team_size < 1) {
293: if (KokkosKernels::Impl::kk_is_gpu_exec_space<ExecutionSpace>()) {
294: team_size = 256 / vector_length;
295: } else {
296: team_size = 1;
297: }
298: }
300: rows_per_team = rows_per_thread * team_size;
302: if (rows_per_team < 0) {
303: PetscInt nnz_per_team = 4096;
304: PetscInt conc = ExecutionSpace().concurrency();
305: while ((conc * nnz_per_team * 4 > nnz) && (nnz_per_team > 256)) nnz_per_team /= 2;
306: rows_per_team = (nnz_per_team + nnz_per_row - 1) / nnz_per_row;
307: }
308: PetscFunctionReturn(PETSC_SUCCESS);
309: }
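/* A worked example of the rule above on a GPU execution space, assuming
   vector_length_max() >= 4: with nnz = 2400 and numRows = 100, nnz_per_row = 24, so
   vector_length grows 1 -> 2 -> 4 (and stops, since 4*6 is not < 24), rows_per_thread = 1,
   team_size = 256/4 = 64, and rows_per_team = 64, i.e., vector_length * team_size = 256.
*/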
311: /*
312: Reduce two sets of global indices into local ones
314: Input Parameters:
315: + n1 - size of garray1[], the first set
316: . garray1[n1] - a sorted global index array (without duplicates)
317: . m - size of indices[], the second set
318: - indices[m] - an unsorted global index array (might have duplicates); on output it will be overwritten with local indices
320: Output Parameters:
321: + n2 - size of garray2[], the merged set, which combines garray1[] and indices[]
322: . garray2[n2] - allocated by callee using PetscMalloc1(). Contains sorted unique global indices (without duplicates). Caller needs to free it.
323: . map[n1] - allocated by caller. It gives garray1[i] = garray2[map[i]]
324: - indices[m] - on output, global indices in this array are rewritten with local ones, i.e., indices_input[i] = garray2[indices_output[i]]
326: Example, say
327: n1 = 5
328: garray1[5] = {1, 4, 7, 8, 10}
329: m = 4
330: indices[4] = {2, 4, 8, 9}
332: Combining them together, we have 7 global indices in garray2[]
333: n2 = 7
334: garray2[7] = {1, 2, 4, 7, 8, 9, 10}
336: And we have map[] to connect "garray1[i] = garray2[map[i]], i=[0,n1)"
337: map[5] = {0, 2, 3, 4, 6}
339: On output, indices[] is updated with local indices
340: indices[4] = {1, 2, 4, 5}
341: */
342: static PetscErrorCode ReduceTwoSetsOfGlobalIndices(PetscInt n1, const PetscInt *garray1, PetscInt m, PetscInt *indices, PetscInt *n2_, PetscInt **garray2_, PetscInt *map)
343: {
344: PetscHMapI g2l = nullptr;
345: PetscHashIter iter;
346: PetscInt tot, key, val; // total unique global indices. key is global id; val is local id
347: PetscInt n2, *garray2;
349: PetscFunctionBegin;
350: tot = 0;
351: PetscCall(PetscHMapICreateWithSize(n1, &g2l));
352: for (PetscInt i = 0; i < m; i++) { // insert those in indices[]
353: PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val)); // if not exist, val is set with -1
354: if (val < 0) PetscCall(PetscHMapISet(g2l, indices[i], tot++)); // val < 0 means gid is not in the hash table yet
355: }
357: for (PetscInt i = 0; i < n1; i++) { // insert those in garray1[]
358: PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &val));
359: if (val < 0) PetscCall(PetscHMapISet(g2l, garray1[i], tot++));
360: }
362: // Pull out (unique) globals in the hash table and put them in garray2[]
363: n2 = tot;
364: PetscCall(PetscMalloc1(n2, &garray2));
365: tot = 0;
366: PetscHashIterBegin(g2l, iter);
367: while (!PetscHashIterAtEnd(g2l, iter)) {
368: PetscHashIterGetKey(g2l, iter, key);
369: PetscHashIterNext(g2l, iter);
370: garray2[tot++] = key;
371: }
373: // Sort garray2[] and then map them to local indices starting from 0
374: PetscCall(PetscSortInt(n2, garray2));
375: PetscCall(PetscHMapIClear(g2l));
376: for (PetscInt i = 0; i < tot; i++) PetscCall(PetscHMapISet(g2l, garray2[i], i)); // i is the local id
378: // Rewrite indices[] with local indices
379: for (PetscInt i = 0; i < m; i++) {
380: PetscCall(PetscHMapIGetWithDefault(g2l, indices[i], -1, &val));
381: PetscAssert(val >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Met a negative local column index");
382: indices[i] = val;
383: }
384: // Record the map that maps garray1[i] to garray2[map[i]]
385: for (PetscInt i = 0; i < n1; i++) PetscCall(PetscHMapIGetWithDefault(g2l, garray1[i], -1, &map[i]));
386: PetscCall(PetscHMapIDestroy(&g2l));
387: *n2_ = n2;
388: *garray2_ = garray2;
389: PetscFunctionReturn(PETSC_SUCCESS);
390: }
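/* A minimal call sketch mirroring the worked example in the comment above:

     PetscInt garray1[] = {1, 4, 7, 8, 10}, indices[] = {2, 4, 8, 9};
     PetscInt map[5], n2, *garray2;
     PetscCall(ReduceTwoSetsOfGlobalIndices(5, garray1, 4, indices, &n2, &garray2, map));
     // now n2 == 7, garray2[] == {1, 2, 4, 7, 8, 9, 10}, map[] == {0, 2, 3, 4, 6},
     // and indices[] == {1, 2, 4, 5}
     PetscCall(PetscFree(garray2));
*/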
392: /*
393: MatMPIAIJKokkosReduce - Reduce rows of an MPIAIJKOKKOS matrix (E, in split form) to produce another matrix (F, also in split form, stored in mm)
395: It is the reverse of MatMPIAIJKokkosBcast() in some sense, but with a different signature since we do not really need a fully populated MPIAIJKOKKOS E.
397: Think of each row of E as a leaf; the given ownerSF then specifies roots for the leaves. Roots may connect to multiple leaves.
398: In this routine, we sparse-merge leaves (rows) at their roots to form potentially longer rows in F. F's number of rows will be nroots of ownerSF.
400: Input Parameters:
401: + comm - MPI communicator of E
402: . A - diag block of E, using local column indices
403: . B - off-diag block of E, using local column indices
404: . cstart - (global) start column of Ed
405: . cend - (global) end column + 1 of Ed. In other words, E's column ownership is the range [cstart, cend)
406: . garray1[n1] - global column indices of Eo. Here n1 is Eo's column size.
407: . ownerSF - the SF specifying ownership (roots) of rows in E
408: . reuse - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
409: - mm - to stash intermediate data structures for reuse
411: Output Parameters:
412: + map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices().
413: - mm - contains various info, such as garray2[], F (Fd, Fo) etc.
415: Notes:
416: When reuse = MAT_REUSE_MATRIX, cstart, cend, garray1, ownerSF, map are not significant.
418: */
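/* An illustrative picture of the reduction, with made-up ranks and row numbers: if
   ownerSF maps leaf rows of E to roots as

     rank 0: E row 3 --\
     rank 1: E row 0 ---+--> root (F row 1 on rank 0)

   then the nonzeros of both contributing rows are packed into leafBuf, reduced into
   rootBuf through the SF built below, and finally combined (via jmap/jperm) into a
   single, potentially longer row of Fd/Fo.
*/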
419: static PetscErrorCode MatMPIAIJKokkosReduceBegin(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
420: {
421: PetscFunctionBegin;
422: if (reuse == MAT_INITIAL_MATRIX) {
423: PetscInt Em = A.numRows(), Fm;
424: PetscInt n1 = B.numCols();
426: PetscCall(PetscSFGetGraph(ownerSF, &Fm, NULL, NULL, NULL)); // Fm = #rows of F = nroots of ownerSF
428: // Do the analysis on host
429: auto Ai_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.row_map);
430: auto Aj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A.graph.entries);
431: auto Bi_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.row_map);
432: auto Bj_h = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), B.graph.entries);
433: const MatRowMapType *Ai = Ai_h.data(), *Bi = Bi_h.data();
434: const MatColIdxType *Aj = Aj_h.data(), *Bj = Bj_h.data();
436: // Count how many nonzeros of each row in E lie to the left of the diag block [cstart,cend)
437: PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
438: PetscInt *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
439: for (PetscInt i = 0; i < Em; i++) {
440: const PetscInt *first, *last, *it;
441: PetscInt count, step;
442: // std::lower_bound(first,last,cstart), but need to use global column indices
443: first = Bj + Bi[i];
444: last = Bj + Bi[i + 1];
445: count = last - first;
446: while (count > 0) {
447: it = first;
448: step = count / 2;
449: it += step;
450: if (garray1[*it] < cstart) { // map local to global
451: first = ++it;
452: count -= step + 1;
453: } else count = step;
454: }
455: E_NzLeft[i] = first - (Bj + Bi[i]);
456: E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
457: }
459: // Get length of rows (i.e., sizes of leaves) that contribute to my roots
460: const PetscMPIInt *iranks, *ranks;
461: const PetscInt *ioffset, *irootloc, *roffset, *rmine;
462: PetscInt niranks, nranks;
463: MPI_Request *reqs;
464: PetscMPIInt tag;
465: PetscSF reduceSF;
466: PetscInt *sdisp, *rdisp;
468: PetscCall(PetscCommGetNewTag(comm, &tag));
469: PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks connecting to roots on this process (I'll recv from them)
470: PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, &rmine, NULL)); // get root ranks this process connects (I'll send to them)
472: // Find out length of each row I will receive. Even for the same row index, when they are from
473: // different senders, they might have different lengths (and sparsity patterns)
474: PetscInt sendRowCnt = roffset[nranks], recvRowCnt = ioffset[niranks];
475: PetscInt *sendRowLen, *recvRowLen; // lengths of rows I need to send/recv, grouped per process
477: PetscCall(PetscMalloc5(sendRowCnt, &sendRowLen, recvRowCnt + 1, &recvRowLen, nranks, &sdisp, niranks + 1, &rdisp, nranks + niranks, &reqs));
479: for (PetscInt i = 0; i < sendRowCnt; i++) sendRowLen[i] = E_RowLen[rmine[i]];
480: recvRowLen[0] = 0; // since we will make it in CSR format later
481: recvRowLen++; // advance the pointer now
482: for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPI_Irecv(&recvRowLen[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
483: for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPI_Isend(&sendRowLen[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[i]));
484: PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
486: // Build the real PetscSF for reducing E rows (buffer to buffer)
487: rdisp[0] = 0;
488: for (PetscInt i = 0; i < niranks; i++) {
489: rdisp[i + 1] = rdisp[i];
490: for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) { rdisp[i + 1] += recvRowLen[j]; }
491: }
492: recvRowLen--; // put it back into csr format
493: for (PetscInt i = 0; i < recvRowCnt; i++) recvRowLen[i + 1] += recvRowLen[i];
495: for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&sdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
496: for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&rdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
497: PetscCallMPI(MPI_Waitall(nranks + niranks, reqs, MPI_STATUSES_IGNORE));
499: PetscInt nleaves = 0, Enz = 0; // leaves are nonzeros I will send
500: PetscInt nroots = rdisp[niranks]; // roots are nonzeros I will recv
501: PetscSFNode *iremote;
503: for (PetscInt i = 0; i < Em; i++) Enz += E_RowLen[i];
504: PetscAssert(A.nnz() + B.nnz() == Enz, comm, PETSC_ERR_PLIB, "Enz should be equal to sum of nnz of A and B");
505: PetscCall(PetscMalloc1(Enz, &iremote)); // no free, since we give ownership to reduceSF
507: for (PetscInt i = 0; i < nranks; i++) {
508: PetscInt count = 0;
509: for (PetscInt j = roffset[i]; j < roffset[i + 1]; j++) count += E_RowLen[rmine[j]];
510: for (PetscInt j = 0; j < count; j++) {
511: iremote[nleaves + j].rank = ranks[i];
512: iremote[nleaves + j].index = sdisp[i] + j;
513: }
514: nleaves += count;
515: }
516: PetscCheck(nleaves == Enz, comm, PETSC_ERR_PLIB, "nleaves should be equal to Enz");
518: PetscCall(PetscSFCreate(comm, &reduceSF));
519: PetscCall(PetscSFSetGraph(reduceSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
521: // Copy (global) column indices of the needed rows in E to sendCol[], and then PetscSFReduce to recvCol[]
522: PetscInt *sendCol, *recvCol;
523: PetscCall(PetscMalloc2(nleaves, &sendCol, nroots, &recvCol));
524: for (PetscInt k = 0; k < roffset[nranks]; k++) {
525: PetscInt i = rmine[k]; // row to be copied
526: PetscInt *buf = &sendCol[Ai[i] + Bi[i]];
527: PetscInt nzLeft = E_NzLeft[i];
528: PetscInt alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
529: for (PetscInt j = 0; j < alen + blen; j++) {
530: if (j < nzLeft) {
531: buf[j] = garray1[Bj[Bi[i] + j]]; // left B, in global
532: } else if (j < nzLeft + alen) {
533: buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
534: } else {
535: buf[j] = garray1[Bj[Bi[i] + j - alen]]; // right B, in global
536: }
537: }
538: }
539: PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_INT, PETSC_MEMTYPE_HOST, sendCol, PETSC_MEMTYPE_HOST, recvCol, MPI_REPLACE));
540: PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, sendCol, recvCol, MPI_REPLACE));
542: // With recvCol[], we do a series of analyses to get i, j of Fd, Fo, and build plans to reduce nonzeros in recv buffers to Fd and Fo
543: PetscInt *recvRowPerm, *recvColSorted;
544: PetscInt *recvNzPerm, *recvNzPermSorted;
545: PetscCall(PetscMalloc4(recvRowCnt, &recvRowPerm, nroots, &recvColSorted, nroots, &recvNzPerm, nroots, &recvNzPermSorted));
547: for (PetscInt i = 0; i < nroots; i++) recvNzPerm[i] = i; // numbering all received nonzeros
548: for (PetscInt i = 0; i < recvRowCnt; i++) recvRowPerm[i] = i; // put up a permutation array, so that after sorting we know where to get a row in recvCol[]
549: PetscCall(PetscSortIntWithPermutation(recvRowCnt, irootloc, recvRowPerm)); // irootloc[] (owned by ownerSF) won't be changed
551: // i[] arrays and nz counts are always the easiest to compute
552: MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1);
553: MatRowMapType *Fdi, *Foi;
554: PetscInt FnzDups = 0, Fdnz = 0, FdnzDups = 0, Fonz = 0, FonzDups = 0; // nz (with or without dups) in F, Fd, Fo
555: PetscInt iter;
557: Kokkos::deep_copy(Fdi_h, 0); // zero, as we will do 'val++' on them
558: Kokkos::deep_copy(Foi_h, 0);
559: Fdi = Fdi_h.data() + 1; // +1 for easy indexing in code below
560: Foi = Foi_h.data() + 1;
561: iter = 0;
562: while (iter < recvRowCnt) { // iter over received rows
563: PetscInt curRowIdx = irootloc[recvRowPerm[iter]];
564: PetscInt dupRows = 1; // current row has this many contributing rows (of various sparsity patterns)
566: while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
568: // Copy column indices (and their permutation) of these rows into recvColSorted & recvNzPermSorted
569: PetscInt nz = 0; // nz (with dups) in the current row
570: PetscInt *jbuf = recvColSorted + FnzDups;
571: PetscInt *pbuf = recvNzPermSorted + FnzDups;
572: PetscInt *jbuf2 = jbuf; // temp pointers
573: PetscInt *pbuf2 = pbuf;
574: for (PetscInt d = 0; d < dupRows; d++) {
575: PetscInt i = recvRowPerm[iter + d];
576: PetscInt len = recvRowLen[i + 1] - recvRowLen[i];
577: PetscCall(PetscArraycpy(jbuf2, &recvCol[recvRowLen[i]], len));
578: PetscCall(PetscArraycpy(pbuf2, &recvNzPerm[recvRowLen[i]], len));
579: jbuf2 += len;
580: pbuf2 += len;
581: nz += len;
582: }
583: PetscCall(PetscIntSortSemiOrderedWithArray(nz, jbuf, pbuf)); // It could be improved with k-way merge sort, since the rows are already sorted
585: // Scan column indices (in jbuf[0,nz), might have dups) of this row, and see how many go to Fd and how many go to Fo
586: PetscInt cur = 0;
587: while (cur < nz) {
588: PetscInt curColIdx = jbuf[cur];
589: PetscInt dups = 1;
591: while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
592: if (curColIdx >= cstart && curColIdx < cend) {
593: Fdi[curRowIdx]++;
594: FdnzDups += dups;
595: } else {
596: Foi[curRowIdx]++;
597: FonzDups += dups;
598: }
599: cur += dups;
600: }
602: FnzDups += nz;
603: iter += dupRows; // Move to next unique row
604: }
606: Fdi = Fdi_h.data(); // restore Fdi, Foi and make them CSR
607: Foi = Foi_h.data();
608: for (PetscInt i = 0; i < Fm; i++) {
609: Fdi[i + 1] += Fdi[i];
610: Foi[i + 1] += Foi[i];
611: }
612: Fdnz = Fdi[Fm];
613: Fonz = Foi[Fm];
614: PetscCall(PetscFree2(sendCol, recvCol));
616: // Allocate j, jmap, jperm for Fd and Fo
617: MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
618: MatRowMapKokkosViewHost Fdjmap_h(NoInit("Fdjmap_h"), Fdnz + 1), Fojmap_h(NoInit("Fojmap_h"), Fonz + 1); // +1 to make csr
619: MatRowMapKokkosViewHost Fdjperm_h(NoInit("Fdjperm_h"), FdnzDups), Fojperm_h(NoInit("Fojperm_h"), FonzDups);
620: MatColIdxType *Fdj = Fdj_h.data(), *Foj = Foj_h.data();
621: MatRowMapType *Fdjmap = Fdjmap_h.data(), *Fojmap = Fojmap_h.data();
622: MatRowMapType *Fdjperm = Fdjperm_h.data(), *Fojperm = Fojperm_h.data();
624: // Scan recvColSorted[] again, and fill j, jmap, jperm for Fd and Fo
625: Fdjmap[0] = 0;
626: Fojmap[0] = 0;
627: FnzDups = 0;
628: Fdnz = 0;
629: Fonz = 0;
630: iter = 0; // iter over received rows
631: while (iter < recvRowCnt) {
632: PetscInt curRowIdx = irootloc[recvRowPerm[iter]]; // current row idx
633: PetscInt dupRows = 1; // It has this many contributing rows (of various lengths)
634: PetscInt nz = 0; // nz (with dups) in the current row
636: while (iter + dupRows < recvRowCnt && irootloc[recvRowPerm[iter + dupRows]] == curRowIdx) dupRows++;
637: for (PetscInt d = 0; d < dupRows; d++) {
638: PetscInt i = recvRowPerm[iter + d];
639: nz += recvRowLen[i + 1] - recvRowLen[i];
640: }
642: PetscInt *jbuf = recvColSorted + FnzDups;
643: // Scan columns (in jbuf[0,nz) of this row, copy them and their permutation to j[] and jperm[] of Fd and Fo
644: PetscInt cur = 0;
645: while (cur < nz) {
646: PetscInt curColIdx = jbuf[cur];
647: PetscInt dups = 1;
649: while (cur + dups < nz && jbuf[cur + dups] == curColIdx) dups++;
650: if (curColIdx >= cstart && curColIdx < cend) {
651: Fdj[Fdnz] = curColIdx - cstart; // easily convert to local
652: Fdjmap[Fdnz + 1] = Fdjmap[Fdnz] + dups;
653: for (PetscInt j = 0; j < dups; j++) Fdjperm[Fdjmap[Fdnz] + j] = recvNzPermSorted[FnzDups + j];
654: FdnzDups += dups;
655: Fdnz++;
656: } else {
657: Foj[Fonz] = curColIdx; // in global
658: Fojmap[Fonz + 1] = Fojmap[Fonz] + dups;
659: for (PetscInt j = 0; j < dups; j++) Fojperm[Fojmap[Fonz] + j] = recvNzPermSorted[FnzDups + j];
660: FonzDups += dups;
661: Fonz++;
662: }
663: cur += dups;
664: FnzDups += dups;
665: }
666: iter += dupRows; // Move to next unique row
667: }
668: PetscCall(PetscFree4(recvRowPerm, recvColSorted, recvNzPerm, recvNzPermSorted));
669: PetscCall(PetscFree5(sendRowLen, recvRowLen, sdisp, rdisp, reqs));
671: // Combine global column indices in garray1[] and Foj[]
672: PetscInt n2, *garray2;
674: PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
675: mm->sf = reduceSF;
676: mm->leafBuf = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
677: mm->rootBuf = MatScalarKokkosView(NoInit("rootBuf"), nroots);
678: mm->garray = garray2; // give ownership, so no free
679: mm->n = n2;
680: mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
681: mm->Fdjmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjmap_h);
682: mm->Fdjperm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdjperm_h);
683: mm->Fojmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojmap_h);
684: mm->Fojperm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fojperm_h);
686: // Output Fd and Fo in KokkosCsrMatrix format
687: MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz);
688: MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
689: MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
690: MatScalarKokkosView Foa_d(NoInit("Foa_d"), Fonz);
691: MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
692: MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
694: PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
695: PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d)); // Fo's column size is n2, length of garray2[]
697: // Compute kernel launch parameters in merging E
698: PetscInt teamSize, vectorLength, rowsPerTeam;
700: teamSize = vectorLength = rowsPerTeam = -1;
701: PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Em, Enz, -1, teamSize, vectorLength, rowsPerTeam));
702: mm->E_TeamSize = teamSize;
703: mm->E_VectorLength = vectorLength;
704: mm->E_RowsPerTeam = rowsPerTeam;
705: } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
707: // Handy aliases
708: auto &Aa = A.values;
709: auto &Ba = B.values;
710: const auto &Ai = A.graph.row_map;
711: const auto &Bi = B.graph.row_map;
712: const auto &E_NzLeft = mm->E_NzLeft;
713: auto &leafBuf = mm->leafBuf;
714: auto &rootBuf = mm->rootBuf;
715: PetscSF reduceSF = mm->sf;
716: PetscInt Em = A.numRows();
717: PetscInt teamSize = mm->E_TeamSize;
718: PetscInt vectorLength = mm->E_VectorLength;
719: PetscInt rowsPerTeam = mm->E_RowsPerTeam;
720: PetscInt workSets = (Em + rowsPerTeam - 1) / rowsPerTeam;
722: // Copy rows in A/B of E to leafBuf, then reduce it to rootBuf
723: PetscCallCXX(Kokkos::parallel_for(
724: Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
725: Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
726: PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in E
727: if (i < Em) {
728: PetscInt disp = Ai(i) + Bi(i);
729: PetscInt alen = Ai(i + 1) - Ai(i);
730: PetscInt blen = Bi(i + 1) - Bi(i);
731: PetscInt nzleft = E_NzLeft(i);
733: Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
734: MatScalar &val = leafBuf(disp + j);
735: if (j < nzleft) { // B left
736: val = Ba(Bi(i) + j);
737: } else if (j < nzleft + alen) { // diag A
738: val = Aa(Ai(i) + j - nzleft);
739: } else { // B right
740: val = Ba(Bi(i) + j - alen);
741: }
742: });
743: }
744: });
745: }));
746: PetscCall(PetscSFReduceWithMemTypeBegin(reduceSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, leafBuf.data(), PETSC_MEMTYPE_KOKKOS, rootBuf.data(), MPI_REPLACE));
747: PetscFunctionReturn(PETSC_SUCCESS);
748: }
750: // To finish MatMPIAIJKokkosReduce.
751: static PetscErrorCode MatMPIAIJKokkosReduceEnd(MPI_Comm comm, KokkosCsrMatrix A, KokkosCsrMatrix B, PetscInt cstart, PetscInt cend, const PetscInt *garray1, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AtB *mm)
752: {
753: auto &leafBuf = mm->leafBuf;
754: auto &rootBuf = mm->rootBuf;
755: auto &Fda = mm->Fd.values;
756: const auto &Fdjmap = mm->Fdjmap;
757: const auto &Fdjperm = mm->Fdjperm;
758: auto Fdnz = mm->Fd.nnz();
759: auto &Foa = mm->Fo.values;
760: const auto &Fojmap = mm->Fojmap;
761: const auto &Fojperm = mm->Fojperm;
762: auto Fonz = mm->Fo.nnz();
763: PetscSF reduceSF = mm->sf;
765: PetscFunctionBegin;
766: PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, leafBuf.data(), rootBuf.data(), MPI_REPLACE));
768: // Reduce data in rootBuf to Fd and Fo
769: PetscCallCXX(Kokkos::parallel_for(
770: Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fdnz), KOKKOS_LAMBDA(const MatRowMapType i) {
771: PetscScalar sum = 0.0;
772: for (MatRowMapType k = Fdjmap(i); k < Fdjmap(i + 1); k++) sum += rootBuf(Fdjperm(k));
773: Fda(i) = sum;
774: }));
776: PetscCallCXX(Kokkos::parallel_for(
777: Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Fonz), KOKKOS_LAMBDA(const MatRowMapType i) {
778: PetscScalar sum = 0.0;
779: for (MatRowMapType k = Fojmap(i); k < Fojmap(i + 1); k++) sum += rootBuf(Fojperm(k));
780: Foa(i) = sum;
781: }));
782: PetscFunctionReturn(PETSC_SUCCESS);
783: }
785: /*
786: MatMPIAIJKokkosBcast - Bcast local rows of an MPIAIJKOKKOS matrix (E) to produce a local matrix (F, stored in mm) in split form
788: This is a complex routine. It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports
789: device execution and involves various index mappings.
791: In the given ownerSF, leaves correspond to rows in F, and roots correspond to rows in E. Roots may connect to multiple leaves.
792: Suppose F's j-th row is connected to a root identified by PetscSFNode (k,i); this means we need to bcast the i-th row of E on rank k
793: to j-th row of F. ownerSF is not an arbitrary SF, instead it is the Mvctx of another MPIAIJ matrix A that is able to perform A*E.
794: F has the same column layout as E.
796: Conceptually F has global column indices. In this routine, we split F into diagonal Fd and off-diagonal Fo.
797: Fd uses local column indices, which are easy to compute. We just need to subtract the "local column range start" from the global indices.
798: Fo had global column indices at first. We will reduce them into local ones. In doing that, we also take into account the global
799: column indices that E's off-diag block has. Let's say there are n1 such indices stored in garray1[]. We will reduce them along with
800: column indices in Fo and update Fo with local indices.
802: Input Parameters:
803: + E - the MPIAIJKOKKOS matrix
804: . ownerSF - the ownership SF (insignificant in MAT_REUSE_MATRIX)
805: . reuse - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
806: - mm - to stash matproduct intermediate data structures
808: Output Parameters:
809: + map[n1] - allocated by caller. It maps garray1[] to garray2[]. See more at ReduceTwoSetsOfGlobalIndices.
810: - mm - contains various info, such as garray2[], Fd, Fo, etc.
812: Notes:
813: When reuse = MAT_REUSE_MATRIX, ownerSF, map are not significant.
814: The routine is provided in split-phase form, MatMPIAIJKokkosBcastBegin/End(), to allow overlapping computation with communication.
815: */
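/* An illustrative picture of the bcast, with made-up ranks and row numbers: if a leaf
   of ownerSF says "F row 2 on rank r comes from E row 7 on rank k", then

     rank k: E row 7 --> rank r: F row 2   (one root may feed leaves on many ranks)

   The owner packs the row's values into rootBuf, bcasts them to leafBuf through the SF
   built below, and the receiver splits them into Fd (local col ids) and Fo (compacted
   col ids via garray2[]).
*/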
816: static PetscErrorCode MatMPIAIJKokkosBcastBegin(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
817: {
818: Mat_MPIAIJ *empi = static_cast<Mat_MPIAIJ *>(E->data);
819: Mat A = empi->A, B = empi->B; // diag and off-diag
820: Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
821: PetscInt Em = E->rmap->n; // #local rows
822: MPI_Comm comm;
824: PetscFunctionBegin;
825: PetscCall(PetscObjectGetComm((PetscObject)E, &comm));
826: if (reuse == MAT_INITIAL_MATRIX) {
827: Mat_SeqAIJ *aseq = static_cast<Mat_SeqAIJ *>(A->data), *bseq = static_cast<Mat_SeqAIJ *>(B->data);
828: PetscInt n1 = B->cmap->n, *Ai = aseq->i, *Aj = aseq->j, *Bi = bseq->i, *Bj = bseq->j;
829: const PetscInt *garray1 = empi->garray; // its size is n1
830: PetscInt cstart, cend;
831: PetscSF bcastSF;
833: PetscCall(MatGetOwnershipRangeColumn(E, &cstart, &cend));
835: // Count how many nonzeros of each row in E lie to the left of the diag block [cstart,cend)
836: PetscIntKokkosViewHost E_NzLeft_h(NoInit("E_NzLeft_h"), Em), E_RowLen_h(NoInit("E_RowLen_h"), Em);
837: PetscInt *E_NzLeft = E_NzLeft_h.data(), *E_RowLen = E_RowLen_h.data();
838: for (PetscInt i = 0; i < Em; i++) {
839: const PetscInt *first, *last, *it;
840: PetscInt count, step;
841: // std::lower_bound(first,last,cstart), but need to use global column indices
842: first = Bj + Bi[i];
843: last = Bj + Bi[i + 1];
844: count = last - first;
845: while (count > 0) {
846: it = first;
847: step = count / 2;
848: it += step;
849: if (empi->garray[*it] < cstart) { // map local to global
850: first = ++it;
851: count -= step + 1;
852: } else count = step;
853: }
854: E_NzLeft[i] = first - (Bj + Bi[i]);
855: E_RowLen[i] = (Ai[i + 1] - Ai[i]) + (Bi[i + 1] - Bi[i]);
856: }
858: // Compute row pointer Fi of F
859: PetscInt *Fi, Fm, Fnz;
860: PetscCall(PetscSFGetGraph(ownerSF, NULL, &Fm, NULL, NULL)); // Fm = #rows of F = nleaves of ownerSF
861: PetscCall(PetscMalloc1(Fm + 1, &Fi));
862: Fi[0] = 0;
863: PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, E_RowLen, PETSC_MEMTYPE_HOST, &Fi[1], MPI_REPLACE));
864: PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, E_RowLen, &Fi[1], MPI_REPLACE));
865: for (PetscInt i = 0; i < Fm; i++) Fi[i + 1] += Fi[i];
866: Fnz = Fi[Fm];
868: // Build the real PetscSF for bcasting E rows (buffer to buffer)
869: const PetscMPIInt *iranks, *ranks;
870: const PetscInt *ioffset, *irootloc, *roffset;
871: PetscInt niranks, nranks, *sdisp, *rdisp;
872: MPI_Request *reqs;
873: PetscMPIInt tag;
875: PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc)); // get leaf ranks referencing roots on this process
876: PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL, NULL)); // recv info
877: PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
879: sdisp[0] = 0; // send displacement
880: for (PetscInt i = 0; i < niranks; i++) {
881: sdisp[i + 1] = sdisp[i];
882: for (PetscInt j = ioffset[i]; j < ioffset[i + 1]; j++) {
883: PetscInt r = irootloc[j]; // row to be sent
884: sdisp[i + 1] += E_RowLen[r];
885: }
886: }
888: PetscCall(PetscCommGetNewTag(comm, &tag));
889: for (PetscInt i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
890: for (PetscInt i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
891: PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));
893: PetscInt nleaves = Fnz; // leaves are nonzeros I will receive
894: PetscInt nroots = sdisp[niranks]; // roots are nonzeros I will send
895: PetscSFNode *iremote; // give ownership to bcastSF
896: PetscCall(PetscMalloc1(nleaves, &iremote));
897: for (PetscInt i = 0; i < nranks; i++) { // for each sender rank
898: PetscInt k = 0;
899: for (PetscInt j = Fi[roffset[i]]; j < Fi[roffset[i + 1]]; j++) { // I will receive rows [roffset[i], roffset[i+1]) of F from ranks[i]
900: iremote[j].rank = ranks[i];
901: iremote[j].index = rdisp[i] + k; // their root location
902: k++;
903: }
904: }
905: PetscCall(PetscSFCreate(comm, &bcastSF));
906: PetscCall(PetscSFSetGraph(bcastSF, nroots, nleaves, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
907: PetscCall(PetscFree3(sdisp, rdisp, reqs));
909: // Build a plan (rowoffset, irootloc, E_NzLeft) to copy rows in E to rootdata of bcastSF in parallel
910: PetscIntKokkosViewHost rowoffset_h(NoInit("rowoffset_h"), ioffset[niranks] + 1);
911: PetscInt *rowoffset = rowoffset_h.data(); // for each entry (row) indicated in irootloc[], we calculate its destination offset for the copy
912: rowoffset[0] = 0;
913: for (PetscInt i = 0; i < ioffset[niranks]; i++) { rowoffset[i + 1] = rowoffset[i] + E_RowLen[irootloc[i]]; }
915: // Copy (global) column indices of the needed rows in E to a buffer, and then bcast to Fj[]
916: PetscInt *jbuf, *Fj;
917: PetscCall(PetscMalloc2(nroots, &jbuf, Fnz, &Fj));
918: for (PetscInt k = 0; k < ioffset[niranks]; k++) {
919: PetscInt i = irootloc[k]; // row to be copied
920: PetscInt *buf = &jbuf[rowoffset[k]];
921: PetscInt nzLeft = E_NzLeft[i];
922: PetscInt alen = Ai[i + 1] - Ai[i], blen = Bi[i + 1] - Bi[i];
923: for (PetscInt j = 0; j < alen + blen; j++) {
924: if (j < nzLeft) {
925: buf[j] = empi->garray[Bj[Bi[i] + j]]; // left B, in global
926: } else if (j < nzLeft + alen) {
927: buf[j] = Aj[Ai[i] + j - nzLeft] + cstart; // diag A, also in global
928: } else {
929: buf[j] = empi->garray[Bj[Bi[i] + j - alen]]; // right B, in global
930: }
931: }
932: }
933: PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_INT, PETSC_MEMTYPE_HOST, jbuf, PETSC_MEMTYPE_HOST, Fj, MPI_REPLACE));
934: PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf, Fj, MPI_REPLACE));
936: // Build a plan (i.e., F_NzLeft) to split F into Fd and Fo
937: MatRowMapKokkosViewHost Fdi_h(NoInit("Fdi_h"), Fm + 1), Foi_h(NoInit("Foi_h"), Fm + 1); // row pointer of Fd, Fo
938: MatColIdxKokkosViewHost F_NzLeft_h(NoInit("F_NzLeft_h"), Fm); // split each row of F into Left, Diag, Right. We only need to record #nz in Left and Diag.
939: MatRowMapType *Fdi = Fdi_h.data(), *Foi = Foi_h.data();
940: MatColIdxType *F_NzLeft = F_NzLeft_h.data();
942: Fdi[0] = Foi[0] = 0;
943: for (PetscInt i = 0; i < Fm; i++) {
944: PetscInt *first, *last, *lb1, *lb2;
945: // cut the row into: Left, [cstart, cend), Right
946: first = Fj + Fi[i];
947: last = Fj + Fi[i + 1];
948: lb1 = std::lower_bound(first, last, cstart);
949: F_NzLeft[i] = lb1 - first;
950: lb2 = std::lower_bound(first, last, cend);
951: Fdi[i + 1] = lb2 - lb1; // row i length in Fdi
952: Foi[i + 1] = (Fi[i + 1] - Fi[i]) - Fdi[i + 1]; // row i length in Foi
953: }
954: for (PetscInt i = 0; i < Fm; i++) {
955: Fdi[i + 1] += Fdi[i];
956: Foi[i + 1] += Foi[i];
957: }
959: // Fill Fdj[] and Foj[], i.e., columns of Fd and Fo. Fdj[] are local, but Foj[] are not yet.
960: PetscInt Fdnz = Fdi[Fm], Fonz = Foi[Fm];
961: MatColIdxKokkosViewHost Fdj_h(NoInit("Fdj_h"), Fdnz), Foj_h(NoInit("Foj_h"), Fonz);
962: MatColIdxType *Fdj = Fdj_h.data(), *Foj = Foj_h.data(), gid;
964: for (PetscInt i = 0; i < Fm; i++) {
965: PetscInt nzLeft = F_NzLeft[i];
966: PetscInt len = Fdi[i + 1] - Fdi[i]; // diag row len
967: for (PetscInt j = 0; j < Fi[i + 1] - Fi[i]; j++) {
968: gid = Fj[Fi[i] + j];
969: if (j < nzLeft) { // left, in global
970: Foj[Foi[i] + j] = gid;
971: } else if (j < nzLeft + len) { // diag, in local
972: Fdj[Fdi[i] + j - nzLeft] = gid - cstart;
973: } else { // right, in global
974: Foj[Foi[i] + j - len] = gid;
975: }
976: }
977: }
978: PetscCall(PetscFree2(jbuf, Fj));
979: PetscCall(PetscFree(Fi));
981: // Reduce global indices in Foj[] and garray1[] into local ones
982: PetscInt n2, *garray2;
983: PetscCall(ReduceTwoSetsOfGlobalIndices(n1, garray1, Fonz, Foj, &n2, &garray2, map));
985: // Record the plans built above, for reuse
986: PetscIntKokkosViewHost tmp(const_cast<PetscInt *>(irootloc), ioffset[niranks]); // irootloc[] is owned by ownerSF. We create a copy for safety
987: PetscIntKokkosViewHost irootloc_h(NoInit("irootloc_h"), ioffset[niranks]);
988: Kokkos::deep_copy(irootloc_h, tmp);
989: mm->sf = bcastSF;
990: mm->E_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), E_NzLeft_h);
991: mm->F_NzLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), F_NzLeft_h);
992: mm->irootloc = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), irootloc_h);
993: mm->rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowoffset_h);
994: mm->rootBuf = MatScalarKokkosView(NoInit("rootBuf"), nroots);
995: mm->leafBuf = MatScalarKokkosView(NoInit("leafBuf"), nleaves);
996: mm->garray = garray2;
997: mm->n = n2;
999: // Output Fd and Fo in KokkosCsrMatrix format
1000: MatScalarKokkosView Fda_d(NoInit("Fda_d"), Fdnz), Foa_d(NoInit("Foa_d"), Fonz);
1001: MatRowMapKokkosView Fdi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdi_h);
1002: MatColIdxKokkosView Fdj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Fdj_h);
1003: MatRowMapKokkosView Foi_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foi_h);
1004: MatColIdxKokkosView Foj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Foj_h);
1006: PetscCallCXX(mm->Fd = KokkosCsrMatrix("Fd", Fm, cend - cstart, Fdnz, Fda_d, Fdi_d, Fdj_d));
1007: PetscCallCXX(mm->Fo = KokkosCsrMatrix("Fo", Fm, n2, Fonz, Foa_d, Foi_d, Foj_d));
1009: // Compute kernel launch parameters in merging E or splitting F
1010: PetscInt teamSize, vectorLength, rowsPerTeam;
1012: teamSize = vectorLength = rowsPerTeam = -1;
1013: PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(mm->irootloc.extent(0), mm->rootBuf.extent(0), -1, teamSize, vectorLength, rowsPerTeam));
1014: mm->E_TeamSize = teamSize;
1015: mm->E_VectorLength = vectorLength;
1016: mm->E_RowsPerTeam = rowsPerTeam;
1018: teamSize = vectorLength = rowsPerTeam = -1;
1019: PetscCall(MatMergeGetLaunchParameters<DefaultExecutionSpace>(Fm, Fnz, -1, teamSize, vectorLength, rowsPerTeam));
1020: mm->F_TeamSize = teamSize;
1021: mm->F_VectorLength = vectorLength;
1022: mm->F_RowsPerTeam = rowsPerTeam;
1023: } else PetscCheck(reuse == MAT_REUSE_MATRIX, comm, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
1025: // Sync E's values to device
1026: akok->a_dual.sync_device();
1027: bkok->a_dual.sync_device();
1029: // Handy aliases
1030: const auto &Aa = akok->a_dual.view_device();
1031: const auto &Ba = bkok->a_dual.view_device();
1032: const auto &Ai = akok->i_dual.view_device();
1033: const auto &Bi = bkok->i_dual.view_device();
1035: // Fetch the plans
1036: PetscIntKokkosView &E_NzLeft = mm->E_NzLeft;
1037: PetscSF &bcastSF = mm->sf;
1038: MatScalarKokkosView &rootBuf = mm->rootBuf;
1039: MatScalarKokkosView &leafBuf = mm->leafBuf;
1040: PetscIntKokkosView &irootloc = mm->irootloc;
1041: PetscIntKokkosView &rowoffset = mm->rowoffset;
1043: PetscInt teamSize = mm->E_TeamSize;
1044: PetscInt vectorLength = mm->E_VectorLength;
1045: PetscInt rowsPerTeam = mm->E_RowsPerTeam;
1046: PetscInt workSets = (irootloc.extent(0) + rowsPerTeam - 1) / rowsPerTeam;
1048: // Copy rows in A/B of E to rootBuf, then bcast it to leafBuf
1049: PetscCallCXX(Kokkos::parallel_for(
1050: Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1051: Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
1052: size_t r = t.league_rank() * rowsPerTeam + k; // r-th entry in irootloc[]
1053: if (r < irootloc.extent(0)) {
1054: PetscInt i = irootloc(r); // row i of E
1055: PetscInt disp = rowoffset(r);
1056: PetscInt alen = Ai(i + 1) - Ai(i);
1057: PetscInt blen = Bi(i + 1) - Bi(i);
1058: PetscInt nzleft = E_NzLeft(i);
1060: Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
1061: if (j < nzleft) { // B left
1062: rootBuf(disp + j) = Ba(Bi(i) + j);
1063: } else if (j < nzleft + alen) { // diag A
1064: rootBuf(disp + j) = Aa(Ai(i) + j - nzleft);
1065: } else { // B right
1066: rootBuf(disp + j) = Ba(Bi(i) + j - alen);
1067: }
1068: });
1069: }
1070: });
1071: }));
1072: PetscCall(PetscSFBcastWithMemTypeBegin(bcastSF, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, rootBuf.data(), PETSC_MEMTYPE_KOKKOS, leafBuf.data(), MPI_REPLACE));
1073: PetscFunctionReturn(PETSC_SUCCESS);
1074: }
1076: // To finish MatMPIAIJKokkosBcast.
1077: static PetscErrorCode MatMPIAIJKokkosBcastEnd(Mat E, PetscSF ownerSF, MatReuse reuse, PetscInt *map, MatMatStruct_AB *mm)
1078: {
1079: PetscFunctionBegin;
1080: const auto &Fd = mm->Fd;
1081: const auto &Fo = mm->Fo;
1082: const auto &Fdi = Fd.graph.row_map;
1083: const auto &Foi = Fo.graph.row_map;
1084: auto &Fda = Fd.values;
1085: auto &Foa = Fo.values;
1086: auto Fm = Fd.numRows();
1088: PetscIntKokkosView &F_NzLeft = mm->F_NzLeft;
1089: PetscSF &bcastSF = mm->sf;
1090: MatScalarKokkosView &rootBuf = mm->rootBuf;
1091: MatScalarKokkosView &leafBuf = mm->leafBuf;
1092: PetscInt teamSize = mm->F_TeamSize;
1093: PetscInt vectorLength = mm->F_VectorLength;
1094: PetscInt rowsPerTeam = mm->F_RowsPerTeam;
1095: PetscInt workSets = (Fm + rowsPerTeam - 1) / rowsPerTeam;
1097: PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, rootBuf.data(), leafBuf.data(), MPI_REPLACE));
1099: // Update Fda and Foa with new data in leafBuf (as if it is Fa)
1100: PetscCallCXX(Kokkos::parallel_for(
1101: Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), workSets, teamSize, vectorLength), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1102: Kokkos::parallel_for(Kokkos::TeamThreadRange(t, 0, rowsPerTeam), [&](PetscInt k) {
1103: PetscInt i = t.league_rank() * rowsPerTeam + k; // i-th row in F
1104: if (i < Fm) {
1105: PetscInt nzLeft = F_NzLeft(i);
1106: PetscInt alen = Fdi(i + 1) - Fdi(i);
1107: PetscInt blen = Foi(i + 1) - Foi(i);
1108: PetscInt Fii = Fdi(i) + Foi(i);
1110: Kokkos::parallel_for(Kokkos::ThreadVectorRange(t, alen + blen), [&](PetscInt j) {
1111: PetscScalar val = leafBuf(Fii + j);
1112: if (j < nzLeft) { // left
1113: Foa(Foi(i) + j) = val;
1114: } else if (j < nzLeft + alen) { // diag
1115: Fda(Fdi(i) + j - nzLeft) = val;
1116: } else { // right
1117: Foa(Foi(i) + j - alen) = val;
1118: }
1119: });
1120: }
1121: });
1122: }));
1123: PetscFunctionReturn(PETSC_SUCCESS);
1124: }
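/* An illustrative derivation of the C = A^t * B algorithm implemented below, using the
   usual MPIAIJ block notation (A's local part is [Ad Ao], similarly for B): the rows of
   C owned by this rank receive

     C_local = Adt * [Bd Bo]  +  (rows of Aot * [Bd Bo] computed on other ranks)

   i.e., C1 = Adt*Bd and C2 (via C2_mid) = Adt*Bo are this rank's own contributions,
   while C3 = Aot*Bd and C4 = Aot*Bo are contributions destined for other ranks' rows;
   the latter are reduced to their owners via MatMPIAIJKokkosReduce into F = (Fd, Fo),
   so that Cd = C1 + Fd and Co = C2 + Fo.
*/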
1126: static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
1127: {
1128: Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1129: Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1130: KokkosCsrMatrix Adt, Aot, Ad, Ao, Bd, Bo;
1131: PetscInt cstart, cend;
1132: MPI_Comm comm;
1134: PetscFunctionBegin;
1135: PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
1136: PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
1137: PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
1138: PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1139: PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1140: PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1141: PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1143: // TODO: add command line options to select spgemm algorithms
1144: auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
1146: // CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4
1147: #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
1148: #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
1149: spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1150: #endif
1151: #endif
1153: PetscCallCXX(mm->kh1.create_spgemm_handle(spgemm_alg));
1154: PetscCallCXX(mm->kh2.create_spgemm_handle(spgemm_alg));
1155: PetscCallCXX(mm->kh3.create_spgemm_handle(spgemm_alg));
1156: PetscCallCXX(mm->kh4.create_spgemm_handle(spgemm_alg));
1158: // Aot * (B's diag + B's off-diag)
1159: PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Aot, false, Bd, false, mm->C3));
1160: PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Aot, false, Bo, false, mm->C4));
1161: // KK spgemm_symbolic() only populates the result's row map, but not its columns.
1162: // TODO: Remove the fake spgemm_numeric() after KK fixes this problem.
1163: PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
1164: PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
1165: #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1167: PetscCallCXX(sort_crs_matrix(mm->C3));
1168: PetscCallCXX(sort_crs_matrix(mm->C4));
1169: #endif
1171: // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
1172: PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
1173: PetscCall(MatGetOwnershipRangeColumn(B, &cstart, &cend));
1174: PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1176: // Adt * (B's diag + B's off-diag)
1177: PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Adt, false, Bd, false, mm->C1));
1178: PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1179: PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
1180: PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1181: #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1182: PetscCallCXX(sort_crs_matrix(mm->C1));
1183: PetscCallCXX(sort_crs_matrix(mm->C2_mid));
1184: #endif
1186: PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, cstart, cend, bmpi->garray, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1188: // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
1189: MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
1190: PetscIntKokkosView map = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1191: PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
1192: PetscCallCXX(mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj));
1194: // C = (C1+Fd, C2+Fo)
1195: PetscCallCXX(mm->kh1.create_spadd_handle(true)); // C1, Fd are sorted
1196: PetscCallCXX(mm->kh2.create_spadd_handle(true)); // C2, Fo are sorted
1197: PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->Fd, mm->Cd));
1198: PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->Fo, mm->Co));
1199: PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
1200: PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
1201: PetscFunctionReturn(PETSC_SUCCESS);
1202: }
1204: static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AtB *mm)
1205: {
1206: Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1207: Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1208: KokkosCsrMatrix Adt, Aot, Bd, Bo;
1209: MPI_Comm comm;
1211: PetscFunctionBegin;
1212: PetscCall(PetscObjectGetComm((PetscObject)B, &comm));
1213: PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->A, &Adt));
1214: PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(ampi->B, &Aot));
1215: PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1216: PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1218: // Aot * (B's diag + B's off-diag)
1219: PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Aot, false, Bd, false, mm->C3));
1220: PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Aot, false, Bo, false, mm->C4));
1222: // Reduce E (i.e., C3 and C4)'s rows to form F, and overlap the communication
1223: PetscCall(MatMPIAIJKokkosReduceBegin(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
1225: // Adt * (B's diag + B's off-diag)
1226: PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Adt, false, Bd, false, mm->C1));
1227: PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Adt, false, Bo, false, mm->C2_mid));
1229: PetscCall(MatMPIAIJKokkosReduceEnd(comm, mm->C3, mm->C4, 0, 0, NULL, NULL, MAT_REUSE_MATRIX, NULL, mm));
1231: // C = (C1+Fd, C2+Fo)
1232: PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->Fd, mm->Cd));
1233: PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->Fo, mm->Co));
1234: PetscFunctionReturn(PETSC_SUCCESS);
1235: }
1237: /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos
1239: Input Parameters:
1240: + product - the Mat_Product carrying out this computation; passed in to access info about the product.
1241: . A - an MPIAIJKOKKOS matrix
1242: . B - an MPIAIJKOKKOS matrix
1243: - mm - a struct used to stash intermediate data when computing AB; it persists from the symbolic to the numeric operation.
1244: */
1245: static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1246: {
1247: Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1248: Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1249: KokkosCsrMatrix Ad, Ao, Bd, Bo;
1251: PetscFunctionBegin;
1252: PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1253: PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1254: PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1255: PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1257: // TODO: add command line options to select spgemm algorithms
1258: auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; // default is TPL if enabled, otherwise KK
1260: // CUDA 10.2's spgemm has bugs; we prefer the SpGEMMreuse APIs introduced in CUDA 11.4
1261: #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
1262: #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
1263: spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
1264: #endif
1265: #endif
1267: mm->kh1.create_spgemm_handle(spgemm_alg);
1268: mm->kh2.create_spgemm_handle(spgemm_alg);
1269: mm->kh3.create_spgemm_handle(spgemm_alg);
1270: mm->kh4.create_spgemm_handle(spgemm_alg);
1272: // Bcast B's rows to form F, and overlap the communication
1273: PetscIntKokkosViewHost map_h(NoInit("map_h"), bmpi->B->cmap->n);
1274: PetscCall(MatMPIAIJKokkosBcastBegin(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1276: // A's diag * (B's diag + B's off-diag)
1277: PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh1, Ad, false, Bd, false, mm->C1));
1278: PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh2, Ad, false, Bo, false, mm->C2_mid)); // C2 aliases with C2_mid, except with new column indices
1279: // KK spgemm_symbolic() only populates the result's row map, but not its columns.
1280: // TODO: Remove the fake spgemm_numeric() once KK fixes this problem.
1281: PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
1282: PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1283: #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1284: PetscCallCXX(sort_crs_matrix(mm->C1));
1285: PetscCallCXX(sort_crs_matrix(mm->C2_mid));
1286: #endif
1288: PetscCall(MatMPIAIJKokkosBcastEnd(B, ampi->Mvctx, MAT_INITIAL_MATRIX, map_h.data(), mm));
1290: // A's off-diag * (F's diag + F's off-diag)
1291: PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1292: PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1293: PetscCallCXX(KokkosSparse::spgemm_symbolic(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1294: PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1295: #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
1296: PetscCallCXX(sort_crs_matrix(mm->C3));
1297: PetscCallCXX(sort_crs_matrix(mm->C4));
1298: #endif
1300: // Create C2, which shares a, i arrays with C2_mid, but with new column indices and potentially larger column size
1301: MatColIdxKokkosView oldj = mm->C2_mid.graph.entries, newj(NoInit("j"), oldj.extent(0));
1302: PetscIntKokkosView map = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), map_h);
1303: PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, oldj.extent(0)), KOKKOS_LAMBDA(const PetscInt i) { newj(i) = map(oldj(i)); }));
1304: mm->C2 = KokkosCsrMatrix("C2", mm->C2_mid.numRows(), mm->n /*new column size*/, mm->C2_mid.nnz(), mm->C2_mid.values, mm->C2_mid.graph.row_map, newj);
1306: // C = (Cd, Co) = (C1+C3, C2+C4)
1307: mm->kh1.create_spadd_handle(true); // C1, C3 are sorted
1308: mm->kh2.create_spadd_handle(true); // C2, C4 are sorted
1309: PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh1, mm->C1, mm->C3, mm->Cd));
1310: PetscCallCXX(KokkosSparse::spadd_symbolic(&mm->kh2, mm->C2, mm->C4, mm->Co));
1311: PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
1312: PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
1313: PetscFunctionReturn(PETSC_SUCCESS);
1314: }
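// In outline, the AB symbolic routine above (and the numeric routine below) computes C = A*B, with each
// MPIAIJ matrix split into a diag and an off-diag block:
//   F = (Fd, Fo) holds the off-process rows of B needed by this rank, obtained via the Bcast above, so that
//   Cd = C1 + C3 = Ad*Bd + Ao*Fd,  and  Co = C2 + C4 = Ad*Bo (with columns remapped) + Ao*Fo.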
1316: static PetscErrorCode MatProductNumeric_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1317: {
1318: Mat_MPIAIJ *ampi = static_cast<Mat_MPIAIJ *>(A->data);
1319: Mat_MPIAIJ *bmpi = static_cast<Mat_MPIAIJ *>(B->data);
1320: KokkosCsrMatrix Ad, Ao, Bd, Bo;
1322: PetscFunctionBegin;
1323: PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->A, &Ad));
1324: PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(ampi->B, &Ao));
1325: PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->A, &Bd));
1326: PetscCall(MatSeqAIJKokkosGetKokkosCsrMatrix(bmpi->B, &Bo));
1328: // Bcast B's rows to form F, and overlap the communication
1329: PetscCall(MatMPIAIJKokkosBcastBegin(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1331: // A's diag * (B's diag + B's off-diag)
1332: PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh1, Ad, false, Bd, false, mm->C1));
1333: PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh2, Ad, false, Bo, false, mm->C2_mid));
1335: PetscCall(MatMPIAIJKokkosBcastEnd(B, NULL, MAT_REUSE_MATRIX, NULL, mm));
1337: // A's off-diag * (F's diag + F's off-diag)
1338: PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh3, Ao, false, mm->Fd, false, mm->C3));
1339: PetscCallCXX(KokkosSparse::spgemm_numeric(mm->kh4, Ao, false, mm->Fo, false, mm->C4));
1341: // C = (Cd, Co) = (C1+C3, C2+C4)
1342: PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh1, 1.0, mm->C1, 1.0, mm->C3, mm->Cd));
1343: PetscCallCXX(KokkosSparse::spadd_numeric(&mm->kh2, 1.0, mm->C2, 1.0, mm->C4, mm->Co));
1344: PetscFunctionReturn(PETSC_SUCCESS);
1345: }
1347: static PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
1348: {
1349: Mat_MPIAIJ *cmpi = static_cast<Mat_MPIAIJ *>(C->data);
1350: Mat_Product *product;
1351: MatProductData_MPIAIJKokkos *pdata;
1352: MatProductType ptype;
1353: Mat A, B;
1355: PetscFunctionBegin;
1356: MatCheckProduct(C, 1); // make sure C is a product
1357: product = C->product;
1358: pdata = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
1359: ptype = product->type;
1360: A = product->A;
1361: B = product->B;
1363: // See if the numeric phase has already been done in the symbolic phase (e.g., when the user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
1364: // If so, skip the numeric phase, but reset the flag so that the next time the user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
1365: // we still do the numeric phase.
1366: if (pdata->reusesym) { // numeric reuses results from symbolic
1367: pdata->reusesym = PETSC_FALSE;
1368: PetscFunctionReturn(PETSC_SUCCESS);
1369: }
1371: if (ptype == MATPRODUCT_AB) {
1372: PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1373: } else if (ptype == MATPRODUCT_AtB) {
1374: PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, A, B, pdata->mmAtB));
1375: } else if (ptype == MATPRODUCT_PtAP) { // BtAB, computed as Z = AB; C = BtZ
1376: PetscCall(MatProductNumeric_MPIAIJKokkos_AB(product, A, B, pdata->mmAB));
1377: PetscCall(MatProductNumeric_MPIAIJKokkos_AtB(product, B, pdata->Z, pdata->mmAtB));
1378: }
1380: PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->A)); // mark that C's diag and off-diag blocks on device are modified
1381: PetscCall(MatSeqAIJKokkosModifyDevice(cmpi->B));
1382: PetscFunctionReturn(PETSC_SUCCESS);
1383: }
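/* The reusesym logic above supports the common user-level calling sequence sketched below
   (a minimal sketch; A, B, C are assumed to be compatible AIJKOKKOS matrices):

     MatMatMult(A, B, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &C); // symbolic + numeric done together
     // ... change the values (not the nonzero pattern) of A and/or B ...
     MatMatMult(A, B, MAT_REUSE_MATRIX, PETSC_DEFAULT, &C);   // redo only the numeric phase, reusing C
*/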
1385: static PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
1386: {
1387: Mat A, B;
1388: Mat_Product *product;
1389: MatProductType ptype;
1390: MatProductData_MPIAIJKokkos *pdata;
1391: MatMatStruct *mm = NULL;
1392: PetscInt m, n, M, N;
1393: Mat Cd, Co;
1394: MPI_Comm comm;
1396: PetscFunctionBegin;
1397: PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
1398: MatCheckProduct(C, 1);
1399: product = C->product;
1400: PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
1401: ptype = product->type;
1402: A = product->A;
1403: B = product->B;
1405: switch (ptype) {
1406: case MATPRODUCT_AB:
1407: m = A->rmap->n;
1408: n = B->cmap->n;
1409: M = A->rmap->N;
1410: N = B->cmap->N;
1411: break;
1412: case MATPRODUCT_AtB:
1413: m = A->cmap->n;
1414: n = B->cmap->n;
1415: M = A->cmap->N;
1416: N = B->cmap->N;
1417: break;
1418: case MATPRODUCT_PtAP:
1419: m = B->cmap->n;
1420: n = B->cmap->n;
1421: M = B->cmap->N;
1422: N = B->cmap->N;
1423: break; /* BtAB */
1424: default:
1425: SETERRQ(comm, PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
1426: }
1428: PetscCall(MatSetSizes(C, m, n, M, N));
1429: PetscCall(PetscLayoutSetUp(C->rmap));
1430: PetscCall(PetscLayoutSetUp(C->cmap));
1431: PetscCall(MatSetType(C, ((PetscObject)A)->type_name));
1433: pdata = new MatProductData_MPIAIJKokkos();
1434: pdata->reusesym = product->api_user;
1436: if (ptype == MATPRODUCT_AB) {
1437: auto mmAB = new MatMatStruct_AB();
1438: PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB));
1439: mm = pdata->mmAB = mmAB;
1440: } else if (ptype == MATPRODUCT_AtB) {
1441: auto mmAtB = new MatMatStruct_AtB();
1442: PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, B, mmAtB));
1443: mm = pdata->mmAtB = mmAtB;
1444: } else if (ptype == MATPRODUCT_PtAP) { // C = BtAB, computed as Z = AB; C = BtZ
1445: Mat Zd, Zo, Z; // Zd, Zo are owned by pdata->Z
1447: auto mmAB = new MatMatStruct_AB();
1448: PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmAB)); // Z stored as mmAB->{Cd, Co}
1449: PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Cd, &Zd));
1450: PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mmAB->Co, &Zo));
1451: pdata->mmAB = mmAB;
1453: m = A->rmap->n; // Z's layout
1454: n = B->cmap->n;
1455: M = A->rmap->N;
1456: N = B->cmap->N;
1457: PetscCall(MatCreate(comm, &Z));
1458: PetscCall(MatSetSizes(Z, m, n, M, N));
1459: PetscCall(PetscLayoutSetUp(Z->rmap));
1460: PetscCall(PetscLayoutSetUp(Z->cmap));
1461: PetscCall(MatSetType(Z, MATMPIAIJKOKKOS));
1462: PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Z, Zd, Zo, mmAB->garray));
1464: auto mmAtB = new MatMatStruct_AtB();
1465: PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, Z, mmAtB)); // final result C stored as mmAtB->{Cd, Co}
1467: pdata->Z = Z; // give ownership to pdata
1468: mm = pdata->mmAtB = mmAtB;
1469: }
1471: PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Cd, &Cd));
1472: PetscCall(MatCreateSeqAIJKokkosWithKokkosCsrMatrix(PETSC_COMM_SELF, mm->Co, &Co));
1473: PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co, mm->garray));
1475: C->product->data = pdata;
1476: C->product->destroy = MatProductDataDestroy_MPIAIJKokkos;
1477: C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
1478: PetscFunctionReturn(PETSC_SUCCESS);
1479: }
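/* A minimal sketch of the MatProduct calling sequence that reaches the symbolic/numeric routines above
   (A and B are assumed to be MPIAIJKOKKOS matrices; C is created by the product):

     Mat C;
     MatProductCreate(A, B, NULL, &C);
     MatProductSetType(C, MATPRODUCT_PtAP);  // or MATPRODUCT_AB, MATPRODUCT_AtB
     MatProductSetFromOptions(C);
     MatProductSymbolic(C);                  // stashes the MatMatStruct data in C->product
     MatProductNumeric(C);                   // fills C's values; may be repeated as A, B change
*/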
1481: PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1482: {
1483: Mat_Product *product = mat->product;
1484: PetscBool match = PETSC_FALSE;
1485: PetscBool usecpu = PETSC_FALSE;
1487: PetscFunctionBegin;
1488: MatCheckProduct(mat, 1);
1489: if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1490: if (match) { /* we can always fall back to the CPU if requested */
1491: switch (product->type) {
1492: case MATPRODUCT_AB:
1493: if (product->api_user) {
1494: PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
1495: PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1496: PetscOptionsEnd();
1497: } else {
1498: PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
1499: PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1500: PetscOptionsEnd();
1501: }
1502: break;
1503: case MATPRODUCT_AtB:
1504: if (product->api_user) {
1505: PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
1506: PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1507: PetscOptionsEnd();
1508: } else {
1509: PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
1510: PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1511: PetscOptionsEnd();
1512: }
1513: break;
1514: case MATPRODUCT_PtAP:
1515: if (product->api_user) {
1516: PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
1517: PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1518: PetscOptionsEnd();
1519: } else {
1520: PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
1521: PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1522: PetscOptionsEnd();
1523: }
1524: break;
1525: default:
1526: break;
1527: }
1528: match = (PetscBool)!usecpu;
1529: }
1530: if (match) {
1531: switch (product->type) {
1532: case MATPRODUCT_AB:
1533: case MATPRODUCT_AtB:
1534: case MATPRODUCT_PtAP:
1535: mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1536: break;
1537: default:
1538: break;
1539: }
1540: }
1541: /* fallback to MPIAIJ ops */
1542: if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
1543: PetscFunctionReturn(PETSC_SUCCESS);
1544: }
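// For example, the AB product can be pushed back to the CPU implementation with
//   -matmatmult_backend_cpu              (when invoked through the MatMatMult() API), or
//   -mat_product_algorithm_backend_cpu   (when invoked through the MatProduct API),
// in which case productsymbolic is left unset here and MatProductSetFromOptions_MPIAIJ() takes over.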
1546: // Mirror of MatCOOStruct_MPIAIJ on device
1547: struct MatCOOStruct_MPIAIJKokkos {
1548: PetscCount n;
1549: PetscSF sf;
1550: PetscCount Annz, Bnnz;
1551: PetscCount Annz2, Bnnz2;
1552: PetscCountKokkosView Ajmap1, Aperm1;
1553: PetscCountKokkosView Bjmap1, Bperm1;
1554: PetscCountKokkosView Aimap2, Ajmap2, Aperm2;
1555: PetscCountKokkosView Bimap2, Bjmap2, Bperm2;
1556: PetscCountKokkosView Cperm1;
1557: MatScalarKokkosView sendbuf, recvbuf;
1559: MatCOOStruct_MPIAIJKokkos(const MatCOOStruct_MPIAIJ *coo_h)
1560: {
1561: auto &exec = PetscGetKokkosExecutionSpace();
1563: n = coo_h->n;
1564: sf = coo_h->sf;
1565: Annz = coo_h->Annz;
1566: Bnnz = coo_h->Bnnz;
1567: Annz2 = coo_h->Annz2;
1568: Bnnz2 = coo_h->Bnnz2;
1569: Ajmap1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap1, coo_h->Annz + 1));
1570: Aperm1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm1, coo_h->Atot1));
1571: Bjmap1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap1, coo_h->Bnnz + 1));
1572: Bperm1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm1, coo_h->Btot1));
1573: Aimap2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aimap2, coo_h->Annz2));
1574: Ajmap2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Ajmap2, coo_h->Annz2 + 1));
1575: Aperm2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Aperm2, coo_h->Atot2));
1576: Bimap2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bimap2, coo_h->Bnnz2));
1577: Bjmap2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bjmap2, coo_h->Bnnz2 + 1));
1578: Bperm2 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Bperm2, coo_h->Btot2));
1579: Cperm1 = Kokkos::create_mirror_view_and_copy(exec, PetscCountKokkosViewHost(coo_h->Cperm1, coo_h->sendlen));
1580: sendbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->sendbuf, coo_h->sendlen));
1581: recvbuf = Kokkos::create_mirror_view(Kokkos::WithoutInitializing, exec, MatScalarKokkosViewHost(coo_h->recvbuf, coo_h->recvlen));
1582: PetscCallVoid(PetscObjectReference((PetscObject)sf));
1583: }
1585: ~MatCOOStruct_MPIAIJKokkos() { PetscCallVoid(PetscSFDestroy(&sf)); }
1586: };
1588: static PetscErrorCode MatCOOStructDestroy_MPIAIJKokkos(void *data)
1589: {
1590: PetscFunctionBegin;
1591: PetscCallCXX(delete static_cast<MatCOOStruct_MPIAIJKokkos *>(data));
1592: PetscFunctionReturn(PETSC_SUCCESS);
1593: }
1595: static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1596: {
1597: PetscContainer container_h, container_d;
1598: MatCOOStruct_MPIAIJ *coo_h;
1599: MatCOOStruct_MPIAIJKokkos *coo_d;
1601: PetscFunctionBegin;
1602: PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */
1603: mat->preallocated = PETSC_TRUE;
1604: PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
1605: PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
1606: PetscCall(MatZeroEntries(mat));
1608: // Copy the COO struct to device
1609: PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
1610: PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h));
1611: PetscCallCXX(coo_d = new MatCOOStruct_MPIAIJKokkos(coo_h));
1613: // Put the COO struct in a container and then attach that to the matrix
1614: PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &container_d));
1615: PetscCall(PetscContainerSetPointer(container_d, coo_d));
1616: PetscCall(PetscContainerSetUserDestroy(container_d, MatCOOStructDestroy_MPIAIJKokkos));
1617: PetscCall(PetscObjectCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject)container_d));
1618: PetscCall(PetscContainerDestroy(&container_d));
1619: PetscFunctionReturn(PETSC_SUCCESS);
1620: }
1622: static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode)
1623: {
1624: Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
1625: Mat A = mpiaij->A, B = mpiaij->B;
1626: MatScalarKokkosView Aa, Ba;
1627: MatScalarKokkosView v1;
1628: PetscMemType memtype;
1629: PetscContainer container;
1630: MatCOOStruct_MPIAIJKokkos *coo;
1631: Kokkos::DefaultExecutionSpace &exec = PetscGetKokkosExecutionSpace();
1633: PetscFunctionBegin;
1634: PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
1635: PetscCall(PetscContainerGetPointer(container, (void **)&coo));
1637: const auto &n = coo->n;
1638: const auto &Annz = coo->Annz;
1639: const auto &Annz2 = coo->Annz2;
1640: const auto &Bnnz = coo->Bnnz;
1641: const auto &Bnnz2 = coo->Bnnz2;
1642: const auto &vsend = coo->sendbuf;
1643: const auto &v2 = coo->recvbuf;
1644: const auto &Ajmap1 = coo->Ajmap1;
1645: const auto &Ajmap2 = coo->Ajmap2;
1646: const auto &Aimap2 = coo->Aimap2;
1647: const auto &Bjmap1 = coo->Bjmap1;
1648: const auto &Bjmap2 = coo->Bjmap2;
1649: const auto &Bimap2 = coo->Bimap2;
1650: const auto &Aperm1 = coo->Aperm1;
1651: const auto &Aperm2 = coo->Aperm2;
1652: const auto &Bperm1 = coo->Bperm1;
1653: const auto &Bperm2 = coo->Bperm2;
1654: const auto &Cperm1 = coo->Cperm1;
1656: PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
1657: if (PetscMemTypeHost(memtype)) { /* If the user provided v[] on host, copy it to device */
1658: v1 = Kokkos::create_mirror_view_and_copy(exec, MatScalarKokkosViewHost((PetscScalar *)v, n));
1659: } else {
1660: v1 = MatScalarKokkosView((PetscScalar *)v, n); /* Directly use v[]'s memory */
1661: }
1663: if (imode == INSERT_VALUES) {
1664: PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1665: PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
1666: } else {
1667: PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
1668: PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
1669: }
1671: PetscCall(PetscLogGpuTimeBegin());
1672: /* Pack entries to be sent to remote processes */
1673: Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, vsend.extent(0)), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });
1675: /* Send remote entries to their owners and overlap the communication with local computation */
1676: PetscCall(PetscSFReduceWithMemTypeBegin(coo->sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
1677: /* Add local entries to A and B in one kernel */
1678: Kokkos::parallel_for(
1679: Kokkos::RangePolicy<>(exec, 0, Annz + Bnnz), KOKKOS_LAMBDA(PetscCount i) {
1680: PetscScalar sum = 0.0;
1681: if (i < Annz) {
1682: for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
1683: Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1684: } else {
1685: i -= Annz;
1686: for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
1687: Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
1688: }
1689: });
1690: PetscCall(PetscSFReduceEnd(coo->sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));
1692: /* Add received remote entries to A and B in one kernel */
1693: Kokkos::parallel_for(
1694: Kokkos::RangePolicy<>(exec, 0, Annz2 + Bnnz2), KOKKOS_LAMBDA(PetscCount i) {
1695: if (i < Annz2) {
1696: for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
1697: } else {
1698: i -= Annz2;
1699: for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
1700: }
1701: });
1702: PetscCall(PetscLogGpuTimeEnd());
1704: if (imode == INSERT_VALUES) {
1705: PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
1706: PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
1707: } else {
1708: PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
1709: PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
1710: }
1711: PetscFunctionReturn(PETSC_SUCCESS);
1712: }
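/* A minimal sketch of the user-facing COO assembly path served by the two routines above
   (n, coo_i, coo_j, coo_v are assumed to be supplied by the caller; coo_v may live in host or device memory):

     MatSetPreallocationCOO(mat, n, coo_i, coo_j); // analyze the pattern once; build host and device COO structs
     MatSetValuesCOO(mat, coo_v, ADD_VALUES);      // repeatable; packs and adds values with Kokkos kernels on device
*/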
1714: static PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1715: {
1716: PetscFunctionBegin;
1717: PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
1718: PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
1719: PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
1720: PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
1721: PetscCall(MatDestroy_MPIAIJ(A));
1722: PetscFunctionReturn(PETSC_SUCCESS);
1723: }
1725: static PetscErrorCode MatShift_MPIAIJKokkos(Mat A, PetscScalar a)
1726: {
1727: Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(A->data);
1728: PetscBool congruent;
1730: PetscFunctionBegin;
1731: PetscCall(MatHasCongruentLayouts(A, &congruent));
1732: if (congruent) { // square matrix whose diagonal entries lie entirely in the diag block
1733: PetscCall(MatShift(mpiaij->A, a));
1734: } else { // too hard, use the general version
1735: PetscCall(MatShift_Basic(A, a));
1736: }
1737: PetscFunctionReturn(PETSC_SUCCESS);
1738: }
1740: static PetscErrorCode MatSetOps_MPIAIJKokkos(Mat B)
1741: {
1742: PetscFunctionBegin;
1743: B->ops->assemblyend = MatAssemblyEnd_MPIAIJKokkos;
1744: B->ops->mult = MatMult_MPIAIJKokkos;
1745: B->ops->multadd = MatMultAdd_MPIAIJKokkos;
1746: B->ops->multtranspose = MatMultTranspose_MPIAIJKokkos;
1747: B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1748: B->ops->destroy = MatDestroy_MPIAIJKokkos;
1749: B->ops->shift = MatShift_MPIAIJKokkos;
1751: PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
1752: PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
1753: PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
1754: PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
1755: PetscFunctionReturn(PETSC_SUCCESS);
1756: }
1758: PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
1759: {
1760: Mat B;
1761: Mat_MPIAIJ *a;
1763: PetscFunctionBegin;
1764: if (reuse == MAT_INITIAL_MATRIX) {
1765: PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
1766: } else if (reuse == MAT_REUSE_MATRIX) {
1767: PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
1768: }
1769: B = *newmat;
1771: B->boundtocpu = PETSC_FALSE;
1772: PetscCall(PetscFree(B->defaultvectype));
1773: PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
1774: PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));
1776: a = static_cast<Mat_MPIAIJ *>(A->data);
1777: if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
1778: if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
1779: if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));
1780: PetscCall(MatSetOps_MPIAIJKokkos(B));
1781: PetscFunctionReturn(PETSC_SUCCESS);
1782: }
1784: /*MC
1785: MATAIJKOKKOS - "mpiaijkokkos", a matrix type to be used for CSR sparse matrices with Kokkos
1787: A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types
1789: Options Database Key:
1790: . -mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS`
1792: Level: beginner
1794: .seealso: [](ch_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
1795: M*/
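/* For example, an existing code that calls MatSetFromOptions() on its matrices can switch to this type
   at run time (a sketch; the executable name is a placeholder):

     ./app -mat_type aijkokkos
*/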
1796: PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
1797: {
1798: PetscFunctionBegin;
1799: PetscCall(PetscKokkosInitializeCheck());
1800: PetscCall(MatCreate_MPIAIJ(A));
1801: PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
1802: PetscFunctionReturn(PETSC_SUCCESS);
1803: }
1805: /*@C
1806: MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKKOS` (compressed row) format
1807: (the default parallel PETSc format). This matrix will ultimately be pushed down
1808: to Kokkos for calculations.
1810: Collective
1812: Input Parameters:
1813: + comm - MPI communicator
1814: . m - number of local rows (or `PETSC_DECIDE` to have calculated if `M` is given)
1815: This value should be the same as the local size used in creating the
1816: y vector for the matrix-vector product y = Ax.
1817: . n - number of local columns (or `PETSC_DECIDE` to have calculated if `N` is given).
1818: This value should be the same as the local size used in creating the
1819: x vector for the matrix-vector product y = Ax. For square matrices `n` is almost always `m`.
1820: . M - number of global rows (or `PETSC_DETERMINE` to have calculated if `m` is given)
1821: . N - number of global columns (or `PETSC_DETERMINE` to have calculated if `n` is given)
1822: . d_nz - number of nonzeros per row in DIAGONAL portion of local submatrix
1823: (same value is used for all local rows)
1824: . d_nnz - array containing the number of nonzeros in the various rows of the
1825: DIAGONAL portion of the local submatrix (possibly different for each row)
1826: or `NULL`, if `d_nz` is used to specify the nonzero structure.
1827: The size of this array is equal to the number of local rows, i.e `m`.
1828: For matrices you plan to factor you must leave room for the diagonal entry and
1829: put in the entry even if it is zero.
1830: . o_nz - number of nonzeros per row in the OFF-DIAGONAL portion of local
1831: submatrix (same value is used for all local rows).
1832: - o_nnz - array containing the number of nonzeros in the various rows of the
1833: OFF-DIAGONAL portion of the local submatrix (possibly different for
1834: each row) or `NULL`, if `o_nz` is used to specify the nonzero
1835: structure. The size of this array is equal to the number
1836: of local rows, i.e `m`.
1838: Output Parameter:
1839: . A - the matrix
1841: Level: intermediate
1843: Notes:
1844: It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
1845: MatXXXXSetPreallocation() paradigm instead of this routine directly.
1846: [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
1848: The AIJ format (also called compressed row storage) is fully compatible with standard Fortran
1849: storage. That is, the stored row and column indices can begin at
1850: either one (as in Fortran) or zero.
1852: .seealso: [](ch_matrices), `Mat`, `MATAIJKOKKOS`, `MATSEQAIJKOKKOS`, `MATMPIAIJKOKKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`,
1853: `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
1854: @*/
1855: PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A)
1856: {
1857: PetscMPIInt size;
1859: PetscFunctionBegin;
1860: PetscCall(MatCreate(comm, A));
1861: PetscCall(MatSetSizes(*A, m, n, M, N));
1862: PetscCallMPI(MPI_Comm_size(comm, &size));
1863: if (size > 1) {
1864: PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
1865: PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
1866: } else {
1867: PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
1868: PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
1869: }
1870: PetscFunctionReturn(PETSC_SUCCESS);
1871: }
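/* A minimal usage sketch of MatCreateAIJKokkos(); the sizes and preallocation values below are illustrative only:

     Mat A;
     MatCreateAIJKokkos(PETSC_COMM_WORLD, PETSC_DECIDE, PETSC_DECIDE, 100, 100, 5, NULL, 2, NULL, &A);
     // ... MatSetValues(), MatAssemblyBegin()/MatAssemblyEnd(), then use A ...
     MatDestroy(&A);
*/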