Actual source code: irk.c

  1: /*
  2:   Code for timestepping with implicit Runge-Kutta method

  4:   Notes:
  5:   The general system is written as

  7:   F(t,U,Udot) = 0

  9: */
 10: #include <petsc/private/tsimpl.h>
 11: #include <petscdm.h>
 12: #include <petscdt.h>

/* Name of the scheme installed by TSCreate_IRK() through TSIRKSetType() */
static TSIRKType         TSIRKDefault = TSIRKGAUSS;
/* Set by TSIRKRegisterAll() once the built-in schemes are registered; cleared by TSIRKRegisterDestroy() */
static PetscBool         TSIRKRegisterAllCalled;
/* Set by TSIRKInitializePackage(); cleared by TSIRKFinalizePackage() */
static PetscBool         TSIRKPackageInitialized;
/* Registry mapping scheme names to tableau-creation routines (filled by TSIRKRegister()) */
static PetscFunctionList TSIRKList;

/* Coefficients of an IRK scheme plus quantities derived from them; arrays are allocated in TSIRKTableauCreate() */
struct _IRKTableau {
  PetscReal   *A, *b, *c;                  /* stage matrix (nstages*nstages), weights and abscissae (nstages each) */
  PetscScalar *A_inv, *A_inv_rowsum, *I_s; /* inverse of A, read as A_inv[i + j*nstages] for entry (i,j); its row sums; nstages x nstages identity */
  PetscReal   *binterp; /* Dense output formula; NOTE(review): allocated with nstages entries but indexed as B[i*pinterp + j] in TSInterpolate_IRK() */
};

