Actual source code: bjkokkos.kokkos.cxx

  1: #define PETSC_SKIP_CXX_COMPLEX_FIX // Kokkos::complex does not need the petsc complex fix

  3: #include <petsc/private/pcbjkokkosimpl.h>

  5: #include <petsc/private/kspimpl.h>
  6: #include <petscksp.h>
  7: #include <../src/mat/impls/aij/mpi/mpiaij.h>
  8: #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
  9: #include <petscsection.h>
 10: #include <petscdmcomposite.h>

 12: #include <../src/mat/impls/aij/seq/aij.h>
 13: #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>

 15: #include <petscdevice_cupm.h>

 17: static PetscErrorCode PCBJKOKKOSCreateKSP_BJKOKKOS(PC pc)
 18: {
 19:   const char    *prefix;
 20:   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
 21:   DM             dm;

 23:   PetscFunctionBegin;
 24:   PetscCall(KSPCreate(PetscObjectComm((PetscObject)pc), &jac->ksp));
 25:   PetscCall(KSPSetNestLevel(jac->ksp, pc->kspnestlevel));
 26:   PetscCall(KSPSetErrorIfNotConverged(jac->ksp, pc->erroriffailure));
 27:   PetscCall(PetscObjectIncrementTabLevel((PetscObject)jac->ksp, (PetscObject)pc, 1));
 28:   PetscCall(PCGetOptionsPrefix(pc, &prefix));
 29:   PetscCall(KSPSetOptionsPrefix(jac->ksp, prefix));
 30:   PetscCall(KSPAppendOptionsPrefix(jac->ksp, "pc_bjkokkos_"));
 31:   PetscCall(PCGetDM(pc, &dm));
 32:   if (dm) {
 33:     PetscCall(KSPSetDM(jac->ksp, dm));
 34:     PetscCall(KSPSetDMActive(jac->ksp, PETSC_FALSE));
 35:   }
 36:   jac->reason       = PETSC_FALSE;
 37:   jac->monitor      = PETSC_FALSE;
 38:   jac->batch_target = 0;
 39:   jac->rank_target  = 0;
 40:   jac->nsolves_team = 1;
 41:   jac->ksp->max_it  = 50; // this is really for GMRES w/o restarts
 42:   PetscFunctionReturn(PETSC_SUCCESS);
 43: }

 45: // y <-- Ax
 46: KOKKOS_INLINE_FUNCTION PetscErrorCode MatMult(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, const PetscInt start, const PetscInt end, const PetscScalar *x_loc, PetscScalar *y_loc)
 47: {
 48:   Kokkos::parallel_for(Kokkos::TeamThreadRange(team, start, end), [=](const int rowb) {
 49:     int                rowa = ic[rowb];
 50:     int                n    = glb_Aai[rowa + 1] - glb_Aai[rowa];
 51:     const PetscInt    *aj   = glb_Aaj + glb_Aai[rowa]; // global
 52:     const PetscScalar *aa   = glb_Aaa + glb_Aai[rowa];
 53:     PetscScalar        sum;
 54:     Kokkos::parallel_reduce(Kokkos::ThreadVectorRange(team, n), [=](const int i, PetscScalar &lsum) { lsum += aa[i] * x_loc[r[aj[i]] - start]; }, sum);
 55:     Kokkos::single(Kokkos::PerThread(team), [=]() { y_loc[rowb - start] = sum; });
 56:   });
 57:   team.team_barrier();
 58:   return PETSC_SUCCESS;
 59: }
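
For readers unfamiliar with Kokkos hierarchical parallelism, the MatMult() kernel above maps one matrix block to a team: rows are spread over threads (TeamThreadRange) and the nonzeros of each row over vector lanes (ThreadVectorRange). The following standalone sketch (not part of this file; all sizes, names, and values are illustrative only) shows the same pattern on a tiny hard-coded CSR matrix:

#include <Kokkos_Core.hpp>
#include <cstdio>

// Standalone sketch: one team computes y = A*x for a small CSR matrix using the
// same TeamThreadRange/ThreadVectorRange nesting as MatMult() above.
int main(int argc, char **argv)
{
  Kokkos::initialize(argc, argv);
  {
    using policy = Kokkos::TeamPolicy<>;
    using member = policy::member_type;
    const int n = 4, nnz = 8;
    Kokkos::View<int *>    Ai("Ai", n + 1), Aj("Aj", nnz);
    Kokkos::View<double *> Aa("Aa", nnz), x("x", n), y("y", n);
    auto h_Ai = Kokkos::create_mirror_view(Ai);
    auto h_Aj = Kokkos::create_mirror_view(Aj);
    auto h_Aa = Kokkos::create_mirror_view(Aa);
    auto h_x  = Kokkos::create_mirror_view(x);
    const int    rowptr[5] = {0, 2, 4, 6, 8};               // two 2x2 diagonal blocks
    const int    colidx[8] = {0, 1, 0, 1, 2, 3, 2, 3};
    const double vals[8]   = {4, -1, -1, 4, 4, -1, -1, 4};
    for (int i = 0; i <= n; i++) h_Ai(i) = rowptr[i];
    for (int k = 0; k < nnz; k++) { h_Aj(k) = colidx[k]; h_Aa(k) = vals[k]; }
    for (int i = 0; i < n; i++) h_x(i) = 1.0;
    Kokkos::deep_copy(Ai, h_Ai);
    Kokkos::deep_copy(Aj, h_Aj);
    Kokkos::deep_copy(Aa, h_Aa);
    Kokkos::deep_copy(x, h_x);
    Kokkos::parallel_for(
      "spmv", policy(1, Kokkos::AUTO, 2), KOKKOS_LAMBDA(const member &team) {
        // threads over rows
        Kokkos::parallel_for(Kokkos::TeamThreadRange(team, n), [=](const int row) {
          double sum;
          // vector lanes over the nonzeros of this row
          Kokkos::parallel_reduce(
            Kokkos::ThreadVectorRange(team, Ai(row + 1) - Ai(row)),
            [=](const int k, double &lsum) { lsum += Aa(Ai(row) + k) * x(Aj(Ai(row) + k)); }, sum);
          Kokkos::single(Kokkos::PerThread(team), [=]() { y(row) = sum; });
        });
      });
    Kokkos::fence();
    auto h_y = Kokkos::create_mirror_view(y);
    Kokkos::deep_copy(h_y, y);
    printf("y[0] = %g\n", h_y(0));
  }
  Kokkos::finalize();
  return 0;
}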

 61: // y <-- A^T x; TODO: use a per-thread temp buffer with a reduction at the end instead of atomics?
 62: KOKKOS_INLINE_FUNCTION PetscErrorCode MatMultTranspose(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, const PetscInt start, const PetscInt end, const PetscScalar *x_loc, PetscScalar *y_loc)
 63: {
 64:   Kokkos::parallel_for(Kokkos::TeamVectorRange(team, end - start), [=](int i) { y_loc[i] = 0; });
 65:   team.team_barrier();
 66:   Kokkos::parallel_for(Kokkos::TeamThreadRange(team, start, end), [=](const int rowb) {
 67:     int                rowa = ic[rowb];
 68:     int                n    = glb_Aai[rowa + 1] - glb_Aai[rowa];
 69:     const PetscInt    *aj   = glb_Aaj + glb_Aai[rowa]; // global
 70:     const PetscScalar *aa   = glb_Aaa + glb_Aai[rowa];
 71:     const PetscScalar  xx   = x_loc[rowb - start]; // rowb = ic[rowa] = ic[r[rowb]]
 72:     Kokkos::parallel_for(Kokkos::ThreadVectorRange(team, n), [=](const int &i) {
 73:       PetscScalar val = aa[i] * xx;
 74:       Kokkos::atomic_fetch_add(&y_loc[r[aj[i]] - start], val);
 75:     });
 76:   });
 77:   team.team_barrier();
 78:   return PETSC_SUCCESS;
 79: }

 81: typedef struct Batch_MetaData_TAG {
 82:   PetscInt           flops;
 83:   PetscInt           its;
 84:   KSPConvergedReason reason;
 85: } Batch_MetaData;

 87: // Solve Ax = b with TFQMR, right preconditioned (A B^-1 (B x) = b) so the convergence test uses the un-preconditioned residual
 88: static KOKKOS_INLINE_FUNCTION PetscErrorCode BJSolve_TFQMR(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, PetscScalar *work_space_global, const int stride_global, const int nShareVec, PetscScalar *work_space_shared, const int stride_shared, PetscReal rtol, PetscReal atol, PetscReal dtol, PetscInt maxit, Batch_MetaData *metad, const PetscInt start, const PetscInt end, const PetscScalar glb_idiag[], const PetscScalar *glb_b, PetscScalar *glb_x, bool monitor)
 89: {
 90:   using Kokkos::parallel_for;
 91:   using Kokkos::parallel_reduce;
 92:   int                Nblk = end - start, it, m, stride = stride_shared, idx = 0;
 93:   PetscReal          dp, dpold, w, dpest, tau, psi, cm, r0;
 94:   const PetscScalar *Diag = &glb_idiag[start];
 95:   PetscScalar       *ptr  = work_space_shared, rho, rhoold, a, s, b, eta, etaold, psiold, cf, dpi;

 97:   if (idx++ == nShareVec) {
 98:     ptr    = work_space_global;
 99:     stride = stride_global;
100:   }
101:   PetscScalar *XX = ptr;
102:   ptr += stride;
103:   if (idx++ == nShareVec) {
104:     ptr    = work_space_global;
105:     stride = stride_global;
106:   }
107:   PetscScalar *R = ptr;
108:   ptr += stride;
109:   if (idx++ == nShareVec) {
110:     ptr    = work_space_global;
111:     stride = stride_global;
112:   }
113:   PetscScalar *RP = ptr;
114:   ptr += stride;
115:   if (idx++ == nShareVec) {
116:     ptr    = work_space_global;
117:     stride = stride_global;
118:   }
119:   PetscScalar *V = ptr;
120:   ptr += stride;
121:   if (idx++ == nShareVec) {
122:     ptr    = work_space_global;
123:     stride = stride_global;
124:   }
125:   PetscScalar *T = ptr;
126:   ptr += stride;
127:   if (idx++ == nShareVec) {
128:     ptr    = work_space_global;
129:     stride = stride_global;
130:   }
131:   PetscScalar *Q = ptr;
132:   ptr += stride;
133:   if (idx++ == nShareVec) {
134:     ptr    = work_space_global;
135:     stride = stride_global;
136:   }
137:   PetscScalar *P = ptr;
138:   ptr += stride;
139:   if (idx++ == nShareVec) {
140:     ptr    = work_space_global;
141:     stride = stride_global;
142:   }
143:   PetscScalar *U = ptr;
144:   ptr += stride;
145:   if (idx++ == nShareVec) {
146:     ptr    = work_space_global;
147:     stride = stride_global;
148:   }
149:   PetscScalar *D = ptr;
150:   ptr += stride;
151:   if (idx++ == nShareVec) {
152:     ptr    = work_space_global;
153:     stride = stride_global;
154:   }
155:   PetscScalar *AUQ = V;

157:   // init: get b, zero x
158:   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=](int rowb) {
159:     int rowa         = ic[rowb];
160:     R[rowb - start]  = glb_b[rowa];
161:     XX[rowb - start] = 0;
162:   });
163:   team.team_barrier();
164:   parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &lsum) { lsum += R[idx] * PetscConj(R[idx]); }, dpi);
165:   team.team_barrier();
166:   r0 = dp = PetscSqrtReal(PetscRealPart(dpi));
167:   // diagnostics
168: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
169:   if (monitor) Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("%3d KSP Residual norm %14.12e\n", 0, (double)dp); });
170: #endif
171:   if (dp < atol) {
172:     metad->reason = KSP_CONVERGED_ATOL_NORMAL;
173:     it            = 0;
174:     goto done;
175:   }
176:   if (0 == maxit) {
177:     metad->reason = KSP_CONVERGED_ITS;
178:     it            = 0;
179:     goto done;
180:   }

182:   /* Make the initial Rp = R */
183:   parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { RP[idx] = R[idx]; });
184:   team.team_barrier();
185:   /* Set the initial conditions */
186:   etaold = 0.0;
187:   psiold = 0.0;
188:   tau    = dp;
189:   dpold  = dp;

191:   /* rhoold = (r,rp)     */
192:   parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &dot) { dot += R[idx] * PetscConj(RP[idx]); }, rhoold);
193:   team.team_barrier();
194:   parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
195:     U[idx] = R[idx];
196:     P[idx] = R[idx];
197:     T[idx] = Diag[idx] * P[idx];
198:     D[idx] = 0;
199:   });
200:   team.team_barrier();
201:   static_cast<void>(MatMult(team, glb_Aai, glb_Aaj, glb_Aaa, r, ic, start, end, T, V));

203:   it = 0;
204:   do {
205:     /* s <- (v,rp)          */
206:     parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &dot) { dot += V[idx] * PetscConj(RP[idx]); }, s);
207:     team.team_barrier();
208:     if (s == 0) {
209:       metad->reason = KSP_CONVERGED_HAPPY_BREAKDOWN;
210:       goto done;
211:     }
212:     a = rhoold / s; /* a <- rho / s         */
213:     /* q <- u - a v    VecWAXPY(w,alpha,x,y): w = alpha x + y.     */
214:     /* t <- u + q           */
215:     parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
216:       Q[idx] = U[idx] - a * V[idx];
217:       T[idx] = U[idx] + Q[idx];
218:     });
219:     team.team_barrier();
220:     // KSP_PCApplyBAorAB
221:     parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { T[idx] = Diag[idx] * T[idx]; });
222:     team.team_barrier();
223:     static_cast<void>(MatMult(team, glb_Aai, glb_Aaj, glb_Aaa, r, ic, start, end, T, AUQ));
224:     /* r <- r - a K (u + q) */
225:     parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { R[idx] = R[idx] - a * AUQ[idx]; });
226:     team.team_barrier();
227:     parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &lsum) { lsum += R[idx] * PetscConj(R[idx]); }, dpi);
228:     team.team_barrier();
229:     dp = PetscSqrtReal(PetscRealPart(dpi));
230:     for (m = 0; m < 2; m++) {
231:       if (!m) w = PetscSqrtReal(dp * dpold);
232:       else w = dp;
233:       psi = w / tau;
234:       cm  = 1.0 / PetscSqrtReal(1.0 + psi * psi);
235:       tau = tau * psi * cm;
236:       eta = cm * cm * a;
237:       cf  = psiold * psiold * etaold / a;
238:       if (!m) {
239:         /* D = U + cf D */
240:         parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { D[idx] = U[idx] + cf * D[idx]; });
241:       } else {
242:         /* D = Q + cf D */
243:         parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { D[idx] = Q[idx] + cf * D[idx]; });
244:       }
245:       team.team_barrier();
246:       parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { XX[idx] = XX[idx] + eta * D[idx]; });
247:       team.team_barrier();
248:       dpest = PetscSqrtReal(2 * it + m + 2.0) * tau;
249: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
250:       if (monitor && m == 1) Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("%3d KSP Residual norm %14.12e\n", it + 1, (double)dpest); });
251: #endif
252:       if (dpest < atol) {
253:         metad->reason = KSP_CONVERGED_ATOL_NORMAL;
254:         goto done;
255:       }
256:       if (dpest / r0 < rtol) {
257:         metad->reason = KSP_CONVERGED_RTOL_NORMAL;
258:         goto done;
259:       }
260: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
261:       if (dpest / r0 > dtol) {
262:         metad->reason = KSP_DIVERGED_DTOL;
263:         Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("ERROR block %d diverged: %d it, res=%e, r_0=%e\n", team.league_rank(), it, dpest, r0); });
264:         goto done;
265:       }
266: #else
267:       if (dpest / r0 > dtol) {
268:         metad->reason = KSP_DIVERGED_DTOL;
269:         goto done;
270:       }
271: #endif
272:       if (it + 1 == maxit) {
273:         metad->reason = KSP_CONVERGED_ITS;
274: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
275:         Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("ERROR block %d diverged: TFQMR %d:%d it, res=%e, r_0=%e r_res=%e\n", team.league_rank(), it, m, dpest, r0, dpest / r0); });
276: #endif
277:         goto done;
278:       }
279:       etaold = eta;
280:       psiold = psi;
281:     }

283:     /* rho <- (r,rp)       */
284:     parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &dot) { dot += R[idx] * PetscConj(RP[idx]); }, rho);
285:     team.team_barrier();
286:     if (rho == 0) {
287:       metad->reason = KSP_CONVERGED_HAPPY_BREAKDOWN;
288:       goto done;
289:     }
290:     b = rho / rhoold; /* b <- rho / rhoold   */
291:     /* u <- r + b q        */
292:     /* p <- u + b(q + b p) */
293:     parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
294:       U[idx] = R[idx] + b * Q[idx];
295:       Q[idx] = Q[idx] + b * P[idx];
296:       P[idx] = U[idx] + b * Q[idx];
297:     });
298:     /* v <- K p  */
299:     team.team_barrier();
300:     parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { T[idx] = Diag[idx] * P[idx]; });
301:     team.team_barrier();
302:     static_cast<void>(MatMult(team, glb_Aai, glb_Aaj, glb_Aaa, r, ic, start, end, T, V));

304:     rhoold = rho;
305:     dpold  = dp;

307:     it++;
308:   } while (it < maxit);
309: done:
310:   // KSPUnwindPreconditioner
311:   parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { XX[idx] = Diag[idx] * XX[idx]; });
312:   team.team_barrier();
313:   // put x into Plex order
314:   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=](int rowb) {
315:     int rowa    = ic[rowb];
316:     glb_x[rowa] = XX[rowb - start];
317:   });
318:   metad->its = it;
319:   if (1) {
320:     int nnz;
321:     parallel_reduce(Kokkos::TeamVectorRange(team, start, end), [=](const int idx, int &lsum) { lsum += (glb_Aai[idx + 1] - glb_Aai[idx]); }, nnz);
322:     metad->flops = 2 * (metad->its * (10 * Nblk + 2 * nnz) + 5 * Nblk);
323:   } else {
324:     metad->flops = 2 * (metad->its * (10 * Nblk + 2 * 50 * Nblk) + 5 * Nblk); // guess
325:   }
326:   return PETSC_SUCCESS;
327: }

329: // Solve Ax = b with BiCG
330: static KOKKOS_INLINE_FUNCTION PetscErrorCode BJSolve_BICG(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, PetscScalar *work_space_global, const int stride_global, const int nShareVec, PetscScalar *work_space_shared, const int stride_shared, PetscReal rtol, PetscReal atol, PetscReal dtol, PetscInt maxit, Batch_MetaData *metad, const PetscInt start, const PetscInt end, const PetscScalar glb_idiag[], const PetscScalar *glb_b, PetscScalar *glb_x, bool monitor)
331: {
332:   using Kokkos::parallel_for;
333:   using Kokkos::parallel_reduce;
334:   int                Nblk = end - start, it, stride = stride_shared, idx = 0; // start in shared mem
335:   PetscReal          dp, r0;
336:   const PetscScalar *Di  = &glb_idiag[start];
337:   PetscScalar       *ptr = work_space_shared, dpi, a = 1.0, beta, betaold = 1.0, t1, t2;

339:   if (idx++ == nShareVec) {
340:     ptr    = work_space_global;
341:     stride = stride_global;
342:   }
343:   PetscScalar *XX = ptr;
344:   ptr += stride;
345:   if (idx++ == nShareVec) {
346:     ptr    = work_space_global;
347:     stride = stride_global;
348:   }
349:   PetscScalar *Rl = ptr;
350:   ptr += stride;
351:   if (idx++ == nShareVec) {
352:     ptr    = work_space_global;
353:     stride = stride_global;
354:   }
355:   PetscScalar *Zl = ptr;
356:   ptr += stride;
357:   if (idx++ == nShareVec) {
358:     ptr    = work_space_global;
359:     stride = stride_global;
360:   }
361:   PetscScalar *Pl = ptr;
362:   ptr += stride;
363:   if (idx++ == nShareVec) {
364:     ptr    = work_space_global;
365:     stride = stride_global;
366:   }
367:   PetscScalar *Rr = ptr;
368:   ptr += stride;
369:   if (idx++ == nShareVec) {
370:     ptr    = work_space_global;
371:     stride = stride_global;
372:   }
373:   PetscScalar *Zr = ptr;
374:   ptr += stride;
375:   if (idx++ == nShareVec) {
376:     ptr    = work_space_global;
377:     stride = stride_global;
378:   }
379:   PetscScalar *Pr = ptr;
380:   ptr += stride;

382:   /*     r <- b (x is 0) */
383:   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=](int rowb) {
384:     int rowa         = ic[rowb];
385:     Rl[rowb - start] = Rr[rowb - start] = glb_b[rowa];
386:     XX[rowb - start]                    = 0;
387:   });
388:   team.team_barrier();
389:   /*     z <- Br         */
390:   parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
391:     Zr[idx] = Di[idx] * Rr[idx];
392:     Zl[idx] = Di[idx] * Rl[idx];
393:   });
394:   team.team_barrier();
395:   /*    dp <- r'*r       */
396:   parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &lsum) { lsum += Rr[idx] * PetscConj(Rr[idx]); }, dpi);
397:   team.team_barrier();
398:   r0 = dp = PetscSqrtReal(PetscRealPart(dpi));
399: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
400:   if (monitor) Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("%3d KSP Residual norm %14.12e\n", 0, (double)dp); });
401: #endif
402:   if (dp < atol) {
403:     metad->reason = KSP_CONVERGED_ATOL_NORMAL;
404:     it            = 0;
405:     goto done;
406:   }
407:   if (0 == maxit) {
408:     metad->reason = KSP_CONVERGED_ITS;
409:     it            = 0;
410:     goto done;
411:   }

413:   it = 0;
414:   do {
415:     /*     beta <- r'z     */
416:     parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &dot) { dot += Zr[idx] * PetscConj(Rl[idx]); }, beta);
417:     team.team_barrier();
418: #if PCBJKOKKOS_VERBOSE_LEVEL >= 6
419:   #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
420:     Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("%7d beta = Z.R = %22.14e \n", it, (double)PetscRealPart(beta)); });
421:   #endif
422: #endif
423:     if (beta == 0.0) {
424:       metad->reason = KSP_CONVERGED_HAPPY_BREAKDOWN;
425:       goto done;
426:     }
427:     if (it == 0) {
428:       /*     p <- z          */
429:       parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
430:         Pr[idx] = Zr[idx];
431:         Pl[idx] = Zl[idx];
432:       });
433:     } else {
434:       t1 = beta / betaold;
435:       /*     p <- z + b* p   */
436:       t2 = PetscConj(t1);
437:       parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
438:         Pr[idx] = t1 * Pr[idx] + Zr[idx];
439:         Pl[idx] = t2 * Pl[idx] + Zl[idx];
440:       });
441:     }
442:     team.team_barrier();
443:     betaold = beta;
444:     /*     z <- Kp         */
445:     static_cast<void>(MatMult(team, glb_Aai, glb_Aaj, glb_Aaa, r, ic, start, end, Pr, Zr));
446:     static_cast<void>(MatMultTranspose(team, glb_Aai, glb_Aaj, glb_Aaa, r, ic, start, end, Pl, Zl));
447:     /*     dpi <- z'p      */
448:     parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &lsum) { lsum += Zr[idx] * PetscConj(Pl[idx]); }, dpi);
449:     team.team_barrier();
450:     if (dpi == 0) {
451:       metad->reason = KSP_CONVERGED_HAPPY_BREAKDOWN;
452:       goto done;
453:     }
454:     //
455:     a  = beta / dpi; /*     a = beta/p'z    */
456:     t1 = -a;
457:     t2 = PetscConj(t1);
458:     /*     x <- x + ap     */
459:     parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
460:       XX[idx] = XX[idx] + a * Pr[idx];
461:       Rr[idx] = Rr[idx] + t1 * Zr[idx];
462:       Rl[idx] = Rl[idx] + t2 * Zl[idx];
463:     });
464:     team.team_barrier();
465:     team.team_barrier();
466:     /*    dp <- r'*r       */
467:     parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &lsum) { lsum += Rr[idx] * PetscConj(Rr[idx]); }, dpi);
468:     team.team_barrier();
469:     dp = PetscSqrtReal(PetscRealPart(dpi));
470: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
471:     if (monitor) Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("%3d KSP Residual norm %14.12e\n", it + 1, (double)dp); });
472: #endif
473:     if (dp < atol) {
474:       metad->reason = KSP_CONVERGED_ATOL_NORMAL;
475:       goto done;
476:     }
477:     if (dp / r0 < rtol) {
478:       metad->reason = KSP_CONVERGED_RTOL_NORMAL;
479:       goto done;
480:     }
481: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
482:     if (dp / r0 > dtol) {
483:       metad->reason = KSP_DIVERGED_DTOL;
484:       Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("ERROR block %d diverged: %d it, res=%e, r_0=%e (BICG does this)\n", team.league_rank(), it, dp, r0); });
485:       goto done;
486:     }
487: #else
488:     if (dp / r0 > dtol) {
489:       metad->reason = KSP_DIVERGED_DTOL;
490:       goto done;
491:     }
492: #endif
493:     if (it + 1 == maxit) {
494:       metad->reason = KSP_CONVERGED_ITS; // don't worry about hitting max iterations
495: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
496:       Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("ERROR block %d diverged: BICG %d it, res=%e, r_0=%e r_res=%e\n", team.league_rank(), it, dp, r0, dp / r0); });
497: #endif
498:       goto done;
499:     }
500:     /* z <- Br  */
501:     parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
502:       Zr[idx] = Di[idx] * Rr[idx];
503:       Zl[idx] = Di[idx] * Rl[idx];
504:     });

506:     it++;
507:   } while (it < maxit);
508: done:
509:   // put x back into Plex order
510:   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=](int rowb) {
511:     int rowa    = ic[rowb];
512:     glb_x[rowa] = XX[rowb - start];
513:   });
514:   metad->its = it;
515:   if (1) {
516:     int nnz;
517:     parallel_reduce(Kokkos::TeamVectorRange(team, start, end), [=](const int idx, int &lsum) { lsum += (glb_Aai[idx + 1] - glb_Aai[idx]); }, nnz);
518:     metad->flops = 2 * (metad->its * (10 * Nblk + 2 * nnz) + 5 * Nblk);
519:   } else {
520:     metad->flops = 2 * (metad->its * (10 * Nblk + 2 * 50 * Nblk) + 5 * Nblk); // guess
521:   }
522:   return PETSC_SUCCESS;
523: }

525: // KSP solver: solve Ax = b; xout is output, bin is input
526: static PetscErrorCode PCApply_BJKOKKOS(PC pc, Vec bin, Vec xout)
527: {
528:   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
529:   Mat            A = pc->pmat, Aseq = A;
530:   PetscMPIInt    rank;

532:   PetscFunctionBegin;
533:   PetscCallMPI(MPI_Comm_rank(PetscObjectComm((PetscObject)A), &rank));
534:   if (!A->spptr) {
535:     Aseq = ((Mat_MPIAIJ *)A->data)->A; // MPI
536:   }
537:   PetscCall(MatSeqAIJKokkosSyncDevice(Aseq));
538:   {
539:     PetscInt           maxit = jac->ksp->max_it;
540:     const PetscInt     conc = Kokkos::DefaultExecutionSpace().concurrency(), openmp = !!(conc < 1000), team_size = (openmp == 0 && PCBJKOKKOS_VEC_SIZE != 1) ? PCBJKOKKOS_TEAM_SIZE : 1;
541:     const PetscInt     nwork = jac->nwork, nBlk = jac->nBlocks;
542:     PetscScalar       *glb_xdata = NULL, *dummy;
543:     PetscReal          rtol = jac->ksp->rtol, atol = jac->ksp->abstol, dtol = jac->ksp->divtol;
544:     const PetscScalar *glb_idiag = jac->d_idiag_k->data(), *glb_bdata = NULL;
545:     const PetscInt    *glb_Aai, *glb_Aaj, *d_bid_eqOffset = jac->d_bid_eqOffset_k->data();
546:     const PetscScalar *glb_Aaa;
547:     const PetscInt    *d_isicol = jac->d_isicol_k->data(), *d_isrow = jac->d_isrow_k->data();
548:     PCFailedReason     pcreason;
549:     KSPIndex           ksp_type_idx = jac->ksp_type_idx;
550:     PetscMemType       mtype;
551:     PetscContainer     container;
552:     PetscInt           batch_sz;                // the number of repeated DMs, [DM_e_1, DM_e_2, ..., DM_e_batch_sz, DM_i_1, ...]
553:     VecScatter         plex_batch = NULL;       // not used
554:     Vec                bvec;                    // a copy of b for scatter (just alias to bin now)
555:     PetscBool          monitor  = jac->monitor; // captured
556:     PetscInt           view_bid = jac->batch_target;
557:     MatInfo            info;

559:     PetscCall(MatSeqAIJGetCSRAndMemType(Aseq, &glb_Aai, &glb_Aaj, &dummy, &mtype));
560:     jac->max_nits = 0;
561:     glb_Aaa       = dummy;
562:     if (jac->rank_target != rank) view_bid = -1; // turn off all but one process
563:     PetscCall(MatGetInfo(A, MAT_LOCAL, &info));
564:     // the field-major IS (if present) maps plex I/O to/from block/field-major ordering
565:     PetscCall(PetscObjectQuery((PetscObject)A, "plex_batch_is", (PetscObject *)&container));
566:     if (container) {
567:       PetscCall(VecDuplicate(bin, &bvec));
568:       PetscCall(PetscContainerGetPointer(container, (void **)&plex_batch));
569:       PetscCall(VecScatterBegin(plex_batch, bin, bvec, INSERT_VALUES, SCATTER_FORWARD));
570:       PetscCall(VecScatterEnd(plex_batch, bin, bvec, INSERT_VALUES, SCATTER_FORWARD));
571:       SETERRQ(PetscObjectComm((PetscObject)A), PETSC_ERR_USER, "No plex_batch_is -- require NO field major ordering for now");
572:     } else {
573:       bvec = bin;
574:     }
575:     // get x
576:     PetscCall(VecGetArrayAndMemType(xout, &glb_xdata, &mtype));
577: #if defined(PETSC_HAVE_CUDA)
578:     PetscCheck(PetscMemTypeDevice(mtype), PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "No GPU data for x %d != %d", static_cast<int>(mtype), static_cast<int>(PETSC_MEMTYPE_DEVICE));
579: #endif
580:     PetscCall(VecGetArrayReadAndMemType(bvec, &glb_bdata, &mtype));
581: #if defined(PETSC_HAVE_CUDA)
582:     PetscCheck(PetscMemTypeDevice(mtype), PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "No GPU data for b");
583: #endif
584:     // get batch size
585:     PetscCall(PetscObjectQuery((PetscObject)A, "batch size", (PetscObject *)&container));
586:     if (container) {
587:       PetscInt *pNf = NULL;
588:       PetscCall(PetscContainerGetPointer(container, (void **)&pNf));
589:       batch_sz = *pNf; // number of times to repeat the DMs
590:     } else batch_sz = 1;
591:     PetscCheck(nBlk % batch_sz == 0, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "batch_sz = %" PetscInt_FMT ", nBlk = %" PetscInt_FMT, batch_sz, nBlk);
592:     if (ksp_type_idx == BATCH_KSP_GMRESKK_IDX) {
593:       // KK solver - move PETSc data into Kokkos Views, setup solver, solve, move data out of Kokkos, process metadata (convergence tests, etc.)
594: #if defined(PETSC_HAVE_KOKKOS_KERNELS_BATCH)
595:       PetscCall(PCApply_BJKOKKOSKERNELS(pc, glb_bdata, glb_xdata, glb_Aai, glb_Aaj, glb_Aaa, team_size, info, batch_sz, &pcreason));
596: #else
597:       PetscCheck(ksp_type_idx != BATCH_KSP_GMRESKK_IDX, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "Type: BATCH_KSP_GMRES not supported for complex");
598: #endif
599:     } else { // Kokkos Krylov
600:       using scr_mem_t    = Kokkos::DefaultExecutionSpace::scratch_memory_space;
601:       using vect2D_scr_t = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, scr_mem_t>;
602:       Kokkos::View<Batch_MetaData *, Kokkos::DefaultExecutionSpace> d_metadata("solver meta data", nBlk);
603:       int                                                           stride_shared, stride_global, global_buff_words;
604:       d_bid_eqOffset = jac->d_bid_eqOffset_k->data();
605:       // solve each block independently
606:       int scr_bytes_team_shared = 0, nShareVec = 0, nGlobBVec = 0;
607:       if (jac->const_block_size) { // use shared memory for work vectors only if constant block size - TODO: test efficiency loss
608:         size_t      maximum_shared_mem_size = 64000;
609:         PetscDevice device;
610:         PetscCall(PetscDeviceGetDefault_Internal(&device));
611:         PetscCall(PetscDeviceGetAttribute(device, PETSC_DEVICE_ATTR_SIZE_T_SHARED_MEM_PER_BLOCK, &maximum_shared_mem_size));
612:         stride_shared = jac->const_block_size;                                                   // captured
613:         nShareVec     = maximum_shared_mem_size / (jac->const_block_size * sizeof(PetscScalar)); // integer floor, number of vectors that fit in shared
614:         if (nShareVec > nwork) nShareVec = nwork;
615:         else nGlobBVec = nwork - nShareVec;
616:         global_buff_words     = jac->n * nGlobBVec;
617:         scr_bytes_team_shared = jac->const_block_size * nShareVec * sizeof(PetscScalar);
618:       } else {
619:         scr_bytes_team_shared = 0;
620:         stride_shared         = 0;
621:         global_buff_words     = jac->n * nwork;
622:         nGlobBVec             = nwork; // not strictly needed -- TODO: fix
623:       }
624:       stride_global = jac->n; // captured
625: #if defined(PETSC_HAVE_CUDA)
626:       nvtxRangePushA("batch-kokkos-solve");
627: #endif
628:       Kokkos::View<PetscScalar *, Kokkos::DefaultExecutionSpace> d_work_vecs_k("workvectors", global_buff_words); // global work vectors
629: #if PCBJKOKKOS_VERBOSE_LEVEL > 1
630:       PetscCall(PetscInfo(pc, "\tn = %d. %d shared bytes/team, %d global mem words, rtol=%e, num blocks %d, team_size=%d, %d vector threads, %d shared vectors, %d global vectors\n", (int)jac->n, scr_bytes_team_shared, global_buff_words, (double)rtol, (int)nBlk, (int)team_size, PCBJKOKKOS_VEC_SIZE, nShareVec, nGlobBVec));
631: #endif
632:       PetscScalar *d_work_vecs = d_work_vecs_k.data();
633:       Kokkos::parallel_for(
634:         "Solve", Kokkos::TeamPolicy<Kokkos::LaunchBounds<256, 4>>(nBlk, team_size, PCBJKOKKOS_VEC_SIZE).set_scratch_size(PCBJKOKKOS_SHARED_LEVEL, Kokkos::PerTeam(scr_bytes_team_shared)), KOKKOS_LAMBDA(const team_member team) {
635:           const int    blkID = team.league_rank(), start = d_bid_eqOffset[blkID], end = d_bid_eqOffset[blkID + 1];
636:           vect2D_scr_t work_vecs_shared(team.team_scratch(PCBJKOKKOS_SHARED_LEVEL), end - start, nShareVec);
637:           PetscScalar *work_buff_shared = work_vecs_shared.data();
638:           PetscScalar *work_buff_global = &d_work_vecs[start]; // the 'start' offset is already included
639:           bool         print            = monitor && (blkID == view_bid);
640:           switch (ksp_type_idx) {
641:           case BATCH_KSP_BICG_IDX:
642:             static_cast<void>(BJSolve_BICG(team, glb_Aai, glb_Aaj, glb_Aaa, d_isrow, d_isicol, work_buff_global, stride_global, nShareVec, work_buff_shared, stride_shared, rtol, atol, dtol, maxit, &d_metadata[blkID], start, end, glb_idiag, glb_bdata, glb_xdata, print));
643:             break;
644:           case BATCH_KSP_TFQMR_IDX:
645:             static_cast<void>(BJSolve_TFQMR(team, glb_Aai, glb_Aaj, glb_Aaa, d_isrow, d_isicol, work_buff_global, stride_global, nShareVec, work_buff_shared, stride_shared, rtol, atol, dtol, maxit, &d_metadata[blkID], start, end, glb_idiag, glb_bdata, glb_xdata, print));
646:             break;
647:           default:
648: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
649:             printf("Unknown KSP type %d\n", ksp_type_idx);
650: #else
651:             /* void */;
652: #endif
653:           }
654:         });
655:       Kokkos::fence();
656: #if defined(PETSC_HAVE_CUDA)
657:       nvtxRangePop();
658:       nvtxRangePushA("Post-solve-metadata");
659: #endif
660:       auto h_metadata = Kokkos::create_mirror(Kokkos::HostSpace::memory_space(), d_metadata);
661:       Kokkos::deep_copy(h_metadata, d_metadata);
662:       PetscInt max_nnit = -1;
663: #if PCBJKOKKOS_VERBOSE_LEVEL > 1
664:       PetscInt mbid = 0;
665: #endif
666:       int in[2], out[2];
667:       if (jac->reason) { // -pc_bjkokkos_ksp_converged_reason
668: #if PCBJKOKKOS_VERBOSE_LEVEL >= 3
669:   #if PCBJKOKKOS_VERBOSE_LEVEL >= 4
670:         PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Iterations\n"));
671:   #endif
672:         // assume species major
673:   #if PCBJKOKKOS_VERBOSE_LEVEL == 3
674:         if (batch_sz != 1) PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "%s: max iterations per species:", ksp_type_idx == BATCH_KSP_BICG_IDX ? "bicg" : "tfqmr"));
675:         else PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "    Linear solve converged due to %s iterations ", ksp_type_idx == BATCH_KSP_BICG_IDX ? "bicg" : "tfqmr"));
676:   #endif
677:         for (PetscInt dmIdx = 0, head = 0, s = 0; dmIdx < jac->num_dms; dmIdx += batch_sz) {
678:           for (PetscInt f = 0, idx = head; f < jac->dm_Nf[dmIdx]; f++, idx++, s++) {
679:             for (int bid = 0; bid < batch_sz; bid++) {
680:   #if PCBJKOKKOS_VERBOSE_LEVEL >= 4
681:               jac->max_nits += h_metadata[idx + bid * jac->dm_Nf[dmIdx]].its; // report total number of iterations with high verbose
682:               if (h_metadata[idx + bid * jac->dm_Nf[dmIdx]].its > max_nnit) {
683:                 max_nnit = h_metadata[idx + bid * jac->dm_Nf[dmIdx]].its;
684:                 mbid     = bid;
685:               }
686:   #else
687:               if (h_metadata[idx + bid * jac->dm_Nf[dmIdx]].its > max_nnit) {
688:                 jac->max_nits = max_nnit = h_metadata[idx + bid * jac->dm_Nf[dmIdx]].its;
689:                 mbid                     = bid;
690:               }
691:   #endif
692:             }
693:   #if PCBJKOKKOS_VERBOSE_LEVEL >= 4
694:             PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "%2" PetscInt_FMT ":", s));
695:             for (int bid = 0; bid < batch_sz; bid++) PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "%3" PetscInt_FMT " ", h_metadata[idx + bid * jac->dm_Nf[dmIdx]].its));
696:             PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "\n"));
697:   #else // == 3
698:             PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "%3" PetscInt_FMT " ", max_nnit));
699:   #endif
700:           }
701:           head += batch_sz * jac->dm_Nf[dmIdx];
702:         }
703:   #if PCBJKOKKOS_VERBOSE_LEVEL == 3
704:         PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "\n"));
705:   #endif
706: #endif
707:         if (max_nnit == -1) { // PCBJKOKKOS_VERBOSE_LEVEL < 3: find the max iteration count over all blocks
708:           for (int blkID = 0; blkID < nBlk; blkID++) {
709:             if (h_metadata[blkID].its > max_nnit) {
710:               jac->max_nits = max_nnit = h_metadata[blkID].its;
711: #if PCBJKOKKOS_VERBOSE_LEVEL > 1
712:               mbid = blkID;
713: #endif
714:             }
715:           }
716:         }
717:         in[0] = max_nnit;
718:         in[1] = rank;
719:         PetscCallMPI(MPIU_Allreduce(in, out, 1, MPI_2INT, MPI_MAXLOC, PetscObjectComm((PetscObject)A)));
720: #if PCBJKOKKOS_VERBOSE_LEVEL > 1
721:         if (0 == rank) {
722:           if (batch_sz != 1)
723:             PetscCall(PetscPrintf(PETSC_COMM_SELF, "[%d] Linear solve converged due to %s iterations %d (max), on block %" PetscInt_FMT ", species %" PetscInt_FMT " (max)\n", out[1], KSPConvergedReasons[h_metadata[mbid].reason], out[0], mbid % batch_sz, mbid / batch_sz));
724:           else PetscCall(PetscPrintf(PETSC_COMM_SELF, "[%d] Linear solve converged due to %s iterations %d (max), on block %" PetscInt_FMT "\n", out[1], KSPConvergedReasons[h_metadata[mbid].reason], out[0], mbid));
725:         }
726: #endif
727:       }
728:       for (int blkID = 0; blkID < nBlk; blkID++) {
729:         PetscCall(PetscLogGpuFlops((PetscLogDouble)h_metadata[blkID].flops));
730:         PetscCheck(h_metadata[blkID].reason >= 0 || !jac->ksp->errorifnotconverged, PetscObjectComm((PetscObject)pc), PETSC_ERR_CONV_FAILED, "ERROR reason=%s, its=%" PetscInt_FMT ". species %" PetscInt_FMT ", batch %" PetscInt_FMT,
731:                    KSPConvergedReasons[h_metadata[blkID].reason], h_metadata[blkID].its, blkID / batch_sz, blkID % batch_sz);
732:       }
733:       {
734:         int errsum;
735:         Kokkos::parallel_reduce(
736:           nBlk,
737:           KOKKOS_LAMBDA(const int idx, int &lsum) {
738:             if (d_metadata[idx].reason < 0) ++lsum;
739:           },
740:           errsum);
741:         pcreason = errsum ? PC_SUBPC_ERROR : PC_NOERROR;
742:         if (!errsum && !jac->max_nits) { // set max its to give back to top KSP
743:           for (int blkID = 0; blkID < nBlk; blkID++) {
744:             if (h_metadata[blkID].its > jac->max_nits) jac->max_nits = h_metadata[blkID].its;
745:           }
746:         } else if (errsum) {
747:           PetscCall(PetscPrintf(PETSC_COMM_SELF, "[%d] ERROR Kokkos batch solver did not converge in all solves\n", (int)rank));
748:         }
749:       }
750: #if defined(PETSC_HAVE_CUDA)
751:       nvtxRangePop();
752: #endif
753:     } // end of Kokkos (not Kernels) solvers block
754:     PetscCall(VecRestoreArrayAndMemType(xout, &glb_xdata));
755:     PetscCall(VecRestoreArrayReadAndMemType(bvec, &glb_bdata));
756:     PetscCall(PCSetFailedReason(pc, pcreason));
757:     // map back to Plex space - not used
758:     if (plex_batch) {
759:       PetscCall(VecCopy(xout, bvec));
760:       PetscCall(VecScatterBegin(plex_batch, bvec, xout, INSERT_VALUES, SCATTER_REVERSE));
761:       PetscCall(VecScatterEnd(plex_batch, bvec, xout, INSERT_VALUES, SCATTER_REVERSE));
762:       PetscCall(VecDestroy(&bvec));
763:     }
764:   }
765:   PetscFunctionReturn(PETSC_SUCCESS);
766: }
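
PCApply_BJKOKKOS() above looks for an optional PetscContainer composed on the matrix under the key "batch size" to learn how many times the DMs are repeated. A hedged application-side sketch of attaching that metadata follows; the PetscContainer/PetscObjectCompose calls are standard PETSc API, while the helper name and the storage scheme are hypothetical:

// Hypothetical helper (not in PETSc): attach the "batch size" metadata queried by
// PCApply_BJKOKKOS() above. The pointed-to PetscInt must outlive the container.
static PetscInt batch_sz_storage;

static PetscErrorCode AttachBatchSize(Mat A, PetscInt batch_sz)
{
  PetscContainer container;

  PetscFunctionBegin;
  batch_sz_storage = batch_sz;
  PetscCall(PetscContainerCreate(PetscObjectComm((PetscObject)A), &container));
  PetscCall(PetscContainerSetPointer(container, &batch_sz_storage));
  PetscCall(PetscObjectCompose((PetscObject)A, "batch size", (PetscObject)container));
  PetscCall(PetscContainerDestroy(&container)); // the matrix keeps its own reference
  PetscFunctionReturn(PETSC_SUCCESS);
}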

768: static PetscErrorCode PCSetUp_BJKOKKOS(PC pc)
769: {
770:   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
771:   Mat            A = pc->pmat, Aseq = A; // use filtered block matrix, really "P"
772:   PetscBool      flg;

774:   PetscFunctionBegin;
775:   PetscCheck(A, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "No matrix - pc->pmat is required");
776:   PetscCall(PetscObjectTypeCompareAny((PetscObject)A, &flg, MATSEQAIJKOKKOS, MATMPIAIJKOKKOS, MATAIJKOKKOS, ""));
777:   PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONG, "must use '-[dm_]mat_type aijkokkos -[dm_]vec_type kokkos' for -pc_type bjkokkos");
778:   if (!A->spptr) {
779:     Aseq = ((Mat_MPIAIJ *)A->data)->A; // MPI
780:   }
781:   PetscCall(MatSeqAIJKokkosSyncDevice(Aseq));
782:   {
783:     PetscInt    Istart, Iend;
784:     PetscMPIInt rank;
785:     PetscCallMPI(MPI_Comm_rank(PetscObjectComm((PetscObject)A), &rank));
786:     PetscCall(MatGetOwnershipRange(A, &Istart, &Iend));
787:     if (!jac->vec_diag) {
788:       Vec     *subX = NULL;
789:       DM       pack, *subDM = NULL;
790:       PetscInt nDMs, n, *block_sizes = NULL;
791:       IS       isrow, isicol;
792:       { // Permute the matrix to get a block diagonal system: d_isrow_k, d_isicol_k
793:         MatOrderingType rtype;
794:         const PetscInt *rowindices, *icolindices;
795:         rtype = MATORDERINGRCM;
796:         // get the permutation and invert it; should we convert to local indices?
797:         PetscCall(MatGetOrdering(Aseq, rtype, &isrow, &isicol)); // only seems to work for seq matrix
798:         PetscCall(ISDestroy(&isrow));
799:         PetscCall(ISInvertPermutation(isicol, PETSC_DECIDE, &isrow)); // THIS IS BACKWARD -- isrow is inverse
800:         // if (rank==1) PetscCall(ISView(isicol, PETSC_VIEWER_STDOUT_SELF));
801:         if (0) {
802:           Mat mat_block_order; // debug
803:           PetscCall(ISShift(isicol, Istart, isicol));
804:           PetscCall(MatCreateSubMatrix(A, isicol, isicol, MAT_INITIAL_MATRIX, &mat_block_order));
805:           PetscCall(ISShift(isicol, -Istart, isicol));
806:           PetscCall(MatViewFromOptions(mat_block_order, NULL, "-ksp_batch_reorder_view"));
807:           PetscCall(MatDestroy(&mat_block_order));
808:         }
809:         PetscCall(ISGetIndices(isrow, &rowindices)); // local idx
810:         PetscCall(ISGetIndices(isicol, &icolindices));
811:         const Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_isrow_k((PetscInt *)rowindices, A->rmap->n);
812:         const Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_isicol_k((PetscInt *)icolindices, A->rmap->n);
813:         jac->d_isrow_k  = new Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_isrow_k));
814:         jac->d_isicol_k = new Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_isicol_k));
815:         Kokkos::deep_copy(*jac->d_isrow_k, h_isrow_k);
816:         Kokkos::deep_copy(*jac->d_isicol_k, h_isicol_k);
817:         PetscCall(ISRestoreIndices(isrow, &rowindices));
818:         PetscCall(ISRestoreIndices(isicol, &icolindices));
819:         // if (rank==1) PetscCall(ISView(isicol, PETSC_VIEWER_STDOUT_SELF));
820:       }
821:       // get block sizes & allocate vec_diag
822:       PetscCall(PCGetDM(pc, &pack));
823:       if (pack) {
824:         PetscCall(PetscObjectTypeCompare((PetscObject)pack, DMCOMPOSITE, &flg));
825:         if (flg) {
826:           PetscCall(DMCompositeGetNumberDM(pack, &nDMs));
827:           PetscCall(DMCreateGlobalVector(pack, &jac->vec_diag));
828:         } else pack = NULL; // flag for no DM
829:       }
830:       if (!jac->vec_diag) { // get 'nDMs' and sizes 'block_sizes' w/o DMComposite. TODO: User could provide ISs
831:         PetscInt        bsrt, bend, ncols, ntot = 0;
832:         const PetscInt *colsA, nloc = Iend - Istart;
833:         const PetscInt *rowindices, *icolindices;
834:         PetscCall(PetscMalloc1(nloc, &block_sizes)); // very inefficient, too big
835:         PetscCall(ISGetIndices(isrow, &rowindices));
836:         PetscCall(ISGetIndices(isicol, &icolindices));
837:         nDMs = 0;
838:         bsrt = 0;
839:         bend = 1;
840:         for (PetscInt row_B = 0; row_B < nloc; row_B++) { // for all rows in block diagonal space
841:           PetscInt rowA = icolindices[row_B], minj = PETSC_INT_MAX, maxj = 0;
842:           //PetscCall(PetscPrintf(PETSC_COMM_SELF, "\t[%d] rowA = %d\n",rank,rowA));
843:           PetscCall(MatGetRow(Aseq, rowA, &ncols, &colsA, NULL)); // not sorted in permutation
844:           PetscCheck(ncols, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "Empty row not supported: %" PetscInt_FMT, row_B);
845:           for (PetscInt colj = 0; colj < ncols; colj++) {
846:             PetscInt colB = rowindices[colsA[colj]]; // use local idx
847:             //PetscCall(PetscPrintf(PETSC_COMM_SELF, "\t\t[%d] colB = %d\n",rank,colB));
848:             PetscCheck(colB >= 0 && colB < nloc, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "colB < 0: %" PetscInt_FMT, colB);
849:             if (colB > maxj) maxj = colB;
850:             if (colB < minj) minj = colB;
851:           }
852:           PetscCall(MatRestoreRow(Aseq, rowA, &ncols, &colsA, NULL));
853:           if (minj >= bend) { // first column is > max of last block -- new block or last block
854:             //PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "\t\t finish block %d, N loc = %d (%d,%d)\n", nDMs+1, bend - bsrt,bsrt,bend));
855:             block_sizes[nDMs] = bend - bsrt;
856:             ntot += block_sizes[nDMs];
857:             PetscCheck(minj == bend, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "minj != bend: %" PetscInt_FMT " != %" PetscInt_FMT, minj, bend);
858:             bsrt = bend;
859:             bend++; // start with size 1 in new block
860:             nDMs++;
861:           }
862:           if (maxj + 1 > bend) bend = maxj + 1;
863:           PetscCheck(minj >= bsrt || row_B == Iend - 1, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "%" PetscInt_FMT ") minj < bsrt: %" PetscInt_FMT " != %" PetscInt_FMT, rowA, minj, bsrt);
864:           //PetscCall(PetscPrintf(PETSC_COMM_SELF, "[%d] %d) row %d.%d) cols %d : %d ; bsrt = %d, bend = %d\n",rank,row_B,nDMs,rowA,minj,maxj,bsrt,bend));
865:         }
866:         // do last block
867:         //PetscCall(PetscPrintf(PETSC_COMM_SELF, "\t\t\t [%d] finish block %d, N loc = %d (%d,%d)\n", rank, nDMs+1, bend - bsrt,bsrt,bend));
868:         block_sizes[nDMs] = bend - bsrt;
869:         ntot += block_sizes[nDMs];
870:         nDMs++;
871:         // cleanup
872:         PetscCheck(ntot == nloc, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "n total != n local: %" PetscInt_FMT " != %" PetscInt_FMT, ntot, nloc);
873:         PetscCall(ISRestoreIndices(isrow, &rowindices));
874:         PetscCall(ISRestoreIndices(isicol, &icolindices));
875:         PetscCall(PetscRealloc(sizeof(PetscInt) * nDMs, &block_sizes));
876:         PetscCall(MatCreateVecs(A, &jac->vec_diag, NULL));
877:         PetscCall(PetscInfo(pc, "Set up matrix-based metadata (no DMComposite attached to the PC), %" PetscInt_FMT " subdomains\n", nDMs));
878:       }
879:       PetscCall(ISDestroy(&isrow));
880:       PetscCall(ISDestroy(&isicol));
881:       jac->num_dms = nDMs;
882:       PetscCall(VecGetLocalSize(jac->vec_diag, &n));
883:       jac->n         = n;
884:       jac->d_idiag_k = new Kokkos::View<PetscScalar *, Kokkos::LayoutRight>("idiag", n);
885:       // options
886:       PetscCall(PCBJKOKKOSCreateKSP_BJKOKKOS(pc));
887:       PetscCall(KSPSetFromOptions(jac->ksp));
888:       PetscCall(PetscObjectTypeCompareAny((PetscObject)jac->ksp, &flg, KSPBICG, ""));
889:       if (flg) {
890:         jac->ksp_type_idx = BATCH_KSP_BICG_IDX;
891:         jac->nwork        = 7;
892:       } else {
893:         PetscCall(PetscObjectTypeCompareAny((PetscObject)jac->ksp, &flg, KSPTFQMR, ""));
894:         if (flg) {
895:           jac->ksp_type_idx = BATCH_KSP_TFQMR_IDX;
896:           jac->nwork        = 10;
897:         } else {
898: #if defined(PETSC_HAVE_KOKKOS_KERNELS_BATCH)
899:           PetscCall(PetscObjectTypeCompareAny((PetscObject)jac->ksp, &flg, KSPGMRES, ""));
900:           PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONG, "Unsupported batch ksp type");
901:           jac->ksp_type_idx = BATCH_KSP_GMRESKK_IDX;
902:           jac->nwork        = 0;
903: #else
904:           KSPType ksptype;
905:           PetscCall(KSPGetType(jac->ksp, &ksptype));
906:           PetscCheck(flg, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "Type: %s not supported in complex", ksptype);
907: #endif
908:         }
909:       }
910:       PetscOptionsBegin(PetscObjectComm((PetscObject)jac->ksp), ((PetscObject)jac->ksp)->prefix, "Options for Kokkos batch solver", "none");
911:       PetscCall(PetscOptionsBool("-ksp_converged_reason", "", "bjkokkos.kokkos.cxx.c", jac->reason, &jac->reason, NULL));
912:       PetscCall(PetscOptionsBool("-ksp_monitor", "", "bjkokkos.kokkos.cxx.c", jac->monitor, &jac->monitor, NULL));
913:       PetscCall(PetscOptionsInt("-ksp_batch_target", "", "bjkokkos.kokkos.cxx.c", jac->batch_target, &jac->batch_target, NULL));
914:       PetscCall(PetscOptionsInt("-ksp_rank_target", "", "bjkokkos.kokkos.cxx.c", jac->rank_target, &jac->rank_target, NULL));
915:       PetscCall(PetscOptionsInt("-ksp_batch_nsolves_team", "", "bjkokkos.kokkos.cxx.c", jac->nsolves_team, &jac->nsolves_team, NULL));
916:       PetscCheck(jac->batch_target < jac->num_dms, PETSC_COMM_WORLD, PETSC_ERR_ARG_WRONG, "-ksp_batch_target (%" PetscInt_FMT ") >= number of DMs (%" PetscInt_FMT ")", jac->batch_target, jac->num_dms);
917:       PetscOptionsEnd();
918:       // get blocks - jac->d_bid_eqOffset_k
919:       if (pack) {
920:         PetscCall(PetscMalloc(sizeof(*subX) * nDMs, &subX));
921:         PetscCall(PetscMalloc(sizeof(*subDM) * nDMs, &subDM));
922:       }
923:       PetscCall(PetscMalloc(sizeof(*jac->dm_Nf) * nDMs, &jac->dm_Nf));
924:       PetscCall(PetscInfo(pc, "Have %" PetscInt_FMT " blocks, n=%" PetscInt_FMT " rtol=%g type = %s\n", nDMs, n, (double)jac->ksp->rtol, ((PetscObject)jac->ksp)->type_name));
925:       if (pack) PetscCall(DMCompositeGetEntriesArray(pack, subDM));
926:       jac->nBlocks = 0;
927:       for (PetscInt ii = 0; ii < nDMs; ii++) {
928:         PetscInt Nf;
929:         if (subDM) {
930:           DM           dm = subDM[ii];
931:           PetscSection section;
932:           PetscCall(DMGetLocalSection(dm, &section));
933:           PetscCall(PetscSectionGetNumFields(section, &Nf));
934:         } else Nf = 1;
935:         jac->nBlocks += Nf;
936: #if PCBJKOKKOS_VERBOSE_LEVEL <= 2
937:         if (ii == 0) PetscCall(PetscInfo(pc, "%" PetscInt_FMT ") %" PetscInt_FMT " blocks (%" PetscInt_FMT " total)\n", ii, Nf, jac->nBlocks));
938: #else
939:         PetscCall(PetscInfo(pc, "%" PetscInt_FMT ") %" PetscInt_FMT " blocks (%" PetscInt_FMT " total)\n", ii, Nf, jac->nBlocks));
940: #endif
941:         jac->dm_Nf[ii] = Nf;
942:       }
943:       { // d_bid_eqOffset_k
944:         Kokkos::View<PetscInt *, Kokkos::LayoutRight, Kokkos::HostSpace> h_block_offsets("block_offsets", jac->nBlocks + 1);
945:         if (pack) PetscCall(DMCompositeGetAccessArray(pack, jac->vec_diag, nDMs, NULL, subX));
946:         h_block_offsets[0]    = 0;
947:         jac->const_block_size = -1;
948:         for (PetscInt ii = 0, idx = 0; ii < nDMs; ii++) {
949:           PetscInt nloc, nblk;
950:           if (pack) PetscCall(VecGetSize(subX[ii], &nloc));
951:           else nloc = block_sizes[ii];
952:           nblk = nloc / jac->dm_Nf[ii];
953:           PetscCheck(nloc % jac->dm_Nf[ii] == 0, PetscObjectComm((PetscObject)pc), PETSC_ERR_USER, "nloc%%jac->dm_Nf[ii] (%" PetscInt_FMT ") != 0 DMs", nloc % jac->dm_Nf[ii]);
954:           for (PetscInt jj = 0; jj < jac->dm_Nf[ii]; jj++, idx++) {
955:             h_block_offsets[idx + 1] = h_block_offsets[idx] + nblk;
956: #if PCBJKOKKOS_VERBOSE_LEVEL <= 2
957:             if (idx == 0) PetscCall(PetscInfo(pc, "Add first of %" PetscInt_FMT " blocks with %" PetscInt_FMT " equations\n", jac->nBlocks, nblk));
958: #else
959:             PetscCall(PetscInfo(pc, "\t%" PetscInt_FMT ") Add block with %" PetscInt_FMT " equations of %" PetscInt_FMT "\n", idx + 1, nblk, jac->nBlocks));
960: #endif
961:             if (jac->const_block_size == -1) jac->const_block_size = nblk;
962:             else if (jac->const_block_size > 0 && jac->const_block_size != nblk) jac->const_block_size = 0;
963:           }
964:         }
965:         if (pack) {
966:           PetscCall(DMCompositeRestoreAccessArray(pack, jac->vec_diag, nDMs, NULL, subX)); // match the count used in DMCompositeGetAccessArray()
967:           PetscCall(PetscFree(subX));
968:           PetscCall(PetscFree(subDM));
969:         }
970:         jac->d_bid_eqOffset_k = new Kokkos::View<PetscInt *, Kokkos::LayoutRight>(Kokkos::create_mirror(Kokkos::DefaultExecutionSpace::memory_space(), h_block_offsets));
971:         Kokkos::deep_copy(*jac->d_bid_eqOffset_k, h_block_offsets);
972:       }
973:       if (!pack) PetscCall(PetscFree(block_sizes));
974:     }
975:     { // get jac->d_idiag_k (PC setup),
976:       const PetscInt    *d_ai, *d_aj;
977:       const PetscScalar *d_aa;
978:       const PetscInt     conc = Kokkos::DefaultExecutionSpace().concurrency(), openmp = !!(conc < 1000), team_size = (openmp == 0 && PCBJKOKKOS_VEC_SIZE != 1) ? PCBJKOKKOS_TEAM_SIZE : 1;
979:       const PetscInt    *d_bid_eqOffset = jac->d_bid_eqOffset_k->data(), *r = jac->d_isrow_k->data(), *ic = jac->d_isicol_k->data();
980:       PetscScalar       *d_idiag = jac->d_idiag_k->data(), *dummy;
981:       PetscMemType       mtype;
982:       PetscCall(MatSeqAIJGetCSRAndMemType(Aseq, &d_ai, &d_aj, &dummy, &mtype));
983:       d_aa = dummy;
984:       Kokkos::parallel_for(
985:         "Diag", Kokkos::TeamPolicy<>(jac->nBlocks, team_size, PCBJKOKKOS_VEC_SIZE), KOKKOS_LAMBDA(const team_member team) {
986:           const PetscInt blkID = team.league_rank();
987:           Kokkos::parallel_for(Kokkos::TeamThreadRange(team, d_bid_eqOffset[blkID], d_bid_eqOffset[blkID + 1]), [=](const int rowb) {
988:             const PetscInt     rowa = ic[rowb], ai = d_ai[rowa], *aj = d_aj + ai; // grab original data
989:             const PetscScalar *aa   = d_aa + ai;
990:             const PetscInt     nrow = d_ai[rowa + 1] - ai;
991:             int                found;
992:             Kokkos::parallel_reduce(
993:               Kokkos::ThreadVectorRange(team, nrow),
994:               [=](const int &j, int &count) {
995:                 const PetscInt colb = r[aj[j]];
996:                 if (colb == rowb) {
997:                   d_idiag[rowb] = 1. / aa[j];
998:                   count++;
999:                 }
1000:               },
1001:               found);
1002: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
1003:             if (found != 1) Kokkos::single(Kokkos::PerThread(team), [=]() { printf("ERROR row %d) found = %d\n", rowb, found); });
1004: #endif
1005:           });
1006:         });
1007:     }
1008:   }
1009:   PetscFunctionReturn(PETSC_SUCCESS);
1010: }

1012: /* Reset (also the default destroy path if the PC has never been set up) */
1013: static PetscErrorCode PCReset_BJKOKKOS(PC pc)
1014: {
1015:   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;

1017:   PetscFunctionBegin;
1018:   PetscCall(KSPDestroy(&jac->ksp));
1019:   PetscCall(VecDestroy(&jac->vec_diag));
1020:   if (jac->d_bid_eqOffset_k) delete jac->d_bid_eqOffset_k;
1021:   if (jac->d_idiag_k) delete jac->d_idiag_k;
1022:   if (jac->d_isrow_k) delete jac->d_isrow_k;
1023:   if (jac->d_isicol_k) delete jac->d_isicol_k;
1024:   jac->d_bid_eqOffset_k = NULL;
1025:   jac->d_idiag_k        = NULL;
1026:   jac->d_isrow_k        = NULL;
1027:   jac->d_isicol_k       = NULL;
1028:   PetscCall(PetscObjectComposeFunction((PetscObject)pc, "PCBJKOKKOSGetKSP_C", NULL)); // not published now (causes configure errors)
1029:   PetscCall(PetscObjectComposeFunction((PetscObject)pc, "PCBJKOKKOSSetKSP_C", NULL));
1030:   PetscCall(PetscFree(jac->dm_Nf));
1031:   jac->dm_Nf = NULL;
1032:   if (jac->rowOffsets) delete jac->rowOffsets;
1033:   if (jac->colIndices) delete jac->colIndices;
1034:   if (jac->batch_b) delete jac->batch_b;
1035:   if (jac->batch_x) delete jac->batch_x;
1036:   if (jac->batch_values) delete jac->batch_values;
1037:   jac->rowOffsets   = NULL;
1038:   jac->colIndices   = NULL;
1039:   jac->batch_b      = NULL;
1040:   jac->batch_x      = NULL;
1041:   jac->batch_values = NULL;
1042:   PetscFunctionReturn(PETSC_SUCCESS);
1043: }

1045: static PetscErrorCode PCDestroy_BJKOKKOS(PC pc)
1046: {
1047:   PetscFunctionBegin;
1048:   PetscCall(PCReset_BJKOKKOS(pc));
1049:   PetscCall(PetscFree(pc->data));
1050:   PetscFunctionReturn(PETSC_SUCCESS);
1051: }

1053: static PetscErrorCode PCView_BJKOKKOS(PC pc, PetscViewer viewer)
1054: {
1055:   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
1056:   PetscBool      iascii;

1058:   PetscFunctionBegin;
1059:   if (!jac->ksp) PetscCall(PCBJKOKKOSCreateKSP_BJKOKKOS(pc));
1060:   PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &iascii));
1061:   if (iascii) {
1062:     PetscCall(PetscViewerASCIIPrintf(viewer, "  Batched device linear solver: Krylov (KSP) method with Jacobi preconditioning\n"));
1063:     PetscCall(PetscViewerASCIIPrintf(viewer, "\t\tnwork = %" PetscInt_FMT ", rel tol = %e, abs tol = %e, div tol = %e, max it = %" PetscInt_FMT ", type = %s\n", jac->nwork, (double)jac->ksp->rtol, (double)jac->ksp->abstol, (double)jac->ksp->divtol, jac->ksp->max_it,
1064:                                      ((PetscObject)jac->ksp)->type_name));
1065:   }
1066:   PetscFunctionReturn(PETSC_SUCCESS);
1067: }

1069: static PetscErrorCode PCSetFromOptions_BJKOKKOS(PC pc, PetscOptionItems *PetscOptionsObject)
1070: {
1071:   PetscFunctionBegin;
1072:   PetscOptionsHeadBegin(PetscOptionsObject, "PC BJKOKKOS options");
1073:   PetscOptionsHeadEnd();
1074:   PetscFunctionReturn(PETSC_SUCCESS);
1075: }

1077: static PetscErrorCode PCBJKOKKOSSetKSP_BJKOKKOS(PC pc, KSP ksp)
1078: {
1079:   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;

1081:   PetscFunctionBegin;
1082:   PetscCall(PetscObjectReference((PetscObject)ksp));
1083:   PetscCall(KSPDestroy(&jac->ksp));
1084:   jac->ksp = ksp;
1085:   PetscFunctionReturn(PETSC_SUCCESS);
1086: }

1088: /*@
1089:   PCBJKOKKOSSetKSP - Sets the `KSP` context for `PCBJKOKKOS`

1091:   Collective

1093:   Input Parameters:
1094: + pc  - the `PCBJKOKKOS` preconditioner context
1095: - ksp - the `KSP` solver

1097:   Level: advanced

1099:   Notes:
1100:   The `PC` and the `KSP` must have the same communicator

1102:   If the `PC` is not `PCBJKOKKOS` this function returns without doing anything

1104: .seealso: [](ch_ksp), `PCBJKOKKOSGetKSP()`, `PCBJKOKKOS`
1105: @*/
1106: PetscErrorCode PCBJKOKKOSSetKSP(PC pc, KSP ksp)
1107: {
1108:   PetscFunctionBegin;
1109:   PetscValidHeaderSpecific(pc, PC_CLASSID, 1);
1110:   PetscValidHeaderSpecific(ksp, KSP_CLASSID, 2);
1111:   PetscCheckSameComm(pc, 1, ksp, 2);
1112:   PetscTryMethod(pc, "PCBJKOKKOSSetKSP_C", (PC, KSP), (pc, ksp));
1113:   PetscFunctionReturn(PETSC_SUCCESS);
1114: }
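
A hedged usage sketch for PCBJKOKKOSSetKSP(): 'pc' is assumed to already exist and be of type PCBJKOKKOS, and the type/tolerances shown are arbitrary choices, not requirements.

// Usage sketch (assumes 'pc' is a PCBJKOKKOS); PCBJKOKKOSSetKSP() takes its own
// reference, so the local handle can be destroyed afterwards.
KSP inner;

PetscCall(KSPCreate(PetscObjectComm((PetscObject)pc), &inner));
PetscCall(KSPSetType(inner, KSPTFQMR));
PetscCall(KSPSetTolerances(inner, 1.e-6, PETSC_DEFAULT, PETSC_DEFAULT, 50));
PetscCall(PCBJKOKKOSSetKSP(pc, inner));
PetscCall(KSPDestroy(&inner));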

1116: static PetscErrorCode PCBJKOKKOSGetKSP_BJKOKKOS(PC pc, KSP *ksp)
1117: {
1118:   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;

1120:   PetscFunctionBegin;
1121:   if (!jac->ksp) PetscCall(PCBJKOKKOSCreateKSP_BJKOKKOS(pc));
1122:   *ksp = jac->ksp;
1123:   PetscFunctionReturn(PETSC_SUCCESS);
1124: }

1126: /*@
1127:   PCBJKOKKOSGetKSP - Gets the `KSP` context for the `PCBJKOKKOS` preconditioner

1129:   Not Collective but `KSP` returned is parallel if `PC` was parallel

1131:   Input Parameter:
1132: . pc - the preconditioner context

1134:   Output Parameter:
1135: . ksp - the `KSP` solver

1137:   Level: advanced

1139:   Notes:
1140:   You must call `KSPSetUp()` before calling `PCBJKOKKOSGetKSP()`.

1142:   If the `PC` is not a `PCBJKOKKOS` object it raises an error

1144: .seealso: [](ch_ksp), `PCBJKOKKOS`, `PCBJKOKKOSSetKSP()`
1145: @*/
1146: PetscErrorCode PCBJKOKKOSGetKSP(PC pc, KSP *ksp)
1147: {
1148:   PetscFunctionBegin;
1149:   PetscValidHeaderSpecific(pc, PC_CLASSID, 1);
1150:   PetscAssertPointer(ksp, 2);
1151:   PetscUseMethod(pc, "PCBJKOKKOSGetKSP_C", (PC, KSP *), (pc, ksp));
1152:   PetscFunctionReturn(PETSC_SUCCESS);
1153: }
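
A hedged usage sketch for PCBJKOKKOSGetKSP(); 'ksp' is an assumed outer solver whose preconditioner is PCBJKOKKOS, and the inspection shown is just one possible use.

// Usage sketch: after KSPSetUp() on the outer solver (per the Note above), fetch
// the inner batched KSP to inspect it.
KSP     inner;
KSPType inner_type;
PC      pc;

PetscCall(KSPSetUp(ksp));
PetscCall(KSPGetPC(ksp, &pc));
PetscCall(PCBJKOKKOSGetKSP(pc, &inner));
PetscCall(KSPGetType(inner, &inner_type));
PetscCall(PetscPrintf(PETSC_COMM_WORLD, "batched inner KSP type: %s\n", inner_type));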

1155: static PetscErrorCode PCPostSolve_BJKOKKOS(PC pc, KSP ksp, Vec b, Vec x)
1156: {
1157:   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;

1159:   PetscFunctionBegin;
1161:   ksp->its = jac->max_nits;
1162:   PetscFunctionReturn(PETSC_SUCCESS);
1163: }

1165: static PetscErrorCode PCPreSolve_BJKOKKOS(PC pc, KSP ksp, Vec b, Vec x)
1166: {
1167:   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;

1169:   PetscFunctionBegin;
1171:   jac->ksp->errorifnotconverged = ksp->errorifnotconverged;
1172:   PetscFunctionReturn(PETSC_SUCCESS);
1173: }

1175: /*MC
1176:      PCBJKOKKOS - A batched Krylov/block Jacobi solver that runs the solve of each diagonal block of a block diagonal `MATSEQAIJ` in a Kokkos thread group

1178:    Options Database Key:
1179: .     -pc_bjkokkos_ - options prefix for its `KSP` options

1181:    Level: intermediate

1183:    Note:
1184:     For use with -ksp_type preonly to bypass any computation on the CPU

1186:    Developer Notes:
1187:    The entire Krylov solve (TFQMR or BICG) with diagonal preconditioning for each block of a block diagonal matrix runs within a single Kokkos thread group (e.g., one block per SM on NVIDIA GPUs). A non-block diagonal matrix is accepted but this path is not tested; one should create an explicit block diagonal matrix and use it as the preconditioning matrix in the outer KSP solver. Variable block sizes are supported and tested in src/ts/utils/dmplexlandau/tutorials/ex[1|2].c

1189: .seealso: [](ch_ksp), `PCCreate()`, `PCSetType()`, `PCType`, `PC`, `PCBJACOBI`,
1190:           `PCSHELL`, `PCCOMPOSITE`, `PCSetUseAmat()`, `PCBJKOKKOSGetKSP()`
1191: M*/
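
A hedged setup sketch matching the notes above: the outer KSP 'ksp' is assumed to have been created elsewhere with an aijkokkos operator, and the comment lists roughly equivalent command-line options.

// Setup sketch: run the outer KSP as preonly and let PCBJKOKKOS do the batched
// block solves on the device. Roughly equivalent command-line options:
//   -ksp_type preonly -pc_type bjkokkos -pc_bjkokkos_ksp_type tfqmr
//   -dm_mat_type aijkokkos -dm_vec_type kokkos
PC pc;

PetscCall(KSPSetType(ksp, KSPPREONLY));
PetscCall(KSPGetPC(ksp, &pc));
PetscCall(PCSetType(pc, PCBJKOKKOS));
PetscCall(KSPSetFromOptions(ksp));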

1193: PETSC_EXTERN PetscErrorCode PCCreate_BJKOKKOS(PC pc)
1194: {
1195:   PC_PCBJKOKKOS *jac;

1197:   PetscFunctionBegin;
1198:   PetscCall(PetscNew(&jac));
1199:   pc->data = (void *)jac;

1201:   jac->ksp              = NULL;
1202:   jac->vec_diag         = NULL;
1203:   jac->d_bid_eqOffset_k = NULL;
1204:   jac->d_idiag_k        = NULL;
1205:   jac->d_isrow_k        = NULL;
1206:   jac->d_isicol_k       = NULL;
1207:   jac->nBlocks          = 1;
1208:   jac->max_nits         = 0;

1210:   PetscCall(PetscMemzero(pc->ops, sizeof(struct _PCOps)));
1211:   pc->ops->apply          = PCApply_BJKOKKOS;
1212:   pc->ops->applytranspose = NULL;
1213:   pc->ops->setup          = PCSetUp_BJKOKKOS;
1214:   pc->ops->reset          = PCReset_BJKOKKOS;
1215:   pc->ops->destroy        = PCDestroy_BJKOKKOS;
1216:   pc->ops->setfromoptions = PCSetFromOptions_BJKOKKOS;
1217:   pc->ops->view           = PCView_BJKOKKOS;
1218:   pc->ops->postsolve      = PCPostSolve_BJKOKKOS;
1219:   pc->ops->presolve       = PCPreSolve_BJKOKKOS;

1221:   jac->rowOffsets   = NULL;
1222:   jac->colIndices   = NULL;
1223:   jac->batch_b      = NULL;
1224:   jac->batch_x      = NULL;
1225:   jac->batch_values = NULL;

1227:   PetscCall(PetscObjectComposeFunction((PetscObject)pc, "PCBJKOKKOSGetKSP_C", PCBJKOKKOSGetKSP_BJKOKKOS));
1228:   PetscCall(PetscObjectComposeFunction((PetscObject)pc, "PCBJKOKKOSSetKSP_C", PCBJKOKKOSSetKSP_BJKOKKOS));
1229:   PetscFunctionReturn(PETSC_SUCCESS);
1230: }