Actual source code: bjkokkos.kokkos.cxx
1: #define PETSC_SKIP_CXX_COMPLEX_FIX // Kokkos::complex does not need the petsc complex fix
3: #include <petsc/private/pcbjkokkosimpl.h>
5: #include <petsc/private/kspimpl.h>
6: #include <petscksp.h>
7: #include <../src/mat/impls/aij/mpi/mpiaij.h>
8: #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
9: #include <petscsection.h>
10: #include <petscdmcomposite.h>
12: #include <../src/mat/impls/aij/seq/aij.h>
13: #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
15: #include <petscdevice_cupm.h>
17: static PetscErrorCode PCBJKOKKOSCreateKSP_BJKOKKOS(PC pc)
18: {
19: const char *prefix;
20: PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
21: DM dm;
23: PetscFunctionBegin;
24: PetscCall(KSPCreate(PetscObjectComm((PetscObject)pc), &jac->ksp));
25: PetscCall(KSPSetNestLevel(jac->ksp, pc->kspnestlevel));
26: PetscCall(KSPSetErrorIfNotConverged(jac->ksp, pc->erroriffailure));
27: PetscCall(PetscObjectIncrementTabLevel((PetscObject)jac->ksp, (PetscObject)pc, 1));
28: PetscCall(PCGetOptionsPrefix(pc, &prefix));
29: PetscCall(KSPSetOptionsPrefix(jac->ksp, prefix));
30: PetscCall(KSPAppendOptionsPrefix(jac->ksp, "pc_bjkokkos_"));
31: PetscCall(PCGetDM(pc, &dm));
32: if (dm) {
33: PetscCall(KSPSetDM(jac->ksp, dm));
34: PetscCall(KSPSetDMActive(jac->ksp, PETSC_FALSE));
35: }
36: jac->reason = PETSC_FALSE;
37: jac->monitor = PETSC_FALSE;
38: jac->batch_target = 0;
39: jac->rank_target = 0;
40: jac->nsolves_team = 1;
41: jac->ksp->max_it = 50; // this is really for GMRES w/o restarts
42: PetscFunctionReturn(PETSC_SUCCESS);
43: }
45: // y <-- Ax
46: KOKKOS_INLINE_FUNCTION PetscErrorCode MatMult(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, const PetscInt start, const PetscInt end, const PetscScalar *x_loc, PetscScalar *y_loc)
47: {
48: Kokkos::parallel_for(Kokkos::TeamThreadRange(team, start, end), [=](const int rowb) {
49: int rowa = ic[rowb];
50: int n = glb_Aai[rowa + 1] - glb_Aai[rowa];
51: const PetscInt *aj = glb_Aaj + glb_Aai[rowa]; // global
52: const PetscScalar *aa = glb_Aaa + glb_Aai[rowa];
53: PetscScalar sum;
54: Kokkos::parallel_reduce(Kokkos::ThreadVectorRange(team, n), [=](const int i, PetscScalar &lsum) { lsum += aa[i] * x_loc[r[aj[i]] - start]; }, sum);
55: Kokkos::single(Kokkos::PerThread(team), [=]() { y_loc[rowb - start] = sum; });
56: });
57: team.team_barrier();
58: return PETSC_SUCCESS;
59: }
61: // temp buffer per thread with reduction at end?
62: KOKKOS_INLINE_FUNCTION PetscErrorCode MatMultTranspose(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, const PetscInt start, const PetscInt end, const PetscScalar *x_loc, PetscScalar *y_loc)
63: {
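// y <-- A^T x in the permuted (block) ordering: y is zeroed, then each row scatters aa[i]*x[row]
// into y at the permuted column positions; atomic adds are needed since rows race on shared destinations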
64: Kokkos::parallel_for(Kokkos::TeamVectorRange(team, end - start), [=](int i) { y_loc[i] = 0; });
65: team.team_barrier();
66: Kokkos::parallel_for(Kokkos::TeamThreadRange(team, start, end), [=](const int rowb) {
67: int rowa = ic[rowb];
68: int n = glb_Aai[rowa + 1] - glb_Aai[rowa];
69: const PetscInt *aj = glb_Aaj + glb_Aai[rowa]; // global
70: const PetscScalar *aa = glb_Aaa + glb_Aai[rowa];
71: const PetscScalar xx = x_loc[rowb - start]; // rowb = ic[rowa] = ic[r[rowb]]
72: Kokkos::parallel_for(Kokkos::ThreadVectorRange(team, n), [=](const int &i) {
73: PetscScalar val = aa[i] * xx;
74: Kokkos::atomic_fetch_add(&y_loc[r[aj[i]] - start], val);
75: });
76: });
77: team.team_barrier();
78: return PETSC_SUCCESS;
79: }
81: typedef struct Batch_MetaData_TAG {
82: PetscInt flops;
83: PetscInt its;
84: KSPConvergedReason reason;
85: } Batch_MetaData;
87: // Solve A(BB^-1)x = y with TFQMR. Right preconditioned to get un-preconditioned residual
88: static KOKKOS_INLINE_FUNCTION PetscErrorCode BJSolve_TFQMR(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, PetscScalar *work_space_global, const int stride_global, const int nShareVec, PetscScalar *work_space_shared, const int stride_shared, PetscReal rtol, PetscReal atol, PetscReal dtol, PetscInt maxit, Batch_MetaData *metad, const PetscInt start, const PetscInt end, const PetscScalar glb_idiag[], const PetscScalar *glb_b, PetscScalar *glb_x, bool monitor)
89: {
90: using Kokkos::parallel_for;
91: using Kokkos::parallel_reduce;
92: int Nblk = end - start, it, m, stride = stride_shared, idx = 0;
93: PetscReal dp, dpold, w, dpest, tau, psi, cm, r0;
94: const PetscScalar *Diag = &glb_idiag[start];
95: PetscScalar *ptr = work_space_shared, rho, rhoold, a, s, b, eta, etaold, psiold, cf, dpi;
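// Carve the work vectors out of memory: the first nShareVec vectors come from the team's shared
// scratch buffer (spaced stride_shared apart); once idx reaches nShareVec the pointer switches to
// the global buffer (spaced stride_global apart)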
97: if (idx++ == nShareVec) {
98: ptr = work_space_global;
99: stride = stride_global;
100: }
101: PetscScalar *XX = ptr;
102: ptr += stride;
103: if (idx++ == nShareVec) {
104: ptr = work_space_global;
105: stride = stride_global;
106: }
107: PetscScalar *R = ptr;
108: ptr += stride;
109: if (idx++ == nShareVec) {
110: ptr = work_space_global;
111: stride = stride_global;
112: }
113: PetscScalar *RP = ptr;
114: ptr += stride;
115: if (idx++ == nShareVec) {
116: ptr = work_space_global;
117: stride = stride_global;
118: }
119: PetscScalar *V = ptr;
120: ptr += stride;
121: if (idx++ == nShareVec) {
122: ptr = work_space_global;
123: stride = stride_global;
124: }
125: PetscScalar *T = ptr;
126: ptr += stride;
127: if (idx++ == nShareVec) {
128: ptr = work_space_global;
129: stride = stride_global;
130: }
131: PetscScalar *Q = ptr;
132: ptr += stride;
133: if (idx++ == nShareVec) {
134: ptr = work_space_global;
135: stride = stride_global;
136: }
137: PetscScalar *P = ptr;
138: ptr += stride;
139: if (idx++ == nShareVec) {
140: ptr = work_space_global;
141: stride = stride_global;
142: }
143: PetscScalar *U = ptr;
144: ptr += stride;
145: if (idx++ == nShareVec) {
146: ptr = work_space_global;
147: stride = stride_global;
148: }
149: PetscScalar *D = ptr;
150: ptr += stride;
151: if (idx++ == nShareVec) {
152: ptr = work_space_global;
153: stride = stride_global;
154: }
155: PetscScalar *AUQ = V;
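// AUQ aliases V (V's contents are no longer needed once A*(U+Q) is formed), apparently to save one work vector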
157: // init: get b, zero x
158: parallel_for(Kokkos::TeamVectorRange(team, start, end), [=](int rowb) {
159: int rowa = ic[rowb];
160: R[rowb - start] = glb_b[rowa];
161: XX[rowb - start] = 0;
162: });
163: team.team_barrier();
164: parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &lsum) { lsum += R[idx] * PetscConj(R[idx]); }, dpi);
165: team.team_barrier();
166: r0 = dp = PetscSqrtReal(PetscRealPart(dpi));
167: // diagnostics
168: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
169: if (monitor) Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("%3d KSP Residual norm %14.12e\n", 0, (double)dp); });
170: #endif
171: if (dp < atol) {
172: metad->reason = KSP_CONVERGED_ATOL_NORMAL;
173: it = 0;
174: goto done;
175: }
176: if (0 == maxit) {
177: metad->reason = KSP_CONVERGED_ITS;
178: it = 0;
179: goto done;
180: }
182: /* Make the initial Rp = R */
183: parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { RP[idx] = R[idx]; });
184: team.team_barrier();
185: /* Set the initial conditions */
186: etaold = 0.0;
187: psiold = 0.0;
188: tau = dp;
189: dpold = dp;
191: /* rhoold = (r,rp) */
192: parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &dot) { dot += R[idx] * PetscConj(RP[idx]); }, rhoold);
193: team.team_barrier();
194: parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
195: U[idx] = R[idx];
196: P[idx] = R[idx];
197: T[idx] = Diag[idx] * P[idx];
198: D[idx] = 0;
199: });
200: team.team_barrier();
201: static_cast<void>(MatMult(team, glb_Aai, glb_Aaj, glb_Aaa, r, ic, start, end, T, V));
203: it = 0;
204: do {
205: /* s <- (v,rp) */
206: parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &dot) { dot += V[idx] * PetscConj(RP[idx]); }, s);
207: team.team_barrier();
208: if (s == 0) {
209: metad->reason = KSP_CONVERGED_HAPPY_BREAKDOWN;
210: goto done;
211: }
212: a = rhoold / s; /* a <- rho / s */
213: /* q <- u - a v VecWAXPY(w,alpha,x,y): w = alpha x + y. */
214: /* t <- u + q */
215: parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
216: Q[idx] = U[idx] - a * V[idx];
217: T[idx] = U[idx] + Q[idx];
218: });
219: team.team_barrier();
220: // KSP_PCApplyBAorAB
221: parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { T[idx] = Diag[idx] * T[idx]; });
222: team.team_barrier();
223: static_cast<void>(MatMult(team, glb_Aai, glb_Aaj, glb_Aaa, r, ic, start, end, T, AUQ));
224: /* r <- r - a K (u + q) */
225: parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { R[idx] = R[idx] - a * AUQ[idx]; });
226: team.team_barrier();
227: parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &lsum) { lsum += R[idx] * PetscConj(R[idx]); }, dpi);
228: team.team_barrier();
229: dp = PetscSqrtReal(PetscRealPart(dpi));
230: for (m = 0; m < 2; m++) {
231: if (!m) w = PetscSqrtReal(dp * dpold);
232: else w = dp;
233: psi = w / tau;
234: cm = 1.0 / PetscSqrtReal(1.0 + psi * psi);
235: tau = tau * psi * cm;
236: eta = cm * cm * a;
237: cf = psiold * psiold * etaold / a;
238: if (!m) {
239: /* D = U + cf D */
240: parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { D[idx] = U[idx] + cf * D[idx]; });
241: } else {
242: /* D = Q + cf D */
243: parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { D[idx] = Q[idx] + cf * D[idx]; });
244: }
245: team.team_barrier();
246: parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { XX[idx] = XX[idx] + eta * D[idx]; });
247: team.team_barrier();
248: dpest = PetscSqrtReal(2 * it + m + 2.0) * tau;
249: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
250: if (monitor && m == 1) Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("%3d KSP Residual norm %14.12e\n", it + 1, (double)dpest); });
251: #endif
252: if (dpest < atol) {
253: metad->reason = KSP_CONVERGED_ATOL_NORMAL;
254: goto done;
255: }
256: if (dpest / r0 < rtol) {
257: metad->reason = KSP_CONVERGED_RTOL_NORMAL;
258: goto done;
259: }
260: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
261: if (dpest / r0 > dtol) {
262: metad->reason = KSP_DIVERGED_DTOL;
263: Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("ERROR block %d diverged: %d it, res=%e, r_0=%e\n", team.league_rank(), it, dpest, r0); });
264: goto done;
265: }
266: #else
267: if (dpest / r0 > dtol) {
268: metad->reason = KSP_DIVERGED_DTOL;
269: goto done;
270: }
271: #endif
272: if (it + 1 == maxit) {
273: metad->reason = KSP_CONVERGED_ITS;
274: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
275: Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("ERROR block %d diverged: TFQMR %d:%d it, res=%e, r_0=%e r_res=%e\n", team.league_rank(), it, m, dpest, r0, dpest / r0); });
276: #endif
277: goto done;
278: }
279: etaold = eta;
280: psiold = psi;
281: }
283: /* rho <- (r,rp) */
284: parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &dot) { dot += R[idx] * PetscConj(RP[idx]); }, rho);
285: team.team_barrier();
286: if (rho == 0) {
287: metad->reason = KSP_CONVERGED_HAPPY_BREAKDOWN;
288: goto done;
289: }
290: b = rho / rhoold; /* b <- rho / rhoold */
291: /* u <- r + b q */
292: /* p <- u + b(q + b p) */
293: parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
294: U[idx] = R[idx] + b * Q[idx];
295: Q[idx] = Q[idx] + b * P[idx];
296: P[idx] = U[idx] + b * Q[idx];
297: });
298: /* v <- K p */
299: team.team_barrier();
300: parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { T[idx] = Diag[idx] * P[idx]; });
301: team.team_barrier();
302: static_cast<void>(MatMult(team, glb_Aai, glb_Aaj, glb_Aaa, r, ic, start, end, T, V));
304: rhoold = rho;
305: dpold = dp;
307: it++;
308: } while (it < maxit);
309: done:
310: // KSPUnwindPreconditioner
311: parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { XX[idx] = Diag[idx] * XX[idx]; });
312: team.team_barrier();
313: // put x into Plex order
314: parallel_for(Kokkos::TeamVectorRange(team, start, end), [=](int rowb) {
315: int rowa = ic[rowb];
316: glb_x[rowa] = XX[rowb - start];
317: });
318: metad->its = it;
319: if (1) {
320: int nnz;
321: parallel_reduce(Kokkos::TeamVectorRange(team, start, end), [=](const int idx, int &lsum) { lsum += (glb_Aai[idx + 1] - glb_Aai[idx]); }, nnz);
322: metad->flops = 2 * (metad->its * (10 * Nblk + 2 * nnz) + 5 * Nblk);
323: } else {
324: metad->flops = 2 * (metad->its * (10 * Nblk + 2 * 50 * Nblk) + 5 * Nblk); // guess
325: }
326: return PETSC_SUCCESS;
327: }
329: // Solve Ax = y with biCG
330: static KOKKOS_INLINE_FUNCTION PetscErrorCode BJSolve_BICG(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, PetscScalar *work_space_global, const int stride_global, const int nShareVec, PetscScalar *work_space_shared, const int stride_shared, PetscReal rtol, PetscReal atol, PetscReal dtol, PetscInt maxit, Batch_MetaData *metad, const PetscInt start, const PetscInt end, const PetscScalar glb_idiag[], const PetscScalar *glb_b, PetscScalar *glb_x, bool monitor)
331: {
332: using Kokkos::parallel_for;
333: using Kokkos::parallel_reduce;
334: int Nblk = end - start, it, stride = stride_shared, idx = 0; // start in shared mem
335: PetscReal dp, r0;
336: const PetscScalar *Di = &glb_idiag[start];
337: PetscScalar *ptr = work_space_shared, dpi, a = 1.0, beta, betaold = 1.0, t1, t2;
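// Same work-vector carving as in BJSolve_TFQMR: shared scratch first, then the global buffer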
339: if (idx++ == nShareVec) {
340: ptr = work_space_global;
341: stride = stride_global;
342: }
343: PetscScalar *XX = ptr;
344: ptr += stride;
345: if (idx++ == nShareVec) {
346: ptr = work_space_global;
347: stride = stride_global;
348: }
349: PetscScalar *Rl = ptr;
350: ptr += stride;
351: if (idx++ == nShareVec) {
352: ptr = work_space_global;
353: stride = stride_global;
354: }
355: PetscScalar *Zl = ptr;
356: ptr += stride;
357: if (idx++ == nShareVec) {
358: ptr = work_space_global;
359: stride = stride_global;
360: }
361: PetscScalar *Pl = ptr;
362: ptr += stride;
363: if (idx++ == nShareVec) {
364: ptr = work_space_global;
365: stride = stride_global;
366: }
367: PetscScalar *Rr = ptr;
368: ptr += stride;
369: if (idx++ == nShareVec) {
370: ptr = work_space_global;
371: stride = stride_global;
372: }
373: PetscScalar *Zr = ptr;
374: ptr += stride;
375: if (idx++ == nShareVec) {
376: ptr = work_space_global;
377: stride = stride_global;
378: }
379: PetscScalar *Pr = ptr;
380: ptr += stride;
382: /* r <- b (x is 0) */
383: parallel_for(Kokkos::TeamVectorRange(team, start, end), [=](int rowb) {
384: int rowa = ic[rowb];
385: Rl[rowb - start] = Rr[rowb - start] = glb_b[rowa];
386: XX[rowb - start] = 0;
387: });
388: team.team_barrier();
389: /* z <- Br */
390: parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
391: Zr[idx] = Di[idx] * Rr[idx];
392: Zl[idx] = Di[idx] * Rl[idx];
393: });
394: team.team_barrier();
395: /* dp <- r'*r */
396: parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &lsum) { lsum += Rr[idx] * PetscConj(Rr[idx]); }, dpi);
397: team.team_barrier();
398: r0 = dp = PetscSqrtReal(PetscRealPart(dpi));
399: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
400: if (monitor) Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("%3d KSP Residual norm %14.12e\n", 0, (double)dp); });
401: #endif
402: if (dp < atol) {
403: metad->reason = KSP_CONVERGED_ATOL_NORMAL;
404: it = 0;
405: goto done;
406: }
407: if (0 == maxit) {
408: metad->reason = KSP_CONVERGED_ITS;
409: it = 0;
410: goto done;
411: }
413: it = 0;
414: do {
415: /* beta <- r'z */
416: parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &dot) { dot += Zr[idx] * PetscConj(Rl[idx]); }, beta);
417: team.team_barrier();
418: #if PCBJKOKKOS_VERBOSE_LEVEL >= 6
419: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
420: Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("%7d beta = Z.R = %22.14e \n", it, (double)beta); });
421: #endif
422: #endif
423: if (beta == 0.0) {
424: metad->reason = KSP_CONVERGED_HAPPY_BREAKDOWN;
425: goto done;
426: }
427: if (it == 0) {
428: /* p <- z */
429: parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
430: Pr[idx] = Zr[idx];
431: Pl[idx] = Zl[idx];
432: });
433: } else {
434: t1 = beta / betaold;
435: /* p <- z + b* p */
436: t2 = PetscConj(t1);
437: parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
438: Pr[idx] = t1 * Pr[idx] + Zr[idx];
439: Pl[idx] = t2 * Pl[idx] + Zl[idx];
440: });
441: }
442: team.team_barrier();
443: betaold = beta;
444: /* z <- Kp */
445: static_cast<void>(MatMult(team, glb_Aai, glb_Aaj, glb_Aaa, r, ic, start, end, Pr, Zr));
446: static_cast<void>(MatMultTranspose(team, glb_Aai, glb_Aaj, glb_Aaa, r, ic, start, end, Pl, Zl));
447: /* dpi <- z'p */
448: parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &lsum) { lsum += Zr[idx] * PetscConj(Pl[idx]); }, dpi);
449: team.team_barrier();
450: if (dpi == 0) {
451: metad->reason = KSP_CONVERGED_HAPPY_BREAKDOWN;
452: goto done;
453: }
454: //
455: a = beta / dpi; /* a = beta/p'z */
456: t1 = -a;
457: t2 = PetscConj(t1);
458: /* x <- x + ap */
459: parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
460: XX[idx] = XX[idx] + a * Pr[idx];
461: Rr[idx] = Rr[idx] + t1 * Zr[idx];
462: Rl[idx] = Rl[idx] + t2 * Zl[idx];
463: });
464: team.team_barrier();
466: /* dp <- r'*r */
467: parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &lsum) { lsum += Rr[idx] * PetscConj(Rr[idx]); }, dpi);
468: team.team_barrier();
469: dp = PetscSqrtReal(PetscRealPart(dpi));
470: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
471: if (monitor) Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("%3d KSP Residual norm %14.12e\n", it + 1, (double)dp); });
472: #endif
473: if (dp < atol) {
474: metad->reason = KSP_CONVERGED_ATOL_NORMAL;
475: goto done;
476: }
477: if (dp / r0 < rtol) {
478: metad->reason = KSP_CONVERGED_RTOL_NORMAL;
479: goto done;
480: }
481: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
482: if (dp / r0 > dtol) {
483: metad->reason = KSP_DIVERGED_DTOL;
484: Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("ERROR block %d diverged: %d it, res=%e, r_0=%e (BICG does this)\n", team.league_rank(), it, dp, r0); });
485: goto done;
486: }
487: #else
488: if (dp / r0 > dtol) {
489: metad->reason = KSP_DIVERGED_DTOL;
490: goto done;
491: }
492: #endif
493: if (it + 1 == maxit) {
494: metad->reason = KSP_CONVERGED_ITS; // don't worry about hitting max iterations
495: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
496: Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("ERROR block %d diverged: BICG %d it, res=%e, r_0=%e r_res=%e\n", team.league_rank(), it, dp, r0, dp / r0); });
497: #endif
498: goto done;
499: }
500: /* z <- Br */
501: parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
502: Zr[idx] = Di[idx] * Rr[idx];
503: Zl[idx] = Di[idx] * Rl[idx];
504: });
506: it++;
507: } while (it < maxit);
508: done:
509: // put x back into Plex order
510: parallel_for(Kokkos::TeamVectorRange(team, start, end), [=](int rowb) {
511: int rowa = ic[rowb];
512: glb_x[rowa] = XX[rowb - start];
513: });
514: metad->its = it;
515: if (1) {
516: int nnz;
517: parallel_reduce(Kokkos::TeamVectorRange(team, start, end), [=](const int idx, int &lsum) { lsum += (glb_Aai[idx + 1] - glb_Aai[idx]); }, nnz);
518: metad->flops = 2 * (metad->its * (10 * Nblk + 2 * nnz) + 5 * Nblk);
519: } else {
520: metad->flops = 2 * (metad->its * (10 * Nblk + 2 * 50 * Nblk) + 5 * Nblk); // guess
521: }
522: return PETSC_SUCCESS;
523: }
525: // KSP solver: solve Ax = b; xout is the output, bin is the input
526: static PetscErrorCode PCApply_BJKOKKOS(PC pc, Vec bin, Vec xout)
527: {
528: PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
529: Mat A = pc->pmat, Aseq = A;
530: PetscMPIInt rank;
532: PetscFunctionBegin;
533: PetscCallMPI(MPI_Comm_rank(PetscObjectComm((PetscObject)A), &rank));
534: if (!A->spptr) {
535: Aseq = ((Mat_MPIAIJ *)A->data)->A; // MPI
536: }
537: PetscCall(MatSeqAIJKokkosSyncDevice(Aseq));
538: {
539: PetscInt maxit = jac->ksp->max_it;
540: const PetscInt conc = Kokkos::DefaultExecutionSpace().concurrency(), openmp = !!(conc < 1000), team_size = (openmp == 0 && PCBJKOKKOS_VEC_SIZE != 1) ? PCBJKOKKOS_TEAM_SIZE : 1;
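// Heuristic: concurrency() < 1000 is taken to mean a host (OpenMP) backend, which forces team_size = 1;
// on devices PCBJKOKKOS_TEAM_SIZE and PCBJKOKKOS_VEC_SIZE are used (unless the vector length is 1)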
541: const PetscInt nwork = jac->nwork, nBlk = jac->nBlocks;
542: PetscScalar *glb_xdata = NULL, *dummy;
543: PetscReal rtol = jac->ksp->rtol, atol = jac->ksp->abstol, dtol = jac->ksp->divtol;
544: const PetscScalar *glb_idiag = jac->d_idiag_k->data(), *glb_bdata = NULL;
545: const PetscInt *glb_Aai, *glb_Aaj, *d_bid_eqOffset = jac->d_bid_eqOffset_k->data();
546: const PetscScalar *glb_Aaa;
547: const PetscInt *d_isicol = jac->d_isicol_k->data(), *d_isrow = jac->d_isrow_k->data();
548: PCFailedReason pcreason;
549: KSPIndex ksp_type_idx = jac->ksp_type_idx;
550: PetscMemType mtype;
551: PetscContainer container;
552: PetscInt batch_sz; // the number of repeated DMs, [DM_e_1, DM_e_2, DM_e_batch_sz, DM_i_1, ...]
553: VecScatter plex_batch = NULL; // not used
554: Vec bvec; // a copy of b for scatter (just alias to bin now)
555: PetscBool monitor = jac->monitor; // captured
556: PetscInt view_bid = jac->batch_target;
557: MatInfo info;
559: PetscCall(MatSeqAIJGetCSRAndMemType(Aseq, &glb_Aai, &glb_Aaj, &dummy, &mtype));
560: jac->max_nits = 0;
561: glb_Aaa = dummy;
562: if (jac->rank_target != rank) view_bid = -1; // turn off all but one process
563: PetscCall(MatGetInfo(A, MAT_LOCAL, &info));
564: // get the field-major mapping ("plex_batch_is") used to map Plex I/O to/from block/field-major ordering
565: PetscCall(PetscObjectQuery((PetscObject)A, "plex_batch_is", (PetscObject *)&container));
566: if (container) {
567: PetscCall(VecDuplicate(bin, &bvec));
568: PetscCall(PetscContainerGetPointer(container, (void **)&plex_batch));
569: PetscCall(VecScatterBegin(plex_batch, bin, bvec, INSERT_VALUES, SCATTER_FORWARD));
570: PetscCall(VecScatterEnd(plex_batch, bin, bvec, INSERT_VALUES, SCATTER_FORWARD));
571: SETERRQ(PetscObjectComm((PetscObject)A), PETSC_ERR_USER, "No plex_batch_is -- require NO field major ordering for now");
572: } else {
573: bvec = bin;
574: }
575: // get x
576: PetscCall(VecGetArrayAndMemType(xout, &glb_xdata, &mtype));
577: #if defined(PETSC_HAVE_CUDA)
578: PetscCheck(PetscMemTypeDevice(mtype), PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "No GPU data for x %d != %d", static_cast<int>(mtype), static_cast<int>(PETSC_MEMTYPE_DEVICE));
579: #endif
580: PetscCall(VecGetArrayReadAndMemType(bvec, &glb_bdata, &mtype));
581: #if defined(PETSC_HAVE_CUDA)
582: PetscCheck(PetscMemTypeDevice(mtype), PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "No GPU data for b");
583: #endif
584: // get batch size
585: PetscCall(PetscObjectQuery((PetscObject)A, "batch size", (PetscObject *)&container));
586: if (container) {
587: PetscInt *pNf = NULL;
588: PetscCall(PetscContainerGetPointer(container, (void **)&pNf));
589: batch_sz = *pNf; // number of times to repeat the DMs
590: } else batch_sz = 1;
591: PetscCheck(nBlk % batch_sz == 0, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "batch_sz = %" PetscInt_FMT ", nBlk = %" PetscInt_FMT, batch_sz, nBlk);
592: if (ksp_type_idx == BATCH_KSP_GMRESKK_IDX) {
593: // KK solver - move PETSc data into Kokkos Views, setup solver, solve, move data out of Kokkos, process metadata (convergence tests, etc.)
594: #if defined(PETSC_HAVE_KOKKOS_KERNELS_BATCH)
595: PetscCall(PCApply_BJKOKKOSKERNELS(pc, glb_bdata, glb_xdata, glb_Aai, glb_Aaj, glb_Aaa, team_size, info, batch_sz, &pcreason));
596: #else
597: PetscCheck(ksp_type_idx != BATCH_KSP_GMRESKK_IDX, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "Type: BATCH_KSP_GMRES not supported for complex");
598: #endif
599: } else { // Kokkos Krylov
600: using scr_mem_t = Kokkos::DefaultExecutionSpace::scratch_memory_space;
601: using vect2D_scr_t = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, scr_mem_t>;
602: Kokkos::View<Batch_MetaData *, Kokkos::DefaultExecutionSpace> d_metadata("solver meta data", nBlk);
603: int stride_shared, stride_global, global_buff_words;
604: d_bid_eqOffset = jac->d_bid_eqOffset_k->data();
605: // solve each block independently
606: int scr_bytes_team_shared = 0, nShareVec = 0, nGlobBVec = 0;
607: if (jac->const_block_size) { // use shared memory for work vectors only if constant block size - TODO: test efficiency loss
608: size_t maximum_shared_mem_size = 64000;
609: PetscDevice device;
610: PetscCall(PetscDeviceGetDefault_Internal(&device));
611: PetscCall(PetscDeviceGetAttribute(device, PETSC_DEVICE_ATTR_SIZE_T_SHARED_MEM_PER_BLOCK, &maximum_shared_mem_size));
612: stride_shared = jac->const_block_size; // captured
613: nShareVec = maximum_shared_mem_size / (jac->const_block_size * sizeof(PetscScalar)); // integer floor, number of vectors that fit in shared
614: if (nShareVec > nwork) nShareVec = nwork;
615: else nGlobBVec = nwork - nShareVec;
616: global_buff_words = jac->n * nGlobBVec;
617: scr_bytes_team_shared = jac->const_block_size * nShareVec * sizeof(PetscScalar);
618: } else {
619: scr_bytes_team_shared = 0;
620: stride_shared = 0;
621: global_buff_words = jac->n * nwork;
622: nGlobBVec = nwork; // not needed == fix
623: }
624: stride_global = jac->n; // captured
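// Work-vector layout for the batched solves: nShareVec vectors of length const_block_size live in team
// scratch (stride_shared apart); the remaining nGlobBVec vectors live in the global buffer allocated
// below, each of length jac->n (stride_global apart), offset by the block's start row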
625: #if defined(PETSC_HAVE_CUDA)
626: nvtxRangePushA("batch-kokkos-solve");
627: #endif
628: Kokkos::View<PetscScalar *, Kokkos::DefaultExecutionSpace> d_work_vecs_k("workvectors", global_buff_words); // global work vectors
629: #if PCBJKOKKOS_VERBOSE_LEVEL > 1
630: PetscCall(PetscInfo(pc, "\tn = %d. %d shared bytes/team, %d global mem bytes, rtol=%e, num blocks %d, team_size=%d, %d vector threads, %d shared vectors, %d global vectors\n", (int)jac->n, scr_bytes_team_shared, global_buff_words, rtol, (int)nBlk, (int)team_size, PCBJKOKKOS_VEC_SIZE, nShareVec, nGlobBVec));
631: #endif
632: PetscScalar *d_work_vecs = d_work_vecs_k.data();
633: Kokkos::parallel_for(
634: "Solve", Kokkos::TeamPolicy<Kokkos::LaunchBounds<256, 4>>(nBlk, team_size, PCBJKOKKOS_VEC_SIZE).set_scratch_size(PCBJKOKKOS_SHARED_LEVEL, Kokkos::PerTeam(scr_bytes_team_shared)), KOKKOS_LAMBDA(const team_member team) {
635: const int blkID = team.league_rank(), start = d_bid_eqOffset[blkID], end = d_bid_eqOffset[blkID + 1];
636: vect2D_scr_t work_vecs_shared(team.team_scratch(PCBJKOKKOS_SHARED_LEVEL), end - start, nShareVec);
637: PetscScalar *work_buff_shared = work_vecs_shared.data();
638: PetscScalar *work_buff_global = &d_work_vecs[start]; // the block's 'start' offset is already applied here
639: bool print = monitor && (blkID == view_bid);
640: switch (ksp_type_idx) {
641: case BATCH_KSP_BICG_IDX:
642: static_cast<void>(BJSolve_BICG(team, glb_Aai, glb_Aaj, glb_Aaa, d_isrow, d_isicol, work_buff_global, stride_global, nShareVec, work_buff_shared, stride_shared, rtol, atol, dtol, maxit, &d_metadata[blkID], start, end, glb_idiag, glb_bdata, glb_xdata, print));
643: break;
644: case BATCH_KSP_TFQMR_IDX:
645: static_cast<void>(BJSolve_TFQMR(team, glb_Aai, glb_Aaj, glb_Aaa, d_isrow, d_isicol, work_buff_global, stride_global, nShareVec, work_buff_shared, stride_shared, rtol, atol, dtol, maxit, &d_metadata[blkID], start, end, glb_idiag, glb_bdata, glb_xdata, print));
646: break;
647: default:
648: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
649: printf("Unknown KSP type %d\n", ksp_type_idx);
650: #else
651: /* void */;
652: #endif
653: }
654: });
655: Kokkos::fence();
656: #if defined(PETSC_HAVE_CUDA)
657: nvtxRangePop();
658: nvtxRangePushA("Post-solve-metadata");
659: #endif
660: auto h_metadata = Kokkos::create_mirror(Kokkos::HostSpace::memory_space(), d_metadata);
661: Kokkos::deep_copy(h_metadata, d_metadata);
662: PetscInt max_nnit = -1;
663: #if PCBJKOKKOS_VERBOSE_LEVEL > 1
664: PetscInt mbid = 0;
665: #endif
666: int in[2], out[2];
667: if (jac->reason) { // -pc_bjkokkos_ksp_converged_reason
668: #if PCBJKOKKOS_VERBOSE_LEVEL >= 3
669: #if PCBJKOKKOS_VERBOSE_LEVEL >= 4
670: PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Iterations\n"));
671: #endif
672: // assume species major
673: #if PCBJKOKKOS_VERBOSE_LEVEL == 3
674: if (batch_sz != 1) PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "%s: max iterations per species:", ksp_type_idx == BATCH_KSP_BICG_IDX ? "bicg" : "tfqmr"));
675: else PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), " Linear solve converged due to %s iterations ", ksp_type_idx == BATCH_KSP_BICG_IDX ? "bicg" : "tfqmr"));
676: #endif
677: for (PetscInt dmIdx = 0, head = 0, s = 0; dmIdx < jac->num_dms; dmIdx += batch_sz) {
678: for (PetscInt f = 0, idx = head; f < jac->dm_Nf[dmIdx]; f++, idx++, s++) {
679: for (int bid = 0; bid < batch_sz; bid++) {
680: #if PCBJKOKKOS_VERBOSE_LEVEL >= 4
681: jac->max_nits += h_metadata[idx + bid * jac->dm_Nf[dmIdx]].its; // report total number of iterations with high verbose
682: if (h_metadata[idx + bid * jac->dm_Nf[dmIdx]].its > max_nnit) {
683: max_nnit = h_metadata[idx + bid * jac->dm_Nf[dmIdx]].its;
684: mbid = bid;
685: }
686: #else
687: if (h_metadata[idx + bid * jac->dm_Nf[dmIdx]].its > max_nnit) {
688: jac->max_nits = max_nnit = h_metadata[idx + bid * jac->dm_Nf[dmIdx]].its;
689: mbid = bid;
690: }
691: #endif
692: }
693: #if PCBJKOKKOS_VERBOSE_LEVEL >= 4
694: PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "%2" PetscInt_FMT ":", s));
695: for (int bid = 0; bid < batch_sz; bid++) PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "%3" PetscInt_FMT " ", h_metadata[idx + bid * jac->dm_Nf[dmIdx]].its));
696: PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "\n"));
697: #else // == 3
698: PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "%3" PetscInt_FMT " ", max_nnit));
699: #endif
700: }
701: head += batch_sz * jac->dm_Nf[dmIdx];
702: }
703: #if PCBJKOKKOS_VERBOSE_LEVEL == 3
704: PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "\n"));
705: #endif
706: #endif
707: if (max_nnit == -1) { // < 3
708: for (int blkID = 0; blkID < nBlk; blkID++) {
709: if (h_metadata[blkID].its > max_nnit) {
710: jac->max_nits = max_nnit = h_metadata[blkID].its;
711: #if PCBJKOKKOS_VERBOSE_LEVEL > 1
712: mbid = blkID;
713: #endif
714: }
715: }
716: }
717: in[0] = max_nnit;
718: in[1] = rank;
719: PetscCallMPI(MPIU_Allreduce(in, out, 1, MPI_2INT, MPI_MAXLOC, PetscObjectComm((PetscObject)A)));
720: #if PCBJKOKKOS_VERBOSE_LEVEL > 1
721: if (0 == rank) {
722: if (batch_sz != 1)
723: PetscCall(PetscPrintf(PETSC_COMM_SELF, "[%d] Linear solve converged due to %s iterations %d (max), on block %" PetscInt_FMT ", species %" PetscInt_FMT " (max)\n", out[1], KSPConvergedReasons[h_metadata[mbid].reason], out[0], mbid % batch_sz, mbid / batch_sz));
724: else PetscCall(PetscPrintf(PETSC_COMM_SELF, "[%d] Linear solve converged due to %s iterations %d (max), on block %" PetscInt_FMT "\n", out[1], KSPConvergedReasons[h_metadata[mbid].reason], out[0], mbid));
725: }
726: #endif
727: }
728: for (int blkID = 0; blkID < nBlk; blkID++) {
729: PetscCall(PetscLogGpuFlops((PetscLogDouble)h_metadata[blkID].flops));
730: PetscCheck(h_metadata[blkID].reason >= 0 || !jac->ksp->errorifnotconverged, PetscObjectComm((PetscObject)pc), PETSC_ERR_CONV_FAILED, "ERROR reason=%s, its=%" PetscInt_FMT ". species %" PetscInt_FMT ", batch %" PetscInt_FMT,
731: KSPConvergedReasons[h_metadata[blkID].reason], h_metadata[blkID].its, blkID / batch_sz, blkID % batch_sz);
732: }
733: {
734: int errsum;
735: Kokkos::parallel_reduce(
736: nBlk,
737: KOKKOS_LAMBDA(const int idx, int &lsum) {
738: if (d_metadata[idx].reason < 0) ++lsum;
739: },
740: errsum);
741: pcreason = errsum ? PC_SUBPC_ERROR : PC_NOERROR;
742: if (!errsum && !jac->max_nits) { // set max its to give back to top KSP
743: for (int blkID = 0; blkID < nBlk; blkID++) {
744: if (h_metadata[blkID].its > jac->max_nits) jac->max_nits = h_metadata[blkID].its;
745: }
746: } else if (errsum) {
747: PetscCall(PetscPrintf(PETSC_COMM_SELF, "[%d] ERROR Kokkos batch solver did not converge in all solves\n", (int)rank));
748: }
749: }
750: #if defined(PETSC_HAVE_CUDA)
751: nvtxRangePop();
752: #endif
753: } // end of Kokkos (not Kernels) solvers block
754: PetscCall(VecRestoreArrayAndMemType(xout, &glb_xdata));
755: PetscCall(VecRestoreArrayReadAndMemType(bvec, &glb_bdata));
756: PetscCall(PCSetFailedReason(pc, pcreason));
757: // map back to Plex space - not used
758: if (plex_batch) {
759: PetscCall(VecCopy(xout, bvec));
760: PetscCall(VecScatterBegin(plex_batch, bvec, xout, INSERT_VALUES, SCATTER_REVERSE));
761: PetscCall(VecScatterEnd(plex_batch, bvec, xout, INSERT_VALUES, SCATTER_REVERSE));
762: PetscCall(VecDestroy(&bvec));
763: }
764: }
765: PetscFunctionReturn(PETSC_SUCCESS);
766: }
768: static PetscErrorCode PCSetUp_BJKOKKOS(PC pc)
769: {
770: PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
771: Mat A = pc->pmat, Aseq = A; // use filtered block matrix, really "P"
772: PetscBool flg;
774: PetscFunctionBegin;
775: PetscCheck(A, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONG, "No matrix - A is used above");
776: PetscCall(PetscObjectTypeCompareAny((PetscObject)A, &flg, MATSEQAIJKOKKOS, MATMPIAIJKOKKOS, MATAIJKOKKOS, ""));
777: PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONG, "must use '-[dm_]mat_type aijkokkos -[dm_]vec_type kokkos' for -pc_type bjkokkos");
778: if (!A->spptr) {
779: Aseq = ((Mat_MPIAIJ *)A->data)->A; // MPI
780: }
781: PetscCall(MatSeqAIJKokkosSyncDevice(Aseq));
782: {
783: PetscInt Istart, Iend;
784: PetscMPIInt rank;
785: PetscCallMPI(MPI_Comm_rank(PetscObjectComm((PetscObject)A), &rank));
786: PetscCall(MatGetOwnershipRange(A, &Istart, &Iend));
787: if (!jac->vec_diag) {
788: Vec *subX = NULL;
789: DM pack, *subDM = NULL;
790: PetscInt nDMs, n, *block_sizes = NULL;
791: IS isrow, isicol;
792: { // Permute the matrix to get a block diagonal system: d_isrow_k, d_isicol_k
793: MatOrderingType rtype;
794: const PetscInt *rowindices, *icolindices;
795: rtype = MATORDERINGRCM;
796: // get permutation. And invert. should we convert to local indices?
797: PetscCall(MatGetOrdering(Aseq, rtype, &isrow, &isicol)); // only seems to work for seq matrix
798: PetscCall(ISDestroy(&isrow));
799: PetscCall(ISInvertPermutation(isicol, PETSC_DECIDE, &isrow)); // THIS IS BACKWARD -- isrow is inverse
800: // if (rank==1) PetscCall(ISView(isicol, PETSC_VIEWER_STDOUT_SELF));
801: if (0) {
802: Mat mat_block_order; // debug
803: PetscCall(ISShift(isicol, Istart, isicol));
804: PetscCall(MatCreateSubMatrix(A, isicol, isicol, MAT_INITIAL_MATRIX, &mat_block_order));
805: PetscCall(ISShift(isicol, -Istart, isicol));
806: PetscCall(MatViewFromOptions(mat_block_order, NULL, "-ksp_batch_reorder_view"));
807: PetscCall(MatDestroy(&mat_block_order));
808: }
809: PetscCall(ISGetIndices(isrow, &rowindices)); // local idx
810: PetscCall(ISGetIndices(isicol, &icolindices));
811: const Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_isrow_k((PetscInt *)rowindices, A->rmap->n);
812: const Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_isicol_k((PetscInt *)icolindices, A->rmap->n);
813: jac->d_isrow_k = new Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_isrow_k));
814: jac->d_isicol_k = new Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_isicol_k));
815: Kokkos::deep_copy(*jac->d_isrow_k, h_isrow_k);
816: Kokkos::deep_copy(*jac->d_isicol_k, h_isicol_k);
817: PetscCall(ISRestoreIndices(isrow, &rowindices));
818: PetscCall(ISRestoreIndices(isicol, &icolindices));
819: // if (rank==1) PetscCall(ISView(isicol, PETSC_VIEWER_STDOUT_SELF));
820: }
821: // get block sizes & allocate vec_diag
822: PetscCall(PCGetDM(pc, &pack));
823: if (pack) {
824: PetscCall(PetscObjectTypeCompare((PetscObject)pack, DMCOMPOSITE, &flg));
825: if (flg) {
826: PetscCall(DMCompositeGetNumberDM(pack, &nDMs));
827: PetscCall(DMCreateGlobalVector(pack, &jac->vec_diag));
828: } else pack = NULL; // flag for no DM
829: }
830: if (!jac->vec_diag) { // get 'nDMs' and sizes 'block_sizes' w/o DMComposite. TODO: User could provide ISs
831: PetscInt bsrt, bend, ncols, ntot = 0;
832: const PetscInt *colsA, nloc = Iend - Istart;
833: const PetscInt *rowindices, *icolindices;
834: PetscCall(PetscMalloc1(nloc, &block_sizes)); // very inefficient, too big
835: PetscCall(ISGetIndices(isrow, &rowindices));
836: PetscCall(ISGetIndices(isicol, &icolindices));
837: nDMs = 0;
838: bsrt = 0;
839: bend = 1;
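// Detect the block structure from the sparsity pattern: walk rows in permuted (RCM) order and track the
// extent [bsrt, bend) of the current block; a new block starts when a row's smallest permuted column
// index reaches bend, i.e., the row no longer touches the current block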
840: for (PetscInt row_B = 0; row_B < nloc; row_B++) { // for all rows in block diagonal space
841: PetscInt rowA = icolindices[row_B], minj = PETSC_INT_MAX, maxj = 0;
842: //PetscCall(PetscPrintf(PETSC_COMM_SELF, "\t[%d] rowA = %d\n",rank,rowA));
843: PetscCall(MatGetRow(Aseq, rowA, &ncols, &colsA, NULL)); // not sorted in permutation
844: PetscCheck(ncols, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "Empty row not supported: %" PetscInt_FMT, row_B);
845: for (PetscInt colj = 0; colj < ncols; colj++) {
846: PetscInt colB = rowindices[colsA[colj]]; // use local idx
847: //PetscCall(PetscPrintf(PETSC_COMM_SELF, "\t\t[%d] colB = %d\n",rank,colB));
848: PetscCheck(colB >= 0 && colB < nloc, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "colB < 0: %" PetscInt_FMT, colB);
849: if (colB > maxj) maxj = colB;
850: if (colB < minj) minj = colB;
851: }
852: PetscCall(MatRestoreRow(Aseq, rowA, &ncols, &colsA, NULL));
853: if (minj >= bend) { // first column is > max of last block -- new block or last block
854: //PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "\t\t finish block %d, N loc = %d (%d,%d)\n", nDMs+1, bend - bsrt,bsrt,bend));
855: block_sizes[nDMs] = bend - bsrt;
856: ntot += block_sizes[nDMs];
857: PetscCheck(minj == bend, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "minj != bend: %" PetscInt_FMT " != %" PetscInt_FMT, minj, bend);
858: bsrt = bend;
859: bend++; // start with size 1 in new block
860: nDMs++;
861: }
862: if (maxj + 1 > bend) bend = maxj + 1;
863: PetscCheck(minj >= bsrt || row_B == Iend - 1, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "%" PetscInt_FMT ") minj < bsrt: %" PetscInt_FMT " != %" PetscInt_FMT, rowA, minj, bsrt);
864: //PetscCall(PetscPrintf(PETSC_COMM_SELF, "[%d] %d) row %d.%d) cols %d : %d ; bsrt = %d, bend = %d\n",rank,row_B,nDMs,rowA,minj,maxj,bsrt,bend));
865: }
866: // do last block
867: //PetscCall(PetscPrintf(PETSC_COMM_SELF, "\t\t\t [%d] finish block %d, N loc = %d (%d,%d)\n", rank, nDMs+1, bend - bsrt,bsrt,bend));
868: block_sizes[nDMs] = bend - bsrt;
869: ntot += block_sizes[nDMs];
870: nDMs++;
871: // cleanup
872: PetscCheck(ntot == nloc, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "n total != n local: %" PetscInt_FMT " != %" PetscInt_FMT, ntot, nloc);
873: PetscCall(ISRestoreIndices(isrow, &rowindices));
874: PetscCall(ISRestoreIndices(isicol, &icolindices));
875: PetscCall(PetscRealloc(sizeof(PetscInt) * nDMs, &block_sizes));
876: PetscCall(MatCreateVecs(A, &jac->vec_diag, NULL));
877: PetscCall(PetscInfo(pc, "Setup Matrix based meta data (not DMComposite not attached to PC) %" PetscInt_FMT " sub domains\n", nDMs));
878: }
879: PetscCall(ISDestroy(&isrow));
880: PetscCall(ISDestroy(&isicol));
881: jac->num_dms = nDMs;
882: PetscCall(VecGetLocalSize(jac->vec_diag, &n));
883: jac->n = n;
884: jac->d_idiag_k = new Kokkos::View<PetscScalar *, Kokkos::LayoutRight>("idiag", n);
885: // options
886: PetscCall(PCBJKOKKOSCreateKSP_BJKOKKOS(pc));
887: PetscCall(KSPSetFromOptions(jac->ksp));
888: PetscCall(PetscObjectTypeCompareAny((PetscObject)jac->ksp, &flg, KSPBICG, ""));
889: if (flg) {
890: jac->ksp_type_idx = BATCH_KSP_BICG_IDX;
891: jac->nwork = 7;
892: } else {
893: PetscCall(PetscObjectTypeCompareAny((PetscObject)jac->ksp, &flg, KSPTFQMR, ""));
894: if (flg) {
895: jac->ksp_type_idx = BATCH_KSP_TFQMR_IDX;
896: jac->nwork = 10;
897: } else {
898: #if defined(PETSC_HAVE_KOKKOS_KERNELS_BATCH)
899: PetscCall(PetscObjectTypeCompareAny((PetscObject)jac->ksp, &flg, KSPGMRES, ""));
900: PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONG, "Unsupported batch ksp type");
901: jac->ksp_type_idx = BATCH_KSP_GMRESKK_IDX;
902: jac->nwork = 0;
903: #else
904: KSPType ksptype;
905: PetscCall(KSPGetType(jac->ksp, &ksptype));
906: PetscCheck(flg, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "Type: %s not supported in complex", ksptype);
907: #endif
908: }
909: }
910: PetscOptionsBegin(PetscObjectComm((PetscObject)jac->ksp), ((PetscObject)jac->ksp)->prefix, "Options for Kokkos batch solver", "none");
911: PetscCall(PetscOptionsBool("-ksp_converged_reason", "", "bjkokkos.kokkos.cxx.c", jac->reason, &jac->reason, NULL));
912: PetscCall(PetscOptionsBool("-ksp_monitor", "", "bjkokkos.kokkos.cxx.c", jac->monitor, &jac->monitor, NULL));
913: PetscCall(PetscOptionsInt("-ksp_batch_target", "", "bjkokkos.kokkos.cxx.c", jac->batch_target, &jac->batch_target, NULL));
914: PetscCall(PetscOptionsInt("-ksp_rank_target", "", "bjkokkos.kokkos.cxx.c", jac->rank_target, &jac->rank_target, NULL));
915: PetscCall(PetscOptionsInt("-ksp_batch_nsolves_team", "", "bjkokkos.kokkos.cxx.c", jac->nsolves_team, &jac->nsolves_team, NULL));
916: PetscCheck(jac->batch_target < jac->num_dms, PETSC_COMM_WORLD, PETSC_ERR_ARG_WRONG, "-ksp_batch_target (%" PetscInt_FMT ") >= number of DMs (%" PetscInt_FMT ")", jac->batch_target, jac->num_dms);
917: PetscOptionsEnd();
918: // get blocks - jac->d_bid_eqOffset_k
919: if (pack) {
920: PetscCall(PetscMalloc(sizeof(*subX) * nDMs, &subX));
921: PetscCall(PetscMalloc(sizeof(*subDM) * nDMs, &subDM));
922: }
923: PetscCall(PetscMalloc(sizeof(*jac->dm_Nf) * nDMs, &jac->dm_Nf));
924: PetscCall(PetscInfo(pc, "Have %" PetscInt_FMT " blocks, n=%" PetscInt_FMT " rtol=%g type = %s\n", nDMs, n, (double)jac->ksp->rtol, ((PetscObject)jac->ksp)->type_name));
925: if (pack) PetscCall(DMCompositeGetEntriesArray(pack, subDM));
926: jac->nBlocks = 0;
927: for (PetscInt ii = 0; ii < nDMs; ii++) {
928: PetscInt Nf;
929: if (subDM) {
930: DM dm = subDM[ii];
931: PetscSection section;
932: PetscCall(DMGetLocalSection(dm, &section));
933: PetscCall(PetscSectionGetNumFields(section, &Nf));
934: } else Nf = 1;
935: jac->nBlocks += Nf;
936: #if PCBJKOKKOS_VERBOSE_LEVEL <= 2
937: if (ii == 0) PetscCall(PetscInfo(pc, "%" PetscInt_FMT ") %" PetscInt_FMT " blocks (%" PetscInt_FMT " total)\n", ii, Nf, jac->nBlocks));
938: #else
939: PetscCall(PetscInfo(pc, "%" PetscInt_FMT ") %" PetscInt_FMT " blocks (%" PetscInt_FMT " total)\n", ii, Nf, jac->nBlocks));
940: #endif
941: jac->dm_Nf[ii] = Nf;
942: }
943: { // d_bid_eqOffset_k
944: Kokkos::View<PetscInt *, Kokkos::LayoutRight, Kokkos::HostSpace> h_block_offsets("block_offsets", jac->nBlocks + 1);
945: if (pack) PetscCall(DMCompositeGetAccessArray(pack, jac->vec_diag, nDMs, NULL, subX));
946: h_block_offsets[0] = 0;
947: jac->const_block_size = -1;
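// h_block_offsets[b] is the first equation of block b (one block per field of each sub-DM);
// const_block_size ends up as the common block size if all blocks are equal, or 0 if they differ --
// equal-sized blocks enable the shared-memory work vectors in PCApply_BJKOKKOS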
948: for (PetscInt ii = 0, idx = 0; ii < nDMs; ii++) {
949: PetscInt nloc, nblk;
950: if (pack) PetscCall(VecGetSize(subX[ii], &nloc));
951: else nloc = block_sizes[ii];
952: nblk = nloc / jac->dm_Nf[ii];
953: PetscCheck(nloc % jac->dm_Nf[ii] == 0, PetscObjectComm((PetscObject)pc), PETSC_ERR_USER, "nloc%%jac->dm_Nf[ii] (%" PetscInt_FMT ") != 0 DMs", nloc % jac->dm_Nf[ii]);
954: for (PetscInt jj = 0; jj < jac->dm_Nf[ii]; jj++, idx++) {
955: h_block_offsets[idx + 1] = h_block_offsets[idx] + nblk;
956: #if PCBJKOKKOS_VERBOSE_LEVEL <= 2
957: if (idx == 0) PetscCall(PetscInfo(pc, "Add first of %" PetscInt_FMT " blocks with %" PetscInt_FMT " equations\n", jac->nBlocks, nblk));
958: #else
959: PetscCall(PetscInfo(pc, "\t%" PetscInt_FMT ") Add block with %" PetscInt_FMT " equations of %" PetscInt_FMT "\n", idx + 1, nblk, jac->nBlocks));
960: #endif
961: if (jac->const_block_size == -1) jac->const_block_size = nblk;
962: else if (jac->const_block_size > 0 && jac->const_block_size != nblk) jac->const_block_size = 0;
963: }
964: }
965: if (pack) {
966: PetscCall(DMCompositeRestoreAccessArray(pack, jac->vec_diag, jac->nBlocks, NULL, subX));
967: PetscCall(PetscFree(subX));
968: PetscCall(PetscFree(subDM));
969: }
970: jac->d_bid_eqOffset_k = new Kokkos::View<PetscInt *, Kokkos::LayoutRight>(Kokkos::create_mirror(Kokkos::DefaultExecutionSpace::memory_space(), h_block_offsets));
971: Kokkos::deep_copy(*jac->d_bid_eqOffset_k, h_block_offsets);
972: }
973: if (!pack) PetscCall(PetscFree(block_sizes));
974: }
975: { // get jac->d_idiag_k (PC setup),
976: const PetscInt *d_ai, *d_aj;
977: const PetscScalar *d_aa;
978: const PetscInt conc = Kokkos::DefaultExecutionSpace().concurrency(), openmp = !!(conc < 1000), team_size = (openmp == 0 && PCBJKOKKOS_VEC_SIZE != 1) ? PCBJKOKKOS_TEAM_SIZE : 1;
979: const PetscInt *d_bid_eqOffset = jac->d_bid_eqOffset_k->data(), *r = jac->d_isrow_k->data(), *ic = jac->d_isicol_k->data();
980: PetscScalar *d_idiag = jac->d_idiag_k->data(), *dummy;
981: PetscMemType mtype;
982: PetscCall(MatSeqAIJGetCSRAndMemType(Aseq, &d_ai, &d_aj, &dummy, &mtype));
983: d_aa = dummy;
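// Build the inverse of the pointwise diagonal in block (permuted) ordering: for each permuted row, find
// the entry whose permuted column matches the row and store its reciprocal; found != 1 below flags a
// missing (or duplicated) diagonal entry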
984: Kokkos::parallel_for(
985: "Diag", Kokkos::TeamPolicy<>(jac->nBlocks, team_size, PCBJKOKKOS_VEC_SIZE), KOKKOS_LAMBDA(const team_member team) {
986: const PetscInt blkID = team.league_rank();
987: Kokkos::parallel_for(Kokkos::TeamThreadRange(team, d_bid_eqOffset[blkID], d_bid_eqOffset[blkID + 1]), [=](const int rowb) {
988: const PetscInt rowa = ic[rowb], ai = d_ai[rowa], *aj = d_aj + ai; // grab original data
989: const PetscScalar *aa = d_aa + ai;
990: const PetscInt nrow = d_ai[rowa + 1] - ai;
991: int found;
992: Kokkos::parallel_reduce(
993: Kokkos::ThreadVectorRange(team, nrow),
994: [=](const int &j, int &count) {
995: const PetscInt colb = r[aj[j]];
996: if (colb == rowb) {
997: d_idiag[rowb] = 1. / aa[j];
998: count++;
999: }
1000: },
1001: found);
1002: #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
1003: if (found != 1) Kokkos::single(Kokkos::PerThread(team), [=]() { printf("ERROR row %d) found = %d\n", rowb, found); });
1004: #endif
1005: });
1006: });
1007: }
1008: }
1009: PetscFunctionReturn(PETSC_SUCCESS);
1010: }
1012: /* Reset (also used as cleanup by the default destroy if the PC was never set up) */
1013: static PetscErrorCode PCReset_BJKOKKOS(PC pc)
1014: {
1015: PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
1017: PetscFunctionBegin;
1018: PetscCall(KSPDestroy(&jac->ksp));
1019: PetscCall(VecDestroy(&jac->vec_diag));
1020: if (jac->d_bid_eqOffset_k) delete jac->d_bid_eqOffset_k;
1021: if (jac->d_idiag_k) delete jac->d_idiag_k;
1022: if (jac->d_isrow_k) delete jac->d_isrow_k;
1023: if (jac->d_isicol_k) delete jac->d_isicol_k;
1024: jac->d_bid_eqOffset_k = NULL;
1025: jac->d_idiag_k = NULL;
1026: jac->d_isrow_k = NULL;
1027: jac->d_isicol_k = NULL;
1028: PetscCall(PetscObjectComposeFunction((PetscObject)pc, "PCBJKOKKOSGetKSP_C", NULL)); // not published now (causes configure errors)
1029: PetscCall(PetscObjectComposeFunction((PetscObject)pc, "PCBJKOKKOSSetKSP_C", NULL));
1030: PetscCall(PetscFree(jac->dm_Nf));
1031: jac->dm_Nf = NULL;
1032: if (jac->rowOffsets) delete jac->rowOffsets;
1033: if (jac->colIndices) delete jac->colIndices;
1034: if (jac->batch_b) delete jac->batch_b;
1035: if (jac->batch_x) delete jac->batch_x;
1036: if (jac->batch_values) delete jac->batch_values;
1037: jac->rowOffsets = NULL;
1038: jac->colIndices = NULL;
1039: jac->batch_b = NULL;
1040: jac->batch_x = NULL;
1041: jac->batch_values = NULL;
1042: PetscFunctionReturn(PETSC_SUCCESS);
1043: }
1045: static PetscErrorCode PCDestroy_BJKOKKOS(PC pc)
1046: {
1047: PetscFunctionBegin;
1048: PetscCall(PCReset_BJKOKKOS(pc));
1049: PetscCall(PetscFree(pc->data));
1050: PetscFunctionReturn(PETSC_SUCCESS);
1051: }
1053: static PetscErrorCode PCView_BJKOKKOS(PC pc, PetscViewer viewer)
1054: {
1055: PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
1056: PetscBool iascii;
1058: PetscFunctionBegin;
1059: if (!jac->ksp) PetscCall(PCBJKOKKOSCreateKSP_BJKOKKOS(pc));
1060: PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &iascii));
1061: if (iascii) {
1062: PetscCall(PetscViewerASCIIPrintf(viewer, " Batched device linear solver: Krylov (KSP) method with Jacobi preconditioning\n"));
1063: PetscCall(PetscViewerASCIIPrintf(viewer, "\t\tnwork = %" PetscInt_FMT ", rel tol = %e, abs tol = %e, div tol = %e, max it =%" PetscInt_FMT ", type = %s\n", jac->nwork, jac->ksp->rtol, jac->ksp->abstol, jac->ksp->divtol, jac->ksp->max_it,
1064: ((PetscObject)jac->ksp)->type_name));
1065: }
1066: PetscFunctionReturn(PETSC_SUCCESS);
1067: }
1069: static PetscErrorCode PCSetFromOptions_BJKOKKOS(PC pc, PetscOptionItems *PetscOptionsObject)
1070: {
1071: PetscFunctionBegin;
1072: PetscOptionsHeadBegin(PetscOptionsObject, "PC BJKOKKOS options");
1073: PetscOptionsHeadEnd();
1074: PetscFunctionReturn(PETSC_SUCCESS);
1075: }
1077: static PetscErrorCode PCBJKOKKOSSetKSP_BJKOKKOS(PC pc, KSP ksp)
1078: {
1079: PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
1081: PetscFunctionBegin;
1082: PetscCall(PetscObjectReference((PetscObject)ksp));
1083: PetscCall(KSPDestroy(&jac->ksp));
1084: jac->ksp = ksp;
1085: PetscFunctionReturn(PETSC_SUCCESS);
1086: }
1088: /*@
1089: PCBJKOKKOSSetKSP - Sets the `KSP` context for `PCBJKOKKOS`
1091: Collective
1093: Input Parameters:
1094: + pc - the `PCBJKOKKOS` preconditioner context
1095: - ksp - the `KSP` solver
1097: Level: advanced
1099: Notes:
1100: The `PC` and the `KSP` must have the same communicator
1102: If the `PC` is not `PCBJKOKKOS` this function returns without doing anything
1104: .seealso: [](ch_ksp), `PCBJKOKKOSGetKSP()`, `PCBJKOKKOS`
1105: @*/
1106: PetscErrorCode PCBJKOKKOSSetKSP(PC pc, KSP ksp)
1107: {
1108: PetscFunctionBegin;
1111: PetscCheckSameComm(pc, 1, ksp, 2);
1112: PetscTryMethod(pc, "PCBJKOKKOSSetKSP_C", (PC, KSP), (pc, ksp));
1113: PetscFunctionReturn(PETSC_SUCCESS);
1114: }
1116: static PetscErrorCode PCBJKOKKOSGetKSP_BJKOKKOS(PC pc, KSP *ksp)
1117: {
1118: PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
1120: PetscFunctionBegin;
1121: if (!jac->ksp) PetscCall(PCBJKOKKOSCreateKSP_BJKOKKOS(pc));
1122: *ksp = jac->ksp;
1123: PetscFunctionReturn(PETSC_SUCCESS);
1124: }
1126: /*@
1127: PCBJKOKKOSGetKSP - Gets the `KSP` context for the `PCBJKOKKOS` preconditioner
1129: Not Collective but `KSP` returned is parallel if `PC` was parallel
1131: Input Parameter:
1132: . pc - the preconditioner context
1134: Output Parameter:
1135: . ksp - the `KSP` solver
1137: Level: advanced
1139: Notes:
1140: You must call `KSPSetUp()` before calling `PCBJKOKKOSGetKSP()`.
1142: If the `PC` is not a `PCBJKOKKOS` object it raises an error
1144: .seealso: [](ch_ksp), `PCBJKOKKOS`, `PCBJKOKKOSSetKSP()`
1145: @*/
1146: PetscErrorCode PCBJKOKKOSGetKSP(PC pc, KSP *ksp)
1147: {
1148: PetscFunctionBegin;
1150: PetscAssertPointer(ksp, 2);
1151: PetscUseMethod(pc, "PCBJKOKKOSGetKSP_C", (PC, KSP *), (pc, ksp));
1152: PetscFunctionReturn(PETSC_SUCCESS);
1153: }
1155: static PetscErrorCode PCPostSolve_BJKOKKOS(PC pc, KSP ksp, Vec b, Vec x)
1156: {
1157: PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
1159: PetscFunctionBegin;
1161: ksp->its = jac->max_nits;
1162: PetscFunctionReturn(PETSC_SUCCESS);
1163: }
1165: static PetscErrorCode PCPreSolve_BJKOKKOS(PC pc, KSP ksp, Vec b, Vec x)
1166: {
1167: PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
1169: PetscFunctionBegin;
1171: jac->ksp->errorifnotconverged = ksp->errorifnotconverged;
1172: PetscFunctionReturn(PETSC_SUCCESS);
1173: }
1175: /*MC
1176: PCBJKOKKOS - Defines a preconditioner that applies a Krylov solver and preconditioner to the blocks in a `MATSEQAIJ` matrix on the GPU using Kokkos
1178: Options Database Key:
1179: . -pc_bjkokkos_ - options prefix for its `KSP` options
1181: Level: intermediate
1183: Note:
1184: For use with -ksp_type preonly to bypass any computation on the CPU
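Example Usage:
A sketch of typical options (assuming the default options prefix; the solver choice and tolerances are illustrative):
.vb
-ksp_type preonly -pc_type bjkokkos -pc_bjkokkos_ksp_type tfqmr -pc_bjkokkos_ksp_rtol 1e-6 -dm_mat_type aijkokkos -dm_vec_type kokkos
.ve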
1186: Developer Notes:
1187: The documentation is incomplete. Is this a block Jacobi preconditioner?
1189: Why does it have its own `KSP`? Where is the `KSP` run if used with -ksp_type preonly?
1191: .seealso: [](ch_ksp), `PCCreate()`, `PCSetType()`, `PCType`, `PC`, `PCBJACOBI`,
1192: `PCSHELL`, `PCCOMPOSITE`, `PCSetUseAmat()`, `PCBJKOKKOSGetKSP()`
1193: M*/
1195: PETSC_EXTERN PetscErrorCode PCCreate_BJKOKKOS(PC pc)
1196: {
1197: PC_PCBJKOKKOS *jac;
1199: PetscFunctionBegin;
1200: PetscCall(PetscNew(&jac));
1201: pc->data = (void *)jac;
1203: jac->ksp = NULL;
1204: jac->vec_diag = NULL;
1205: jac->d_bid_eqOffset_k = NULL;
1206: jac->d_idiag_k = NULL;
1207: jac->d_isrow_k = NULL;
1208: jac->d_isicol_k = NULL;
1209: jac->nBlocks = 1;
1210: jac->max_nits = 0;
1212: PetscCall(PetscMemzero(pc->ops, sizeof(struct _PCOps)));
1213: pc->ops->apply = PCApply_BJKOKKOS;
1214: pc->ops->applytranspose = NULL;
1215: pc->ops->setup = PCSetUp_BJKOKKOS;
1216: pc->ops->reset = PCReset_BJKOKKOS;
1217: pc->ops->destroy = PCDestroy_BJKOKKOS;
1218: pc->ops->setfromoptions = PCSetFromOptions_BJKOKKOS;
1219: pc->ops->view = PCView_BJKOKKOS;
1220: pc->ops->postsolve = PCPostSolve_BJKOKKOS;
1221: pc->ops->presolve = PCPreSolve_BJKOKKOS;
1223: jac->rowOffsets = NULL;
1224: jac->colIndices = NULL;
1225: jac->batch_b = NULL;
1226: jac->batch_x = NULL;
1227: jac->batch_values = NULL;
1229: PetscCall(PetscObjectComposeFunction((PetscObject)pc, "PCBJKOKKOSGetKSP_C", PCBJKOKKOSGetKSP_BJKOKKOS));
1230: PetscCall(PetscObjectComposeFunction((PetscObject)pc, "PCBJKOKKOSSetKSP_C", PCBJKOKKOSSetKSP_BJKOKKOS));
1231: PetscFunctionReturn(PETSC_SUCCESS);
1232: }