typedef struct _IRKTableau *IRKTableau;

 25: typedef struct _IRKTableau *IRKTableau;

 27: typedef struct {
 28:   char        *method_name;
 29:   PetscInt     order;   /* Classical approximation order of the method */
 30:   PetscInt     nstages; /* Number of stages */
 31:   PetscBool    stiffly_accurate;
 32:   PetscInt     pinterp; /* Interpolation order */
 33:   IRKTableau   tableau;
 34:   Vec          U0;    /* Backup vector */
 35:   Vec          Z;     /* Combined stage vector */
 36:   Vec         *Y;     /* States computed during the step */
 37:   Vec          Ydot;  /* Work vector holding time derivatives during residual evaluation */
 38:   Vec          U;     /* U is used to compute Ydot = shift(Y-U) */
 39:   Vec         *YdotI; /* Work vectors to hold the residual evaluation */
 40:   Mat          TJ;    /* KAIJ matrix for the Jacobian of the combined system */
 41:   PetscScalar *work;  /* Scalar work */
 42:   TSStepStatus status;
 43:   PetscBool    rebuild_completion;
 44:   PetscReal    ccfl;
 45: } TS_IRK;

 47: /*@C
 48:    TSIRKTableauCreate - create the tableau for TSIRK and provide the entries

 50:    Not Collective

 52:    Input Parameters:
 53: +  ts - timestepping context
 54: .  nstages - number of stages, this is the dimension of the matrices below
 55: .  A - stage coefficients (dimension nstages*nstages, row-major)
 56: .  b - step completion table (dimension nstages)
 57: .  c - abscissa (dimension nstages)
 58: .  binterp - coefficients of the interpolation formula (dimension nstages)
 59: .  A_inv - inverse of A (dimension nstages*nstages, row-major)
 60: .  A_inv_rowsum - row sum of the inverse of A (dimension nstages)
 61: -  I_s - identity matrix (dimension nstages*nstages)

 63:    Level: advanced

 65: .seealso: `TSIRK`, `TSIRKRegister()`
 66: @*/
 67: PetscErrorCode TSIRKTableauCreate(TS ts, PetscInt nstages, const PetscReal *A, const PetscReal *b, const PetscReal *c, const PetscReal *binterp, const PetscScalar *A_inv, const PetscScalar *A_inv_rowsum, const PetscScalar *I_s)
 68: {
 69:   TS_IRK    *irk = (TS_IRK *)ts->data;
 70:   IRKTableau tab = irk->tableau;

 72:   irk->order = nstages;
 73:   PetscMalloc3(PetscSqr(nstages), &tab->A, PetscSqr(nstages), &tab->A_inv, PetscSqr(nstages), &tab->I_s);
 74:   PetscMalloc4(nstages, &tab->b, nstages, &tab->c, nstages, &tab->binterp, nstages, &tab->A_inv_rowsum);
 75:   PetscArraycpy(tab->A, A, PetscSqr(nstages));
 76:   PetscArraycpy(tab->b, b, nstages);
 77:   PetscArraycpy(tab->c, c, nstages);
 78:   /* optional coefficient arrays */
 79:   if (binterp) PetscArraycpy(tab->binterp, binterp, nstages);
 80:   if (A_inv) PetscArraycpy(tab->A_inv, A_inv, PetscSqr(nstages));
 81:   if (A_inv_rowsum) PetscArraycpy(tab->A_inv_rowsum, A_inv_rowsum, nstages);
 82:   if (I_s) PetscArraycpy(tab->I_s, I_s, PetscSqr(nstages));
 83:   return 0;
 84: }

 86: /* Arrays should be freed with PetscFree3(A,b,c) */
 87: static PetscErrorCode TSIRKCreate_Gauss(TS ts)
 88: {
 89:   PetscInt     nstages;
 90:   PetscReal   *gauss_A_real, *gauss_b, *b, *gauss_c;
 91:   PetscScalar *gauss_A, *gauss_A_inv, *gauss_A_inv_rowsum, *I_s;
 92:   PetscScalar *G0, *G1;
 93:   PetscInt     i, j;
 94:   Mat          G0mat, G1mat, Amat;

 96:   TSIRKGetNumStages(ts, &nstages);
 97:   PetscMalloc3(PetscSqr(nstages), &gauss_A_real, nstages, &gauss_b, nstages, &gauss_c);
 98:   PetscMalloc4(PetscSqr(nstages), &gauss_A, PetscSqr(nstages), &gauss_A_inv, nstages, &gauss_A_inv_rowsum, PetscSqr(nstages), &I_s);
 99:   PetscMalloc3(nstages, &b, PetscSqr(nstages), &G0, PetscSqr(nstages), &G1);
100:   PetscDTGaussQuadrature(nstages, 0., 1., gauss_c, b);
101:   for (i = 0; i < nstages; i++) gauss_b[i] = b[i]; /* copy to possibly-complex array */

103:   /* A^T = G0^{-1} G1 */
104:   for (i = 0; i < nstages; i++) {
105:     for (j = 0; j < nstages; j++) {
106:       G0[i * nstages + j] = PetscPowRealInt(gauss_c[i], j);
107:       G1[i * nstages + j] = PetscPowRealInt(gauss_c[i], j + 1) / (j + 1);
108:     }
109:   }
110:   /* The arrays above are row-aligned, but we create dense matrices as the transpose */
111:   MatCreateSeqDense(PETSC_COMM_SELF, nstages, nstages, G0, &G0mat);
112:   MatCreateSeqDense(PETSC_COMM_SELF, nstages, nstages, G1, &G1mat);
113:   MatCreateSeqDense(PETSC_COMM_SELF, nstages, nstages, gauss_A, &Amat);
114:   MatLUFactor(G0mat, NULL, NULL, NULL);
115:   MatMatSolve(G0mat, G1mat, Amat);
116:   MatTranspose(Amat, MAT_INPLACE_MATRIX, &Amat);
117:   for (i = 0; i < nstages; i++)
118:     for (j = 0; j < nstages; j++) gauss_A_real[i * nstages + j] = PetscRealPart(gauss_A[i * nstages + j]);

120:   MatDestroy(&G0mat);
121:   MatDestroy(&G1mat);
122:   MatDestroy(&Amat);
123:   PetscFree3(b, G0, G1);

125:   { /* Invert A */
126:     /* PETSc does not provide a routine to calculate the inverse of a general matrix.
127:      * To get the inverse of A, we form a sequential BAIJ matrix from it, consisting of a single block with block size
128:      * equal to the dimension of A, and then use MatInvertBlockDiagonal(). */
129:     Mat                A_baij;
130:     PetscInt           idxm[1] = {0}, idxn[1] = {0};
131:     const PetscScalar *A_inv;

133:     MatCreateSeqBAIJ(PETSC_COMM_SELF, nstages, nstages, nstages, 1, NULL, &A_baij);
134:     MatSetOption(A_baij, MAT_ROW_ORIENTED, PETSC_FALSE);
135:     MatSetValuesBlocked(A_baij, 1, idxm, 1, idxn, gauss_A, INSERT_VALUES);
136:     MatAssemblyBegin(A_baij, MAT_FINAL_ASSEMBLY);
137:     MatAssemblyEnd(A_baij, MAT_FINAL_ASSEMBLY);
138:     MatInvertBlockDiagonal(A_baij, &A_inv);
139:     PetscMemcpy(gauss_A_inv, A_inv, nstages * nstages * sizeof(PetscScalar));
140:     MatDestroy(&A_baij);
141:   }

143:   /* Compute row sums A_inv_rowsum and identity I_s */
144:   for (i = 0; i < nstages; i++) {
145:     gauss_A_inv_rowsum[i] = 0;
146:     for (j = 0; j < nstages; j++) {
147:       gauss_A_inv_rowsum[i] += gauss_A_inv[i + nstages * j];
148:       I_s[i + nstages * j] = 1. * (i == j);
149:     }
150:   }
151:   TSIRKTableauCreate(ts, nstages, gauss_A_real, gauss_b, gauss_c, NULL, gauss_A_inv, gauss_A_inv_rowsum, I_s);
152:   PetscFree3(gauss_A_real, gauss_b, gauss_c);
153:   PetscFree4(gauss_A, gauss_A_inv, gauss_A_inv_rowsum, I_s);
154:   return 0;
155: }

