Actual source code: ex20td.c
/* Help text handed to PetscInitialize(); summarises the two -sa_method modes. */
static char help[] = "Performs adjoint sensitivity analysis for a van der Pol like \n"
                     "equation with time dependent parameters using two approaches : \n"
                     "track : track only local sensitivities at each adjoint step \n"
                     "and accumulate them in a global array \n"
                     "global : track parameters at all timesteps together \n"
                     "Choose one of the two at runtime by -sa_method {track,global}. \n";
8: /*
9: Simple example to demonstrate TSAdjoint capabilities for time dependent params
10: without integral cost terms using either a tracking or global method.
12: Modify the Van Der Pol Eq to :
13: [u1'] = [mu1(t)*u2]
14: [u2'] = [mu2(t)*((1-u1^2)*u2-u1)]
15: (with initial conditions & params independent)
17: Define uref to be solution with initial conditions (2,-2/3), mu=(1,1e3)
18: - u_ref : (1.5967,-1.02969)
20: Define cost function as cost = (2-norm(u - u_ref))^2;
22: Initialization for the adjoint TS :
23: - dcost/dy|final_time = 2*(u-u_ref)|final_time
24: - dcost/dp|final_time = 0
26: The tracking method only tracks local sensitivity at each time step
27: and accumulates these sensitivities in a global array. Since the structure
28: of the equations being solved at each time step does not change, the jacobian
29: wrt parameters is defined analogous to constant RHSJacobian for a linear
30: TSSolve and the size of the jacP is independent of the number of time
31: steps. Enable this mode of adjoint analysis by -sa_method track.
33: The global method combines the parameters at all timesteps and tracks them
34: together. Thus, the columns of the jacP matrix are filled dependent upon the
35: time step. Also, the dimensions of the jacP matrix now depend upon the number
36: of time steps. Enable this mode of adjoint analysis by -sa_method global.
38: Since the equations here have parameters at predefined time steps, this
39: example should be run with non adaptive time stepping solvers only. This
40: can be ensured by -ts_adapt_type none (which is the default behavior only
41: for certain TS solvers like TSCN. If using an explicit TS like TSRK,
42: please be sure to add the aforementioned option to disable adaptive
43: timestepping.)
44: */
46: /*
47: Include "petscts.h" so that we can use TS solvers. Note that this file
48: automatically includes:
49: petscsys.h - base PETSc routines petscvec.h - vectors
50: petscmat.h - matrices
51: petscis.h - index sets petscksp.h - Krylov subspace methods
52: petscviewer.h - viewers petscpc.h - preconditioners
53: petscksp.h - linear solvers petscsnes.h - nonlinear solvers
54: */
55: #include <petscts.h>
57: extern PetscErrorCode RHSFunction(TS, PetscReal, Vec, Vec, void *);
58: extern PetscErrorCode RHSJacobian(TS, PetscReal, Vec, Mat, Mat, void *);
59: extern PetscErrorCode RHSJacobianP_track(TS, PetscReal, Vec, Mat, void *);
60: extern PetscErrorCode RHSJacobianP_global(TS, PetscReal, Vec, Mat, void *);
61: extern PetscErrorCode Monitor(TS, PetscInt, PetscReal, Vec, void *);
62: extern PetscErrorCode AdjointMonitor(TS, PetscInt, PetscReal, Vec, PetscInt, Vec *, Vec *, void *);
64: /*
65: User-defined application context - contains data needed by the
66: application-provided call-back routines.
67: */
69: typedef struct {
70: /*------------- Forward solve data structures --------------*/
71: PetscInt max_steps; /* number of steps to run ts for */
72: PetscReal final_time; /* final time to integrate to*/
73: PetscReal time_step; /* ts integration time step */
74: Vec mu1; /* time dependent params */
75: Vec mu2; /* time dependent params */
76: Vec U; /* solution vector */
77: Mat A; /* Jacobian matrix */
79: /*------------- Adjoint solve data structures --------------*/
80: Mat Jacp; /* JacobianP matrix */
81: Vec lambda; /* adjoint variable */
82: Vec mup; /* adjoint variable */
84: /*------------- Global accumation vecs for monitor based tracking --------------*/
85: Vec sens_mu1; /* global sensitivity accumulation */
86: Vec sens_mu2; /* global sensitivity accumulation */
87: PetscInt adj_idx; /* to keep track of adjoint solve index */
88: } AppCtx;
/* Sensitivity analysis strategy selected with -sa_method. */
typedef enum {
  SA_TRACK  = 0,
  SA_GLOBAL = 1
} SAMethod;

