Actual source code: ex3.c

/* Help text passed to PetscInitialize() in main(): a one-line summary of the example followed by
   sample command lines. */
static char help[] = "Shallow water test cases with data assimilation.\n"
                     "Implements 1D shallow water equations with 2 DOF per grid point (h, hu).\n\n"
                     "Example usage:\n"
                     "  ./ex3 -steps 100 -obs_freq 5 -obs_error 0.1 -petscda_view -petscda_ensemble_size 30\n"
                     "  ./ex3 -ex3_test wave -steps 500\n\n";

  7: /* Data assimilation framework header (provides PetscDA) */
  8: #include <petscda.h>
  9: /* PETSc DMDA header (provides DM, DMDA functionality) */
 10: #include <petscdmda.h>
 11: #include <petscdmplex.h>
 12: #include <petscts.h>
 13: #include <petscvec.h>

 15: /* Default parameter values */
 16: #define DEFAULT_N             80 /* 80 grid points */
 17: #define DEFAULT_STEPS         100
 18: #define DEFAULT_OBS_FREQ      5
 19: #define DEFAULT_RANDOM_SEED   12345
 20: #define DEFAULT_G             9.81
 21: #define DEFAULT_DT            0.02
 22: #define DEFAULT_OBS_ERROR_STD 0.01
 23: #define DEFAULT_ENSEMBLE_SIZE 30
 24: #define SPINUP_STEPS          0 /* No spinup needed - wave test has smooth analytical initial condition */

 26: /* Minimum valid parameter values */
 27: #define MIN_N                 1
 28: #define MIN_ENSEMBLE_SIZE     2
 29: #define MIN_OBS_FREQ          1
 30: #define DEFAULT_PROGRESS_FREQ 10 /* Print progress every N steps by default */

/* Test case types: selects which initial condition ShallowWaterSolution() evaluates */
typedef enum {
  EX3_TEST_DAM,
  EX3_TEST_WAVE
} Ex3TestType;

/* Registry mapping a test name ("dam", "wave") to the function that sets the matching Ex3TestType.
   Filled by Ex3TestInitializePackage() and released by Ex3TestFinalizePackage(); the flag guards
   against registering twice. */
static PetscFunctionList Ex3TestList               = NULL;
static PetscBool         Ex3TestPackageInitialized = PETSC_FALSE;

/* Numerical flux schemes that ShallowWaterRHS() chooses between */
typedef enum {
  EX3_FLUX_RUSANOV,
  EX3_FLUX_MC
} Ex3FluxType;

/* String table for Ex3FluxType handed to PetscOptionsEnum(): the value names in enum order,
   followed by the enum type name, the common prefix, and a NULL terminator */
static const char *const Ex3FluxTypes[] = {"rusanov", "mc", "Ex3FluxType", "EX3_FLUX_", NULL};

