Actual source code: irk.c

  1: /*
  2:   Code for timestepping with implicit Runge-Kutta method

  4:   Notes:
  5:   The general system is written as

  7:   F(t,U,Udot) = 0

  9: */
 10: #include <petsc/private/tsimpl.h>
 11: #include <petscdm.h>
 12: #include <petscdt.h>

/* Scheme selected by TSCreate_IRK() */
static TSIRKType         TSIRKDefault = TSIRKGAUSS;
/* Set by TSIRKRegisterAll() once the built-in schemes are registered; cleared by TSIRKRegisterDestroy() */
static PetscBool         TSIRKRegisterAllCalled;
/* Set by TSIRKInitializePackage(); cleared by TSIRKFinalizePackage() */
static PetscBool         TSIRKPackageInitialized;
/* Scheme name -> tableau-construction routine, filled by TSIRKRegister() */
static PetscFunctionList TSIRKList;

/* Coefficients of an implicit Runge-Kutta scheme, plus quantities derived from them for the stage solve.
   All arrays are allocated by TSIRKTableauCreate() and released by TSIRKTableauReset(). */
struct _IRKTableau {
  PetscReal   *A, *b, *c;                 /* stage matrix (nstages*nstages), completion weights and abscissae (nstages each) */
  PetscScalar *A_inv, *A_inv_rowsum, *I_s; /* inverse of A, row sums of that inverse, and the nstages x nstages identity */
  PetscReal   *binterp; /* Dense output formula */
};

/* Handle type used by TS_IRK to reference its tableau */
typedef struct _IRKTableau *IRKTableau;

/* Implementation data stored in ts->data for TSIRK */
typedef struct {
  char        *method_name;        /* name of the selected scheme (a key of TSIRKList) */
  PetscInt     order;   /* Classical approximation order of the method */
  PetscInt     nstages; /* Number of stages */
  PetscBool    stiffly_accurate;   /* printed by TSView_IRK(); not assigned in this file */
  PetscInt     pinterp; /* Interpolation order; used by TSInterpolate_IRK() as the number of terms per stage in binterp */
  IRKTableau   tableau;
  Vec          U0;    /* Backup vector: copy of ts->vec_sol at the start of the step, restored by TSRollBack_IRK() */
  Vec          Z;     /* Combined stage vector: all stage values interlaced, the unknown of the SNES solve */
  Vec         *Y;     /* States computed during the step */
  Vec          Ydot;  /* Work vector holding time derivatives during residual evaluation */
  Vec          U;     /* U is used to compute Ydot = shift(Y-U) */
  Vec         *YdotI; /* Work vectors to hold the residual evaluation; after the solve TSStep_IRK() stores the stage derivatives here */
  Mat          TJ;    /* KAIJ matrix for the Jacobian of the combined system */
  PetscScalar *work;  /* Scalar work (nstages entries) */
  TSStepStatus status;
  PetscBool    rebuild_completion; /* not used in this file */
  PetscReal    ccfl;               /* not used in this file */
} TS_IRK;

 47: /*@C
 48:   TSIRKTableauCreate - create the tableau for `TSIRK` and provide the entries

 50:   Not Collective

 52:   Input Parameters:
 53: + ts           - timestepping context
 54: . nstages      - number of stages, this is the dimension of the matrices below
 55: . A            - stage coefficients (dimension nstages*nstages, row-major)
 56: . b            - step completion table (dimension nstages)
 57: . c            - abscissa (dimension nstages)
 58: . binterp      - coefficients of the interpolation formula (dimension nstages)
 59: . A_inv        - inverse of A (dimension nstages*nstages, row-major)
 60: . A_inv_rowsum - row sum of the inverse of A (dimension nstages)
 61: - I_s          - identity matrix (dimension nstages*nstages)

 63:   Level: advanced

 65: .seealso: [](ch_ts), `TSIRK`, `TSIRKRegister()`
 66: @*/
 67: PetscErrorCode TSIRKTableauCreate(TS ts, PetscInt nstages, const PetscReal *A, const PetscReal *b, const PetscReal *c, const PetscReal *binterp, const PetscScalar *A_inv, const PetscScalar *A_inv_rowsum, const PetscScalar *I_s)
 68: {
 69:   TS_IRK    *irk = (TS_IRK *)ts->data;
 70:   IRKTableau tab = irk->tableau;

 72:   PetscFunctionBegin;
 73:   irk->order = nstages;
 74:   PetscCall(PetscMalloc3(PetscSqr(nstages), &tab->A, PetscSqr(nstages), &tab->A_inv, PetscSqr(nstages), &tab->I_s));
 75:   PetscCall(PetscMalloc4(nstages, &tab->b, nstages, &tab->c, nstages, &tab->binterp, nstages, &tab->A_inv_rowsum));
 76:   PetscCall(PetscArraycpy(tab->A, A, PetscSqr(nstages)));
 77:   PetscCall(PetscArraycpy(tab->b, b, nstages));
 78:   PetscCall(PetscArraycpy(tab->c, c, nstages));
 79:   /* optional coefficient arrays */
 80:   if (binterp) PetscCall(PetscArraycpy(tab->binterp, binterp, nstages));
 81:   if (A_inv) PetscCall(PetscArraycpy(tab->A_inv, A_inv, PetscSqr(nstages)));
 82:   if (A_inv_rowsum) PetscCall(PetscArraycpy(tab->A_inv_rowsum, A_inv_rowsum, nstages));
 83:   if (I_s) PetscCall(PetscArraycpy(tab->I_s, I_s, PetscSqr(nstages)));
 84:   PetscFunctionReturn(PETSC_SUCCESS);
 85: }

 87: /* Arrays should be freed with PetscFree3(A,b,c) */
 88: static PetscErrorCode TSIRKCreate_Gauss(TS ts)
 89: {
 90:   PetscInt     nstages;
 91:   PetscReal   *gauss_A_real, *gauss_b, *b, *gauss_c;
 92:   PetscScalar *gauss_A, *gauss_A_inv, *gauss_A_inv_rowsum, *I_s;
 93:   PetscScalar *G0, *G1;
 94:   PetscInt     i, j;
 95:   Mat          G0mat, G1mat, Amat;

 97:   PetscFunctionBegin;
 98:   PetscCall(TSIRKGetNumStages(ts, &nstages));
 99:   PetscCall(PetscMalloc3(PetscSqr(nstages), &gauss_A_real, nstages, &gauss_b, nstages, &gauss_c));
100:   PetscCall(PetscMalloc4(PetscSqr(nstages), &gauss_A, PetscSqr(nstages), &gauss_A_inv, nstages, &gauss_A_inv_rowsum, PetscSqr(nstages), &I_s));
101:   PetscCall(PetscMalloc3(nstages, &b, PetscSqr(nstages), &G0, PetscSqr(nstages), &G1));
102:   PetscCall(PetscDTGaussQuadrature(nstages, 0., 1., gauss_c, b));
103:   for (i = 0; i < nstages; i++) gauss_b[i] = b[i]; /* copy to possibly-complex array */

105:   /* A^T = G0^{-1} G1 */
106:   for (i = 0; i < nstages; i++) {
107:     for (j = 0; j < nstages; j++) {
108:       G0[i * nstages + j] = PetscPowRealInt(gauss_c[i], j);
109:       G1[i * nstages + j] = PetscPowRealInt(gauss_c[i], j + 1) / (j + 1);
110:     }
111:   }
112:   /* The arrays above are row-aligned, but we create dense matrices as the transpose */
113:   PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, nstages, nstages, G0, &G0mat));
114:   PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, nstages, nstages, G1, &G1mat));
115:   PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, nstages, nstages, gauss_A, &Amat));
116:   PetscCall(MatLUFactor(G0mat, NULL, NULL, NULL));
117:   PetscCall(MatMatSolve(G0mat, G1mat, Amat));
118:   PetscCall(MatTranspose(Amat, MAT_INPLACE_MATRIX, &Amat));
119:   for (i = 0; i < nstages; i++)
120:     for (j = 0; j < nstages; j++) gauss_A_real[i * nstages + j] = PetscRealPart(gauss_A[i * nstages + j]);

