Actual source code: snespatch.c
1: /*
2: Defines a SNES that can consist of a collection of SNESes on patches of the domain
3: */
4: #include <petsc/private/vecimpl.h>
5: #include <petsc/private/snesimpl.h>
6: #include <petsc/private/pcpatchimpl.h>
7: #include <petscsf.h>
8: #include <petscsection.h>
10: typedef struct {
11: PC pc; /* The linear patch preconditioner */
12: } SNES_Patch;
14: static PetscErrorCode SNESPatchComputeResidual_Private(SNES snes, Vec x, Vec F, PetscCtx ctx)
15: {
16: PC pc = (PC)ctx;
17: PC_PATCH *pcpatch = (PC_PATCH *)pc->data;
18: PetscInt pt, size, i;
19: const PetscInt *indices;
20: const PetscScalar *X;
21: PetscScalar *XWithAll;
23: PetscFunctionBegin;
24: /* scatter from x to patch->patchStateWithAll[pt] */
25: pt = pcpatch->currentPatch;
26: PetscCall(ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size));
28: PetscCall(ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices));
29: PetscCall(VecGetArrayRead(x, &X));
30: PetscCall(VecGetArray(pcpatch->patchStateWithAll, &XWithAll));
32: for (i = 0; i < size; ++i) XWithAll[indices[i]] = X[i];
34: PetscCall(VecRestoreArray(pcpatch->patchStateWithAll, &XWithAll));
35: PetscCall(VecRestoreArrayRead(x, &X));
36: PetscCall(ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices));
38: PetscCall(PCPatchComputeFunction_Internal(pc, pcpatch->patchStateWithAll, F, pt));
39: PetscFunctionReturn(PETSC_SUCCESS);
40: }
42: static PetscErrorCode SNESPatchComputeJacobian_Private(SNES snes, Vec x, Mat J, Mat M, PetscCtx ctx)
43: {
44: PC pc = (PC)ctx;
45: PC_PATCH *pcpatch = (PC_PATCH *)pc->data;
46: PetscInt pt, size, i;
47: const PetscInt *indices;
48: const PetscScalar *X;
49: PetscScalar *XWithAll;
51: PetscFunctionBegin;
52: /* scatter from x to patch->patchStateWithAll[pt] */
53: pt = pcpatch->currentPatch;
54: PetscCall(ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size));
56: PetscCall(ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices));
57: PetscCall(VecGetArrayRead(x, &X));
58: PetscCall(VecGetArray(pcpatch->patchStateWithAll, &XWithAll));
60: for (i = 0; i < size; ++i) XWithAll[indices[i]] = X[i];
62: PetscCall(VecRestoreArray(pcpatch->patchStateWithAll, &XWithAll));
63: PetscCall(VecRestoreArrayRead(x, &X));
64: PetscCall(ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices));
66: PetscCall(PCPatchComputeOperator_Internal(pc, pcpatch->patchStateWithAll, M, pcpatch->currentPatch, PETSC_FALSE));
67: PetscFunctionReturn(PETSC_SUCCESS);
68: }
70: static PetscErrorCode PCSetUp_PATCH_Nonlinear(PC pc)
71: {
72: PC_PATCH *patch = (PC_PATCH *)pc->data;
73: const char *prefix;
74: PetscInt i, pStart, dof, maxDof = -1;
76: PetscFunctionBegin;
77: if (!pc->setupcalled) {
78: PetscCall(PetscMalloc1(patch->npatch, &patch->solver));
79: PetscCall(PCGetOptionsPrefix(pc, &prefix));
80: PetscCall(PetscSectionGetChart(patch->gtolCounts, &pStart, NULL));
81: for (i = 0; i < patch->npatch; ++i) {
82: SNES snes;
84: PetscCall(SNESCreate(PETSC_COMM_SELF, &snes));
85: PetscCall(SNESSetOptionsPrefix(snes, prefix));
86: PetscCall(SNESAppendOptionsPrefix(snes, "sub_"));
87: PetscCall(PetscObjectIncrementTabLevel((PetscObject)snes, (PetscObject)pc, 2));
88: patch->solver[i] = (PetscObject)snes;
90: PetscCall(PetscSectionGetDof(patch->gtolCountsWithAll, i + pStart, &dof));
91: maxDof = PetscMax(maxDof, dof);
92: }
93: PetscCall(VecDuplicate(patch->localUpdate, &patch->localState));
94: PetscCall(VecDuplicate(patch->patchRHS, &patch->patchResidual));
95: PetscCall(VecDuplicate(patch->patchUpdate, &patch->patchState));
97: PetscCall(VecCreateSeq(PETSC_COMM_SELF, maxDof, &patch->patchStateWithAll));
98: PetscCall(VecSetUp(patch->patchStateWithAll));
99: }
100: for (i = 0; i < patch->npatch; ++i) {
101: SNES snes = (SNES)patch->solver[i];
103: PetscCall(SNESSetFunction(snes, patch->patchResidual, SNESPatchComputeResidual_Private, pc));
104: PetscCall(SNESSetJacobian(snes, patch->mat[i], patch->mat[i], SNESPatchComputeJacobian_Private, pc));
105: }
106: if (!pc->setupcalled && patch->optionsSet)
107: for (i = 0; i < patch->npatch; ++i) PetscCall(SNESSetFromOptions((SNES)patch->solver[i]));
108: PetscFunctionReturn(PETSC_SUCCESS);
109: }
111: static PetscErrorCode PCApply_PATCH_Nonlinear(PC pc, PetscInt i, Vec patchRHS, Vec patchUpdate)
112: {
113: PC_PATCH *patch = (PC_PATCH *)pc->data;
114: PetscInt pStart, n;
116: PetscFunctionBegin;
117: patch->currentPatch = i;
118: PetscCall(PetscLogEventBegin(PC_Patch_Solve, pc, 0, 0, 0));
120: /* Scatter the overlapped global state to our patch state vector */
121: PetscCall(PetscSectionGetChart(patch->gtolCounts, &pStart, NULL));
122: PetscCall(PCPatch_ScatterLocal_Private(pc, i + pStart, patch->localState, patch->patchState, INSERT_VALUES, SCATTER_FORWARD, SCATTER_INTERIOR));
123: PetscCall(PCPatch_ScatterLocal_Private(pc, i + pStart, patch->localState, patch->patchStateWithAll, INSERT_VALUES, SCATTER_FORWARD, SCATTER_WITHALL));
125: PetscCall(MatGetLocalSize(patch->mat[i], NULL, &n));
126: patch->patchState->map->n = n;
127: patch->patchState->map->N = n;
128: patchUpdate->map->n = n;
129: patchUpdate->map->N = n;
130: patchRHS->map->n = n;
131: patchRHS->map->N = n;
132: /* Set initial guess to be current state*/
133: PetscCall(VecCopy(patch->patchState, patchUpdate));
134: /* Solve for new state */
135: PetscCall(SNESSolve((SNES)patch->solver[i], patchRHS, patchUpdate));
136: /* To compute update, subtract off previous state */
137: PetscCall(VecAXPY(patchUpdate, -1.0, patch->patchState));
139: PetscCall(PetscLogEventEnd(PC_Patch_Solve, pc, 0, 0, 0));
140: PetscFunctionReturn(PETSC_SUCCESS);
141: }
143: static PetscErrorCode PCReset_PATCH_Nonlinear(PC pc)
144: {
145: PC_PATCH *patch = (PC_PATCH *)pc->data;
147: PetscFunctionBegin;
148: if (patch->solver) {
149: for (PetscInt i = 0; i < patch->npatch; ++i) PetscCall(SNESReset((SNES)patch->solver[i]));
150: }
152: PetscCall(VecDestroy(&patch->patchResidual));
153: PetscCall(VecDestroy(&patch->patchState));
154: PetscCall(VecDestroy(&patch->patchStateWithAll));
156: PetscCall(VecDestroy(&patch->localState));
157: PetscFunctionReturn(PETSC_SUCCESS);
158: }
160: static PetscErrorCode PCDestroy_PATCH_Nonlinear(PC pc)
161: {
162: PC_PATCH *patch = (PC_PATCH *)pc->data;
164: PetscFunctionBegin;
165: if (patch->solver) {
166: for (PetscInt i = 0; i < patch->npatch; ++i) PetscCall(SNESDestroy((SNES *)&patch->solver[i]));
167: PetscCall(PetscFree(patch->solver));
168: }
169: PetscFunctionReturn(PETSC_SUCCESS);
170: }
172: static PetscErrorCode PCUpdateMultiplicative_PATCH_Nonlinear(PC pc, PetscInt i, PetscInt pStart)
173: {
174: PC_PATCH *patch = (PC_PATCH *)pc->data;
176: PetscFunctionBegin;
177: PetscCall(PCPatch_ScatterLocal_Private(pc, i + pStart, patch->patchUpdate, patch->localState, ADD_VALUES, SCATTER_REVERSE, SCATTER_INTERIOR));
178: PetscFunctionReturn(PETSC_SUCCESS);
179: }
181: static PetscErrorCode SNESSetUp_Patch(SNES snes)
182: {
183: SNES_Patch *patch = (SNES_Patch *)snes->data;
184: DM dm;
185: Mat dummy;
186: Vec F;
187: PetscInt n, N;
189: PetscFunctionBegin;
190: PetscCall(SNESGetDM(snes, &dm));
191: PetscCall(PCSetDM(patch->pc, dm));
192: PetscCall(SNESGetFunction(snes, &F, NULL, NULL));
193: PetscCall(VecGetLocalSize(F, &n));
194: PetscCall(VecGetSize(F, &N));
195: PetscCall(MatCreateShell(PetscObjectComm((PetscObject)snes), n, n, N, N, (void *)snes, &dummy));
196: PetscCall(PCSetOperators(patch->pc, dummy, dummy));
197: PetscCall(MatDestroy(&dummy));
198: PetscCall(PCSetUp(patch->pc));
199: /* allocate workspace */
200: PetscFunctionReturn(PETSC_SUCCESS);
201: }
203: static PetscErrorCode SNESReset_Patch(SNES snes)
204: {
205: SNES_Patch *patch = (SNES_Patch *)snes->data;
207: PetscFunctionBegin;
208: PetscCall(PCReset(patch->pc));
209: PetscFunctionReturn(PETSC_SUCCESS);
210: }
212: static PetscErrorCode SNESDestroy_Patch(SNES snes)
213: {
214: SNES_Patch *patch = (SNES_Patch *)snes->data;
216: PetscFunctionBegin;
217: PetscCall(SNESReset_Patch(snes));
218: PetscCall(PCDestroy(&patch->pc));
219: PetscCall(PetscFree(snes->data));
220: PetscFunctionReturn(PETSC_SUCCESS);
221: }
223: static PetscErrorCode SNESSetFromOptions_Patch(SNES snes, PetscOptionItems PetscOptionsObject)
224: {
225: SNES_Patch *patch = (SNES_Patch *)snes->data;
226: const char *prefix;
228: PetscFunctionBegin;
229: PetscCall(PetscObjectGetOptionsPrefix((PetscObject)snes, &prefix));
230: PetscCall(PetscObjectSetOptionsPrefix((PetscObject)patch->pc, prefix));
231: PetscCall(PCSetFromOptions(patch->pc));
232: PetscFunctionReturn(PETSC_SUCCESS);
233: }
235: static PetscErrorCode SNESView_Patch(SNES snes, PetscViewer viewer)
236: {
237: SNES_Patch *patch = (SNES_Patch *)snes->data;
238: PetscBool isascii;
240: PetscFunctionBegin;
241: PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &isascii));
242: if (isascii) PetscCall(PetscViewerASCIIPrintf(viewer, "SNESPATCH\n"));
243: PetscCall(PetscViewerASCIIPushTab(viewer));
244: PetscCall(PCView(patch->pc, viewer));
245: PetscCall(PetscViewerASCIIPopTab(viewer));
246: PetscFunctionReturn(PETSC_SUCCESS);
247: }
249: static PetscErrorCode SNESSolve_Patch(SNES snes)
250: {
251: SNES_Patch *patch = (SNES_Patch *)snes->data;
252: PC_PATCH *pcpatch = (PC_PATCH *)patch->pc->data;
253: SNESLineSearch ls;
254: Vec rhs, update, state, residual;
255: const PetscScalar *globalState = NULL;
256: PetscScalar *localState = NULL;
257: PetscInt its = 0;
258: PetscReal xnorm = 0.0, ynorm = 0.0, fnorm = 0.0;
260: PetscFunctionBegin;
261: PetscCall(SNESGetSolution(snes, &state));
262: PetscCall(SNESGetSolutionUpdate(snes, &update));
263: PetscCall(SNESGetRhs(snes, &rhs));
265: PetscCall(SNESGetFunction(snes, &residual, NULL, NULL));
266: PetscCall(SNESGetLineSearch(snes, &ls));
268: PetscCall(SNESSetConvergedReason(snes, SNES_CONVERGED_ITERATING));
269: PetscCall(VecSet(update, 0.0));
270: PetscCall(SNESComputeFunction(snes, state, residual));
272: PetscCall(VecNorm(state, NORM_2, &xnorm));
273: PetscCall(VecNorm(residual, NORM_2, &fnorm));
274: SNESCheckFunctionDomainError(snes, fnorm);
275: snes->ttol = fnorm * snes->rtol;
277: if (snes->ops->converged) {
278: PetscUseTypeMethod(snes, converged, its, xnorm, ynorm, fnorm, &snes->reason, snes->cnvP);
279: } else {
280: PetscCall(SNESConvergedSkip(snes, its, xnorm, ynorm, fnorm, &snes->reason, NULL));
281: }
282: PetscCall(SNESLogConvergenceHistory(snes, fnorm, 0)); /* should we count lits from the patches? */
283: PetscCall(SNESMonitor(snes, its, fnorm));
285: /* The main solver loop */
286: for (its = 0; its < snes->max_its; its++) {
287: PetscCall(SNESSetIterationNumber(snes, its));
289: /* Scatter state vector to overlapped vector on all patches.
290: The vector pcpatch->localState is scattered to each patch
291: in PCApply_PATCH_Nonlinear. */
292: PetscCall(VecGetArrayRead(state, &globalState));
293: PetscCall(VecGetArray(pcpatch->localState, &localState));
294: PetscCall(PetscSFBcastBegin(pcpatch->sectionSF, MPIU_SCALAR, globalState, localState, MPI_REPLACE));
295: PetscCall(PetscSFBcastEnd(pcpatch->sectionSF, MPIU_SCALAR, globalState, localState, MPI_REPLACE));
296: PetscCall(VecRestoreArray(pcpatch->localState, &localState));
297: PetscCall(VecRestoreArrayRead(state, &globalState));
299: /* The looping over patches happens here */
300: PetscCall(PCApply(patch->pc, rhs, update));
302: /* Apply a line search. This will often be basic with
303: damping = 1/(max number of patches a dof can be in),
304: but not always */
305: PetscCall(VecScale(update, -1.0));
306: PetscCall(SNESLineSearchApply(ls, state, residual, &fnorm, update));
308: PetscCall(VecNorm(state, NORM_2, &xnorm));
309: PetscCall(VecNorm(update, NORM_2, &ynorm));
311: if (snes->ops->converged) {
312: PetscUseTypeMethod(snes, converged, its, xnorm, ynorm, fnorm, &snes->reason, snes->cnvP);
313: } else {
314: PetscCall(SNESConvergedSkip(snes, its, xnorm, ynorm, fnorm, &snes->reason, NULL));
315: }
316: PetscCall(SNESLogConvergenceHistory(snes, fnorm, 0)); /* FIXME: should we count lits? */
317: PetscCall(SNESMonitor(snes, its, fnorm));
318: }
320: if (its == snes->max_its) PetscCall(SNESSetConvergedReason(snes, SNES_DIVERGED_MAX_IT));
321: PetscFunctionReturn(PETSC_SUCCESS);
322: }
324: /*MC
325: SNESPATCH - Solve a nonlinear problem or apply a nonlinear smoother by composing together many nonlinear solvers on (often overlapping) patches {cite}`bruneknepleysmithtu15`
327: Level: intermediate
329: .seealso: [](ch_snes), `SNESFAS`, `SNESCreate()`, `SNESSetType()`, `SNESType`, `SNES`, `PCPATCH`
330: M*/
331: PETSC_EXTERN PetscErrorCode SNESCreate_Patch(SNES snes)
332: {
333: SNES_Patch *patch;
334: PC_PATCH *patchpc;
335: SNESLineSearch linesearch;
337: PetscFunctionBegin;
338: PetscCall(PetscNew(&patch));
340: snes->ops->solve = SNESSolve_Patch;
341: snes->ops->setup = SNESSetUp_Patch;
342: snes->ops->reset = SNESReset_Patch;
343: snes->ops->destroy = SNESDestroy_Patch;
344: snes->ops->setfromoptions = SNESSetFromOptions_Patch;
345: snes->ops->view = SNESView_Patch;
347: PetscCall(SNESGetLineSearch(snes, &linesearch));
348: if (!((PetscObject)linesearch)->type_name) PetscCall(SNESLineSearchSetType(linesearch, SNESLINESEARCHBASIC));
349: snes->usesksp = PETSC_FALSE;
351: snes->alwayscomputesfinalresidual = PETSC_FALSE;
353: PetscCall(SNESParametersInitialize(snes));
355: snes->data = (void *)patch;
356: PetscCall(PCCreate(PetscObjectComm((PetscObject)snes), &patch->pc));
357: PetscCall(PCSetType(patch->pc, PCPATCH));
359: patchpc = (PC_PATCH *)patch->pc->data;
360: patchpc->classname = "snes";
361: patchpc->isNonlinear = PETSC_TRUE;
363: patchpc->setupsolver = PCSetUp_PATCH_Nonlinear;
364: patchpc->applysolver = PCApply_PATCH_Nonlinear;
365: patchpc->resetsolver = PCReset_PATCH_Nonlinear;
366: patchpc->destroysolver = PCDestroy_PATCH_Nonlinear;
367: patchpc->updatemultiplicative = PCUpdateMultiplicative_PATCH_Nonlinear;
368: PetscFunctionReturn(PETSC_SUCCESS);
369: }
371: PetscErrorCode SNESPatchSetDiscretisationInfo(SNES snes, PetscInt nsubspaces, DM *dms, PetscInt *bs, PetscInt *nodesPerCell, const PetscInt **cellNodeMap, const PetscInt *subspaceOffsets, PetscInt numGhostBcs, const PetscInt *ghostBcNodes, PetscInt numGlobalBcs, const PetscInt *globalBcNodes)
372: {
373: SNES_Patch *patch = (SNES_Patch *)snes->data;
374: DM dm;
376: PetscFunctionBegin;
377: PetscCall(SNESGetDM(snes, &dm));
378: PetscCheck(dm, PetscObjectComm((PetscObject)snes), PETSC_ERR_ARG_WRONGSTATE, "DM not yet set on patch SNES");
379: PetscCall(PCSetDM(patch->pc, dm));
380: PetscCall(PCPatchSetDiscretisationInfo(patch->pc, nsubspaces, dms, bs, nodesPerCell, cellNodeMap, subspaceOffsets, numGhostBcs, ghostBcNodes, numGlobalBcs, globalBcNodes));
381: PetscFunctionReturn(PETSC_SUCCESS);
382: }
384: PetscErrorCode SNESPatchSetComputeOperator(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Mat, IS, PetscInt, const PetscInt *, const PetscInt *, void *), PetscCtx ctx)
385: {
386: SNES_Patch *patch = (SNES_Patch *)snes->data;
388: PetscFunctionBegin;
389: PetscCall(PCPatchSetComputeOperator(patch->pc, func, ctx));
390: PetscFunctionReturn(PETSC_SUCCESS);
391: }
393: PetscErrorCode SNESPatchSetComputeFunction(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Vec, IS, PetscInt, const PetscInt *, const PetscInt *, void *), PetscCtx ctx)
394: {
395: SNES_Patch *patch = (SNES_Patch *)snes->data;
397: PetscFunctionBegin;
398: PetscCall(PCPatchSetComputeFunction(patch->pc, func, ctx));
399: PetscFunctionReturn(PETSC_SUCCESS);
400: }
402: PetscErrorCode SNESPatchSetConstructType(SNES snes, PCPatchConstructType ctype, PetscErrorCode (*func)(PC, PetscInt *, IS **, IS *, void *), PetscCtx ctx)
403: {
404: SNES_Patch *patch = (SNES_Patch *)snes->data;
406: PetscFunctionBegin;
407: PetscCall(PCPatchSetConstructType(patch->pc, ctype, func, ctx));
408: PetscFunctionReturn(PETSC_SUCCESS);
409: }
411: PetscErrorCode SNESPatchSetCellNumbering(SNES snes, PetscSection cellNumbering)
412: {
413: SNES_Patch *patch = (SNES_Patch *)snes->data;
415: PetscFunctionBegin;
416: PetscCall(PCPatchSetCellNumbering(patch->pc, cellNumbering));
417: PetscFunctionReturn(PETSC_SUCCESS);
418: }