157: /*@C
158:    TSIRKRegister -  adds a TSIRK implementation

160:    Not Collective

162:    Input Parameters:
163: +  sname - name of user-defined IRK scheme
164: -  function - function to create method context

166:    Notes:
167:    TSIRKRegister() may be called multiple times to add several user-defined families.

169:    Sample usage:
170: .vb
171:    TSIRKRegister("my_scheme",MySchemeCreate);
172: .ve

174:    Then, your scheme can be chosen with the procedural interface via
175: $     TSIRKSetType(ts,"my_scheme")
176:    or at runtime via the option
177: $     -ts_irk_type my_scheme

179:    Level: advanced

181: .seealso: `TSIRKRegisterAll()`
182: @*/
183: PetscErrorCode TSIRKRegister(const char sname[], PetscErrorCode (*function)(TS))
184: {
185:   TSIRKInitializePackage();
186:   PetscFunctionListAdd(&TSIRKList, sname, function);
187:   return 0;
188: }

190: /*@C
191:   TSIRKRegisterAll - Registers all of the implicit Runge-Kutta methods in TSIRK

193:   Not Collective, but should be called by all processes which will need the schemes to be registered

195:   Level: advanced

197: .seealso: `TSIRKRegisterDestroy()`
198: @*/
199: PetscErrorCode TSIRKRegisterAll(void)
200: {
201:   if (TSIRKRegisterAllCalled) return 0;
202:   TSIRKRegisterAllCalled = PETSC_TRUE;

204:   TSIRKRegister(TSIRKGAUSS, TSIRKCreate_Gauss);
205:   return 0;
206: }

208: /*@C
209:    TSIRKRegisterDestroy - Frees the list of schemes that were registered by TSIRKRegister().

211:    Not Collective

213:    Level: advanced

215: .seealso: `TSIRKRegister()`, `TSIRKRegisterAll()`
216: @*/
217: PetscErrorCode TSIRKRegisterDestroy(void)
218: {
219:   TSIRKRegisterAllCalled = PETSC_FALSE;
220:   return 0;
221: }

223: /*@C
224:   TSIRKInitializePackage - This function initializes everything in the TSIRK package. It is called
225:   from TSInitializePackage().

227:   Level: developer

229: .seealso: `PetscInitialize()`
230: @*/
231: PetscErrorCode TSIRKInitializePackage(void)
232: {
233:   if (TSIRKPackageInitialized) return 0;
234:   TSIRKPackageInitialized = PETSC_TRUE;
235:   TSIRKRegisterAll();
236:   PetscRegisterFinalize(TSIRKFinalizePackage);
237:   return 0;
238: }

240: /*@C
241:   TSIRKFinalizePackage - This function destroys everything in the TSIRK package. It is
242:   called from PetscFinalize().

244:   Level: developer

246: .seealso: `PetscFinalize()`
247: @*/
248: PetscErrorCode TSIRKFinalizePackage(void)
249: {
250:   PetscFunctionListDestroy(&TSIRKList);
251:   TSIRKPackageInitialized = PETSC_FALSE;
252:   return 0;
253: }

255: /*
256:  This function can be called before or after ts->vec_sol has been updated.
257: */
258: static PetscErrorCode TSEvaluateStep_IRK(TS ts, PetscInt order, Vec U, PetscBool *done)
259: {
260:   TS_IRK      *irk   = (TS_IRK *)ts->data;
261:   IRKTableau   tab   = irk->tableau;
262:   Vec         *YdotI = irk->YdotI;
263:   PetscScalar *w     = irk->work;
264:   PetscReal    h;
265:   PetscInt     j;

267:   switch (irk->status) {
268:   case TS_STEP_INCOMPLETE:
269:   case TS_STEP_PENDING:
270:     h = ts->time_step;
271:     break;
272:   case TS_STEP_COMPLETE:
273:     h = ts->ptime - ts->ptime_prev;
274:     break;
275:   default:
276:     SETERRQ(PetscObjectComm((PetscObject)ts), PETSC_ERR_PLIB, "Invalid TSStepStatus");
277:   }

279:   VecCopy(ts->vec_sol, U);
280:   for (j = 0; j < irk->nstages; j++) w[j] = h * tab->b[j];
281:   VecMAXPY(U, irk->nstages, w, YdotI);
282:   return 0;
283: }

285: static PetscErrorCode TSRollBack_IRK(TS ts)
286: {
287:   TS_IRK *irk = (TS_IRK *)ts->data;

289:   VecCopy(irk->U0, ts->vec_sol);
290:   return 0;
291: }