/*
  ShallowWaterCtx - Model context shared by the right-hand side and the one-step propagator

  Built by ShallowWaterContextCreate() and released by ShallowWaterContextDestroy(). The destroy
  routine frees only the TS and the struct itself; the DM stored in da is not destroyed there.
*/
typedef struct {
  DM          da;        /* 1D periodic DM storing the shallow water state */
  PetscInt    n_vert;    /* State dimension (number of grid points) */
  PetscReal   L;         /* Domain length */
  PetscReal   g;         /* Gravitational constant */
  PetscReal   dx;        /* Grid spacing (set to L / n_vert on creation) */
  PetscReal   dt;        /* Integration time step size */
  TS          ts;        /* Reusable time stepper for efficiency */
  Ex3TestType test_type; /* Test case type */
  Ex3FluxType flux_type; /* Flux scheme */
} ShallowWaterCtx;

 60: /*
 61:   Limit - MC (Monotonized Central) limiter
 62: */
 63: static PetscReal Limit(PetscReal a, PetscReal b)
 64: {
 65:   PetscReal c = 0.5 * (a + b);
 66:   if (a * b <= 0.0) return 0.0;
 67:   if (c > 0) return PetscMin(2.0 * a, PetscMin(2.0 * b, c));
 68:   else return PetscMax(2.0 * a, PetscMax(2.0 * b, c));
 69: }

 71: /*
 72:   ComputeFlux - Compute physical flux and wave speed for shallow water
 73: */
 74: static void ComputeFlux(PetscReal g, PetscReal h, PetscReal hu, PetscReal *F_h, PetscReal *F_hu, PetscReal *u, PetscReal *c)
 75: {
 76:   if (h > 1e-10) {
 77:     *u    = hu / h;
 78:     *c    = PetscSqrtReal(g * h);
 79:     *F_h  = hu;
 80:     *F_hu = hu * *u + 0.5 * g * h * h;
 81:   } else {
 82:     *u    = 0.0;
 83:     *c    = 0.0;
 84:     *F_h  = 0.0;
 85:     *F_hu = 0.0;
 86:   }
 87: }

 89: /*
 90:   ShallowWaterRHS - Compute the right-hand side of the shallow water equations

 92:   Dispatches to appropriate flux scheme implementation.
 93: */
 94: static PetscErrorCode ShallowWaterRHS(TS ts, PetscReal t, Vec X, Vec F_vec, PetscCtx ctx)
 95: {
 96:   ShallowWaterCtx   *sw = (ShallowWaterCtx *)ctx;
 97:   Vec                X_local;
 98:   const PetscScalar *x;
 99:   PetscScalar       *f;
100:   PetscInt           xs, xm, i;
101:   const PetscInt     ndof = 2; /* h and hu */

103:   PetscFunctionBeginUser;
104:   (void)ts;
105:   (void)t;

107:   PetscCall(DMDAGetCorners(sw->da, &xs, NULL, NULL, &xm, NULL, NULL));
108:   PetscCall(DMGetLocalVector(sw->da, &X_local));
109:   PetscCall(DMGlobalToLocalBegin(sw->da, X, INSERT_VALUES, X_local));
110:   PetscCall(DMGlobalToLocalEnd(sw->da, X, INSERT_VALUES, X_local));
111:   PetscCall(DMDAVecGetArrayRead(sw->da, X_local, &x));
112:   PetscCall(DMDAVecGetArray(sw->da, F_vec, &f));

114:   if (sw->flux_type == EX3_FLUX_RUSANOV) {
115:     /* First-order Rusanov (Local Lax-Friedrichs) scheme */
116:     for (i = xs; i < xs + xm; i++) {
117:       PetscReal h  = PetscRealPart(x[i * ndof]);
118:       PetscReal hu = PetscRealPart(x[i * ndof + 1]);

120:       PetscReal h_im1  = PetscRealPart(x[(i - 1) * ndof]);
121:       PetscReal hu_im1 = PetscRealPart(x[(i - 1) * ndof + 1]);

123:       PetscReal h_ip1  = PetscRealPart(x[(i + 1) * ndof]);
124:       PetscReal hu_ip1 = PetscRealPart(x[(i + 1) * ndof + 1]);

126:       PetscReal F_h_i, F_hu_i, u, c;
127:       PetscReal F_h_im1, F_hu_im1, u_im1, c_im1;
128:       PetscReal F_h_ip1, F_hu_ip1, u_ip1, c_ip1;

130:       ComputeFlux(sw->g, h, hu, &F_h_i, &F_hu_i, &u, &c);
131:       ComputeFlux(sw->g, h_im1, hu_im1, &F_h_im1, &F_hu_im1, &u_im1, &c_im1);
132:       ComputeFlux(sw->g, h_ip1, hu_ip1, &F_h_ip1, &F_hu_ip1, &u_ip1, &c_ip1);

134:       PetscReal alpha_left  = PetscMax(PetscAbsReal(u_im1) + c_im1, PetscAbsReal(u) + c);
135:       PetscReal alpha_right = PetscMax(PetscAbsReal(u) + c, PetscAbsReal(u_ip1) + c_ip1);

137:       PetscReal flux_h_left  = 0.5 * (F_h_im1 + F_h_i - alpha_left * (h - h_im1));
138:       PetscReal flux_hu_left = 0.5 * (F_hu_im1 + F_hu_i - alpha_left * (hu - hu_im1));

140:       PetscReal flux_h_right  = 0.5 * (F_h_i + F_h_ip1 - alpha_right * (h_ip1 - h));
141:       PetscReal flux_hu_right = 0.5 * (F_hu_i + F_hu_ip1 - alpha_right * (hu_ip1 - hu));

143:       f[i * ndof]     = -(flux_h_right - flux_h_left) / sw->dx;
144:       f[i * ndof + 1] = -(flux_hu_right - flux_hu_left) / sw->dx;
145:     }
146:   } else {
147:     /* Second-order MC (Monotonized Central) scheme */
148:     for (i = xs; i < xs + xm; i++) {
149:       /* Read state */
150:       PetscReal h_im2 = PetscRealPart(x[(i - 2) * ndof]);
151:       PetscReal h_im1 = PetscRealPart(x[(i - 1) * ndof]);
152:       PetscReal h_i   = PetscRealPart(x[i * ndof]);
153:       PetscReal h_ip1 = PetscRealPart(x[(i + 1) * ndof]);
154:       PetscReal h_ip2 = PetscRealPart(x[(i + 2) * ndof]);

156:       PetscReal hu_im2 = PetscRealPart(x[(i - 2) * ndof + 1]);
157:       PetscReal hu_im1 = PetscRealPart(x[(i - 1) * ndof + 1]);
158:       PetscReal hu_i   = PetscRealPart(x[i * ndof + 1]);
159:       PetscReal hu_ip1 = PetscRealPart(x[(i + 1) * ndof + 1]);
160:       PetscReal hu_ip2 = PetscRealPart(x[(i + 2) * ndof + 1]);

162:       /* Compute slopes (MC limiter) */
163:       PetscReal s_h_im1 = Limit(h_im1 - h_im2, h_i - h_im1);
164:       PetscReal s_h_i   = Limit(h_i - h_im1, h_ip1 - h_i);
165:       PetscReal s_h_ip1 = Limit(h_ip1 - h_i, h_ip2 - h_ip1);

167:       PetscReal s_hu_im1 = Limit(hu_im1 - hu_im2, hu_i - hu_im1);
168:       PetscReal s_hu_i   = Limit(hu_i - hu_im1, hu_ip1 - hu_i);
169:       PetscReal s_hu_ip1 = Limit(hu_ip1 - hu_i, hu_ip2 - hu_ip1);

171:       /* Reconstruct states at interfaces */
172:       /* Left interface (i-1/2) */
173:       PetscReal h_L_left  = h_im1 + 0.5 * s_h_im1;
174:       PetscReal hu_L_left = hu_im1 + 0.5 * s_hu_im1;
175:       PetscReal h_R_left  = h_i - 0.5 * s_h_i;
176:       PetscReal hu_R_left = hu_i - 0.5 * s_hu_i;

178:       /* Right interface (i+1/2) */
179:       PetscReal h_L_right  = h_i + 0.5 * s_h_i;
180:       PetscReal hu_L_right = hu_i + 0.5 * s_hu_i;
181:       PetscReal h_R_right  = h_ip1 - 0.5 * s_h_ip1;
182:       PetscReal hu_R_right = hu_ip1 - 0.5 * s_hu_ip1;

184:       /* Compute fluxes */
185:       PetscReal F_h_LL, F_hu_LL, u_LL, c_LL;
186:       PetscReal F_h_RL, F_hu_RL, u_RL, c_RL;
187:       PetscReal F_h_LR, F_hu_LR, u_LR, c_LR;
188:       PetscReal F_h_RR, F_hu_RR, u_RR, c_RR;

190:       ComputeFlux(sw->g, h_L_left, hu_L_left, &F_h_LL, &F_hu_LL, &u_LL, &c_LL);
191:       ComputeFlux(sw->g, h_R_left, hu_R_left, &F_h_RL, &F_hu_RL, &u_RL, &c_RL);
192:       ComputeFlux(sw->g, h_L_right, hu_L_right, &F_h_LR, &F_hu_LR, &u_LR, &c_LR);
193:       ComputeFlux(sw->g, h_R_right, hu_R_right, &F_h_RR, &F_hu_RR, &u_RR, &c_RR);

195:       /* Rusanov flux */
196:       PetscReal speed_left   = PetscMax(PetscAbsReal(u_LL) + c_LL, PetscAbsReal(u_RL) + c_RL);
197:       PetscReal flux_h_left  = 0.5 * (F_h_LL + F_h_RL - speed_left * (h_R_left - h_L_left));
198:       PetscReal flux_hu_left = 0.5 * (F_hu_LL + F_hu_RL - speed_left * (hu_R_left - hu_L_left));

200:       PetscReal speed_right   = PetscMax(PetscAbsReal(u_LR) + c_LR, PetscAbsReal(u_RR) + c_RR);
201:       PetscReal flux_h_right  = 0.5 * (F_h_LR + F_h_RR - speed_right * (h_R_right - h_L_right));
202:       PetscReal flux_hu_right = 0.5 * (F_hu_LR + F_hu_RR - speed_right * (hu_R_right - hu_L_right));

204:       /* Update RHS using finite volume method */
205:       f[i * ndof]     = -(flux_h_right - flux_h_left) / sw->dx;
206:       f[i * ndof + 1] = -(flux_hu_right - flux_hu_left) / sw->dx;
207:     }
208:   }

210:   PetscCall(DMDAVecRestoreArrayRead(sw->da, X_local, &x));
211:   PetscCall(DMDAVecRestoreArray(sw->da, F_vec, &f));
212:   PetscCall(DMRestoreLocalVector(sw->da, &X_local));
213:   PetscFunctionReturn(PETSC_SUCCESS);
214: }