122:   PetscCall(MatDestroy(&G0mat));
123:   PetscCall(MatDestroy(&G1mat));
124:   PetscCall(MatDestroy(&Amat));
125:   PetscCall(PetscFree3(b, G0, G1));

127:   { /* Invert A */
128:     /* PETSc does not provide a routine to calculate the inverse of a general matrix.
129:      * To get the inverse of A, we form a sequential BAIJ matrix from it, consisting of a single block with block size
130:      * equal to the dimension of A, and then use MatInvertBlockDiagonal(). */
131:     Mat                A_baij;
132:     PetscInt           idxm[1] = {0}, idxn[1] = {0};
133:     const PetscScalar *A_inv;

135:     PetscCall(MatCreateSeqBAIJ(PETSC_COMM_SELF, nstages, nstages, nstages, 1, NULL, &A_baij));
136:     PetscCall(MatSetOption(A_baij, MAT_ROW_ORIENTED, PETSC_FALSE));
137:     PetscCall(MatSetValuesBlocked(A_baij, 1, idxm, 1, idxn, gauss_A, INSERT_VALUES));
138:     PetscCall(MatAssemblyBegin(A_baij, MAT_FINAL_ASSEMBLY));
139:     PetscCall(MatAssemblyEnd(A_baij, MAT_FINAL_ASSEMBLY));
140:     PetscCall(MatInvertBlockDiagonal(A_baij, &A_inv));
141:     PetscCall(PetscMemcpy(gauss_A_inv, A_inv, nstages * nstages * sizeof(PetscScalar)));
142:     PetscCall(MatDestroy(&A_baij));
143:   }

145:   /* Compute row sums A_inv_rowsum and identity I_s */
146:   for (i = 0; i < nstages; i++) {
147:     gauss_A_inv_rowsum[i] = 0;
148:     for (j = 0; j < nstages; j++) {
149:       gauss_A_inv_rowsum[i] += gauss_A_inv[i + nstages * j];
150:       I_s[i + nstages * j] = 1. * (i == j);
151:     }
152:   }
153:   PetscCall(TSIRKTableauCreate(ts, nstages, gauss_A_real, gauss_b, gauss_c, NULL, gauss_A_inv, gauss_A_inv_rowsum, I_s));
154:   PetscCall(PetscFree3(gauss_A_real, gauss_b, gauss_c));
155:   PetscCall(PetscFree4(gauss_A, gauss_A_inv, gauss_A_inv_rowsum, I_s));
156:   PetscFunctionReturn(PETSC_SUCCESS);
157: }

159: /*@C
160:   TSIRKRegister -  adds a `TSIRK` implementation

162:   Not Collective, No Fortran Support

164:   Input Parameters:
165: + sname    - name of user-defined IRK scheme
166: - function - function to create method context

168:   Level: advanced

170:   Note:
171:   `TSIRKRegister()` may be called multiple times to add several user-defined families.

173:   Example Usage:
174: .vb
175:    TSIRKRegister("my_scheme", MySchemeCreate);
176: .ve

178:   Then, your scheme can be chosen with the procedural interface via
179: .vb
180:   TSIRKSetType(ts, "my_scheme")
181: .ve
182:   or at runtime via the option
183: .vb
184:   -ts_irk_type my_scheme
185: .ve

187: .seealso: [](ch_ts), `TSIRK`, `TSIRKRegisterAll()`
188: @*/
189: PetscErrorCode TSIRKRegister(const char sname[], PetscErrorCode (*function)(TS))
190: {
191:   PetscFunctionBegin;
192:   PetscCall(TSIRKInitializePackage());
193:   PetscCall(PetscFunctionListAdd(&TSIRKList, sname, function));
194:   PetscFunctionReturn(PETSC_SUCCESS);
195: }