/*
  Advances ts->vec_sol by one step. On rejection by TSAdapt the step is rolled back and retried with the
  proposed step size, until it is accepted or ts->reason is set.

  All stages are solved at once through the combined vector irk->Z. After the solve the stage derivatives
     YdotI[i] = (1/h) * ( sum_j A_inv(i,j) Y_j - A_inv_rowsum_i U ),   A_inv(i,j) = A_inv[i + j*nstages]
  are formed, and TSEvaluateStep_IRK() writes the new solution into ts->vec_sol.
*/
static PetscErrorCode TSStep_IRK(TS ts)
{
  TS_IRK        *irk   = (TS_IRK *)ts->data;
  IRKTableau     tab   = irk->tableau;
  PetscScalar   *A_inv = tab->A_inv, *A_inv_rowsum = tab->A_inv_rowsum;
  const PetscInt nstages = irk->nstages;
  SNES           snes;
  PetscInt       i, j, its, lits, bs;
  TSAdapt        adapt;
  PetscInt       rejections     = 0;
  PetscBool      accept         = PETSC_TRUE;
  PetscReal      next_time_step = ts->time_step;

  /* Save the starting state for TSRollBack_IRK() (skipped when ts->steprollback is set) */
  if (!ts->steprollback) VecCopy(ts->vec_sol, irk->U0);
  VecGetBlockSize(ts->vec_sol, &bs);
  /* Initial guess: the current solution in every stage slot of Z */
  for (i = 0; i < nstages; i++) VecStrideScatter(ts->vec_sol, i * bs, irk->Z, INSERT_VALUES);

  irk->status = TS_STEP_INCOMPLETE;
  while (!ts->reason && irk->status != TS_STEP_COMPLETE) {
    /* irk->U is the start-of-step state read by the residual through TSIRKGetVecs() */
    VecCopy(ts->vec_sol, irk->U);
    TSGetSNES(ts, &snes);
    /* Z is not reset between attempts, so a retry starts from the previous SNES solution */
    SNESSolve(snes, NULL, irk->Z);
    SNESGetIterationNumber(snes, &its);
    SNESGetLinearSolveIterations(snes, &lits);
    ts->snes_its += its;
    ts->ksp_its += lits;
    /* Split the combined vector back into the individual stage values */
    VecStrideGatherAll(irk->Z, irk->Y, INSERT_VALUES);
    for (i = 0; i < nstages; i++) {
      VecZeroEntries(irk->YdotI[i]);
      for (j = 0; j < nstages; j++) VecAXPY(irk->YdotI[i], A_inv[i + j * nstages] / ts->time_step, irk->Y[j]);
      VecAXPY(irk->YdotI[i], -A_inv_rowsum[i] / ts->time_step, irk->U);
    }
    /* Candidate solution at the end of the step, written in place into ts->vec_sol */
    irk->status = TS_STEP_INCOMPLETE;
    TSEvaluateStep_IRK(ts, irk->order, ts->vec_sol, NULL);
    irk->status = TS_STEP_PENDING;
    TSGetAdapt(ts, &adapt);
    TSAdaptChoose(adapt, ts, ts->time_step, NULL, &next_time_step, &accept);
    irk->status = accept ? TS_STEP_COMPLETE : TS_STEP_INCOMPLETE;
    if (!accept) {
      /* Restore ts->vec_sol from U0 and retry with the step size proposed by the adaptor */
      TSRollBack_IRK(ts);
      ts->time_step = next_time_step;
      goto reject_step;
    }

    ts->ptime += ts->time_step;
    ts->time_step = next_time_step;
    break;
  reject_step:
    ts->reject++;
    accept = PETSC_FALSE;
    /* Stop once rejections within this step exceed ts->max_reject (a negative max_reject disables the limit) */
    if (!ts->reason && ++rejections > ts->max_reject && ts->max_reject >= 0) {
      ts->reason = TS_DIVERGED_STEP_REJECTED;
      PetscInfo(ts, "Step=%" PetscInt_FMT ", step rejections %" PetscInt_FMT " greater than current TS allowed, stopping solve\n", ts->steps, rejections);
    }
  }
  return 0;
}

351: static PetscErrorCode TSInterpolate_IRK(TS ts, PetscReal itime, Vec U)
352: {
353:   TS_IRK          *irk     = (TS_IRK *)ts->data;
354:   PetscInt         nstages = irk->nstages, pinterp = irk->pinterp, i, j;
355:   PetscReal        h;
356:   PetscReal        tt, t;
357:   PetscScalar     *bt;
358:   const PetscReal *B = irk->tableau->binterp;

361:   switch (irk->status) {
362:   case TS_STEP_INCOMPLETE:
363:   case TS_STEP_PENDING:
364:     h = ts->time_step;
365:     t = (itime - ts->ptime) / h;
366:     break;
367:   case TS_STEP_COMPLETE:
368:     h = ts->ptime - ts->ptime_prev;
369:     t = (itime - ts->ptime) / h + 1; /* In the interval [0,1] */
370:     break;
371:   default:
372:     SETERRQ(PetscObjectComm((PetscObject)ts), PETSC_ERR_PLIB, "Invalid TSStepStatus");
373:   }
374:   PetscMalloc1(nstages, &bt);
375:   for (i = 0; i < nstages; i++) bt[i] = 0;
376:   for (j = 0, tt = t; j < pinterp; j++, tt *= t) {
377:     for (i = 0; i < nstages; i++) bt[i] += h * B[i * pinterp + j] * tt;
378:   }
379:   VecMAXPY(U, nstages, bt, irk->YdotI);
380:   return 0;
381: }

383: static PetscErrorCode TSIRKTableauReset(TS ts)
384: {
385:   TS_IRK    *irk = (TS_IRK *)ts->data;
386:   IRKTableau tab = irk->tableau;

388:   if (!tab) return 0;
389:   PetscFree3(tab->A, tab->A_inv, tab->I_s);
390:   PetscFree4(tab->b, tab->c, tab->binterp, tab->A_inv_rowsum);
391:   return 0;
392: }