216: /*
217:   ShallowWaterContextCreate - Create and initialize a shallow water context with reusable TS object
218: */
219: static PetscErrorCode ShallowWaterContextCreate(DM da, PetscInt n_vert, PetscReal L, PetscReal g, PetscReal dt, Ex3TestType test_type, Ex3FluxType flux_type, ShallowWaterCtx **ctx)
220: {
221:   ShallowWaterCtx *sw;

223:   PetscFunctionBeginUser;
224:   PetscCall(PetscNew(&sw));
225:   sw->da        = da;
226:   sw->n_vert    = n_vert;
227:   sw->L         = L;
228:   sw->g         = g;
229:   sw->dx        = L / n_vert; /* Domain is [0, L] */
230:   sw->dt        = dt;
231:   sw->test_type = test_type;
232:   sw->flux_type = flux_type;

234:   PetscCall(TSCreate(PetscObjectComm((PetscObject)da), &sw->ts));
235:   PetscCall(TSSetProblemType(sw->ts, TS_NONLINEAR));
236:   PetscCall(TSSetRHSFunction(sw->ts, NULL, ShallowWaterRHS, sw));
237:   PetscCall(TSSetType(sw->ts, TSRK));
238:   PetscCall(TSRKSetType(sw->ts, TSRK4));
239:   PetscCall(TSSetTimeStep(sw->ts, dt));
240:   PetscCall(TSSetMaxSteps(sw->ts, 1));
241:   PetscCall(TSSetMaxTime(sw->ts, dt));
242:   PetscCall(TSSetExactFinalTime(sw->ts, TS_EXACTFINALTIME_MATCHSTEP));
243:   PetscCall(TSSetFromOptions(sw->ts));
244:   /* Note: TSSetUp() will be called automatically by TSSolve() when needed */

246:   *ctx = sw;
247:   PetscFunctionReturn(PETSC_SUCCESS);
248: }

250: /*
251:   ShallowWaterContextDestroy - Destroy a shallow water context and its TS object
252: */
253: static PetscErrorCode ShallowWaterContextDestroy(ShallowWaterCtx **ctx)
254: {
255:   PetscFunctionBeginUser;
256:   if (!ctx || !*ctx) PetscFunctionReturn(PETSC_SUCCESS);
257:   PetscCall(TSDestroy(&(*ctx)->ts));
258:   PetscCall(PetscFree(*ctx));
259:   PetscFunctionReturn(PETSC_SUCCESS);
260: }

262: /*
263:   ShallowWaterStep - Advance state vector one time step using shallow water dynamics
264: */
265: static PetscErrorCode ShallowWaterStep(Vec input, Vec output, PetscCtx ctx)
266: {
267:   ShallowWaterCtx *sw = (ShallowWaterCtx *)ctx;

269:   PetscFunctionBeginUser;
270:   /* Copy input to output if they are different vectors */
271:   if (input != output) PetscCall(VecCopy(input, output));

273:   /* Reset the TS time for each integration (required for proper RK4 stepping) */
274:   PetscCall(TSSetTime(sw->ts, 0.0));
275:   PetscCall(TSSetStepNumber(sw->ts, 0));
276:   PetscCall(TSSetMaxTime(sw->ts, sw->dt));
277:   /* Solve one time step: advances output from t=0 to t=dt */
278:   PetscCall(TSSolve(sw->ts, output));
279:   PetscFunctionReturn(PETSC_SUCCESS);
280: }

282: /*
283:   ShallowWaterSolution_Dam - Smooth periodic "dam-like" initial condition

285:   Creates a smooth Gaussian bump compatible with periodic boundaries.
286:   This avoids boundary artifacts while maintaining dam-like evolution.
287: */
288: static PetscErrorCode ShallowWaterSolution_Dam(PetscReal L, PetscReal x, PetscReal *h, PetscReal *hu)
289: {
290:   const PetscReal h_mean = 1.5;      /* Mean water height */
291:   const PetscReal h_amp  = 0.4;      /* Bump amplitude */
292:   const PetscReal x_c    = 0.25 * L; /* Bump center */
293:   const PetscReal sigma  = 0.1 * L;  /* Gaussian width */

295:   PetscFunctionBeginUser;
296:   /* Smooth Gaussian bump: h = h_mean + h_amp * exp(-(x-x_c)^2/(2*sigma^2)) */
297:   PetscReal dx = x - x_c;
298:   /* Handle periodicity: use minimum distance on periodic domain */
299:   if (dx > 0.5 * L) dx -= L;
300:   if (dx < -0.5 * L) dx += L;

302:   *h = h_mean + h_amp * PetscExpReal(-dx * dx / (2.0 * sigma * sigma));
303:   /* Initially at rest */
304:   *hu = 0.0;
305:   PetscFunctionReturn(PETSC_SUCCESS);
306: }

308: /*
309:   ShallowWaterSolution_Wave - Traveling wave initial condition

311:   Sets smooth traveling wave with sinusoidal perturbation.
312:   For shallow water, a rightward-traveling wave requires velocity perturbation
313:   coupled to height: u' = c * (h'/h_mean) where c = sqrt(g*h_mean).
314: */
315: static PetscErrorCode ShallowWaterSolution_Wave(PetscReal L, PetscReal x, PetscReal *h, PetscReal *hu)
316: {
317:   const PetscReal h_mean = 1.5;                       /* Mean water height */
318:   const PetscReal h_amp  = 0.3;                       /* Wave amplitude */
319:   const PetscReal g      = DEFAULT_G;                 /* Gravitational constant */
320:   const PetscReal k      = 2.0 * PETSC_PI / L;        /* Wave number (one wavelength over domain) */
321:   const PetscReal c      = PetscSqrtReal(g * h_mean); /* Wave speed */

323:   PetscFunctionBeginUser;
324:   /* Height field: h = h_mean + h_amp * sin(k*x) */
325:   PetscReal h_pert = h_amp * PetscSinReal(k * x);
326:   *h               = h_mean + h_pert;

328:   /* Velocity for rightward-traveling wave: u = c * (h'/h_mean)
329:      Using linearized shallow water: u ~= (c/h_mean) * h_pert */
330:   PetscReal u = (c / h_mean) * h_pert;
331:   *hu         = (*h) * u;
332:   PetscFunctionReturn(PETSC_SUCCESS);
333: }

335: /*
336:   ShallowWaterSolution - Dispatch to appropriate initial condition based on test type
337: */
338: static PetscErrorCode ShallowWaterSolution(Ex3TestType test_type, PetscReal L, PetscReal x, PetscReal *h, PetscReal *hu)
339: {
340:   PetscFunctionBeginUser;
341:   switch (test_type) {
342:   case EX3_TEST_DAM:
343:     PetscCall(ShallowWaterSolution_Dam(L, x, h, hu));
344:     break;
345:   case EX3_TEST_WAVE:
346:     PetscCall(ShallowWaterSolution_Wave(L, x, h, hu));
347:     break;
348:   default:
349:     SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Unknown test type");
350:   }
351:   PetscFunctionReturn(PETSC_SUCCESS);
352: }

