Actual source code: snespatch.c

  1: /*
  2:       Defines a SNES that can consist of a collection of SNESes on patches of the domain
  3: */
  4: #include <petsc/private/vecimpl.h>
  5: #include <petsc/private/snesimpl.h>
  6: #include <petsc/private/pcpatchimpl.h>
  7: #include <petscsf.h>
  8: #include <petscsection.h>

 10: typedef struct {
 11:   PC pc; /* The linear patch preconditioner */
 12: } SNES_Patch;

 14: static PetscErrorCode SNESPatchComputeResidual_Private(SNES snes, Vec x, Vec F, void *ctx)
 15: {
 16:   PC                 pc      = (PC)ctx;
 17:   PC_PATCH          *pcpatch = (PC_PATCH *)pc->data;
 18:   PetscInt           pt, size, i;
 19:   const PetscInt    *indices;
 20:   const PetscScalar *X;
 21:   PetscScalar       *XWithAll;

 23:   PetscFunctionBegin;
 24:   /* scatter from x to patch->patchStateWithAll[pt] */
 25:   pt = pcpatch->currentPatch;
 26:   PetscCall(ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size));

 28:   PetscCall(ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices));
 29:   PetscCall(VecGetArrayRead(x, &X));
 30:   PetscCall(VecGetArray(pcpatch->patchStateWithAll, &XWithAll));

 32:   for (i = 0; i < size; ++i) XWithAll[indices[i]] = X[i];

 34:   PetscCall(VecRestoreArray(pcpatch->patchStateWithAll, &XWithAll));
 35:   PetscCall(VecRestoreArrayRead(x, &X));
 36:   PetscCall(ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices));

 38:   PetscCall(PCPatchComputeFunction_Internal(pc, pcpatch->patchStateWithAll, F, pt));
 39:   PetscFunctionReturn(PETSC_SUCCESS);
 40: }

 42: static PetscErrorCode SNESPatchComputeJacobian_Private(SNES snes, Vec x, Mat J, Mat M, void *ctx)
 43: {
 44:   PC                 pc      = (PC)ctx;
 45:   PC_PATCH          *pcpatch = (PC_PATCH *)pc->data;
 46:   PetscInt           pt, size, i;
 47:   const PetscInt    *indices;
 48:   const PetscScalar *X;
 49:   PetscScalar       *XWithAll;

 51:   PetscFunctionBegin;
 52:   /* scatter from x to patch->patchStateWithAll[pt] */
 53:   pt = pcpatch->currentPatch;
 54:   PetscCall(ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size));

 56:   PetscCall(ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices));
 57:   PetscCall(VecGetArrayRead(x, &X));
 58:   PetscCall(VecGetArray(pcpatch->patchStateWithAll, &XWithAll));

 60:   for (i = 0; i < size; ++i) XWithAll[indices[i]] = X[i];

 62:   PetscCall(VecRestoreArray(pcpatch->patchStateWithAll, &XWithAll));
 63:   PetscCall(VecRestoreArrayRead(x, &X));
 64:   PetscCall(ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices));

 66:   PetscCall(PCPatchComputeOperator_Internal(pc, pcpatch->patchStateWithAll, M, pcpatch->currentPatch, PETSC_FALSE));
 67:   PetscFunctionReturn(PETSC_SUCCESS);
 68: }

 70: static PetscErrorCode PCSetUp_PATCH_Nonlinear(PC pc)
 71: {
 72:   PC_PATCH   *patch = (PC_PATCH *)pc->data;
 73:   const char *prefix;
 74:   PetscInt    i, pStart, dof, maxDof = -1;

 76:   PetscFunctionBegin;
 77:   if (!pc->setupcalled) {
 78:     PetscCall(PetscMalloc1(patch->npatch, &patch->solver));
 79:     PetscCall(PCGetOptionsPrefix(pc, &prefix));
 80:     PetscCall(PetscSectionGetChart(patch->gtolCounts, &pStart, NULL));
 81:     for (i = 0; i < patch->npatch; ++i) {
 82:       SNES snes;

 84:       PetscCall(SNESCreate(PETSC_COMM_SELF, &snes));
 85:       PetscCall(SNESSetOptionsPrefix(snes, prefix));
 86:       PetscCall(SNESAppendOptionsPrefix(snes, "sub_"));
 87:       PetscCall(PetscObjectIncrementTabLevel((PetscObject)snes, (PetscObject)pc, 2));
 88:       patch->solver[i] = (PetscObject)snes;

 90:       PetscCall(PetscSectionGetDof(patch->gtolCountsWithAll, i + pStart, &dof));
 91:       maxDof = PetscMax(maxDof, dof);
 92:     }
 93:     PetscCall(VecDuplicate(patch->localUpdate, &patch->localState));
 94:     PetscCall(VecDuplicate(patch->patchRHS, &patch->patchResidual));
 95:     PetscCall(VecDuplicate(patch->patchUpdate, &patch->patchState));

 97:     PetscCall(VecCreateSeq(PETSC_COMM_SELF, maxDof, &patch->patchStateWithAll));
 98:     PetscCall(VecSetUp(patch->patchStateWithAll));
 99:   }
100:   for (i = 0; i < patch->npatch; ++i) {
101:     SNES snes = (SNES)patch->solver[i];

103:     PetscCall(SNESSetFunction(snes, patch->patchResidual, SNESPatchComputeResidual_Private, pc));
104:     PetscCall(SNESSetJacobian(snes, patch->mat[i], patch->mat[i], SNESPatchComputeJacobian_Private, pc));
105:   }
106:   if (!pc->setupcalled && patch->optionsSet)
107:     for (i = 0; i < patch->npatch; ++i) PetscCall(SNESSetFromOptions((SNES)patch->solver[i]));
108:   PetscFunctionReturn(PETSC_SUCCESS);
109: }