394: static PetscErrorCode TSReset_IRK(TS ts)
395: {
396:   TS_IRK *irk = (TS_IRK *)ts->data;

398:   TSIRKTableauReset(ts);
399:   if (irk->tableau) PetscFree(irk->tableau);
400:   if (irk->method_name) PetscFree(irk->method_name);
401:   if (irk->work) PetscFree(irk->work);
402:   VecDestroyVecs(irk->nstages, &irk->Y);
403:   VecDestroyVecs(irk->nstages, &irk->YdotI);
404:   VecDestroy(&irk->Ydot);
405:   VecDestroy(&irk->Z);
406:   VecDestroy(&irk->U);
407:   VecDestroy(&irk->U0);
408:   MatDestroy(&irk->TJ);
409:   return 0;
410: }

412: static PetscErrorCode TSIRKGetVecs(TS ts, DM dm, Vec *U)
413: {
414:   TS_IRK *irk = (TS_IRK *)ts->data;

416:   if (U) {
417:     if (dm && dm != ts->dm) {
418:       DMGetNamedGlobalVector(dm, "TSIRK_U", U);
419:     } else *U = irk->U;
420:   }
421:   return 0;
422: }

424: static PetscErrorCode TSIRKRestoreVecs(TS ts, DM dm, Vec *U)
425: {
426:   if (U) {
427:     if (dm && dm != ts->dm) DMRestoreNamedGlobalVector(dm, "TSIRK_U", U);
428:   }
429:   return 0;
430: }

432: /*
433:   This defines the nonlinear equations that is to be solved with SNES
434:     G[e\otimes t + C*dt, Z, Zdot] = 0
435:     Zdot = (In \otimes S)*Z - (In \otimes Se) U
436:   where S = 1/(dt*A)
437: */
438: static PetscErrorCode SNESTSFormFunction_IRK(SNES snes, Vec ZC, Vec FC, TS ts)
439: {
440:   TS_IRK            *irk     = (TS_IRK *)ts->data;
441:   IRKTableau         tab     = irk->tableau;
442:   const PetscInt     nstages = irk->nstages;
443:   const PetscReal   *c       = tab->c;
444:   const PetscScalar *A_inv = tab->A_inv, *A_inv_rowsum = tab->A_inv_rowsum;
445:   DM                 dm, dmsave;
446:   Vec                U, *YdotI = irk->YdotI, Ydot = irk->Ydot, *Y = irk->Y;
447:   PetscReal          h = ts->time_step;
448:   PetscInt           i, j;

450:   SNESGetDM(snes, &dm);
451:   TSIRKGetVecs(ts, dm, &U);
452:   VecStrideGatherAll(ZC, Y, INSERT_VALUES);
453:   dmsave = ts->dm;
454:   ts->dm = dm;
455:   for (i = 0; i < nstages; i++) {
456:     VecZeroEntries(Ydot);
457:     for (j = 0; j < nstages; j++) VecAXPY(Ydot, A_inv[j * nstages + i] / h, Y[j]);
458:     VecAXPY(Ydot, -A_inv_rowsum[i] / h, U); /* Ydot = (S \otimes In)*Z - (Se \otimes In) U */
459:     TSComputeIFunction(ts, ts->ptime + ts->time_step * c[i], Y[i], Ydot, YdotI[i], PETSC_FALSE);
460:   }
461:   VecStrideScatterAll(YdotI, FC, INSERT_VALUES);
462:   ts->dm = dmsave;
463:   TSIRKRestoreVecs(ts, dm, &U);
464:   return 0;
465: }

467: /*
468:    For explicit ODE, the Jacobian is
469:      JC = I_n \otimes S - J \otimes I_s
470:    For DAE, the Jacobian is
471:      JC = M_n \otimes S - J \otimes I_s
472: */
473: static PetscErrorCode SNESTSFormJacobian_IRK(SNES snes, Vec ZC, Mat JC, Mat JCpre, TS ts)
474: {
475:   TS_IRK          *irk     = (TS_IRK *)ts->data;
476:   IRKTableau       tab     = irk->tableau;
477:   const PetscInt   nstages = irk->nstages;
478:   const PetscReal *c       = tab->c;
479:   DM               dm, dmsave;
480:   Vec             *Y = irk->Y, Ydot = irk->Ydot;
481:   Mat              J;
482:   PetscScalar     *S;
483:   PetscInt         i, j, bs;

485:   SNESGetDM(snes, &dm);
486:   /* irk->Ydot has already been computed in SNESTSFormFunction_IRK (SNES guarantees this) */
487:   dmsave = ts->dm;
488:   ts->dm = dm;
489:   VecGetBlockSize(Y[nstages - 1], &bs);
490:   if (ts->equation_type <= TS_EQ_ODE_EXPLICIT) { /* Support explicit formulas only */
491:     VecStrideGather(ZC, (nstages - 1) * bs, Y[nstages - 1], INSERT_VALUES);
492:     MatKAIJGetAIJ(JC, &J);
493:     TSComputeIJacobian(ts, ts->ptime + ts->time_step * c[nstages - 1], Y[nstages - 1], Ydot, 0, J, J, PETSC_FALSE);
494:     MatKAIJGetS(JC, NULL, NULL, &S);
495:     for (i = 0; i < nstages; i++)
496:       for (j = 0; j < nstages; j++) S[i + nstages * j] = tab->A_inv[i + nstages * j] / ts->time_step;
497:     MatKAIJRestoreS(JC, &S);
498:   } else SETERRQ(PetscObjectComm((PetscObject)ts), PETSC_ERR_SUP, "TSIRK %s does not support implicit formula", irk->method_name); /* TODO: need the mass matrix for DAE  */
499:   ts->dm = dmsave;
500:   return 0;
501: }