354: /*
355:   CreateObservationMatrix - Create observation matrix H for shallow water, and H1 as scalar version

357:   Observes water height (h) at every other grid point.
358:   This creates a sparse matrix mapping from full state (n_vert*ndof) to observations.
359:   For n_vert=80 grid points, we observe at points 0, 2, 4, ..., 78
360: */
361: static PetscErrorCode CreateObservationMatrix(PetscInt n_vert, PetscInt ndof, PetscInt nobs, Vec state, Mat *H, Mat *H1)
362: {
363:   PetscInt i, local_state_size;

365:   PetscFunctionBeginUser;
366:   PetscCheck(n_vert == 2 * nobs, PETSC_COMM_WORLD, PETSC_ERR_ARG_INCOMP, "Number of grid points (%" PetscInt_FMT ") must equal 2*nobs (%" PetscInt_FMT ")", n_vert, 2 * nobs);

368:   PetscCall(VecGetLocalSize(state, &local_state_size));

370:   /* Create observation matrix H (nobs x n_vert*ndof) */
371:   PetscCall(MatCreateAIJ(PETSC_COMM_WORLD, PETSC_DECIDE, local_state_size, nobs, n_vert * ndof, 1, NULL, 0, NULL, H));
372:   PetscCall(MatSetFromOptions(*H));

374:   PetscCall(MatCreateAIJ(PETSC_COMM_WORLD, PETSC_DECIDE, local_state_size / ndof, nobs, n_vert, 1, NULL, 0, NULL, H1));
375:   PetscCall(MatSetFromOptions(*H1));

377:   /* Observe water height (h) at every other grid point */
378:   for (i = 0; i < nobs; i++) {
379:     PetscInt grid_point = 2 * i; /* Observe at points 0, 2, 4, ... */
380:     PetscCall(MatSetValue(*H1, i, grid_point, 1.0, INSERT_VALUES));
381:     /* pick out the h component (first DOF) at that grid point */
382:     PetscCall(MatSetValue(*H, i, grid_point * ndof, 1.0, INSERT_VALUES));
383:   }

385:   PetscCall(MatAssemblyBegin(*H, MAT_FINAL_ASSEMBLY));
386:   PetscCall(MatAssemblyEnd(*H, MAT_FINAL_ASSEMBLY));
387:   PetscCall(MatAssemblyBegin(*H1, MAT_FINAL_ASSEMBLY));
388:   PetscCall(MatAssemblyEnd(*H1, MAT_FINAL_ASSEMBLY));

390:   PetscCall(MatViewFromOptions(*H1, NULL, "-H_view"));
391:   PetscFunctionReturn(PETSC_SUCCESS);
392: }

394: /*
395:   CreateLocalizationMatrix - Create and initialize localization matrix Q for shallow water

397:   Q is a (num_vert x obs_size) matrix that specifies which observations affect each state variable.
398:   For no localization (global assimilation), each state variable uses all observations.
399: */
400: static PetscErrorCode CreateLocalizationMatrix(PetscInt num_vert, PetscInt obs_size, Mat *Q)
401: {
402:   PetscInt i, j;

404:   PetscFunctionBeginUser;
405:   /* Create Q matrix (num_vert x obs_size)
406:      Each row will have obs_size non-zeros (all observations affect each state variable) */
407:   PetscCall(MatCreateAIJ(PETSC_COMM_WORLD, PETSC_DECIDE, PETSC_DECIDE, num_vert, obs_size, obs_size, NULL, 0, NULL, Q));
408:   PetscCall(MatSetFromOptions(*Q));

410:   /* Initialize with no localization (global): each state variable uses all observations */
411:   for (i = 0; i < num_vert; i++) {
412:     for (j = 0; j < obs_size; j++) PetscCall(MatSetValue(*Q, i, j, 1.0, INSERT_VALUES));
413:   }
414:   PetscCall(MatAssemblyBegin(*Q, MAT_FINAL_ASSEMBLY));
415:   PetscCall(MatAssemblyEnd(*Q, MAT_FINAL_ASSEMBLY));
416:   PetscFunctionReturn(PETSC_SUCCESS);
417: }

419: /*
420:   ValidateParameters - Validate input parameters and apply constraints
421: */
422: static PetscErrorCode ValidateParameters(PetscInt *n, PetscInt *nobs, PetscInt *steps, PetscInt *obs_freq, PetscInt *ensemble_size, PetscReal *dt, PetscReal *g, PetscReal *obs_error_std)
423: {
424:   PetscFunctionBeginUser;
425:   PetscCheck(*n > 0, PETSC_COMM_WORLD, PETSC_ERR_ARG_OUTOFRANGE, "State dimension n must be positive, got %" PetscInt_FMT, *n);
426:   PetscCheck(*steps >= 0, PETSC_COMM_WORLD, PETSC_ERR_ARG_OUTOFRANGE, "Number of steps must be non-negative, got %" PetscInt_FMT, *steps);
427:   PetscCheck(*ensemble_size >= MIN_ENSEMBLE_SIZE, PETSC_COMM_WORLD, PETSC_ERR_ARG_OUTOFRANGE, "Ensemble size must be at least %d for meaningful statistics, got %" PetscInt_FMT, MIN_ENSEMBLE_SIZE, *ensemble_size);

429:   if (*obs_freq < MIN_OBS_FREQ) {
430:     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Warning: Observation frequency adjusted from %" PetscInt_FMT " to %d\n", *obs_freq, MIN_OBS_FREQ));
431:     *obs_freq = MIN_OBS_FREQ;
432:   }
433:   if (*obs_freq > *steps && *steps > 0) PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Warning: Observation frequency (%" PetscInt_FMT ") > total steps (%" PetscInt_FMT "), no observations will be assimilated.\n", *obs_freq, *steps));

435:   PetscCheck(*dt > 0.0, PETSC_COMM_WORLD, PETSC_ERR_ARG_OUTOFRANGE, "Time step dt must be positive, got %g", (double)*dt);
436:   PetscCheck(*obs_error_std > 0.0, PETSC_COMM_WORLD, PETSC_ERR_ARG_OUTOFRANGE, "Observation error std must be positive, got %g", (double)*obs_error_std);
437:   PetscCheck(PetscIsNormalReal(*g), PETSC_COMM_WORLD, PETSC_ERR_ARG_OUTOFRANGE, "Gravitational constant g must be a normal real number");
438:   PetscFunctionReturn(PETSC_SUCCESS);
439: }

441: /*
442:   ComputeRMSE - Compute root mean square error between two vectors
443: */
444: static PetscErrorCode ComputeRMSE(Vec v1, Vec v2, Vec work, PetscInt n, PetscReal *rmse)
445: {
446:   PetscReal norm;

448:   PetscFunctionBeginUser;
449:   PetscCall(VecWAXPY(work, -1.0, v2, v1));
450:   PetscCall(VecNorm(work, NORM_2, &norm));
451:   *rmse = norm / PetscSqrtReal((PetscReal)n);
452:   PetscFunctionReturn(PETSC_SUCCESS);
453: }

/* Forward declaration: Ex3TestInitializePackage() registers this routine with
   PetscRegisterFinalize() before its definition appears below */
static PetscErrorCode Ex3TestFinalizePackage(void);

458: /* Test type setters */
459: static PetscErrorCode Ex3SetTest_Dam(Ex3TestType *test_type)
460: {
461:   PetscFunctionBeginUser;
462:   *test_type = EX3_TEST_DAM;
463:   PetscFunctionReturn(PETSC_SUCCESS);
464: }

466: static PetscErrorCode Ex3SetTest_Wave(Ex3TestType *test_type)
467: {
468:   PetscFunctionBeginUser;
469:   *test_type = EX3_TEST_WAVE;
470:   PetscFunctionReturn(PETSC_SUCCESS);
471: }

473: /* Package initialization */
474: static PetscErrorCode Ex3TestInitializePackage(void)
475: {
476:   PetscFunctionBeginUser;
477:   if (Ex3TestPackageInitialized) PetscFunctionReturn(PETSC_SUCCESS);
478:   Ex3TestPackageInitialized = PETSC_TRUE;
479:   PetscCall(PetscFunctionListAdd(&Ex3TestList, "dam", Ex3SetTest_Dam));
480:   PetscCall(PetscFunctionListAdd(&Ex3TestList, "wave", Ex3SetTest_Wave));
481:   PetscCall(PetscRegisterFinalize(Ex3TestFinalizePackage));
482:   PetscFunctionReturn(PETSC_SUCCESS);
483: }