197: /*@C
198:   TSIRKRegisterAll - Registers all of the implicit Runge-Kutta methods in `TSIRK`

200:   Not Collective, but should be called by all processes which will need the schemes to be registered

202:   Level: advanced

204: .seealso: [](ch_ts), `TSIRK`, `TSIRKRegisterDestroy()`
205: @*/
206: PetscErrorCode TSIRKRegisterAll(void)
207: {
208:   PetscFunctionBegin;
209:   if (TSIRKRegisterAllCalled) PetscFunctionReturn(PETSC_SUCCESS);
210:   TSIRKRegisterAllCalled = PETSC_TRUE;

212:   PetscCall(TSIRKRegister(TSIRKGAUSS, TSIRKCreate_Gauss));
213:   PetscFunctionReturn(PETSC_SUCCESS);
214: }

216: /*@C
217:   TSIRKRegisterDestroy - Frees the list of schemes that were registered by `TSIRKRegister()`.

219:   Not Collective

221:   Level: advanced

223: .seealso: [](ch_ts), `TSIRK`, `TSIRKRegister()`, `TSIRKRegisterAll()`
224: @*/
225: PetscErrorCode TSIRKRegisterDestroy(void)
226: {
227:   PetscFunctionBegin;
228:   TSIRKRegisterAllCalled = PETSC_FALSE;
229:   PetscFunctionReturn(PETSC_SUCCESS);
230: }

232: /*@C
233:   TSIRKInitializePackage - This function initializes everything in the `TSIRK` package. It is called
234:   from `TSInitializePackage()`.

236:   Level: developer

238: .seealso: [](ch_ts), `TSIRK`, `PetscInitialize()`, `TSIRKFinalizePackage()`, `TSInitializePackage()`
239: @*/
240: PetscErrorCode TSIRKInitializePackage(void)
241: {
242:   PetscFunctionBegin;
243:   if (TSIRKPackageInitialized) PetscFunctionReturn(PETSC_SUCCESS);
244:   TSIRKPackageInitialized = PETSC_TRUE;
245:   PetscCall(TSIRKRegisterAll());
246:   PetscCall(PetscRegisterFinalize(TSIRKFinalizePackage));
247:   PetscFunctionReturn(PETSC_SUCCESS);
248: }

250: /*@C
251:   TSIRKFinalizePackage - This function destroys everything in the `TSIRK` package. It is
252:   called from `PetscFinalize()`.

254:   Level: developer

256: .seealso: [](ch_ts), `TSIRK`, `PetscFinalize()`, `TSInitializePackage()`
257: @*/
258: PetscErrorCode TSIRKFinalizePackage(void)
259: {
260:   PetscFunctionBegin;
261:   PetscCall(PetscFunctionListDestroy(&TSIRKList));
262:   TSIRKPackageInitialized = PETSC_FALSE;
263:   PetscFunctionReturn(PETSC_SUCCESS);
264: }

266: /*
267:  This function can be called before or after ts->vec_sol has been updated.
268: */
269: static PetscErrorCode TSEvaluateStep_IRK(TS ts, PetscInt order, Vec U, PetscBool *done)
270: {
271:   TS_IRK      *irk   = (TS_IRK *)ts->data;
272:   IRKTableau   tab   = irk->tableau;
273:   Vec         *YdotI = irk->YdotI;
274:   PetscScalar *w     = irk->work;
275:   PetscReal    h;
276:   PetscInt     j;

278:   PetscFunctionBegin;
279:   switch (irk->status) {
280:   case TS_STEP_INCOMPLETE:
281:   case TS_STEP_PENDING:
282:     h = ts->time_step;
283:     break;
284:   case TS_STEP_COMPLETE:
285:     h = ts->ptime - ts->ptime_prev;
286:     break;
287:   default:
288:     SETERRQ(PetscObjectComm((PetscObject)ts), PETSC_ERR_PLIB, "Invalid TSStepStatus");
289:   }

291:   PetscCall(VecCopy(ts->vec_sol, U));
292:   for (j = 0; j < irk->nstages; j++) w[j] = h * tab->b[j];
293:   PetscCall(VecMAXPY(U, irk->nstages, w, YdotI));
294:   PetscFunctionReturn(PETSC_SUCCESS);
295: }

297: static PetscErrorCode TSRollBack_IRK(TS ts)
298: {
299:   TS_IRK *irk = (TS_IRK *)ts->data;

301:   PetscFunctionBegin;
302:   PetscCall(VecCopy(irk->U0, ts->vec_sol));
303:   PetscFunctionReturn(PETSC_SUCCESS);
304: }