/* Name table for PetscOptionsEnum(): value names, then the enum type name,
   the common prefix, and a terminating null entry. */
static const char *const SAMethods[] = {
  "TRACK",
  "GLOBAL",
  "SAMethod",
  "SA_",
  NULL
};
96: /* ----------------------- Explicit form of the ODE -------------------- */
98: PetscErrorCode RHSFunction(TS ts, PetscReal t, Vec U, Vec F, void *ctx)
99: {
100: AppCtx *user = (AppCtx *)ctx;
101: PetscScalar *f;
102: PetscInt curr_step;
103: const PetscScalar *u;
104: const PetscScalar *mu1;
105: const PetscScalar *mu2;
107: PetscFunctionBeginUser;
108: PetscCall(TSGetStepNumber(ts, &curr_step));
109: PetscCall(VecGetArrayRead(U, &u));
110: PetscCall(VecGetArrayRead(user->mu1, &mu1));
111: PetscCall(VecGetArrayRead(user->mu2, &mu2));
112: PetscCall(VecGetArray(F, &f));
113: f[0] = mu1[curr_step] * u[1];
114: f[1] = mu2[curr_step] * ((1. - u[0] * u[0]) * u[1] - u[0]);
115: PetscCall(VecRestoreArrayRead(U, &u));
116: PetscCall(VecRestoreArrayRead(user->mu1, &mu1));
117: PetscCall(VecRestoreArrayRead(user->mu2, &mu2));
118: PetscCall(VecRestoreArray(F, &f));
119: PetscFunctionReturn(PETSC_SUCCESS);
120: }
122: PetscErrorCode RHSJacobian(TS ts, PetscReal t, Vec U, Mat A, Mat B, void *ctx)
123: {
124: AppCtx *user = (AppCtx *)ctx;
125: PetscInt rowcol[] = {0, 1};
126: PetscScalar J[2][2];
127: PetscInt curr_step;
128: const PetscScalar *u;
129: const PetscScalar *mu1;
130: const PetscScalar *mu2;
132: PetscFunctionBeginUser;
133: PetscCall(TSGetStepNumber(ts, &curr_step));
134: PetscCall(VecGetArrayRead(user->mu1, &mu1));
135: PetscCall(VecGetArrayRead(user->mu2, &mu2));
136: PetscCall(VecGetArrayRead(U, &u));
137: J[0][0] = 0;
138: J[1][0] = -mu2[curr_step] * (2.0 * u[1] * u[0] + 1.);
139: J[0][1] = mu1[curr_step];
140: J[1][1] = mu2[curr_step] * (1.0 - u[0] * u[0]);
141: PetscCall(MatSetValues(A, 2, rowcol, 2, rowcol, &J[0][0], INSERT_VALUES));
142: PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
143: PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
144: PetscCall(VecRestoreArrayRead(U, &u));
145: PetscCall(VecRestoreArrayRead(user->mu1, &mu1));
146: PetscCall(VecRestoreArrayRead(user->mu2, &mu2));
147: PetscFunctionReturn(PETSC_SUCCESS);
148: }
150: /* ------------------ Jacobian wrt parameters for tracking method ------------------ */
152: PetscErrorCode RHSJacobianP_track(TS ts, PetscReal t, Vec U, Mat A, void *ctx)
153: {
154: PetscInt row[] = {0, 1}, col[] = {0, 1};
155: PetscScalar J[2][2];
156: const PetscScalar *u;
158: PetscFunctionBeginUser;
159: PetscCall(VecGetArrayRead(U, &u));
160: J[0][0] = u[1];
161: J[1][0] = 0;
162: J[0][1] = 0;
163: J[1][1] = (1. - u[0] * u[0]) * u[1] - u[0];
164: PetscCall(MatSetValues(A, 2, row, 2, col, &J[0][0], INSERT_VALUES));
165: PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
166: PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
167: PetscCall(VecRestoreArrayRead(U, &u));
168: PetscFunctionReturn(PETSC_SUCCESS);
169: }
171: /* ------------------ Jacobian wrt parameters for global method ------------------ */
173: PetscErrorCode RHSJacobianP_global(TS ts, PetscReal t, Vec U, Mat A, void *ctx)
174: {
175: PetscInt row[] = {0, 1}, col[] = {0, 1};
176: PetscScalar J[2][2];
177: const PetscScalar *u;
178: PetscInt curr_step;
180: PetscFunctionBeginUser;
181: PetscCall(TSGetStepNumber(ts, &curr_step));
182: PetscCall(VecGetArrayRead(U, &u));
183: J[0][0] = u[1];
184: J[1][0] = 0;
185: J[0][1] = 0;
186: J[1][1] = (1. - u[0] * u[0]) * u[1] - u[0];
187: col[0] = curr_step * 2;
188: col[1] = curr_step * 2 + 1;
189: PetscCall(MatSetValues(A, 2, row, 2, col, &J[0][0], INSERT_VALUES));
190: PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
191: PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
192: PetscCall(VecRestoreArrayRead(U, &u));
193: PetscFunctionReturn(PETSC_SUCCESS);
194: }
196: /* Dump solution to console if called */
197: PetscErrorCode Monitor(TS ts, PetscInt step, PetscReal t, Vec U, void *ctx)
198: {
199: PetscFunctionBeginUser;
200: PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n Solution at time %e is \n", (double)t));
201: PetscCall(VecView(U, PETSC_VIEWER_STDOUT_WORLD));
202: PetscFunctionReturn(PETSC_SUCCESS);
203: }
205: /* Customized adjoint monitor to keep track of local
206: sensitivities by storing them in a global sensitivity array.
207: Note : This routine is only used for the tracking method. */
208: PetscErrorCode AdjointMonitor(TS ts, PetscInt steps, PetscReal time, Vec u, PetscInt numcost, Vec *lambda, Vec *mu, void *ctx)
209: {
210: AppCtx *user = (AppCtx *)ctx;
211: PetscInt curr_step;
212: PetscScalar *sensmu1_glob;
213: PetscScalar *sensmu2_glob;
214: const PetscScalar *sensmu_loc;
216: PetscFunctionBeginUser;
217: PetscCall(TSGetStepNumber(ts, &curr_step));
218: /* Note that we skip the first call to the monitor in the adjoint
219: solve since the sensitivities are already set (during
220: initialization of adjoint vectors).
221: We also note that each indvidial TSAdjointSolve calls the monitor
222: twice, once at the step it is integrating from and once at the step
223: it integrates to. Only the second call is useful for transferring
224: local sensitivities to the global array. */
225: if (curr_step == user->adj_idx) {
226: PetscFunctionReturn(PETSC_SUCCESS);
227: } else {
228: PetscCall(VecGetArrayRead(*mu, &sensmu_loc));
229: PetscCall(VecGetArray(user->sens_mu1, &sensmu1_glob));
230: PetscCall(VecGetArray(user->sens_mu2, &sensmu2_glob));
231: sensmu1_glob[curr_step] = sensmu_loc[0];
232: sensmu2_glob[curr_step] = sensmu_loc[1];
233: PetscCall(VecRestoreArray(user->sens_mu1, &sensmu1_glob));
234: PetscCall(VecRestoreArray(user->sens_mu2, &sensmu2_glob));
235: PetscCall(VecRestoreArrayRead(*mu, &sensmu_loc));
236: PetscFunctionReturn(PETSC_SUCCESS);
237: }
238: }
240: int main(int argc, char **argv)
241: {
242: TS ts;
243: AppCtx user;
244: PetscScalar *x_ptr, *y_ptr, *u_ptr;
245: PetscMPIInt size;
246: PetscBool monitor = PETSC_FALSE;
247: SAMethod sa = SA_GLOBAL;
249: /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
250: Initialize program
251: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
252: PetscFunctionBeginUser;
253: PetscCall(PetscInitialize(&argc, &argv, NULL, help));
254: PetscCallMPI(MPI_Comm_size(PETSC_COMM_WORLD, &size));
255: PetscCheck(size == 1, PETSC_COMM_WORLD, PETSC_ERR_WRONG_MPI_SIZE, "This is a uniprocessor example only!");
257: /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
258: Set runtime options
259: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
260: PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "SA analysis options.", "");
261: {
262: PetscCall(PetscOptionsGetBool(NULL, NULL, "-monitor", &monitor, NULL));
263: PetscCall(PetscOptionsEnum("-sa_method", "Sensitivity analysis method (track or global)", "", SAMethods, (PetscEnum)sa, (PetscEnum *)&sa, NULL));
264: }
265: PetscOptionsEnd();
267: user.final_time = 0.1;
268: user.max_steps = 5;
269: user.time_step = user.final_time / user.max_steps;
271: /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
272: Create necessary matrix and vectors for forward solve.
273: Create Jacp matrix for adjoint solve.
274: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
275: PetscCall(VecCreateSeq(PETSC_COMM_WORLD, user.max_steps, &user.mu1));
276: PetscCall(VecCreateSeq(PETSC_COMM_WORLD, user.max_steps, &user.mu2));
277: PetscCall(VecSet(user.mu1, 1.25));
278: PetscCall(VecSet(user.mu2, 1.0e2));
280: /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
281: For tracking method : create the global sensitivity array to
282: accumulate sensitivity with respect to parameters at each step.
283: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
284: if (sa == SA_TRACK) {
285: PetscCall(VecCreateSeq(PETSC_COMM_WORLD, user.max_steps, &user.sens_mu1));
286: PetscCall(VecCreateSeq(PETSC_COMM_WORLD, user.max_steps, &user.sens_mu2));
287: }
289: PetscCall(MatCreate(PETSC_COMM_WORLD, &user.A));
290: PetscCall(MatSetSizes(user.A, PETSC_DECIDE, PETSC_DECIDE, 2, 2));
291: PetscCall(MatSetFromOptions(user.A));
292: PetscCall(MatSetUp(user.A));
293: PetscCall(MatCreateVecs(user.A, &user.U, NULL));
295: /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
296: Note that the dimensions of the Jacp matrix depend upon the
297: sensitivity analysis method being used !
298: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
299: PetscCall(MatCreate(PETSC_COMM_WORLD, &user.Jacp));
300: if (sa == SA_TRACK) PetscCall(MatSetSizes(user.Jacp, PETSC_DECIDE, PETSC_DECIDE, 2, 2));
301: if (sa == SA_GLOBAL) PetscCall(MatSetSizes(user.Jacp, PETSC_DECIDE, PETSC_DECIDE, 2, user.max_steps * 2));
302: PetscCall(MatSetFromOptions(user.Jacp));
303: PetscCall(MatSetUp(user.Jacp));
305: /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
306: Create timestepping solver context
307: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
308: PetscCall(TSCreate(PETSC_COMM_WORLD, &ts));
309: PetscCall(TSSetEquationType(ts, TS_EQ_ODE_EXPLICIT));
310: PetscCall(TSSetType(ts, TSCN));
312: PetscCall(TSSetRHSFunction(ts, NULL, RHSFunction, &user));
313: PetscCall(TSSetRHSJacobian(ts, user.A, user.A, RHSJacobian, &user));
314: if (sa == SA_TRACK) PetscCall(TSSetRHSJacobianP(ts, user.Jacp, RHSJacobianP_track, &user));
315: if (sa == SA_GLOBAL) PetscCall(TSSetRHSJacobianP(ts, user.Jacp, RHSJacobianP_global, &user));
317: PetscCall(TSSetExactFinalTime(ts, TS_EXACTFINALTIME_MATCHSTEP));
318: PetscCall(TSSetMaxTime(ts, user.final_time));
319: PetscCall(TSSetTimeStep(ts, user.final_time / user.max_steps));
321: if (monitor) PetscCall(TSMonitorSet(ts, Monitor, &user, NULL));
322: if (sa == SA_TRACK) PetscCall(TSAdjointMonitorSet(ts, AdjointMonitor, &user, NULL));
324: /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
325: Set initial conditions
326: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
327: PetscCall(VecGetArray(user.U, &x_ptr));
328: x_ptr[0] = 2.0;
329: x_ptr[1] = -2.0 / 3.0;
330: PetscCall(VecRestoreArray(user.U, &x_ptr));
332: /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
333: Save trajectory of solution so that TSAdjointSolve() may be used
334: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
335: PetscCall(TSSetSaveTrajectory(ts));
337: /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
338: Set runtime options
339: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
340: PetscCall(TSSetFromOptions(ts));
342: /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
343: Execute forward model and print solution.
344: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
345: PetscCall(TSSolve(ts, user.U));
346: PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n Solution of forward TS :\n"));
347: PetscCall(VecView(user.U, PETSC_VIEWER_STDOUT_WORLD));
348: PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n Forward TS solve successful! Adjoint run begins!\n"));
350: /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
351: Adjoint model starts here! Create adjoint vectors.
352: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
353: PetscCall(MatCreateVecs(user.A, &user.lambda, NULL));
354: PetscCall(MatCreateVecs(user.Jacp, &user.mup, NULL));
356: /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
357: Set initial conditions for the adjoint vector
358: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
359: PetscCall(VecGetArray(user.U, &u_ptr));
360: PetscCall(VecGetArray(user.lambda, &y_ptr));
361: y_ptr[0] = 2 * (u_ptr[0] - 1.5967);
362: y_ptr[1] = 2 * (u_ptr[1] - -(1.02969));
363: PetscCall(VecRestoreArray(user.lambda, &y_ptr));
364: PetscCall(VecRestoreArray(user.U, &y_ptr));
365: PetscCall(VecSet(user.mup, 0));
367: /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
368: Set number of cost functions.
369: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
370: PetscCall(TSSetCostGradients(ts, 1, &user.lambda, &user.mup));
372: /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
373: The adjoint vector mup has to be reset for each adjoint step when
374: using the tracking method as we want to treat the parameters at each
375: time step one at a time and prevent accumulation of the sensitivities
376: from parameters at previous time steps.
377: This is not necessary for the global method as each time dependent
378: parameter is treated as an independent parameter.
379: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
380: if (sa == SA_TRACK) {
381: for (user.adj_idx = user.max_steps; user.adj_idx > 0; user.adj_idx--) {
382: PetscCall(VecSet(user.mup, 0));
383: PetscCall(TSAdjointSetSteps(ts, 1));
384: PetscCall(TSAdjointSolve(ts));
385: }
386: }
387: if (sa == SA_GLOBAL) PetscCall(TSAdjointSolve(ts));
389: /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
390: Display adjoint sensitivities wrt parameters and initial conditions
391: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
392: if (sa == SA_TRACK) {
393: PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n sensitivity wrt mu1: d[cost]/d[mu1]\n"));
394: PetscCall(VecView(user.sens_mu1, PETSC_VIEWER_STDOUT_WORLD));
395: PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n sensitivity wrt mu2: d[cost]/d[mu2]\n"));
396: PetscCall(VecView(user.sens_mu2, PETSC_VIEWER_STDOUT_WORLD));
397: }
399: if (sa == SA_GLOBAL) {
400: PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n sensitivity wrt params: d[cost]/d[p], where p refers to \nthe interlaced vector made by combining mu1,mu2\n"));
401: PetscCall(VecView(user.mup, PETSC_VIEWER_STDOUT_WORLD));
402: }
404: PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n sensitivity wrt initial conditions: d[cost]/d[u(t=0)]\n"));
405: PetscCall(VecView(user.lambda, PETSC_VIEWER_STDOUT_WORLD));
407: /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
408: Free work space!
409: All PETSc objects should be destroyed when they are no longer needed.
410: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
411: PetscCall(MatDestroy(&user.A));
412: PetscCall(MatDestroy(&user.Jacp));
413: PetscCall(VecDestroy(&user.U));
414: PetscCall(VecDestroy(&user.lambda));
415: PetscCall(VecDestroy(&user.mup));
416: PetscCall(VecDestroy(&user.mu1));
417: PetscCall(VecDestroy(&user.mu2));
418: if (sa == SA_TRACK) {
419: PetscCall(VecDestroy(&user.sens_mu1));
420: PetscCall(VecDestroy(&user.sens_mu2));
421: }
422: PetscCall(TSDestroy(&ts));
423: PetscCall(PetscFinalize());
424: return 0;
425: }
427: /*TEST
429: test:
430: requires: !complex
431: suffix : track
432: args : -sa_method track
434: test:
435: requires: !complex
436: suffix : global
437: args : -sa_method global
439: TEST*/