485: static PetscErrorCode Ex3TestFinalizePackage(void)
486: {
487:   PetscFunctionBeginUser;
488:   Ex3TestPackageInitialized = PETSC_FALSE;
489:   PetscCall(PetscFunctionListDestroy(&Ex3TestList));
490:   PetscFunctionReturn(PETSC_SUCCESS);
491: }

493: int main(int argc, char **argv)
494: {
495:   /* Configuration parameters */
496:   const PetscInt ndof                    = 2; /* Degrees of freedom per grid point: h and hu */
497:   PetscInt       n_vert                  = DEFAULT_N;
498:   PetscInt       steps                   = DEFAULT_STEPS;
499:   PetscInt       obs_freq                = DEFAULT_OBS_FREQ;
500:   PetscInt       random_seed             = DEFAULT_RANDOM_SEED;
501:   PetscInt       ensemble_size           = DEFAULT_ENSEMBLE_SIZE;
502:   PetscInt       n_spin                  = SPINUP_STEPS;
503:   PetscInt       progress_freq           = DEFAULT_PROGRESS_FREQ;
504:   PetscReal      g                       = DEFAULT_G;
505:   PetscReal      dt                      = DEFAULT_DT;
506:   PetscReal      obs_error_std           = DEFAULT_OBS_ERROR_STD;
507:   PetscBool      use_fake_localization   = PETSC_FALSE;
508:   PetscInt       num_observations_vertex = 7;
509:   PetscReal      L                       = (PetscReal)DEFAULT_N; /* Domain length */
510:   PetscReal      bd[3]                   = {L, 0, 0};
511:   Ex3TestType    test_type               = EX3_TEST_DAM;     /* Default to dam-break */
512:   Ex3FluxType    flux_type               = EX3_FLUX_RUSANOV; /* Default to first-order Rusanov */
513:   char           output_file[PETSC_MAX_PATH_LEN];
514:   PetscBool      output_enabled = PETSC_FALSE;
515:   FILE          *fp             = NULL;

517:   /* PETSc objects */
518:   ShallowWaterCtx *sw_ctx = NULL;
519:   DM               da_state;
520:   PetscDA          da;
521:   Vec              x0, x_mean, x_forecast;
522:   Vec              truth_state, rmse_work;
523:   Vec              observation, obs_noise, obs_error_var;
524:   PetscRandom      rng;
525:   Mat              Q = NULL;            /* Localization matrix */
526:   Mat              H = NULL, H1 = NULL; /* Observation operator matrix (h at every other grid point) and scalar version */

528:   /* Statistics tracking */
529:   PetscReal rmse_forecast = 0.0, rmse_analysis = 0.0;
530:   PetscReal sum_rmse_forecast = 0.0, sum_rmse_analysis = 0.0;
531:   PetscInt  n_stat_steps = 0;
532:   PetscInt  obs_count    = 0;
533:   PetscInt  step;

535:   PetscFunctionBeginUser;
536:   PetscCall(PetscInitialize(&argc, &argv, NULL, help));
537:   /* Kokkos initialization deferred to Phase 5 optimization */

539:   /* Initialize test type package */
540:   PetscCall(Ex3TestInitializePackage());

542:   /* Parse command-line options */
543:   PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "Shallow Water [L]ETKF Example", NULL);
544:   PetscCall(PetscOptionsInt("-n", "Number of grid points", "", n_vert, &n_vert, NULL));
545:   PetscCall(PetscOptionsInt("-steps", "Number of time steps", "", steps, &steps, NULL));
546:   PetscCall(PetscOptionsInt("-obs_freq", "Observation frequency", "", obs_freq, &obs_freq, NULL));
547:   PetscCall(PetscOptionsReal("-g", "Gravitational constant", "", g, &g, NULL));
548:   PetscCall(PetscOptionsReal("-dt", "Time step size", "", dt, &dt, NULL));
549:   PetscCall(PetscOptionsReal("-obs_error", "Observation error standard deviation", "", obs_error_std, &obs_error_std, NULL));
550:   PetscCall(PetscOptionsReal("-L", "Domain length", "", L, &L, NULL));
551:   bd[0] = L;
552:   PetscCall(PetscOptionsInt("-random_seed", "Random seed for ensemble perturbations", "", random_seed, &random_seed, NULL));
553:   PetscCall(PetscOptionsInt("-progress_freq", "Print progress every N steps (0 = only first/last)", "", progress_freq, &progress_freq, NULL));
554:   PetscCall(PetscOptionsString("-output_file", "Output file for visualization data", "", "", output_file, sizeof(output_file), &output_enabled));
555:   PetscCall(PetscOptionsBool("-use_fake_localization", "Use fake localization matrix", "", use_fake_localization, &use_fake_localization, NULL));
556:   if (!use_fake_localization) PetscCall(PetscOptionsInt("-petscda_letkf_obs_per_vertex", "Number of observations per vertex", "", num_observations_vertex, &num_observations_vertex, NULL));
557:   else num_observations_vertex = n_vert;
558:   /* Parse test type option */
559:   {
560:     char        testTypeName[256];
561:     const char *defaultType                 = "dam";
562:     PetscBool   set                         = PETSC_FALSE;
563:     PetscErrorCode (*setter)(Ex3TestType *) = NULL;

565:     PetscCall(PetscStrncpy(testTypeName, defaultType, sizeof(testTypeName)));
566:     PetscCall(PetscOptionsFList("-ex3_test", "Test case type", "Ex3SetTest", Ex3TestList, defaultType, testTypeName, sizeof(testTypeName), &set));
567:     if (set) {
568:       PetscCall(PetscFunctionListFind(Ex3TestList, testTypeName, &setter));
569:       PetscCheck(setter, PETSC_COMM_WORLD, PETSC_ERR_ARG_UNKNOWN_TYPE, "Unknown test type \"%s\"", testTypeName);
570:       PetscCall((*setter)(&test_type));
571:     }
572:   }

574:   /* Parse flux type option */
575:   PetscCall(PetscOptionsEnum("-ex3_flux", "Flux scheme (rusanov/mc)", "", Ex3FluxTypes, (PetscEnum)flux_type, (PetscEnum *)&flux_type, NULL));
576:   n_spin = 0; /* No spinup needed for either test - dam evolves naturally, wave is already smooth */
577:   PetscOptionsEnd();

579:   /* LETKF constraint: nobs = n_vert/2, observe every other point */
580:   PetscInt nobs = n_vert / 2;

582:   /* Validate and constrain parameters */
583:   PetscCall(ValidateParameters(&n_vert, &nobs, &steps, &obs_freq, &ensemble_size, &dt, &g, &obs_error_std));

585:   /* Validate progress frequency */
586:   if (progress_freq < 0) {
587:     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Warning: Progress frequency adjusted from %" PetscInt_FMT " to 0 (only first/last)\n", progress_freq));
588:     progress_freq = 0;
589:   }

591:   /* Create 1D periodic DM for state space with ndof=2 */
592:   PetscCall(DMDACreate1d(PETSC_COMM_WORLD, DM_BOUNDARY_PERIODIC, n_vert, ndof, 2, NULL, &da_state));
593:   PetscCall(DMSetFromOptions(da_state));
594:   PetscCall(DMSetUp(da_state));