306: static PetscErrorCode TSStep_IRK(TS ts)
307: {
308:   TS_IRK        *irk   = (TS_IRK *)ts->data;
309:   IRKTableau     tab   = irk->tableau;
310:   PetscScalar   *A_inv = tab->A_inv, *A_inv_rowsum = tab->A_inv_rowsum;
311:   const PetscInt nstages = irk->nstages;
312:   SNES           snes;
313:   PetscInt       i, j, its, lits, bs;
314:   TSAdapt        adapt;
315:   PetscInt       rejections     = 0;
316:   PetscBool      accept         = PETSC_TRUE;
317:   PetscReal      next_time_step = ts->time_step;

319:   PetscFunctionBegin;
320:   if (!ts->steprollback) PetscCall(VecCopy(ts->vec_sol, irk->U0));
321:   PetscCall(VecGetBlockSize(ts->vec_sol, &bs));
322:   for (i = 0; i < nstages; i++) PetscCall(VecStrideScatter(ts->vec_sol, i * bs, irk->Z, INSERT_VALUES));

324:   irk->status = TS_STEP_INCOMPLETE;
325:   while (!ts->reason && irk->status != TS_STEP_COMPLETE) {
326:     PetscCall(VecCopy(ts->vec_sol, irk->U));
327:     PetscCall(TSGetSNES(ts, &snes));
328:     PetscCall(SNESSolve(snes, NULL, irk->Z));
329:     PetscCall(SNESGetIterationNumber(snes, &its));
330:     PetscCall(SNESGetLinearSolveIterations(snes, &lits));
331:     ts->snes_its += its;
332:     ts->ksp_its += lits;
333:     PetscCall(VecStrideGatherAll(irk->Z, irk->Y, INSERT_VALUES));
334:     for (i = 0; i < nstages; i++) {
335:       PetscCall(VecZeroEntries(irk->YdotI[i]));
336:       for (j = 0; j < nstages; j++) PetscCall(VecAXPY(irk->YdotI[i], A_inv[i + j * nstages] / ts->time_step, irk->Y[j]));
337:       PetscCall(VecAXPY(irk->YdotI[i], -A_inv_rowsum[i] / ts->time_step, irk->U));
338:     }
339:     irk->status = TS_STEP_INCOMPLETE;
340:     PetscCall(TSEvaluateStep_IRK(ts, irk->order, ts->vec_sol, NULL));
341:     irk->status = TS_STEP_PENDING;
342:     PetscCall(TSGetAdapt(ts, &adapt));
343:     PetscCall(TSAdaptChoose(adapt, ts, ts->time_step, NULL, &next_time_step, &accept));
344:     irk->status = accept ? TS_STEP_COMPLETE : TS_STEP_INCOMPLETE;
345:     if (!accept) {
346:       PetscCall(TSRollBack_IRK(ts));
347:       ts->time_step = next_time_step;
348:       goto reject_step;
349:     }

351:     ts->ptime += ts->time_step;
352:     ts->time_step = next_time_step;
353:     break;
354:   reject_step:
355:     ts->reject++;
356:     accept = PETSC_FALSE;
357:     if (!ts->reason && ++rejections > ts->max_reject && ts->max_reject >= 0) {
358:       ts->reason = TS_DIVERGED_STEP_REJECTED;
359:       PetscCall(PetscInfo(ts, "Step=%" PetscInt_FMT ", step rejections %" PetscInt_FMT " greater than current TS allowed, stopping solve\n", ts->steps, rejections));
360:     }
361:   }
362:   PetscFunctionReturn(PETSC_SUCCESS);
363: }

365: static PetscErrorCode TSInterpolate_IRK(TS ts, PetscReal itime, Vec U)
366: {
367:   TS_IRK          *irk     = (TS_IRK *)ts->data;
368:   PetscInt         nstages = irk->nstages, pinterp = irk->pinterp, i, j;
369:   PetscReal        h;
370:   PetscReal        tt, t;
371:   PetscScalar     *bt;
372:   const PetscReal *B = irk->tableau->binterp;

374:   PetscFunctionBegin;
375:   PetscCheck(B, PetscObjectComm((PetscObject)ts), PETSC_ERR_SUP, "TSIRK %s does not have an interpolation formula", irk->method_name);
376:   switch (irk->status) {
377:   case TS_STEP_INCOMPLETE:
378:   case TS_STEP_PENDING:
379:     h = ts->time_step;
380:     t = (itime - ts->ptime) / h;
381:     break;
382:   case TS_STEP_COMPLETE:
383:     h = ts->ptime - ts->ptime_prev;
384:     t = (itime - ts->ptime) / h + 1; /* In the interval [0,1] */
385:     break;
386:   default:
387:     SETERRQ(PetscObjectComm((PetscObject)ts), PETSC_ERR_PLIB, "Invalid TSStepStatus");
388:   }
389:   PetscCall(PetscMalloc1(nstages, &bt));
390:   for (i = 0; i < nstages; i++) bt[i] = 0;
391:   for (j = 0, tt = t; j < pinterp; j++, tt *= t) {
392:     for (i = 0; i < nstages; i++) bt[i] += h * B[i * pinterp + j] * tt;
393:   }
394:   PetscCall(VecMAXPY(U, nstages, bt, irk->YdotI));
395:   PetscFunctionReturn(PETSC_SUCCESS);
396: }

398: static PetscErrorCode TSIRKTableauReset(TS ts)
399: {
400:   TS_IRK    *irk = (TS_IRK *)ts->data;
401:   IRKTableau tab = irk->tableau;

403:   PetscFunctionBegin;
404:   if (!tab) PetscFunctionReturn(PETSC_SUCCESS);
405:   PetscCall(PetscFree3(tab->A, tab->A_inv, tab->I_s));
406:   PetscCall(PetscFree4(tab->b, tab->c, tab->binterp, tab->A_inv_rowsum));
407:   PetscFunctionReturn(PETSC_SUCCESS);
408: }

410: static PetscErrorCode TSReset_IRK(TS ts)
411: {
412:   TS_IRK *irk = (TS_IRK *)ts->data;

414:   PetscFunctionBegin;
415:   PetscCall(TSIRKTableauReset(ts));
416:   if (irk->tableau) PetscCall(PetscFree(irk->tableau));
417:   if (irk->method_name) PetscCall(PetscFree(irk->method_name));
418:   if (irk->work) PetscCall(PetscFree(irk->work));
419:   PetscCall(VecDestroyVecs(irk->nstages, &irk->Y));
420:   PetscCall(VecDestroyVecs(irk->nstages, &irk->YdotI));
421:   PetscCall(VecDestroy(&irk->Ydot));
422:   PetscCall(VecDestroy(&irk->Z));
423:   PetscCall(VecDestroy(&irk->U));
424:   PetscCall(VecDestroy(&irk->U0));
425:   PetscCall(MatDestroy(&irk->TJ));
426:   PetscFunctionReturn(PETSC_SUCCESS);
427: }

