Actual source code: taosnes.c

  1: #include <petsc/private/taoimpl.h>

/*
  Tao_SNES - private context of the TAOSNES solver.

  TAOSNES hands the optimization problem to an inner SNES: the SNES residual
  is the TAO gradient, the SNES Jacobian is the TAO Hessian, and the SNES
  objective is the TAO objective (see the TAOSNES* callbacks below).
*/
typedef struct {
  SNES      snes;                 /* inner nonlinear solver; created in TaoCreate_SNES(), destroyed in TaoDestroy_SNES() */
  PetscBool setfromoptionscalled; /* set by TaoSetFromOptions_SNES(); tells the next TaoSetUp_SNES() to call SNESSetFromOptions(), then reset */
} Tao_SNES;

  8: static PetscErrorCode TaoSolve_SNES(Tao tao)
  9: {
 10:   Tao_SNES *taosnes = (Tao_SNES *)tao->data;
 11:   PetscInt  its;

 13:   PetscFunctionBegin;
 14:   /* TODO SNES fails if KSP reaches max_it, while TAO accepts whatever we got */
 15:   PetscCall(SNESSolve(taosnes->snes, NULL, tao->solution));
 16:   /* TODO REASONS */
 17:   tao->reason = TAO_CONVERGED_USER;
 18:   PetscCall(SNESGetIterationNumber(taosnes->snes, &its));
 19:   PetscCall(TaoSetIterationNumber(tao, its));
 20:   PetscFunctionReturn(PETSC_SUCCESS);
 21: }

 23: static PetscErrorCode TaoDestroy_SNES(Tao tao)
 24: {
 25:   Tao_SNES *taosnes = (Tao_SNES *)tao->data;

 27:   PetscFunctionBegin;
 28:   PetscCall(SNESDestroy(&taosnes->snes));
 29:   PetscCall(PetscFree(tao->data));
 30:   PetscFunctionReturn(PETSC_SUCCESS);
 31: }

 33: static PetscErrorCode TAOSNESObj(SNES snes, Vec X, PetscReal *f, void *ctx)
 34: {
 35:   Tao tao = (Tao)ctx;

 37:   PetscFunctionBegin;
 38:   PetscCall(TaoComputeObjective(tao, X, f));
 39:   PetscFunctionReturn(PETSC_SUCCESS);
 40: }

 42: static PetscErrorCode TAOSNESFunc(SNES snes, Vec X, Vec F, void *ctx)
 43: {
 44:   Tao tao = (Tao)ctx;

 46:   PetscFunctionBegin;
 47:   PetscCall(TaoComputeGradient(tao, X, F));
 48:   PetscFunctionReturn(PETSC_SUCCESS);
 49: }

 51: static PetscErrorCode TAOSNESJac(SNES snes, Vec X, Mat A, Mat P, void *ctx)
 52: {
 53:   Tao tao = (Tao)ctx;

 55:   PetscFunctionBegin;
 56:   PetscCall(TaoComputeHessian(tao, X, A, P));
 57:   PetscFunctionReturn(PETSC_SUCCESS);
 58: }

 60: static PetscErrorCode TAOSNESMonitor(SNES snes, PetscInt its, PetscReal fnorm, void *ctx)
 61: {
 62:   Tao       tao = (Tao)ctx;
 63:   PetscReal obj;
 64:   Vec       X;

 66:   PetscFunctionBegin;
 67:   PetscCall(SNESGetSolution(snes, &X));
 68:   PetscCall(TaoComputeObjective(tao, X, &obj));
 69:   PetscCall(TaoSetIterationNumber(tao, its));
 70:   PetscCall(TaoMonitor(tao, its, obj, fnorm, 0.0, 0.0));
 71:   PetscFunctionReturn(PETSC_SUCCESS);
 72: }

 74: static PetscErrorCode TaoSetUp_SNES(Tao tao)
 75: {
 76:   Tao_SNES   *taosnes = (Tao_SNES *)tao->data;
 77:   Mat         A, P;
 78:   const char *prefix;

 80:   PetscFunctionBegin;
 81:   PetscCall(TaoGetOptionsPrefix(tao, &prefix));
 82:   PetscCall(SNESSetOptionsPrefix(taosnes->snes, prefix));
 83:   PetscCall(SNESSetSolution(taosnes->snes, tao->solution));
 84:   PetscCall(SNESSetObjective(taosnes->snes, TAOSNESObj, tao));
 85:   PetscCall(SNESSetFunction(taosnes->snes, NULL, TAOSNESFunc, tao));
 86:   PetscCall(SNESMonitorSet(taosnes->snes, TAOSNESMonitor, tao, NULL));
 87:   PetscCall(TaoGetHessian(tao, &A, &P, NULL, NULL));
 88:   if (A) PetscCall(SNESSetJacobian(taosnes->snes, A, P, TAOSNESJac, tao));
 89:   if (taosnes->setfromoptionscalled) PetscCall(SNESSetFromOptions(taosnes->snes));
 90:   taosnes->setfromoptionscalled = PETSC_FALSE;
 91:   PetscCall(SNESSetUp(taosnes->snes));
 92:   PetscFunctionReturn(PETSC_SUCCESS);
 93: }

 95: static PetscErrorCode TaoSetFromOptions_SNES(Tao tao, PetscOptionItems *PetscOptionsObject)
 96: {
 97:   Tao_SNES *taosnes = (Tao_SNES *)tao->data;

 99:   PetscFunctionBegin;
100:   taosnes->setfromoptionscalled = PETSC_TRUE;
101:   PetscFunctionReturn(PETSC_SUCCESS);
102: }

104: static PetscErrorCode TaoView_SNES(Tao tao, PetscViewer viewer)
105: {
106:   Tao_SNES *taosnes = (Tao_SNES *)tao->data;

108:   PetscFunctionBegin;
109:   PetscCall(SNESView(taosnes->snes, viewer));
110:   PetscFunctionReturn(PETSC_SUCCESS);
111: }

113: /*MC
114:   TAOSNES - nonlinear solver using SNES

116:    Level: advanced

118: .seealso: `TaoCreate()`, `Tao`, `TaoSetType()`, `TaoType`
119: M*/
120: PETSC_EXTERN PetscErrorCode TaoCreate_SNES(Tao tao)
121: {
122:   Tao_SNES *taosnes;

124:   PetscFunctionBegin;
125:   tao->ops->destroy        = TaoDestroy_SNES;
126:   tao->ops->setup          = TaoSetUp_SNES;
127:   tao->ops->setfromoptions = TaoSetFromOptions_SNES;
128:   tao->ops->view           = TaoView_SNES;
129:   tao->ops->solve          = TaoSolve_SNES;

131:   PetscCall(PetscNew(&taosnes));
132:   tao->data = (void *)taosnes;
133:   PetscCall(SNESCreate(PetscObjectComm((PetscObject)tao), &taosnes->snes));
134:   PetscCall(PetscObjectIncrementTabLevel((PetscObject)taosnes->snes, (PetscObject)tao, 1));
135:   PetscFunctionReturn(PETSC_SUCCESS);
136: }