503: static PetscErrorCode DMCoarsenHook_TSIRK(DM fine, DM coarse, void *ctx)
504: {
505:   return 0;
506: }

508: static PetscErrorCode DMRestrictHook_TSIRK(DM fine, Mat restrct, Vec rscale, Mat inject, DM coarse, void *ctx)
509: {
510:   TS  ts = (TS)ctx;
511:   Vec U, U_c;

513:   TSIRKGetVecs(ts, fine, &U);
514:   TSIRKGetVecs(ts, coarse, &U_c);
515:   MatRestrict(restrct, U, U_c);
516:   VecPointwiseMult(U_c, rscale, U_c);
517:   TSIRKRestoreVecs(ts, fine, &U);
518:   TSIRKRestoreVecs(ts, coarse, &U_c);
519:   return 0;
520: }

522: static PetscErrorCode DMSubDomainHook_TSIRK(DM dm, DM subdm, void *ctx)
523: {
524:   return 0;
525: }

527: static PetscErrorCode DMSubDomainRestrictHook_TSIRK(DM dm, VecScatter gscat, VecScatter lscat, DM subdm, void *ctx)
528: {
529:   TS  ts = (TS)ctx;
530:   Vec U, U_c;

532:   TSIRKGetVecs(ts, dm, &U);
533:   TSIRKGetVecs(ts, subdm, &U_c);

535:   VecScatterBegin(gscat, U, U_c, INSERT_VALUES, SCATTER_FORWARD);
536:   VecScatterEnd(gscat, U, U_c, INSERT_VALUES, SCATTER_FORWARD);

538:   TSIRKRestoreVecs(ts, dm, &U);
539:   TSIRKRestoreVecs(ts, subdm, &U_c);
540:   return 0;
541: }

543: static PetscErrorCode TSSetUp_IRK(TS ts)
544: {
545:   TS_IRK        *irk = (TS_IRK *)ts->data;
546:   IRKTableau     tab = irk->tableau;
547:   DM             dm;
548:   Mat            J;
549:   Vec            R;
550:   const PetscInt nstages = irk->nstages;
551:   PetscInt       vsize, bs;

553:   if (!irk->work) PetscMalloc1(irk->nstages, &irk->work);
554:   if (!irk->Y) VecDuplicateVecs(ts->vec_sol, irk->nstages, &irk->Y);
555:   if (!irk->YdotI) VecDuplicateVecs(ts->vec_sol, irk->nstages, &irk->YdotI);
556:   if (!irk->Ydot) VecDuplicate(ts->vec_sol, &irk->Ydot);
557:   if (!irk->U) VecDuplicate(ts->vec_sol, &irk->U);
558:   if (!irk->U0) VecDuplicate(ts->vec_sol, &irk->U0);
559:   if (!irk->Z) {
560:     VecCreate(PetscObjectComm((PetscObject)ts->vec_sol), &irk->Z);
561:     VecGetSize(ts->vec_sol, &vsize);
562:     VecSetSizes(irk->Z, PETSC_DECIDE, vsize * irk->nstages);
563:     VecGetBlockSize(ts->vec_sol, &bs);
564:     VecSetBlockSize(irk->Z, irk->nstages * bs);
565:     VecSetFromOptions(irk->Z);
566:   }
567:   TSGetDM(ts, &dm);
568:   DMCoarsenHookAdd(dm, DMCoarsenHook_TSIRK, DMRestrictHook_TSIRK, ts);
569:   DMSubDomainHookAdd(dm, DMSubDomainHook_TSIRK, DMSubDomainRestrictHook_TSIRK, ts);

571:   TSGetSNES(ts, &ts->snes);
572:   VecDuplicate(irk->Z, &R);
573:   SNESSetFunction(ts->snes, R, SNESTSFormFunction, ts);
574:   TSGetIJacobian(ts, &J, NULL, NULL, NULL);
575:   if (!irk->TJ) {
576:     /* Create the KAIJ matrix for solving the stages */
577:     MatCreateKAIJ(J, nstages, nstages, tab->A_inv, tab->I_s, &irk->TJ);
578:   }
579:   SNESSetJacobian(ts->snes, irk->TJ, irk->TJ, SNESTSFormJacobian, ts);
580:   VecDestroy(&R);
581:   return 0;
582: }

