Actual source code: ex23fwdadj.c

  1: static char help[] = "A toy example for testing forward and adjoint sensitivity analysis of an implicit ODE with a parametrized mass matrice.\n";

/*
  This example solves the simple ODE
    c x' = b x, x(0) = a,
  whose analytical solution is x(T) = a*exp(b/c*T). It computes the derivative of x(T) with
  respect to c (the default, -der 1) or with respect to b (-der 2) using both forward and
  adjoint sensitivity analysis, and prints the analytical derivative for comparison.
*/

 10: #include <petscts.h>

 12: typedef struct _n_User *User;
 13: struct _n_User {
 14:   PetscReal a;
 15:   PetscReal b;
 16:   PetscReal c;
 17:   /* Sensitivity analysis support */
 18:   PetscInt  steps;
 19:   PetscReal ftime;
 20:   Mat       Jac;  /* Jacobian matrix */
 21:   Mat       Jacp; /* JacobianP matrix */
 22:   Vec       x;
 23:   Mat       sp;        /* forward sensitivity variables */
 24:   Vec       lambda[1]; /* adjoint sensitivity variables */
 25:   Vec       mup[1];    /* adjoint sensitivity variables */
 26:   PetscInt  der;
 27: };

 29: static PetscErrorCode IFunction(TS ts, PetscReal t, Vec X, Vec Xdot, Vec F, void *ctx)
 30: {
 31:   User               user = (User)ctx;
 32:   const PetscScalar *x, *xdot;
 33:   PetscScalar       *f;

 35:   PetscFunctionBeginUser;
 36:   PetscCall(VecGetArrayRead(X, &x));
 37:   PetscCall(VecGetArrayRead(Xdot, &xdot));
 38:   PetscCall(VecGetArrayWrite(F, &f));
 39:   f[0] = user->c * xdot[0] - user->b * x[0];
 40:   PetscCall(VecRestoreArrayRead(X, &x));
 41:   PetscCall(VecRestoreArrayRead(Xdot, &xdot));
 42:   PetscCall(VecRestoreArrayWrite(F, &f));
 43:   PetscFunctionReturn(PETSC_SUCCESS);
 44: }

 46: static PetscErrorCode IJacobian(TS ts, PetscReal t, Vec X, Vec Xdot, PetscReal a, Mat A, Mat B, void *ctx)
 47: {
 48:   User               user     = (User)ctx;
 49:   PetscInt           rowcol[] = {0};
 50:   PetscScalar        J[1][1];
 51:   const PetscScalar *x;

 53:   PetscFunctionBeginUser;
 54:   PetscCall(VecGetArrayRead(X, &x));
 55:   J[0][0] = user->c * a - user->b * 1.0;
 56:   PetscCall(MatSetValues(B, 1, rowcol, 1, rowcol, &J[0][0], INSERT_VALUES));
 57:   PetscCall(VecRestoreArrayRead(X, &x));

 59:   PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
 60:   PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
 61:   if (A != B) {
 62:     PetscCall(MatAssemblyBegin(B, MAT_FINAL_ASSEMBLY));
 63:     PetscCall(MatAssemblyEnd(B, MAT_FINAL_ASSEMBLY));
 64:   }
 65:   PetscFunctionReturn(PETSC_SUCCESS);
 66: }

 68: static PetscErrorCode IJacobianP(TS ts, PetscReal t, Vec X, Vec Xdot, PetscReal shift, Mat A, void *ctx)
 69: {
 70:   User               user  = (User)ctx;
 71:   PetscInt           row[] = {0}, col[] = {0};
 72:   PetscScalar        J[1][1];
 73:   const PetscScalar *x, *xdot;
 74:   PetscReal          dt;

 76:   PetscFunctionBeginUser;
 77:   PetscCall(VecGetArrayRead(X, &x));
 78:   PetscCall(VecGetArrayRead(Xdot, &xdot));
 79:   PetscCall(TSGetTimeStep(ts, &dt));
 80:   if (user->der == 1) J[0][0] = xdot[0];
 81:   if (user->der == 2) J[0][0] = -x[0];
 82:   PetscCall(MatSetValues(A, 1, row, 1, col, &J[0][0], INSERT_VALUES));
 83:   PetscCall(VecRestoreArrayRead(X, &x));

 85:   PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
 86:   PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
 87:   PetscFunctionReturn(PETSC_SUCCESS);
 88: }

 90: int main(int argc, char **argv)
 91: {
 92:   TS             ts;
 93:   PetscScalar   *x_ptr;
 94:   PetscMPIInt    size;
 95:   struct _n_User user;
 96:   PetscInt       rows, cols;

 98:   PetscFunctionBeginUser;
 99:   PetscCall(PetscInitialize(&argc, &argv, NULL, help));

101:   PetscCallMPI(MPI_Comm_size(PETSC_COMM_WORLD, &size));
102:   PetscCheck(size == 1, PETSC_COMM_WORLD, PETSC_ERR_WRONG_MPI_SIZE, "This is a uniprocessor example only!");

104:   user.a     = 2.0;
105:   user.b     = 4.0;
106:   user.c     = 3.0;
107:   user.steps = 0;
108:   user.ftime = 1.0;
109:   user.der   = 1;
110:   PetscCall(PetscOptionsGetInt(NULL, NULL, "-der", &user.der, NULL));

112:   rows = 1;
113:   cols = 1;
114:   PetscCall(MatCreate(PETSC_COMM_WORLD, &user.Jac));
115:   PetscCall(MatSetSizes(user.Jac, PETSC_DECIDE, PETSC_DECIDE, 1, 1));
116:   PetscCall(MatSetFromOptions(user.Jac));
117:   PetscCall(MatSetUp(user.Jac));
118:   PetscCall(MatCreateVecs(user.Jac, &user.x, NULL));

120:   PetscCall(TSCreate(PETSC_COMM_WORLD, &ts));
121:   PetscCall(TSSetType(ts, TSBEULER));
122:   PetscCall(TSSetIFunction(ts, NULL, IFunction, &user));
123:   PetscCall(TSSetIJacobian(ts, user.Jac, user.Jac, IJacobian, &user));
124:   PetscCall(TSSetExactFinalTime(ts, TS_EXACTFINALTIME_MATCHSTEP));
125:   PetscCall(TSSetMaxTime(ts, user.ftime));

127:   PetscCall(VecGetArrayWrite(user.x, &x_ptr));
128:   x_ptr[0] = user.a;
129:   PetscCall(VecRestoreArrayWrite(user.x, &x_ptr));
130:   PetscCall(TSSetTimeStep(ts, 0.001));

132:   /* Set up forward sensitivity */
133:   PetscCall(MatCreate(PETSC_COMM_WORLD, &user.Jacp));
134:   PetscCall(MatSetSizes(user.Jacp, PETSC_DECIDE, PETSC_DECIDE, rows, cols));
135:   PetscCall(MatSetFromOptions(user.Jacp));
136:   PetscCall(MatSetUp(user.Jacp));
137:   PetscCall(MatCreateDense(PETSC_COMM_WORLD, PETSC_DECIDE, PETSC_DECIDE, rows, cols, NULL, &user.sp));
138:   PetscCall(MatZeroEntries(user.sp));
139:   PetscCall(TSForwardSetSensitivities(ts, cols, user.sp));
140:   PetscCall(TSSetIJacobianP(ts, user.Jacp, IJacobianP, &user));

142:   PetscCall(TSSetSaveTrajectory(ts));
143:   PetscCall(TSSetFromOptions(ts));

145:   PetscCall(TSSolve(ts, user.x));
146:   PetscCall(TSGetSolveTime(ts, &user.ftime));
147:   PetscCall(TSGetStepNumber(ts, &user.steps));
148:   PetscCall(VecGetArray(user.x, &x_ptr));
149:   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n ode solution %g\n", (double)PetscRealPart(x_ptr[0])));
150:   PetscCall(VecRestoreArray(user.x, &x_ptr));
151:   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n analytical solution %g\n", (double)(user.a * PetscExpReal(user.b / user.c * user.ftime))));

153:   if (user.der == 1) PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n analytical derivative w.r.t. c %g\n", (double)(-user.a * user.ftime * user.b / (user.c * user.c) * PetscExpReal(user.b / user.c * user.ftime))));
154:   if (user.der == 2) PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n analytical derivative w.r.t. b %g\n", (double)(user.a * user.ftime / user.c * PetscExpReal(user.b / user.c * user.ftime))));
155:   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n forward sensitivity:\n"));
156:   PetscCall(MatView(user.sp, PETSC_VIEWER_STDOUT_WORLD));

158:   PetscCall(MatCreateVecs(user.Jac, &user.lambda[0], NULL));
159:   /* Set initial conditions for the adjoint integration */
160:   PetscCall(VecGetArrayWrite(user.lambda[0], &x_ptr));
161:   x_ptr[0] = 1.0;
162:   PetscCall(VecRestoreArrayWrite(user.lambda[0], &x_ptr));
163:   PetscCall(MatCreateVecs(user.Jacp, &user.mup[0], NULL));
164:   PetscCall(VecGetArrayWrite(user.mup[0], &x_ptr));
165:   x_ptr[0] = 0.0;
166:   PetscCall(VecRestoreArrayWrite(user.mup[0], &x_ptr));

168:   PetscCall(TSSetCostGradients(ts, 1, user.lambda, user.mup));
169:   PetscCall(TSAdjointSolve(ts));

171:   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n adjoint sensitivity:\n"));
172:   PetscCall(VecView(user.mup[0], PETSC_VIEWER_STDOUT_WORLD));

174:   PetscCall(MatDestroy(&user.Jac));
175:   PetscCall(MatDestroy(&user.sp));
176:   PetscCall(MatDestroy(&user.Jacp));
177:   PetscCall(VecDestroy(&user.x));
178:   PetscCall(VecDestroy(&user.lambda[0]));
179:   PetscCall(VecDestroy(&user.mup[0]));
180:   PetscCall(TSDestroy(&ts));

182:   PetscCall(PetscFinalize());
183:   return 0;
184: }

186: /*TEST

188:     test:
189:       args: -ts_type beuler

191:     test:
192:       suffix: 2
193:       args: -ts_type cn
194:       output_file: output/ex23fwdadj_1.out

196: TEST*/