429: static PetscErrorCode TSIRKGetVecs(TS ts, DM dm, Vec *U)
430: {
431:   TS_IRK *irk = (TS_IRK *)ts->data;

433:   PetscFunctionBegin;
434:   if (U) {
435:     if (dm && dm != ts->dm) {
436:       PetscCall(DMGetNamedGlobalVector(dm, "TSIRK_U", U));
437:     } else *U = irk->U;
438:   }
439:   PetscFunctionReturn(PETSC_SUCCESS);
440: }

442: static PetscErrorCode TSIRKRestoreVecs(TS ts, DM dm, Vec *U)
443: {
444:   PetscFunctionBegin;
445:   if (U) {
446:     if (dm && dm != ts->dm) PetscCall(DMRestoreNamedGlobalVector(dm, "TSIRK_U", U));
447:   }
448:   PetscFunctionReturn(PETSC_SUCCESS);
449: }

451: /*
452:   This defines the nonlinear equations that is to be solved with SNES
453:     G[e\otimes t + C*dt, Z, Zdot] = 0
454:     Zdot = (In \otimes S)*Z - (In \otimes Se) U
455:   where S = 1/(dt*A)
456: */
457: static PetscErrorCode SNESTSFormFunction_IRK(SNES snes, Vec ZC, Vec FC, TS ts)
458: {
459:   TS_IRK            *irk     = (TS_IRK *)ts->data;
460:   IRKTableau         tab     = irk->tableau;
461:   const PetscInt     nstages = irk->nstages;
462:   const PetscReal   *c       = tab->c;
463:   const PetscScalar *A_inv = tab->A_inv, *A_inv_rowsum = tab->A_inv_rowsum;
464:   DM                 dm, dmsave;
465:   Vec                U, *YdotI = irk->YdotI, Ydot = irk->Ydot, *Y = irk->Y;
466:   PetscReal          h = ts->time_step;
467:   PetscInt           i, j;

469:   PetscFunctionBegin;
470:   PetscCall(SNESGetDM(snes, &dm));
471:   PetscCall(TSIRKGetVecs(ts, dm, &U));
472:   PetscCall(VecStrideGatherAll(ZC, Y, INSERT_VALUES));
473:   dmsave = ts->dm;
474:   ts->dm = dm;
475:   for (i = 0; i < nstages; i++) {
476:     PetscCall(VecZeroEntries(Ydot));
477:     for (j = 0; j < nstages; j++) PetscCall(VecAXPY(Ydot, A_inv[j * nstages + i] / h, Y[j]));
478:     PetscCall(VecAXPY(Ydot, -A_inv_rowsum[i] / h, U)); /* Ydot = (S \otimes In)*Z - (Se \otimes In) U */
479:     PetscCall(TSComputeIFunction(ts, ts->ptime + ts->time_step * c[i], Y[i], Ydot, YdotI[i], PETSC_FALSE));
480:   }
481:   PetscCall(VecStrideScatterAll(YdotI, FC, INSERT_VALUES));
482:   ts->dm = dmsave;
483:   PetscCall(TSIRKRestoreVecs(ts, dm, &U));
484:   PetscFunctionReturn(PETSC_SUCCESS);
485: }

487: /*
488:    For explicit ODE, the Jacobian is
489:      JC = I_n \otimes S - J \otimes I_s
490:    For DAE, the Jacobian is
491:      JC = M_n \otimes S - J \otimes I_s
492: */
493: static PetscErrorCode SNESTSFormJacobian_IRK(SNES snes, Vec ZC, Mat JC, Mat JCpre, TS ts)
494: {
495:   TS_IRK          *irk     = (TS_IRK *)ts->data;
496:   IRKTableau       tab     = irk->tableau;
497:   const PetscInt   nstages = irk->nstages;
498:   const PetscReal *c       = tab->c;
499:   DM               dm, dmsave;
500:   Vec             *Y = irk->Y, Ydot = irk->Ydot;
501:   Mat              J;
502:   PetscScalar     *S;
503:   PetscInt         i, j, bs;

505:   PetscFunctionBegin;
506:   PetscCall(SNESGetDM(snes, &dm));
507:   /* irk->Ydot has already been computed in SNESTSFormFunction_IRK (SNES guarantees this) */
508:   dmsave = ts->dm;
509:   ts->dm = dm;
510:   PetscCall(VecGetBlockSize(Y[nstages - 1], &bs));
511:   if (ts->equation_type <= TS_EQ_ODE_EXPLICIT) { /* Support explicit formulas only */
512:     PetscCall(VecStrideGather(ZC, (nstages - 1) * bs, Y[nstages - 1], INSERT_VALUES));
513:     PetscCall(MatKAIJGetAIJ(JC, &J));
514:     PetscCall(TSComputeIJacobian(ts, ts->ptime + ts->time_step * c[nstages - 1], Y[nstages - 1], Ydot, 0, J, J, PETSC_FALSE));
515:     PetscCall(MatKAIJGetS(JC, NULL, NULL, &S));
516:     for (i = 0; i < nstages; i++)
517:       for (j = 0; j < nstages; j++) S[i + nstages * j] = tab->A_inv[i + nstages * j] / ts->time_step;
518:     PetscCall(MatKAIJRestoreS(JC, &S));
519:   } else SETERRQ(PetscObjectComm((PetscObject)ts), PETSC_ERR_SUP, "TSIRK %s does not support implicit formula", irk->method_name); /* TODO: need the mass matrix for DAE  */
520:   ts->dm = dmsave;
521:   PetscFunctionReturn(PETSC_SUCCESS);
522: }