584: static PetscErrorCode TSSetFromOptions_IRK(TS ts, PetscOptionItems *PetscOptionsObject)
585: {
586:   TS_IRK *irk        = (TS_IRK *)ts->data;
587:   char    tname[256] = TSIRKGAUSS;

589:   PetscOptionsHeadBegin(PetscOptionsObject, "IRK ODE solver options");
590:   {
591:     PetscBool flg1, flg2;
592:     PetscOptionsInt("-ts_irk_nstages", "Stages of the IRK method", "TSIRKSetNumStages", irk->nstages, &irk->nstages, &flg1);
593:     PetscOptionsFList("-ts_irk_type", "Type of IRK method", "TSIRKSetType", TSIRKList, irk->method_name[0] ? irk->method_name : tname, tname, sizeof(tname), &flg2);
594:     if (flg1 || flg2 || !irk->method_name[0]) { /* Create the method tableau after nstages or method is set */
595:       TSIRKSetType(ts, tname);
596:     }
597:   }
598:   PetscOptionsHeadEnd();
599:   return 0;
600: }

602: static PetscErrorCode TSView_IRK(TS ts, PetscViewer viewer)
603: {
604:   TS_IRK   *irk = (TS_IRK *)ts->data;
605:   PetscBool iascii;

607:   PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &iascii);
608:   if (iascii) {
609:     IRKTableau tab = irk->tableau;
610:     TSIRKType  irktype;
611:     char       buf[512];

613:     TSIRKGetType(ts, &irktype);
614:     PetscViewerASCIIPrintf(viewer, "  IRK type %s\n", irktype);
615:     PetscFormatRealArray(buf, sizeof(buf), "% 8.6f", irk->nstages, tab->c);
616:     PetscViewerASCIIPrintf(viewer, "  Abscissa       c = %s\n", buf);
617:     PetscViewerASCIIPrintf(viewer, "Stiffly accurate: %s\n", irk->stiffly_accurate ? "yes" : "no");
618:     PetscFormatRealArray(buf, sizeof(buf), "% 8.6f", PetscSqr(irk->nstages), tab->A);
619:     PetscViewerASCIIPrintf(viewer, "  A coefficients       A = %s\n", buf);
620:   }
621:   return 0;
622: }

624: static PetscErrorCode TSLoad_IRK(TS ts, PetscViewer viewer)
625: {
626:   SNES    snes;
627:   TSAdapt adapt;

629:   TSGetAdapt(ts, &adapt);
630:   TSAdaptLoad(adapt, viewer);
631:   TSGetSNES(ts, &snes);
632:   SNESLoad(snes, viewer);
633:   /* function and Jacobian context for SNES when used with TS is always ts object */
634:   SNESSetFunction(snes, NULL, NULL, ts);
635:   SNESSetJacobian(snes, NULL, NULL, NULL, ts);
636:   return 0;
637: }

639: /*@C
640:   TSIRKSetType - Set the type of IRK scheme

642:   Logically collective

644:   Input Parameters:
645: +  ts - timestepping context
646: -  irktype - type of IRK scheme

648:   Options Database:
649: .  -ts_irk_type <gauss> - set irk type

651:   Level: intermediate

653: .seealso: `TSIRKGetType()`, `TSIRK`, `TSIRKType`, `TSIRKGAUSS`
654: @*/
655: PetscErrorCode TSIRKSetType(TS ts, TSIRKType irktype)
656: {
659:   PetscTryMethod(ts, "TSIRKSetType_C", (TS, TSIRKType), (ts, irktype));
660:   return 0;
661: }

663: /*@C
664:   TSIRKGetType - Get the type of IRK IMEX scheme

666:   Logically collective

668:   Input Parameter:
669: .  ts - timestepping context

671:   Output Parameter:
672: .  irktype - type of IRK-IMEX scheme

674:   Level: intermediate

676: .seealso: `TSIRKGetType()`
677: @*/
678: PetscErrorCode TSIRKGetType(TS ts, TSIRKType *irktype)
679: {
681:   PetscUseMethod(ts, "TSIRKGetType_C", (TS, TSIRKType *), (ts, irktype));
682:   return 0;
683: }

685: /*@C
686:   TSIRKSetNumStages - Set the number of stages of IRK scheme

688:   Logically collective

690:   Input Parameters:
691: +  ts - timestepping context
692: -  nstages - number of stages of IRK scheme

694:   Options Database:
695: .  -ts_irk_nstages <int> - set number of stages

697:   Level: intermediate

699: .seealso: `TSIRKGetNumStages()`, `TSIRK`
700: @*/
701: PetscErrorCode TSIRKSetNumStages(TS ts, PetscInt nstages)
702: {
704:   PetscTryMethod(ts, "TSIRKSetNumStages_C", (TS, PetscInt), (ts, nstages));
705:   return 0;
706: }

708: /*@C
709:   TSIRKGetNumStages - Get the number of stages of IRK scheme

711:   Logically collective

713:   Input Parameters:
714: +  ts - timestepping context
715: -  nstages - number of stages of IRK scheme

717:   Level: intermediate

719: .seealso: `TSIRKSetNumStages()`, `TSIRK`
720: @*/
721: PetscErrorCode TSIRKGetNumStages(TS ts, PetscInt *nstages)
722: {
725:   PetscTryMethod(ts, "TSIRKGetNumStages_C", (TS, PetscInt *), (ts, nstages));
726:   return 0;
727: }