111: static PetscErrorCode PCApply_PATCH_Nonlinear(PC pc, PetscInt i, Vec patchRHS, Vec patchUpdate)
112: {
113:   PC_PATCH *patch = (PC_PATCH *)pc->data;
114:   PetscInt  pStart, n;

116:   PetscFunctionBegin;
117:   patch->currentPatch = i;
118:   PetscCall(PetscLogEventBegin(PC_Patch_Solve, pc, 0, 0, 0));

120:   /* Scatter the overlapped global state to our patch state vector */
121:   PetscCall(PetscSectionGetChart(patch->gtolCounts, &pStart, NULL));
122:   PetscCall(PCPatch_ScatterLocal_Private(pc, i + pStart, patch->localState, patch->patchState, INSERT_VALUES, SCATTER_FORWARD, SCATTER_INTERIOR));
123:   PetscCall(PCPatch_ScatterLocal_Private(pc, i + pStart, patch->localState, patch->patchStateWithAll, INSERT_VALUES, SCATTER_FORWARD, SCATTER_WITHALL));

125:   PetscCall(MatGetLocalSize(patch->mat[i], NULL, &n));
126:   patch->patchState->map->n = n;
127:   patch->patchState->map->N = n;
128:   patchUpdate->map->n       = n;
129:   patchUpdate->map->N       = n;
130:   patchRHS->map->n          = n;
131:   patchRHS->map->N          = n;
132:   /* Set initial guess to be current state*/
133:   PetscCall(VecCopy(patch->patchState, patchUpdate));
134:   /* Solve for new state */
135:   PetscCall(SNESSolve((SNES)patch->solver[i], patchRHS, patchUpdate));
136:   /* To compute update, subtract off previous state */
137:   PetscCall(VecAXPY(patchUpdate, -1.0, patch->patchState));

139:   PetscCall(PetscLogEventEnd(PC_Patch_Solve, pc, 0, 0, 0));
140:   PetscFunctionReturn(PETSC_SUCCESS);
141: }

143: static PetscErrorCode PCReset_PATCH_Nonlinear(PC pc)
144: {
145:   PC_PATCH *patch = (PC_PATCH *)pc->data;
146:   PetscInt  i;

148:   PetscFunctionBegin;
149:   if (patch->solver) {
150:     for (i = 0; i < patch->npatch; ++i) PetscCall(SNESReset((SNES)patch->solver[i]));
151:   }

153:   PetscCall(VecDestroy(&patch->patchResidual));
154:   PetscCall(VecDestroy(&patch->patchState));
155:   PetscCall(VecDestroy(&patch->patchStateWithAll));

157:   PetscCall(VecDestroy(&patch->localState));
158:   PetscFunctionReturn(PETSC_SUCCESS);
159: }

161: static PetscErrorCode PCDestroy_PATCH_Nonlinear(PC pc)
162: {
163:   PC_PATCH *patch = (PC_PATCH *)pc->data;
164:   PetscInt  i;

166:   PetscFunctionBegin;
167:   if (patch->solver) {
168:     for (i = 0; i < patch->npatch; ++i) PetscCall(SNESDestroy((SNES *)&patch->solver[i]));
169:     PetscCall(PetscFree(patch->solver));
170:   }
171:   PetscFunctionReturn(PETSC_SUCCESS);
172: }