596:   /* Create shallow water context with reusable TS object */
597:   PetscCall(ShallowWaterContextCreate(da_state, n_vert, L, g, dt, test_type, flux_type, &sw_ctx));

599:   /* Initialize random number generator */
600:   PetscCall(PetscRandomCreate(PETSC_COMM_WORLD, &rng));
601:   {
602:     PetscMPIInt rank;
603:     PetscCallMPI(MPI_Comm_rank(PETSC_COMM_WORLD, &rank));
604:     PetscCall(PetscRandomSetSeed(rng, (unsigned long)(random_seed + rank)));
605:   }
606:   PetscCall(PetscRandomSetFromOptions(rng));
607:   PetscCall(PetscRandomSeed(rng));

609:   /* Initialize state vectors */
610:   PetscCall(DMCreateGlobalVector(da_state, &x0));

612:   /* Set initial condition based on test type */
613:   {
614:     PetscScalar *x_array;
615:     PetscInt     xs, xm, i;
616:     PetscCall(DMDAGetCorners(da_state, &xs, NULL, NULL, &xm, NULL, NULL));
617:     PetscCall(DMDAVecGetArray(da_state, x0, &x_array));
618:     for (i = xs; i < xs + xm; i++) {
619:       PetscReal x = ((PetscReal)i + 0.5) * L / n_vert;
620:       PetscReal h, hu;
621:       PetscCall(ShallowWaterSolution(test_type, L, x, &h, &hu));
622:       x_array[i * ndof]     = h;
623:       x_array[i * ndof + 1] = hu;
624:     }
625:     PetscCall(DMDAVecRestoreArray(da_state, x0, &x_array));
626:   }

628:   /* Initialize truth trajectory */
629:   PetscCall(VecDuplicate(x0, &truth_state));
630:   PetscCall(VecCopy(x0, truth_state));
631:   PetscCall(VecDuplicate(x0, &rmse_work));

633:   /* Spinup if needed (not used by default - both tests start from their analytical initial conditions) */
634:   if (n_spin > 0) {
635:     PetscInt spinup_progress_interval = (n_spin >= 10) ? (n_spin / 10) : 1;
636:     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Spinning up truth trajectory for %" PetscInt_FMT " steps...\n", n_spin));

638:     for (PetscInt k = 0; k < n_spin; k++) {
639:       PetscCall(ShallowWaterStep(truth_state, truth_state, sw_ctx));

641:       /* Progress reporting for long spinups */
642:       if ((k + 1) % spinup_progress_interval == 0 || (k + 1) == n_spin) PetscCall(PetscPrintf(PETSC_COMM_WORLD, "  Spinup progress: %" PetscInt_FMT "/%" PetscInt_FMT " (%.0f%%)\n", k + 1, n_spin, 100.0 * (k + 1) / n_spin));
643:     }

645:     /* Update x0 to match spun-up state for consistent ensemble initialization */
646:     PetscCall(VecCopy(truth_state, x0));
647:     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Spinup complete. Ensemble will be initialized from spun-up state.\n\n"));
648:   }

650:   /* Create observation matrix H, observing h at every other grid point) */
651:   PetscCall(CreateObservationMatrix(n_vert, ndof, nobs, x0, &H, &H1));

653:   /* Initialize observation vectors using MatCreateVecs from H (same as H1) */
654:   PetscCall(MatCreateVecs(H, NULL, &observation));
655:   PetscCall(VecDuplicate(observation, &obs_noise));
656:   PetscCall(VecDuplicate(observation, &obs_error_var));
657:   PetscCall(VecSet(obs_error_var, obs_error_std * obs_error_std));

659:   /* Create and configure PetscDA for ensemble data assimilation */
660:   PetscCall(PetscDACreate(PETSC_COMM_WORLD, &da));
661:   PetscCall(PetscDASetSizes(da, n_vert * ndof, nobs));  /* State size includes ndof */
662:   PetscCall(PetscDAEnsembleSetSize(da, ensemble_size)); /* State size includes ndof */
663:   {
664:     PetscInt local_state_size, local_obs_size;
665:     PetscCall(VecGetLocalSize(x0, &local_state_size));
666:     PetscCall(VecGetLocalSize(observation, &local_obs_size));
667:     PetscCall(PetscDASetLocalSizes(da, local_state_size, local_obs_size));
668:   }
669:   PetscCall(PetscDASetNDOF(da, ndof)); /* Set number of degrees of freedom per grid point */
670:   PetscCall(PetscDASetFromOptions(da));
671:   PetscCall(PetscDAEnsembleGetSize(da, &ensemble_size));
672:   PetscCall(PetscDASetUp(da));

674:   /* Initialize ensemble statistics vectors */
675:   PetscCall(VecDuplicate(x0, &x_mean));
676:   PetscCall(VecDuplicate(x0, &x_forecast));

678:   /* Set observation error variance */
679:   PetscCall(PetscDASetObsErrorVariance(da, obs_error_var));

681:   /* Create and set localization matrix Q */
682:   {
683:     PetscBool isletkf;
684:     PetscCall(PetscObjectTypeCompare((PetscObject)da, PETSCDALETKF, &isletkf));

686:     if (!use_fake_localization && isletkf) {
687:       Vec          Vecxyz[3] = {NULL, NULL, NULL};
688:       Vec          coord;
689:       DM           cda;
690:       PetscScalar *x_coord;
691:       PetscInt     xs, xm, i;

693:       /* Ensure coordinates are set */
694:       PetscCall(DMDASetUniformCoordinates(da_state, 0.0, L, 0.0, 0.0, 0.0, 0.0));
695:       /* Update coordinates to match cell centers as used in initial condition */
696:       PetscCall(DMGetCoordinateDM(da_state, &cda));
697:       PetscCall(DMGetCoordinates(da_state, &coord));
698:       PetscCall(DMDAGetCorners(cda, &xs, NULL, NULL, &xm, NULL, NULL));
699:       PetscCall(DMDAVecGetArray(cda, coord, &x_coord));
700:       for (i = xs; i < xs + xm; i++) x_coord[i] = ((PetscReal)i + 0.5) * L / n_vert;
701:       PetscCall(DMDAVecRestoreArray(cda, coord, &x_coord));

703:       /* Create Vecxyz[0] */
704:       PetscCall(DMCreateGlobalVector(cda, &Vecxyz[0]));
705:       PetscCall(VecSetFromOptions(Vecxyz[0]));
706:       PetscCall(PetscObjectSetName((PetscObject)Vecxyz[0], "x_coordinate"));
707:       PetscCall(VecCopy(coord, Vecxyz[0]));

709:       PetscCall(PetscDALETKFGetLocalizationMatrix(num_observations_vertex, 1, Vecxyz, bd, H1, &Q));
710:       PetscCall(VecDestroy(&Vecxyz[0]));
711:       PetscCall(PetscDALETKFSetObsPerVertex(da, num_observations_vertex));
712:       PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Localization matrix Q created using PetscDALETKFGetLocalizationMatrix\n"));
713:     } else {
714:       PetscCall(CreateLocalizationMatrix(n_vert, nobs, &Q));
715:       PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Localization matrix Q created: %dx%d, no localization/global (all weights = 1.0)\n", (int)n_vert, (int)nobs));
716:       if (isletkf) {
717:         PetscCall(PetscDALETKFSetObsPerVertex(da, num_observations_vertex)); // fully observed
718:       }
719:     }
720:     PetscCall(PetscDALETKFSetLocalization(da, Q, H));
721:     PetscCall(MatViewFromOptions(Q, NULL, "-Q_view"));
722:     PetscCall(MatDestroy(&Q));
723:   }