729: static PetscErrorCode TSIRKGetType_IRK(TS ts, TSIRKType *irktype)
730: {
731:   TS_IRK *irk = (TS_IRK *)ts->data;

733:   *irktype = irk->method_name;
734:   return 0;
735: }

737: static PetscErrorCode TSIRKSetType_IRK(TS ts, TSIRKType irktype)
738: {
739:   TS_IRK *irk = (TS_IRK *)ts->data;
740:   PetscErrorCode (*irkcreate)(TS);

742:   if (irk->method_name) {
743:     PetscFree(irk->method_name);
744:     TSIRKTableauReset(ts);
745:   }
746:   PetscFunctionListFind(TSIRKList, irktype, &irkcreate);
748:   (*irkcreate)(ts);
749:   PetscStrallocpy(irktype, &irk->method_name);
750:   return 0;
751: }

753: static PetscErrorCode TSIRKSetNumStages_IRK(TS ts, PetscInt nstages)
754: {
755:   TS_IRK *irk = (TS_IRK *)ts->data;

758:   irk->nstages = nstages;
759:   return 0;
760: }

762: static PetscErrorCode TSIRKGetNumStages_IRK(TS ts, PetscInt *nstages)
763: {
764:   TS_IRK *irk = (TS_IRK *)ts->data;

767:   *nstages = irk->nstages;
768:   return 0;
769: }

771: static PetscErrorCode TSDestroy_IRK(TS ts)
772: {
773:   TSReset_IRK(ts);
774:   if (ts->dm) {
775:     DMCoarsenHookRemove(ts->dm, DMCoarsenHook_TSIRK, DMRestrictHook_TSIRK, ts);
776:     DMSubDomainHookRemove(ts->dm, DMSubDomainHook_TSIRK, DMSubDomainRestrictHook_TSIRK, ts);
777:   }
778:   PetscFree(ts->data);
779:   PetscObjectComposeFunction((PetscObject)ts, "TSIRKSetType_C", NULL);
780:   PetscObjectComposeFunction((PetscObject)ts, "TSIRKGetType_C", NULL);
781:   PetscObjectComposeFunction((PetscObject)ts, "TSIRKSetNumStages_C", NULL);
782:   PetscObjectComposeFunction((PetscObject)ts, "TSIRKGetNumStages_C", NULL);
783:   return 0;
784: }

786: /*MC
787:       TSIRK - ODE and DAE solver using Implicit Runge-Kutta schemes

789:   Notes:

791:   TSIRK uses the sparse Kronecker product matrix implementation of MATKAIJ to achieve good arithmetic intensity.

793:   Gauss-Legrendre methods are currently supported. These are A-stable symplectic methods with an arbitrary number of stages. The order of accuracy is 2s when using s stages. The default method uses three stages and thus has an order of six. The number of stages (thus order) can be set with -ts_irk_nstages or TSIRKSetNumStages().

795:   Level: beginner

797: .seealso: `TSCreate()`, `TS`, `TSSetType()`, `TSIRKSetType()`, `TSIRKGetType()`, `TSIRKGAUSS`, `TSIRKRegister()`, `TSIRKSetNumStages()`

799: M*/
800: PETSC_EXTERN PetscErrorCode TSCreate_IRK(TS ts)
801: {
802:   TS_IRK *irk;

804:   TSIRKInitializePackage();

806:   ts->ops->reset          = TSReset_IRK;
807:   ts->ops->destroy        = TSDestroy_IRK;
808:   ts->ops->view           = TSView_IRK;
809:   ts->ops->load           = TSLoad_IRK;
810:   ts->ops->setup          = TSSetUp_IRK;
811:   ts->ops->step           = TSStep_IRK;
812:   ts->ops->interpolate    = TSInterpolate_IRK;
813:   ts->ops->evaluatestep   = TSEvaluateStep_IRK;
814:   ts->ops->rollback       = TSRollBack_IRK;
815:   ts->ops->setfromoptions = TSSetFromOptions_IRK;
816:   ts->ops->snesfunction   = SNESTSFormFunction_IRK;
817:   ts->ops->snesjacobian   = SNESTSFormJacobian_IRK;

819:   ts->usessnes = PETSC_TRUE;

821:   PetscNew(&irk);
822:   ts->data = (void *)irk;

824:   PetscObjectComposeFunction((PetscObject)ts, "TSIRKSetType_C", TSIRKSetType_IRK);
825:   PetscObjectComposeFunction((PetscObject)ts, "TSIRKGetType_C", TSIRKGetType_IRK);
826:   PetscObjectComposeFunction((PetscObject)ts, "TSIRKSetNumStages_C", TSIRKSetNumStages_IRK);
827:   PetscObjectComposeFunction((PetscObject)ts, "TSIRKGetNumStages_C", TSIRKGetNumStages_IRK);
828:   /* 3-stage IRK_Gauss is the default */
829:   PetscNew(&irk->tableau);
830:   irk->nstages = 3;
831:   TSIRKSetType(ts, TSIRKDefault);
832:   return 0;
833: }