174: static PetscErrorCode PCUpdateMultiplicative_PATCH_Nonlinear(PC pc, PetscInt i, PetscInt pStart)
175: {
176:   PC_PATCH *patch = (PC_PATCH *)pc->data;

178:   PetscFunctionBegin;
179:   PetscCall(PCPatch_ScatterLocal_Private(pc, i + pStart, patch->patchUpdate, patch->localState, ADD_VALUES, SCATTER_REVERSE, SCATTER_INTERIOR));
180:   PetscFunctionReturn(PETSC_SUCCESS);
181: }

183: static PetscErrorCode SNESSetUp_Patch(SNES snes)
184: {
185:   SNES_Patch *patch = (SNES_Patch *)snes->data;
186:   DM          dm;
187:   Mat         dummy;
188:   Vec         F;
189:   PetscInt    n, N;

191:   PetscFunctionBegin;
192:   PetscCall(SNESGetDM(snes, &dm));
193:   PetscCall(PCSetDM(patch->pc, dm));
194:   PetscCall(SNESGetFunction(snes, &F, NULL, NULL));
195:   PetscCall(VecGetLocalSize(F, &n));
196:   PetscCall(VecGetSize(F, &N));
197:   PetscCall(MatCreateShell(PetscObjectComm((PetscObject)snes), n, n, N, N, (void *)snes, &dummy));
198:   PetscCall(PCSetOperators(patch->pc, dummy, dummy));
199:   PetscCall(MatDestroy(&dummy));
200:   PetscCall(PCSetUp(patch->pc));
201:   /* allocate workspace */
202:   PetscFunctionReturn(PETSC_SUCCESS);
203: }

205: static PetscErrorCode SNESReset_Patch(SNES snes)
206: {
207:   SNES_Patch *patch = (SNES_Patch *)snes->data;

209:   PetscFunctionBegin;
210:   PetscCall(PCReset(patch->pc));
211:   PetscFunctionReturn(PETSC_SUCCESS);
212: }

214: static PetscErrorCode SNESDestroy_Patch(SNES snes)
215: {
216:   SNES_Patch *patch = (SNES_Patch *)snes->data;

218:   PetscFunctionBegin;
219:   PetscCall(SNESReset_Patch(snes));
220:   PetscCall(PCDestroy(&patch->pc));
221:   PetscCall(PetscFree(snes->data));
222:   PetscFunctionReturn(PETSC_SUCCESS);
223: }

225: static PetscErrorCode SNESSetFromOptions_Patch(SNES snes, PetscOptionItems *PetscOptionsObject)
226: {
227:   SNES_Patch *patch = (SNES_Patch *)snes->data;
228:   const char *prefix;

230:   PetscFunctionBegin;
231:   PetscCall(PetscObjectGetOptionsPrefix((PetscObject)snes, &prefix));
232:   PetscCall(PetscObjectSetOptionsPrefix((PetscObject)patch->pc, prefix));
233:   PetscCall(PCSetFromOptions(patch->pc));
234:   PetscFunctionReturn(PETSC_SUCCESS);
235: }

237: static PetscErrorCode SNESView_Patch(SNES snes, PetscViewer viewer)
238: {
239:   SNES_Patch *patch = (SNES_Patch *)snes->data;
240:   PetscBool   iascii;

242:   PetscFunctionBegin;
243:   PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &iascii));
244:   if (iascii) PetscCall(PetscViewerASCIIPrintf(viewer, "SNESPATCH\n"));
245:   PetscCall(PetscViewerASCIIPushTab(viewer));
246:   PetscCall(PCView(patch->pc, viewer));
247:   PetscCall(PetscViewerASCIIPopTab(viewer));
248:   PetscFunctionReturn(PETSC_SUCCESS);
249: }