524: static PetscErrorCode DMCoarsenHook_TSIRK(DM fine, DM coarse, void *ctx)
525: {
526:   PetscFunctionBegin;
527:   PetscFunctionReturn(PETSC_SUCCESS);
528: }

530: static PetscErrorCode DMRestrictHook_TSIRK(DM fine, Mat restrct, Vec rscale, Mat inject, DM coarse, void *ctx)
531: {
532:   TS  ts = (TS)ctx;
533:   Vec U, U_c;

535:   PetscFunctionBegin;
536:   PetscCall(TSIRKGetVecs(ts, fine, &U));
537:   PetscCall(TSIRKGetVecs(ts, coarse, &U_c));
538:   PetscCall(MatRestrict(restrct, U, U_c));
539:   PetscCall(VecPointwiseMult(U_c, rscale, U_c));
540:   PetscCall(TSIRKRestoreVecs(ts, fine, &U));
541:   PetscCall(TSIRKRestoreVecs(ts, coarse, &U_c));
542:   PetscFunctionReturn(PETSC_SUCCESS);
543: }

545: static PetscErrorCode DMSubDomainHook_TSIRK(DM dm, DM subdm, void *ctx)
546: {
547:   PetscFunctionBegin;
548:   PetscFunctionReturn(PETSC_SUCCESS);
549: }

551: static PetscErrorCode DMSubDomainRestrictHook_TSIRK(DM dm, VecScatter gscat, VecScatter lscat, DM subdm, void *ctx)
552: {
553:   TS  ts = (TS)ctx;
554:   Vec U, U_c;

556:   PetscFunctionBegin;
557:   PetscCall(TSIRKGetVecs(ts, dm, &U));
558:   PetscCall(TSIRKGetVecs(ts, subdm, &U_c));

560:   PetscCall(VecScatterBegin(gscat, U, U_c, INSERT_VALUES, SCATTER_FORWARD));
561:   PetscCall(VecScatterEnd(gscat, U, U_c, INSERT_VALUES, SCATTER_FORWARD));

563:   PetscCall(TSIRKRestoreVecs(ts, dm, &U));
564:   PetscCall(TSIRKRestoreVecs(ts, subdm, &U_c));
565:   PetscFunctionReturn(PETSC_SUCCESS);
566: }

568: static PetscErrorCode TSSetUp_IRK(TS ts)
569: {
570:   TS_IRK        *irk = (TS_IRK *)ts->data;
571:   IRKTableau     tab = irk->tableau;
572:   DM             dm;
573:   Mat            J;
574:   Vec            R;
575:   const PetscInt nstages = irk->nstages;
576:   PetscInt       vsize, bs;

578:   PetscFunctionBegin;
579:   if (!irk->work) PetscCall(PetscMalloc1(irk->nstages, &irk->work));
580:   if (!irk->Y) PetscCall(VecDuplicateVecs(ts->vec_sol, irk->nstages, &irk->Y));
581:   if (!irk->YdotI) PetscCall(VecDuplicateVecs(ts->vec_sol, irk->nstages, &irk->YdotI));
582:   if (!irk->Ydot) PetscCall(VecDuplicate(ts->vec_sol, &irk->Ydot));
583:   if (!irk->U) PetscCall(VecDuplicate(ts->vec_sol, &irk->U));
584:   if (!irk->U0) PetscCall(VecDuplicate(ts->vec_sol, &irk->U0));
585:   if (!irk->Z) {
586:     PetscCall(VecCreate(PetscObjectComm((PetscObject)ts->vec_sol), &irk->Z));
587:     PetscCall(VecGetSize(ts->vec_sol, &vsize));
588:     PetscCall(VecSetSizes(irk->Z, PETSC_DECIDE, vsize * irk->nstages));
589:     PetscCall(VecGetBlockSize(ts->vec_sol, &bs));
590:     PetscCall(VecSetBlockSize(irk->Z, irk->nstages * bs));
591:     PetscCall(VecSetFromOptions(irk->Z));
592:   }
593:   PetscCall(TSGetDM(ts, &dm));
594:   PetscCall(DMCoarsenHookAdd(dm, DMCoarsenHook_TSIRK, DMRestrictHook_TSIRK, ts));
595:   PetscCall(DMSubDomainHookAdd(dm, DMSubDomainHook_TSIRK, DMSubDomainRestrictHook_TSIRK, ts));

597:   PetscCall(TSGetSNES(ts, &ts->snes));
598:   PetscCall(VecDuplicate(irk->Z, &R));
599:   PetscCall(SNESSetFunction(ts->snes, R, SNESTSFormFunction, ts));
600:   PetscCall(TSGetIJacobian(ts, &J, NULL, NULL, NULL));
601:   if (!irk->TJ) {
602:     /* Create the KAIJ matrix for solving the stages */
603:     PetscCall(MatCreateKAIJ(J, nstages, nstages, tab->A_inv, tab->I_s, &irk->TJ));
604:   }
605:   PetscCall(SNESSetJacobian(ts->snes, irk->TJ, irk->TJ, SNESTSFormJacobian, ts));
606:   PetscCall(VecDestroy(&R));
607:   PetscFunctionReturn(PETSC_SUCCESS);
608: }

610: static PetscErrorCode TSSetFromOptions_IRK(TS ts, PetscOptionItems PetscOptionsObject)
611: {
612:   TS_IRK *irk        = (TS_IRK *)ts->data;
613:   char    tname[256] = TSIRKGAUSS;

615:   PetscFunctionBegin;
616:   PetscOptionsHeadBegin(PetscOptionsObject, "IRK ODE solver options");
617:   {
618:     PetscBool flg1, flg2;
619:     PetscCall(PetscOptionsInt("-ts_irk_nstages", "Stages of the IRK method", "TSIRKSetNumStages", irk->nstages, &irk->nstages, &flg1));
620:     PetscCall(PetscOptionsFList("-ts_irk_type", "Type of IRK method", "TSIRKSetType", TSIRKList, irk->method_name[0] ? irk->method_name : tname, tname, sizeof(tname), &flg2));
621:     if (flg1 || flg2 || !irk->method_name[0]) { /* Create the method tableau after nstages or method is set */
622:       PetscCall(TSIRKSetType(ts, tname));
623:     }
624:   }
625:   PetscOptionsHeadEnd();
626:   PetscFunctionReturn(PETSC_SUCCESS);
627: }