725:   /* Initialize ensemble members with perturbations around spun-up state
726:      This is critical for convergence - ensemble needs spread even after spinup */
727:   PetscCall(PetscDAEnsembleInitialize(da, x0, obs_error_std, rng));

729:   PetscCall(PetscDAViewFromOptions(da, NULL, "-petscda_view"));

731:   /* Print configuration summary */
732:   {
733:     const char *test_name = (test_type == EX3_TEST_DAM) ? "Dam-break" : "Traveling wave";
734:     const char *flux_name = (flux_type == EX3_FLUX_RUSANOV) ? "Rusanov (1st order)" : "MC (2nd order)";
735:     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Shallow Water [L]ETKF Example\n"));
736:     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "============================\n"));
737:     PetscCall(PetscPrintf(PETSC_COMM_WORLD,
738:                           "  Test case             : %s\n"
739:                           "  Flux scheme           : %s\n"
740:                           "  State dimension       : %" PetscInt_FMT " (%" PetscInt_FMT " grid points x %d DOF)\n"
741:                           "  Observation dimension : %" PetscInt_FMT "\n"
742:                           "  Ensemble size         : %" PetscInt_FMT "\n"
743:                           "  Domain length (L)     : %.4f\n"
744:                           "  Gravitational const   : %.4f\n"
745:                           "  Time step (dt)        : %.4f\n"
746:                           "  Total steps           : %" PetscInt_FMT "\n"
747:                           "  Observation frequency : %" PetscInt_FMT "\n"
748:                           "  Observation noise std : %.3f\n"
749:                           "  Random seed           : %" PetscInt_FMT "\n"
750:                           "  Localization          : None/Global (%d obs per vertex)\n\n",
751:                           test_name, flux_name, n_vert * ndof, n_vert, (int)ndof, nobs, ensemble_size, (double)L, (double)g, (double)dt, steps, obs_freq, (double)obs_error_std, random_seed, num_observations_vertex));
752:   }

754:   /* Open output file if requested */
755:   if (output_enabled) {
756:     PetscCall(PetscFOpen(PETSC_COMM_WORLD, output_file, "w", &fp));
757:     PetscCall(PetscFPrintf(PETSC_COMM_WORLD, fp, "# Shallow Water [L]ETKF Data Assimilation Output\n"));
758:     PetscCall(PetscFPrintf(PETSC_COMM_WORLD, fp, "# Test case: %s\n", (test_type == EX3_TEST_DAM) ? "Dam-break" : "Traveling wave"));
759:     PetscCall(PetscFPrintf(PETSC_COMM_WORLD, fp, "# n_vert=%d, ndof=%d, nobs=%d, ensemble_size=%d\n", (int)n_vert, (int)ndof, (int)nobs, (int)ensemble_size));
760:     PetscCall(PetscFPrintf(PETSC_COMM_WORLD, fp, "# dt=%.6f, g=%.6f, obs_error_std=%.6f\n", (double)dt, (double)g, (double)obs_error_std));
761:     PetscCall(PetscFPrintf(PETSC_COMM_WORLD, fp, "# Format: step time [truth_h truth_hu]x%d [mean_h mean_hu]x%d [obs]x%d\n", (int)n_vert, (int)n_vert, (int)nobs));
762:     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Writing output to: %s\n\n", output_file));

764:     /* Write initial condition (step 0) */
765:     const PetscScalar *truth_array, *mean_array;
766:     PetscInt           i;

768:     /* Compute initial ensemble mean */
769:     PetscCall(PetscDAEnsembleComputeMean(da, x_mean));

771:     PetscCall(DMDAVecGetArrayRead(da_state, truth_state, &truth_array));
772:     PetscCall(DMDAVecGetArrayRead(da_state, x_mean, &mean_array));

774:     /* Write step 0 and time 0 */
775:     PetscCall(PetscFPrintf(PETSC_COMM_WORLD, fp, "0 0.000000"));

777:     /* Write truth state (h, hu for each grid point) */
778:     for (i = 0; i < n_vert * ndof; i++) PetscCall(PetscFPrintf(PETSC_COMM_WORLD, fp, " %.8e", (double)PetscRealPart(truth_array[i])));

780:     /* Write ensemble mean (h, hu for each grid point) */
781:     for (i = 0; i < n_vert * ndof; i++) PetscCall(PetscFPrintf(PETSC_COMM_WORLD, fp, " %.8e", (double)PetscRealPart(mean_array[i])));

783:     /* Write nan for observations (no observations at step 0) */
784:     for (i = 0; i < nobs; i++) PetscCall(PetscFPrintf(PETSC_COMM_WORLD, fp, " nan"));

786:     PetscCall(PetscFPrintf(PETSC_COMM_WORLD, fp, "\n"));

788:     PetscCall(DMDAVecRestoreArrayRead(da_state, truth_state, &truth_array));
789:     PetscCall(DMDAVecRestoreArrayRead(da_state, x_mean, &mean_array));
790:   }

792:   /* Print initial condition (step 0) */
793:   {
794:     PetscReal rmse_initial;
795:     PetscCall(PetscDAEnsembleComputeMean(da, x_mean));
796:     PetscCall(ComputeRMSE(x_mean, truth_state, rmse_work, n_vert * ndof, &rmse_initial));
797:     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Step %4d, time %6.3f  RMSE_forecast %.5f  RMSE_analysis %.5f [initial]\n", 0, 0.0, (double)rmse_initial, (double)rmse_initial));
798:   }