251: static PetscErrorCode SNESSolve_Patch(SNES snes)
252: {
253:   SNES_Patch        *patch   = (SNES_Patch *)snes->data;
254:   PC_PATCH          *pcpatch = (PC_PATCH *)patch->pc->data;
255:   SNESLineSearch     ls;
256:   Vec                rhs, update, state, residual;
257:   const PetscScalar *globalState = NULL;
258:   PetscScalar       *localState  = NULL;
259:   PetscInt           its         = 0;
260:   PetscReal          xnorm = 0.0, ynorm = 0.0, fnorm = 0.0;

262:   PetscFunctionBegin;
263:   PetscCall(SNESGetSolution(snes, &state));
264:   PetscCall(SNESGetSolutionUpdate(snes, &update));
265:   PetscCall(SNESGetRhs(snes, &rhs));

267:   PetscCall(SNESGetFunction(snes, &residual, NULL, NULL));
268:   PetscCall(SNESGetLineSearch(snes, &ls));

270:   PetscCall(SNESSetConvergedReason(snes, SNES_CONVERGED_ITERATING));
271:   PetscCall(VecSet(update, 0.0));
272:   PetscCall(SNESComputeFunction(snes, state, residual));

274:   PetscCall(VecNorm(state, NORM_2, &xnorm));
275:   PetscCall(VecNorm(residual, NORM_2, &fnorm));
276:   snes->ttol = fnorm * snes->rtol;

278:   if (snes->ops->converged) {
279:     PetscUseTypeMethod(snes, converged, its, xnorm, ynorm, fnorm, &snes->reason, snes->cnvP);
280:   } else {
281:     PetscCall(SNESConvergedSkip(snes, its, xnorm, ynorm, fnorm, &snes->reason, NULL));
282:   }
283:   PetscCall(SNESLogConvergenceHistory(snes, fnorm, 0)); /* should we count lits from the patches? */
284:   PetscCall(SNESMonitor(snes, its, fnorm));

286:   /* The main solver loop */
287:   for (its = 0; its < snes->max_its; its++) {
288:     PetscCall(SNESSetIterationNumber(snes, its));

290:     /* Scatter state vector to overlapped vector on all patches.
291:        The vector pcpatch->localState is scattered to each patch
292:        in PCApply_PATCH_Nonlinear. */
293:     PetscCall(VecGetArrayRead(state, &globalState));
294:     PetscCall(VecGetArray(pcpatch->localState, &localState));
295:     PetscCall(PetscSFBcastBegin(pcpatch->sectionSF, MPIU_SCALAR, globalState, localState, MPI_REPLACE));
296:     PetscCall(PetscSFBcastEnd(pcpatch->sectionSF, MPIU_SCALAR, globalState, localState, MPI_REPLACE));
297:     PetscCall(VecRestoreArray(pcpatch->localState, &localState));
298:     PetscCall(VecRestoreArrayRead(state, &globalState));

300:     /* The looping over patches happens here */
301:     PetscCall(PCApply(patch->pc, rhs, update));

303:     /* Apply a line search. This will often be basic with
304:        damping = 1/(max number of patches a dof can be in),
305:        but not always */
306:     PetscCall(VecScale(update, -1.0));
307:     PetscCall(SNESLineSearchApply(ls, state, residual, &fnorm, update));

309:     PetscCall(VecNorm(state, NORM_2, &xnorm));
310:     PetscCall(VecNorm(update, NORM_2, &ynorm));

312:     if (snes->ops->converged) {
313:       PetscUseTypeMethod(snes, converged, its, xnorm, ynorm, fnorm, &snes->reason, snes->cnvP);
314:     } else {
315:       PetscCall(SNESConvergedSkip(snes, its, xnorm, ynorm, fnorm, &snes->reason, NULL));
316:     }
317:     PetscCall(SNESLogConvergenceHistory(snes, fnorm, 0)); /* FIXME: should we count lits? */
318:     PetscCall(SNESMonitor(snes, its, fnorm));
319:   }

321:   if (its == snes->max_its) PetscCall(SNESSetConvergedReason(snes, SNES_DIVERGED_MAX_IT));
322:   PetscFunctionReturn(PETSC_SUCCESS);
323: }