629: static PetscErrorCode TSView_IRK(TS ts, PetscViewer viewer)
630: {
631:   TS_IRK   *irk = (TS_IRK *)ts->data;
632:   PetscBool iascii;

634:   PetscFunctionBegin;
635:   PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &iascii));
636:   if (iascii) {
637:     IRKTableau tab = irk->tableau;
638:     TSIRKType  irktype;
639:     char       buf[512];

641:     PetscCall(TSIRKGetType(ts, &irktype));
642:     PetscCall(PetscViewerASCIIPrintf(viewer, "  IRK type %s\n", irktype));
643:     PetscCall(PetscFormatRealArray(buf, sizeof(buf), "% 8.6f", irk->nstages, tab->c));
644:     PetscCall(PetscViewerASCIIPrintf(viewer, "  Abscissa       c = %s\n", buf));
645:     PetscCall(PetscViewerASCIIPrintf(viewer, "Stiffly accurate: %s\n", irk->stiffly_accurate ? "yes" : "no"));
646:     PetscCall(PetscFormatRealArray(buf, sizeof(buf), "% 8.6f", PetscSqr(irk->nstages), tab->A));
647:     PetscCall(PetscViewerASCIIPrintf(viewer, "  A coefficients       A = %s\n", buf));
648:   }
649:   PetscFunctionReturn(PETSC_SUCCESS);
650: }

652: static PetscErrorCode TSLoad_IRK(TS ts, PetscViewer viewer)
653: {
654:   SNES    snes;
655:   TSAdapt adapt;

657:   PetscFunctionBegin;
658:   PetscCall(TSGetAdapt(ts, &adapt));
659:   PetscCall(TSAdaptLoad(adapt, viewer));
660:   PetscCall(TSGetSNES(ts, &snes));
661:   PetscCall(SNESLoad(snes, viewer));
662:   /* function and Jacobian context for SNES when used with TS is always ts object */
663:   PetscCall(SNESSetFunction(snes, NULL, NULL, ts));
664:   PetscCall(SNESSetJacobian(snes, NULL, NULL, NULL, ts));
665:   PetscFunctionReturn(PETSC_SUCCESS);
666: }

668: /*@
669:   TSIRKSetType - Set the type of `TSIRK` scheme to use

671:   Logically Collective

673:   Input Parameters:
674: + ts      - timestepping context
675: - irktype - type of `TSIRK` scheme

677:   Options Database Key:
678: . -ts_irk_type <gauss> - set irk type

680:   Level: intermediate

682: .seealso: [](ch_ts), `TSIRKGetType()`, `TSIRK`, `TSIRKType`, `TSIRKGAUSS`
683: @*/
684: PetscErrorCode TSIRKSetType(TS ts, TSIRKType irktype)
685: {
686:   PetscFunctionBegin;
688:   PetscAssertPointer(irktype, 2);
689:   PetscTryMethod(ts, "TSIRKSetType_C", (TS, TSIRKType), (ts, irktype));
690:   PetscFunctionReturn(PETSC_SUCCESS);
691: }

693: /*@
694:   TSIRKGetType - Get the type of `TSIRK` IMEX scheme being used

696:   Logically Collective

698:   Input Parameter:
699: . ts - timestepping context

701:   Output Parameter:
702: . irktype - type of `TSIRK` IMEX scheme

704:   Level: intermediate

706: .seealso: [](ch_ts), `TSIRK`, `TSIRKType`, `TSIRKGAUSS`
707: @*/
708: PetscErrorCode TSIRKGetType(TS ts, TSIRKType *irktype)
709: {
710:   PetscFunctionBegin;
712:   PetscUseMethod(ts, "TSIRKGetType_C", (TS, TSIRKType *), (ts, irktype));
713:   PetscFunctionReturn(PETSC_SUCCESS);
714: }

716: /*@
717:   TSIRKSetNumStages - Set the number of stages of `TSIRK` scheme to use

719:   Logically Collective

721:   Input Parameters:
722: + ts      - timestepping context
723: - nstages - number of stages of `TSIRK` scheme

725:   Options Database Key:
726: . -ts_irk_nstages <int> - set number of stages

728:   Level: intermediate

730: .seealso: [](ch_ts), `TSIRKGetNumStages()`, `TSIRK`
731: @*/
732: PetscErrorCode TSIRKSetNumStages(TS ts, PetscInt nstages)
733: {
734:   PetscFunctionBegin;
736:   PetscTryMethod(ts, "TSIRKSetNumStages_C", (TS, PetscInt), (ts, nstages));
737:   PetscFunctionReturn(PETSC_SUCCESS);
738: }

740: /*@
741:   TSIRKGetNumStages - Get the number of stages of `TSIRK` scheme

743:   Logically Collective

745:   Input Parameters:
746: + ts      - timestepping context
747: - nstages - number of stages of `TSIRK` scheme

749:   Level: intermediate

751: .seealso: [](ch_ts), `TSIRKSetNumStages()`, `TSIRK`
752: @*/
753: PetscErrorCode TSIRKGetNumStages(TS ts, PetscInt *nstages)
754: {
755:   PetscFunctionBegin;
757:   PetscAssertPointer(nstages, 2);
758:   PetscTryMethod(ts, "TSIRKGetNumStages_C", (TS, PetscInt *), (ts, nstages));
759:   PetscFunctionReturn(PETSC_SUCCESS);
760: }