800:   /* Main assimilation cycle: forecast and analysis steps */
801:   for (step = 1; step <= steps; step++) {
802:     PetscReal time = step * dt;

804:     /* Propagate ensemble and truth trajectory from t_{k-1} to t_k */
805:     PetscCall(PetscDAEnsembleForecast(da, ShallowWaterStep, sw_ctx));
806:     PetscCall(ShallowWaterStep(truth_state, truth_state, sw_ctx));

808:     /* Forecast step: compute ensemble mean and forecast RMSE */
809:     PetscCall(PetscDAEnsembleComputeMean(da, x_mean));
810:     PetscCall(VecCopy(x_mean, x_forecast));
811:     PetscCall(ComputeRMSE(x_forecast, truth_state, rmse_work, n_vert * ndof, &rmse_forecast));
812:     rmse_analysis = rmse_forecast;

814:     /* Analysis step: assimilate observations when available */
815:     if (step % obs_freq == 0 && step > 0) {
816:       /* Generate synthetic noisy observations from truth using observation matrix H */
817:       Vec truth_obs, temp_truth;
818:       PetscCall(MatCreateVecs(H, NULL, &truth_obs));
819:       PetscCall(MatCreateVecs(H, &temp_truth, NULL));

821:       /* Apply H to get observations: y = H*x_true
822:          Use temporary vector compatible with H's type to avoid Kokkos vector type issues */
823:       PetscCall(VecCopy(truth_state, temp_truth));
824:       PetscCall(MatMult(H, temp_truth, truth_obs));

826:       /* Add observation noise */
827:       PetscCall(VecSetRandomGaussian(obs_noise, rng, 0.0, obs_error_std));
828:       PetscCall(VecWAXPY(observation, 1.0, obs_noise, truth_obs));

830:       /* Perform the ensemble ([L]ETKF) analysis with observation matrix H */
831:       PetscCall(PetscDAEnsembleAnalysis(da, observation, H));

833:       /* Clean up */
834:       PetscCall(VecDestroy(&temp_truth));
835:       PetscCall(VecDestroy(&truth_obs));

837:       /* Compute analysis RMSE */
838:       PetscCall(PetscDAEnsembleComputeMean(da, x_mean));
839:       PetscCall(ComputeRMSE(x_mean, truth_state, rmse_work, n_vert * ndof, &rmse_analysis));
840:       obs_count++;
841:     }

843:     /* Accumulate statistics */
844:     sum_rmse_forecast += rmse_forecast;
845:     sum_rmse_analysis += rmse_analysis;
846:     n_stat_steps++;

848:     /* Write data to output file if enabled */
849:     if (output_enabled && fp) {
850:       const PetscScalar *truth_array, *mean_array, *obs_array;
851:       PetscInt           i;

853:       PetscCall(DMDAVecGetArrayRead(da_state, truth_state, &truth_array));
854:       PetscCall(DMDAVecGetArrayRead(da_state, x_mean, &mean_array));

856:       /* Write step and time */
857:       PetscCall(PetscFPrintf(PETSC_COMM_WORLD, fp, "%d %.6f", (int)step, (double)time));

859:       /* Write truth state (h, hu for each grid point) */
860:       for (i = 0; i < n_vert * ndof; i++) PetscCall(PetscFPrintf(PETSC_COMM_WORLD, fp, " %.8e", (double)PetscRealPart(truth_array[i])));

862:       /* Write ensemble mean (h, hu for each grid point) */
863:       for (i = 0; i < n_vert * ndof; i++) PetscCall(PetscFPrintf(PETSC_COMM_WORLD, fp, " %.8e", (double)PetscRealPart(mean_array[i])));

865:       /* Write observations (or nan if no observation at this step) */
866:       if (step % obs_freq == 0 && step > 0) {
867:         PetscCall(VecGetArrayRead(observation, &obs_array));
868:         for (i = 0; i < nobs; i++) PetscCall(PetscFPrintf(PETSC_COMM_WORLD, fp, " %.8e", (double)PetscRealPart(obs_array[i])));
869:         PetscCall(VecRestoreArrayRead(observation, &obs_array));
870:       } else {
871:         for (i = 0; i < nobs; i++) PetscCall(PetscFPrintf(PETSC_COMM_WORLD, fp, " nan"));
872:       }

874:       PetscCall(PetscFPrintf(PETSC_COMM_WORLD, fp, "\n"));

876:       PetscCall(DMDAVecRestoreArrayRead(da_state, truth_state, &truth_array));
877:       PetscCall(DMDAVecRestoreArrayRead(da_state, x_mean, &mean_array));
878:     }

880:     /* Progress reporting */
881:     if (progress_freq == 0) {
882:       /* Only print the final step here; step 0 is printed before the loop */
883:       if (step == 0 || step == steps) PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Step %4" PetscInt_FMT ", time %6.3f  RMSE_forecast %.5f  RMSE_analysis %.5f\n", step, (double)time, (double)rmse_forecast, (double)rmse_analysis));
884:     } else {
885:       /* Print every progress_freq steps, plus the final step (step 0 is printed before the loop) */
886:       if ((step % progress_freq == 0) || (step == steps)) PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Step %4" PetscInt_FMT ", time %6.3f  RMSE_forecast %.5f  RMSE_analysis %.5f\n", step, (double)time, (double)rmse_forecast, (double)rmse_analysis));
887:     }
888:   }

890:   /* Report final statistics */
891:   if (n_stat_steps > 0) {
892:     PetscReal avg_rmse_forecast = sum_rmse_forecast / n_stat_steps;
893:     PetscReal avg_rmse_analysis = sum_rmse_analysis / n_stat_steps;
894:     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\nStatistics (%" PetscInt_FMT " steps):\n", n_stat_steps));
895:     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "==================================================\n"));
896:     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "  Mean RMSE (forecast) : %.5f\n", (double)avg_rmse_forecast));
897:     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "  Mean RMSE (analysis) : %.5f\n", (double)avg_rmse_analysis));
898:     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "  Observations used    : %" PetscInt_FMT "\n\n", obs_count));
899:   } else {
900:     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\nWarning: No statistics collected\n\n"));
901:   }

903:   /* Close output file if opened */
904:   if (output_enabled && fp) {
905:     PetscCall(PetscFClose(PETSC_COMM_WORLD, fp));
906:     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Output written to: %s\n", output_file));
907:   }

909:   /* Cleanup */
910:   PetscCall(MatDestroy(&H));
911:   PetscCall(MatDestroy(&H1));
912:   PetscCall(VecDestroy(&x_forecast));
913:   PetscCall(VecDestroy(&x_mean));
914:   PetscCall(VecDestroy(&obs_error_var));
915:   PetscCall(VecDestroy(&obs_noise));
916:   PetscCall(VecDestroy(&observation));
917:   PetscCall(VecDestroy(&rmse_work));
918:   PetscCall(VecDestroy(&truth_state));
919:   PetscCall(VecDestroy(&x0));
920:   PetscCall(PetscDADestroy(&da));
921:   PetscCall(DMDestroy(&da_state));
922:   PetscCall(ShallowWaterContextDestroy(&sw_ctx));
923:   PetscCall(PetscRandomDestroy(&rng));

925:   PetscCall(PetscFinalize());
926:   return 0;
927: }

929: /*TEST

931:   testset:
932:     requires: kokkos_kernels !complex
933:     diff_args: -j
934:     args: -ex3_test dam -steps 10 -progress_freq 1 -petscda_view -petscda_ensemble_size 10 -obs_freq 2 -obs_error 0.03

936:     test:
937:       suffix: letkf_dam
938:       args: -petscda_type letkf -petscda_ensemble_size 7

940:     test:
941:       suffix: etkf_dam
942:       args: -petscda_ensemble_sqrt_type cholesky -petscda_type etkf

944:     test:
945:       nsize: 3
946:       suffix: kokkos_dam
947:       args: -petscda_type letkf -mat_type aijkokkos -vec_type kokkos -petscda_letkf_batch_size 13 -info :vec -petscda_ensemble_size 5 -petscda_letkf_obs_per_vertex 5

949:   testset:
950:     requires: kokkos_kernels !complex
951:     diff_args: -j
952:     args: -ex3_test wave -steps 10 -petscda_view -petscda_ensemble_size 10 -petscda_type letkf -obs_freq 2 -obs_error 0.03

954:     test:
955:       suffix: letkf_wave
956:       args: -petscda_type letkf -petscda_ensemble_size 5

958:     test:
959:       nsize: 3
960:       suffix: kokkos_wave
961:       args: -petscda_type letkf -mat_type aijkokkos -vec_type kokkos -petscda_letkf_batch_size 13 -info :vec -petscda_ensemble_size 5 -petscda_letkf_obs_per_vertex 5

963:     test:
964:       suffix: wave_mc
965:       args: -ex3_flux mc -petscda_type etkf

967: TEST*/