325: /*MC
326:   SNESPATCH - Solve a nonlinear problem or apply a nonlinear smoother by composing together many nonlinear solvers on (often overlapping) patches {cite}`bruneknepleysmithtu15`

328:   Level: intermediate

330: .seealso: [](ch_snes), `SNESFAS`, `SNESCreate()`, `SNESSetType()`, `SNESType`, `SNES`, `PCPATCH`
331: M*/
332: PETSC_EXTERN PetscErrorCode SNESCreate_Patch(SNES snes)
333: {
334:   SNES_Patch    *patch;
335:   PC_PATCH      *patchpc;
336:   SNESLineSearch linesearch;

338:   PetscFunctionBegin;
339:   PetscCall(PetscNew(&patch));

341:   snes->ops->solve          = SNESSolve_Patch;
342:   snes->ops->setup          = SNESSetUp_Patch;
343:   snes->ops->reset          = SNESReset_Patch;
344:   snes->ops->destroy        = SNESDestroy_Patch;
345:   snes->ops->setfromoptions = SNESSetFromOptions_Patch;
346:   snes->ops->view           = SNESView_Patch;

348:   PetscCall(SNESGetLineSearch(snes, &linesearch));
349:   if (!((PetscObject)linesearch)->type_name) PetscCall(SNESLineSearchSetType(linesearch, SNESLINESEARCHBASIC));
350:   snes->usesksp = PETSC_FALSE;

352:   snes->alwayscomputesfinalresidual = PETSC_FALSE;

354:   PetscCall(SNESParametersInitialize(snes));

356:   snes->data = (void *)patch;
357:   PetscCall(PCCreate(PetscObjectComm((PetscObject)snes), &patch->pc));
358:   PetscCall(PCSetType(patch->pc, PCPATCH));

360:   patchpc              = (PC_PATCH *)patch->pc->data;
361:   patchpc->classname   = "snes";
362:   patchpc->isNonlinear = PETSC_TRUE;

364:   patchpc->setupsolver          = PCSetUp_PATCH_Nonlinear;
365:   patchpc->applysolver          = PCApply_PATCH_Nonlinear;
366:   patchpc->resetsolver          = PCReset_PATCH_Nonlinear;
367:   patchpc->destroysolver        = PCDestroy_PATCH_Nonlinear;
368:   patchpc->updatemultiplicative = PCUpdateMultiplicative_PATCH_Nonlinear;
369:   PetscFunctionReturn(PETSC_SUCCESS);
370: }

372: PetscErrorCode SNESPatchSetDiscretisationInfo(SNES snes, PetscInt nsubspaces, DM *dms, PetscInt *bs, PetscInt *nodesPerCell, const PetscInt **cellNodeMap, const PetscInt *subspaceOffsets, PetscInt numGhostBcs, const PetscInt *ghostBcNodes, PetscInt numGlobalBcs, const PetscInt *globalBcNodes)
373: {
374:   SNES_Patch *patch = (SNES_Patch *)snes->data;
375:   DM          dm;

377:   PetscFunctionBegin;
378:   PetscCall(SNESGetDM(snes, &dm));
379:   PetscCheck(dm, PetscObjectComm((PetscObject)snes), PETSC_ERR_ARG_WRONGSTATE, "DM not yet set on patch SNES");
380:   PetscCall(PCSetDM(patch->pc, dm));
381:   PetscCall(PCPatchSetDiscretisationInfo(patch->pc, nsubspaces, dms, bs, nodesPerCell, cellNodeMap, subspaceOffsets, numGhostBcs, ghostBcNodes, numGlobalBcs, globalBcNodes));
382:   PetscFunctionReturn(PETSC_SUCCESS);
383: }

385: PetscErrorCode SNESPatchSetComputeOperator(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Mat, IS, PetscInt, const PetscInt *, const PetscInt *, void *), void *ctx)
386: {
387:   SNES_Patch *patch = (SNES_Patch *)snes->data;

389:   PetscFunctionBegin;
390:   PetscCall(PCPatchSetComputeOperator(patch->pc, func, ctx));
391:   PetscFunctionReturn(PETSC_SUCCESS);
392: }

394: PetscErrorCode SNESPatchSetComputeFunction(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Vec, IS, PetscInt, const PetscInt *, const PetscInt *, void *), void *ctx)
395: {
396:   SNES_Patch *patch = (SNES_Patch *)snes->data;

398:   PetscFunctionBegin;
399:   PetscCall(PCPatchSetComputeFunction(patch->pc, func, ctx));
400:   PetscFunctionReturn(PETSC_SUCCESS);
401: }

403: PetscErrorCode SNESPatchSetConstructType(SNES snes, PCPatchConstructType ctype, PetscErrorCode (*func)(PC, PetscInt *, IS **, IS *, void *), void *ctx)
404: {
405:   SNES_Patch *patch = (SNES_Patch *)snes->data;

407:   PetscFunctionBegin;
408:   PetscCall(PCPatchSetConstructType(patch->pc, ctype, func, ctx));
409:   PetscFunctionReturn(PETSC_SUCCESS);
410: }

412: PetscErrorCode SNESPatchSetCellNumbering(SNES snes, PetscSection cellNumbering)
413: {
414:   SNES_Patch *patch = (SNES_Patch *)snes->data;

416:   PetscFunctionBegin;
417:   PetscCall(PCPatchSetCellNumbering(patch->pc, cellNumbering));
418:   PetscFunctionReturn(PETSC_SUCCESS);
419: }