762: static PetscErrorCode TSIRKGetType_IRK(TS ts, TSIRKType *irktype)
763: {
764:   TS_IRK *irk = (TS_IRK *)ts->data;

766:   PetscFunctionBegin;
767:   *irktype = irk->method_name;
768:   PetscFunctionReturn(PETSC_SUCCESS);
769: }

771: static PetscErrorCode TSIRKSetType_IRK(TS ts, TSIRKType irktype)
772: {
773:   TS_IRK *irk = (TS_IRK *)ts->data;
774:   PetscErrorCode (*irkcreate)(TS);

776:   PetscFunctionBegin;
777:   if (irk->method_name) {
778:     PetscCall(PetscFree(irk->method_name));
779:     PetscCall(TSIRKTableauReset(ts));
780:   }
781:   PetscCall(PetscFunctionListFind(TSIRKList, irktype, &irkcreate));
782:   PetscCheck(irkcreate, PetscObjectComm((PetscObject)ts), PETSC_ERR_ARG_UNKNOWN_TYPE, "Unknown TSIRK type \"%s\" given", irktype);
783:   PetscCall((*irkcreate)(ts));
784:   PetscCall(PetscStrallocpy(irktype, &irk->method_name));
785:   PetscFunctionReturn(PETSC_SUCCESS);
786: }

788: static PetscErrorCode TSIRKSetNumStages_IRK(TS ts, PetscInt nstages)
789: {
790:   TS_IRK *irk = (TS_IRK *)ts->data;

792:   PetscFunctionBegin;
793:   PetscCheck(nstages > 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "input argument, %" PetscInt_FMT ", out of range", nstages);
794:   irk->nstages = nstages;
795:   PetscFunctionReturn(PETSC_SUCCESS);
796: }

798: static PetscErrorCode TSIRKGetNumStages_IRK(TS ts, PetscInt *nstages)
799: {
800:   TS_IRK *irk = (TS_IRK *)ts->data;

802:   PetscFunctionBegin;
803:   PetscAssertPointer(nstages, 2);
804:   *nstages = irk->nstages;
805:   PetscFunctionReturn(PETSC_SUCCESS);
806: }

808: static PetscErrorCode TSDestroy_IRK(TS ts)
809: {
810:   PetscFunctionBegin;
811:   PetscCall(TSReset_IRK(ts));
812:   if (ts->dm) {
813:     PetscCall(DMCoarsenHookRemove(ts->dm, DMCoarsenHook_TSIRK, DMRestrictHook_TSIRK, ts));
814:     PetscCall(DMSubDomainHookRemove(ts->dm, DMSubDomainHook_TSIRK, DMSubDomainRestrictHook_TSIRK, ts));
815:   }
816:   PetscCall(PetscFree(ts->data));
817:   PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSIRKSetType_C", NULL));
818:   PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSIRKGetType_C", NULL));
819:   PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSIRKSetNumStages_C", NULL));
820:   PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSIRKGetNumStages_C", NULL));
821:   PetscFunctionReturn(PETSC_SUCCESS);
822: }

824: /*MC
825:       TSIRK - ODE and DAE solver using Implicit Runge-Kutta schemes

827:   Level: beginner

829:   Notes:
830:   `TSIRK` uses the sparse Kronecker product matrix implementation of `MATKAIJ` to achieve good arithmetic intensity.

832:   Gauss-Legrendre methods are currently supported. These are A-stable symplectic methods with an arbitrary number of stages. The order of accuracy is 2s
833:   when using s stages. The default method uses three stages and thus has an order of six. The number of stages (thus order) can be set with
834:   -ts_irk_nstages or `TSIRKSetNumStages()`.

836: .seealso: [](ch_ts), `TSCreate()`, `TS`, `TSSetType()`, `TSIRKSetType()`, `TSIRKGetType()`, `TSIRKGAUSS`, `TSIRKRegister()`, `TSIRKSetNumStages()`, `TSType`
837: M*/
838: PETSC_EXTERN PetscErrorCode TSCreate_IRK(TS ts)
839: {
840:   TS_IRK *irk;

842:   PetscFunctionBegin;
843:   PetscCall(TSIRKInitializePackage());

845:   ts->ops->reset          = TSReset_IRK;
846:   ts->ops->destroy        = TSDestroy_IRK;
847:   ts->ops->view           = TSView_IRK;
848:   ts->ops->load           = TSLoad_IRK;
849:   ts->ops->setup          = TSSetUp_IRK;
850:   ts->ops->step           = TSStep_IRK;
851:   ts->ops->interpolate    = TSInterpolate_IRK;
852:   ts->ops->evaluatestep   = TSEvaluateStep_IRK;
853:   ts->ops->rollback       = TSRollBack_IRK;
854:   ts->ops->setfromoptions = TSSetFromOptions_IRK;
855:   ts->ops->snesfunction   = SNESTSFormFunction_IRK;
856:   ts->ops->snesjacobian   = SNESTSFormJacobian_IRK;

858:   ts->usessnes = PETSC_TRUE;

860:   PetscCall(PetscNew(&irk));
861:   ts->data = (void *)irk;

863:   PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSIRKSetType_C", TSIRKSetType_IRK));
864:   PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSIRKGetType_C", TSIRKGetType_IRK));
865:   PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSIRKSetNumStages_C", TSIRKSetNumStages_IRK));
866:   PetscCall(PetscObjectComposeFunction((PetscObject)ts, "TSIRKGetNumStages_C", TSIRKGetNumStages_IRK));
867:   /* 3-stage IRK_Gauss is the default */
868:   PetscCall(PetscNew(&irk->tableau));
869:   irk->nstages = 3;
870:   PetscCall(TSIRKSetType(ts, TSIRKDefault));
871:   PetscFunctionReturn(PETSC_SUCCESS);
872: }