Actual source code: ex4.c

/* Usage/help text for this example; handed to PetscInitialize() in main(). */
static char help[] = "2D shallow water LETKF data assimilation example.\n"
                     "Implements 2D shallow water equations with 3 DOF per grid point (h, hu, hv).\n\n"
                     "Usage:\n"
                     "  ./ex4 -steps 100 -nx 41 -ny 41 -petscda_type letkf -petscda_ensemble_size 30\n\n";

  6: #include <petscda.h>
  7: #include <petscdmda.h>
  8: #include <petscts.h>

 10: #include "ex4.h"

/* Default parameter values: main() uses each one as the initial value of the matching run-time
   parameter (grid size, step counts, physical constants, ensemble and observation settings). */
#define DEFAULT_NX                     40
#define DEFAULT_NY                     40
#define DEFAULT_STEPS                  100
#define DEFAULT_OBS_FREQ               5
#define DEFAULT_RANDOM_SEED            12345
#define DEFAULT_G                      9.81
#define DEFAULT_DT                     0.02
#define DEFAULT_LX                     80.0
#define DEFAULT_LY                     80.0
#define DEFAULT_H0                     1.5
#define DEFAULT_AX                     0.2
#define DEFAULT_AY                     0.2
#define DEFAULT_OBS_ERROR_STD          0.01
#define DEFAULT_INIT_PERTURB_AMPLITUDE 0.05
#define DEFAULT_INIT_H_BIAS            0.0
#define DEFAULT_ENSEMBLE_SIZE          30
#define DEFAULT_PROGRESS_FREQ          10
#define DEFAULT_OBS_STRIDE             2
#define DEFAULT_LOCALIZATION_RADIUS    20.0 /* kernel half-width; effective cutoff is 2*radius for gaspari_cohn/gaussian and radius for boxcar (~10 obs spacings on the default 80x80 domain with obs_stride=2) */
#define SPINUP_STEPS                   0

/* Minimum valid parameter values, enforced by ValidateParameters(): an ensemble smaller than
   MIN_ENSEMBLE_SIZE is an error, while an observation frequency below MIN_OBS_FREQ is raised
   to MIN_OBS_FREQ with a warning. */
#define MIN_ENSEMBLE_SIZE 2
#define MIN_OBS_FREQ      1

 38: /*
 39:   ShallowWaterStep2D - PetscDAEnsembleForecast callback that advances every column of a dense ensemble Mat by one TS step.
 40: */
 41: static PetscErrorCode ShallowWaterStep2D(Mat ensemble, PetscCtx ctx)
 42: {
 43:   ShallowWater2DCtx *sw = (ShallowWater2DCtx *)ctx;
 44:   PetscInt           n;

 46:   PetscFunctionBeginUser;
 47:   PetscCall(MatGetSize(ensemble, NULL, &n));
 48:   /* Collective: dense ensemble Mat is row-distributed, so every rank visits every global column j and
 49:      MatDenseGetColumnVec returns the parallel column-Vec that all ranks step together. Use the
 50:      read-write variant because TSSolve reads the column as the initial condition and writes the
 51:      stepped solution back; the write-only variant would skip the device->host sync on device-backed
 52:      dense (MATSEQDENSECUDA/HIP, MATMPIDENSECUDA/HIP) and feed TSSolve stale host data.
 53:      MMS forcing is not exercised here: main() leaves cfg.verify_mms = PETSC_FALSE, so t_start = 0.0
 54:      is correct for this autonomous RHS. */
 55:   for (PetscInt j = 0; j < n; j++) {
 56:     Vec col;

 58:     PetscCall(MatDenseGetColumnVec(ensemble, j, &col));
 59:     PetscCall(ShallowWaterStep2DVec(sw, 0.0, col));
 60:     PetscCall(MatDenseRestoreColumnVec(ensemble, j, &col));
 61:   }
 62:   PetscFunctionReturn(PETSC_SUCCESS);
 63: }

 65: /*
 66:   CreateObservationMatrix2D - Create observation matrix H for 2D shallow water

 68:   Observes water height (h) at every obs_stride-th grid point in both x and y directions.
 69: */
 70: static PetscErrorCode CreateObservationMatrix2D(PetscInt nx, PetscInt ny, PetscInt ndof, PetscInt obs_stride, PetscInt local_state_size, Mat *H, Mat *H1, PetscInt *nobs_out)
 71: {
 72:   PetscInt i, j, obs_idx;
 73:   PetscInt nobs_x, nobs_y, nobs;
 74:   PetscInt rstart, rend;

 76:   PetscFunctionBeginUser;
 77:   /* Calculate number of observations */
 78:   nobs_x = (nx + obs_stride - 1) / obs_stride;
 79:   nobs_y = (ny + obs_stride - 1) / obs_stride;
 80:   nobs   = nobs_x * nobs_y;

 82:   /* Create observation matrix H (nobs x nx*ny*ndof). The column local size must match
 83:      the DMDA-partitioned state vector so that MatMult(H, state, obs) is well-defined. */
 84:   PetscCall(MatCreateAIJ(PETSC_COMM_WORLD, PETSC_DECIDE, local_state_size, nobs, nx * ny * ndof, 1, NULL, 1, NULL, H));
 85:   PetscCall(MatSetFromOptions(*H));

 87:   /* Create H1 for scalar field (nobs x nx*ny); the column local size must match the
 88:      per-grid-point coordinate vectors used for localization. */
 89:   PetscCall(MatCreateAIJ(PETSC_COMM_WORLD, PETSC_DECIDE, local_state_size / ndof, nobs, nx * ny, 1, NULL, 1, NULL, H1));
 90:   PetscCall(MatSetFromOptions(*H1));

 92:   /* Get row ownership range for local process */
 93:   PetscCall(MatGetOwnershipRange(*H, &rstart, &rend));

 95:   /* Observe water height (h) at sparse grid locations - only set local rows */
 96:   obs_idx = 0;
 97:   for (j = 0; j < ny; j += obs_stride) {
 98:     for (i = 0; i < nx; i += obs_stride) {
 99:       if (obs_idx >= rstart && obs_idx < rend) {
100:         PetscInt grid_idx = j * nx + i;
101:         /* H1: select grid point */
102:         PetscCall(MatSetValue(*H1, obs_idx, grid_idx, 1.0, INSERT_VALUES));
103:         /* H: select h component (first DOF) at that grid point */
104:         PetscCall(MatSetValue(*H, obs_idx, grid_idx * ndof, 1.0, INSERT_VALUES));
105:       }
106:       obs_idx++;
107:     }
108:   }

110:   PetscCall(MatAssemblyBegin(*H, MAT_FINAL_ASSEMBLY));
111:   PetscCall(MatAssemblyEnd(*H, MAT_FINAL_ASSEMBLY));
112:   PetscCall(MatAssemblyBegin(*H1, MAT_FINAL_ASSEMBLY));
113:   PetscCall(MatAssemblyEnd(*H1, MAT_FINAL_ASSEMBLY));

115:   PetscCall(MatViewFromOptions(*H1, NULL, "-H_view"));
116:   *nobs_out = nobs;
117:   PetscFunctionReturn(PETSC_SUCCESS);
118: }

120: /*
121:   ComputeRMSE - Compute root mean square error between two vectors
122: */
123: static PetscErrorCode ComputeRMSE(Vec v1, Vec v2, Vec work, PetscInt n, PetscReal *rmse)
124: {
125:   PetscReal norm;

127:   PetscFunctionBeginUser;
128:   PetscCall(VecWAXPY(work, -1.0, v2, v1));
129:   PetscCall(VecNorm(work, NORM_2, &norm));
130:   *rmse = norm / PetscSqrtReal((PetscReal)n);
131:   PetscFunctionReturn(PETSC_SUCCESS);
132: }

134: static PetscErrorCode InitializeBalancedEnsemble(PetscDA da, DM da_state, ShallowWater2DCtx *sw_ctx, PetscInt random_seed, PetscInt ensemble_size, PetscReal init_perturb_amplitude, PetscReal init_h_bias)
135: {
136:   Vec         member;
137:   PetscRandom coef_rng;
138:   PetscInt    nx = sw_ctx->nx, ny = sw_ctx->ny;
139:   PetscReal  *alpha_x, *alpha_y, *beta_x, *beta_y;
140:   PetscReal   mean_ax = 0.0, mean_ay = 0.0, mean_bx = 0.0, mean_by = 0.0;
141:   PetscReal   Lx = sw_ctx->Lx, Ly = sw_ctx->Ly, g = sw_ctx->g, h0 = sw_ctx->h0, Ax = sw_ctx->Ax, Ay = sw_ctx->Ay;

143:   PetscFunctionBeginUser;
144:   PetscCall(PetscMalloc4(ensemble_size, &alpha_x, ensemble_size, &alpha_y, ensemble_size, &beta_x, ensemble_size, &beta_y));

146:   /* The per-member coefficients are sampled redundantly on every rank because they parameterize a
147:      globally smooth perturbation written into a parallel Vec. Use a PETSC_COMM_SELF rng with a
148:      rank-independent seed so that all ranks observe identical coefficient sequences and the
149:      ensemble member fields remain continuous across rank partitions. */
150:   PetscCall(PetscRandomCreate(PETSC_COMM_SELF, &coef_rng));
151:   PetscCall(PetscRandomSetSeed(coef_rng, (unsigned long)random_seed));
152:   PetscCall(PetscRandomSeed(coef_rng));

154:   for (PetscInt e = 0; e < ensemble_size; e++) {
155:     PetscReal r;

157:     PetscCall(PetscRandomGetValueReal(coef_rng, &r));
158:     alpha_x[e] = init_perturb_amplitude * r;
159:     mean_ax += alpha_x[e];
160:     PetscCall(PetscRandomGetValueReal(coef_rng, &r));
161:     alpha_y[e] = init_perturb_amplitude * r;
162:     mean_ay += alpha_y[e];
163:     PetscCall(PetscRandomGetValueReal(coef_rng, &r));
164:     beta_x[e] = init_perturb_amplitude * r;
165:     mean_bx += beta_x[e];
166:     PetscCall(PetscRandomGetValueReal(coef_rng, &r));
167:     beta_y[e] = init_perturb_amplitude * r;
168:     mean_by += beta_y[e];
169:   }
170:   PetscCall(PetscRandomDestroy(&coef_rng));

172:   mean_ax /= ensemble_size;
173:   mean_ay /= ensemble_size;
174:   mean_bx /= ensemble_size;
175:   mean_by /= ensemble_size;

177:   PetscCall(DMCreateGlobalVector(da_state, &member));
178:   for (PetscInt e = 0; e < ensemble_size; e++) {
179:     PetscScalar ***x_array;
180:     PetscInt       xs, ys, xm, ym;
181:     PetscReal      axp = alpha_x[e] - mean_ax;
182:     PetscReal      ayp = alpha_y[e] - mean_ay;
183:     PetscReal      bxp = beta_x[e] - mean_bx;
184:     PetscReal      byp = beta_y[e] - mean_by;
185:     PetscReal      dx  = Lx / nx;
186:     PetscReal      dy  = Ly / ny;
187:     PetscReal      kx  = 2.0 * PETSC_PI / Lx;
188:     PetscReal      ky  = 2.0 * PETSC_PI / Ly;
189:     PetscReal      c   = PetscSqrtReal(g * h0);

191:     PetscCall(DMDAGetCorners(da_state, &xs, &ys, NULL, &xm, &ym, NULL));
192:     PetscCall(DMDAVecGetArrayDOFWrite(da_state, member, &x_array));
193:     for (PetscInt j = ys; j < ys + ym; j++) {
194:       for (PetscInt i = xs; i < xs + xm; i++) {
195:         PetscReal x  = ((PetscReal)i + 0.5) * dx;
196:         PetscReal y  = ((PetscReal)j + 0.5) * dy;
197:         PetscReal sx = PetscSinReal(kx * x), cx = PetscCosReal(kx * x);
198:         PetscReal sy = PetscSinReal(ky * y), cy = PetscCosReal(ky * y);
199:         PetscReal eta_x = axp * sx + bxp * cx;
200:         PetscReal eta_y = ayp * sy + byp * cy;
201:         PetscReal eta   = eta_x + eta_y;
202:         PetscReal u     = (c / h0) * (axp * cx - bxp * sx);
203:         PetscReal v     = (c / h0) * (ayp * cy - byp * sy);
204:         PetscReal hbase, hubase, hvbase;

206:         ShallowWaterSolution_Wave2D(Lx, Ly, x, y, 0.0, g, h0, Ax, Ay, &hbase, &hubase, &hvbase);
207:         x_array[j][i][0] = hbase + init_h_bias + eta;
208:         x_array[j][i][1] = hubase + h0 * u;
209:         x_array[j][i][2] = hvbase + h0 * v;
210:       }
211:     }
212:     PetscCall(DMDAVecRestoreArrayDOFWrite(da_state, member, &x_array));
213:     PetscCall(PetscDAEnsembleSetMember(da, e, member));
214:   }

216:   PetscCall(VecDestroy(&member));
217:   PetscCall(PetscFree4(alpha_x, alpha_y, beta_x, beta_y));
218:   PetscFunctionReturn(PETSC_SUCCESS);
219: }

221: /*
222:   ValidateParameters - Validate input parameters
223: */
224: static PetscErrorCode ValidateParameters(PetscInt *nx, PetscInt *ny, PetscInt *steps, PetscInt *obs_freq, PetscInt *ensemble_size, PetscReal *dt, PetscReal *g, PetscReal *obs_error_std)
225: {
226:   PetscFunctionBeginUser;
227:   PetscCheck(*nx > 0 && *ny > 0, PETSC_COMM_WORLD, PETSC_ERR_ARG_OUTOFRANGE, "Grid dimensions must be positive");
228:   PetscCheck(*steps >= 0, PETSC_COMM_WORLD, PETSC_ERR_ARG_OUTOFRANGE, "Number of steps must be non-negative");
229:   PetscCheck(*ensemble_size >= MIN_ENSEMBLE_SIZE, PETSC_COMM_WORLD, PETSC_ERR_ARG_OUTOFRANGE, "Ensemble size must be at least %d", MIN_ENSEMBLE_SIZE);

231:   if (*obs_freq < MIN_OBS_FREQ) {
232:     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Warning: Observation frequency adjusted from %" PetscInt_FMT " to %d\n", *obs_freq, MIN_OBS_FREQ));
233:     *obs_freq = MIN_OBS_FREQ;
234:   }
235:   if (*obs_freq > *steps && *steps > 0) PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Warning: Observation frequency > total steps, no observations will be assimilated.\n"));

237:   PetscCheck(*dt > 0.0, PETSC_COMM_WORLD, PETSC_ERR_ARG_OUTOFRANGE, "Time step must be positive");
238:   PetscCheck(*obs_error_std > 0.0, PETSC_COMM_WORLD, PETSC_ERR_ARG_OUTOFRANGE, "Observation error std must be positive");
239:   PetscCheck(PetscIsNormalReal(*g), PETSC_COMM_WORLD, PETSC_ERR_ARG_OUTOFRANGE, "Gravitational constant must be a normal real number");
240:   PetscFunctionReturn(PETSC_SUCCESS);
241: }

243: /* Wire LETKF localization (kernel type already chosen by SetFromOptions): build per-grid-point
244:    coordinate vectors, register them, honor a user-supplied radius via the options database, and
245:    print a summary. The NONE kernel needs no setup and is handled by the outer guard. */
246: static PetscErrorCode ConfigureLETKFLocalization(PetscDA da, DM da_state, PetscInt nx, PetscInt ny, PetscInt ndof, PetscInt local_state_size, PetscReal Lx, PetscReal Ly, Mat H1, PetscReal *localization_radius)
247: {
248:   PetscDALETKFLocalizationType loc_type;
249:   Vec                          xyz[3] = {NULL, NULL, NULL};
250:   Vec                          coord;
251:   DM                           cda;
252:   PetscReal                    bd[3] = {Lx, Ly, 0.0};
253:   PetscBool                    radius_set;
254:   const char                  *da_prefix, *kname = NULL;

256:   PetscFunctionBeginUser;
257:   PetscCall(PetscDALETKFGetLocalizationType(da, &loc_type));
258:   if (loc_type == PETSCDA_LETKF_LOC_NONE) {
259:     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Localization disabled (LETKF NONE; equivalent to global ETKF)\n"));
260:     PetscFunctionReturn(PETSC_SUCCESS);
261:   }

263:   PetscCall(DMDASetUniformCoordinates(da_state, 0.0, Lx, 0.0, Ly, 0.0, 0.0));
264:   PetscCall(DMGetCoordinateDM(da_state, &cda));
265:   PetscCall(DMGetCoordinates(da_state, &coord));

267:   /* xyz must share the DMDA per-grid-point partition so that VecStrideGather from
268:      the (block-2) coordinate vector lands in matching local rows. */
269:   for (PetscInt d = 0; d < 2; d++) {
270:     PetscCall(VecCreate(PETSC_COMM_WORLD, &xyz[d]));
271:     PetscCall(VecSetSizes(xyz[d], local_state_size / ndof, nx * ny));
272:     PetscCall(VecSetFromOptions(xyz[d]));
273:     PetscCall(PetscObjectSetName((PetscObject)xyz[d], d == 0 ? "x_coordinate" : "y_coordinate"));
274:     PetscCall(VecStrideGather(coord, d, xyz[d], INSERT_VALUES));
275:   }

277:   PetscCall(PetscObjectGetOptionsPrefix((PetscObject)da, &da_prefix));
278:   PetscCall(PetscOptionsHasName(NULL, da_prefix, "-petscda_letkf_localization_radius", &radius_set));
279:   if (!radius_set) PetscCall(PetscDALETKFSetLocalizationRadius(da, *localization_radius));
280:   PetscCall(PetscDALETKFGetLocalizationRadius(da, localization_radius));
281:   PetscCall(PetscDALETKFSetLocalizationCoordinates(da, xyz, bd, H1));
282:   for (PetscInt d = 0; d < 2; d++) PetscCall(VecDestroy(&xyz[d]));

284:   switch (loc_type) {
285:   case PETSCDA_LETKF_LOC_GASPARI_COHN:
286:     kname = "Gaspari-Cohn";
287:     break;
288:   case PETSCDA_LETKF_LOC_GAUSSIAN:
289:     kname = "Gaussian";
290:     break;
291:   case PETSCDA_LETKF_LOC_BOXCAR:
292:     kname = "boxcar";
293:     break;
294:   case PETSCDA_LETKF_LOC_NONE:
295:   case PETSCDA_LETKF_LOC_NUM_TYPES:
296:     break;
297:   }
298:   PetscCheck(kname, PETSC_COMM_WORLD, PETSC_ERR_ARG_OUTOFRANGE, "Unexpected localization type %d", (int)loc_type);
299:   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Using %s localization with radius %g\n", kname, (double)*localization_radius));
300:   PetscFunctionReturn(PETSC_SUCCESS);
301: }

303: int main(int argc, char **argv)
304: {
305:   ShallowWater2DCtx   *sw_ctx = NULL;
306:   ShallowWater2DConfig cfg;
307:   DM                   da_state;
308:   PetscDA              da;
309:   Vec                  x0, x_mean, x_forecast;
310:   Vec                  truth_state, rmse_work;
311:   Vec                  observation, obs_noise, obs_error_var;
312:   Mat                  H, H1;
313:   PetscRandom          rng;
314:   FILE                *fp = NULL;
315:   char                 output_file[PETSC_MAX_PATH_LEN];
316:   const PetscInt       ndof = 3; /* h, hu, hv */
317:   PetscInt             nx = DEFAULT_NX, ny = DEFAULT_NY;
318:   PetscInt             steps = DEFAULT_STEPS, obs_freq = DEFAULT_OBS_FREQ, obs_stride = DEFAULT_OBS_STRIDE;
319:   PetscInt             ensemble_size = DEFAULT_ENSEMBLE_SIZE, n_spin = SPINUP_STEPS, random_seed = DEFAULT_RANDOM_SEED;
320:   PetscInt             progress_freq = DEFAULT_PROGRESS_FREQ;
321:   PetscInt             nobs = 0, n_stat_steps = 0, obs_count = 0, step;
322:   PetscInt             local_state_size, local_obs_size;
323:   PetscReal            g = DEFAULT_G, dt = DEFAULT_DT;
324:   PetscReal            Lx = DEFAULT_LX, Ly = DEFAULT_LY, h0 = DEFAULT_H0;
325:   PetscReal            Ax = DEFAULT_AX, Ay = DEFAULT_AY;
326:   PetscReal            init_perturb_amplitude = DEFAULT_INIT_PERTURB_AMPLITUDE, init_h_bias = DEFAULT_INIT_H_BIAS;
327:   PetscReal            obs_error_std       = DEFAULT_OBS_ERROR_STD;
328:   PetscReal            localization_radius = DEFAULT_LOCALIZATION_RADIUS;
329:   PetscReal            rmse_initial = 0.0, rmse_forecast = 0.0, rmse_analysis = 0.0, truth_time = 0.0;
330:   PetscReal            sum_rmse_forecast = 0.0, sum_rmse_analysis = 0.0;
331:   PetscReal            dx, dy, c, cfl;
332:   PetscBool            output_enabled = PETSC_FALSE;
333:   PetscMPIInt          rank;

335:   PetscFunctionBeginUser;
336:   PetscCall(PetscInitialize(&argc, &argv, NULL, help));

338:   /* Parse command-line options */
339:   PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "2D Shallow Water LETKF Example", NULL);
340:   PetscCall(PetscOptionsInt("-nx", "Number of grid points in x", "", nx, &nx, NULL));
341:   PetscCall(PetscOptionsInt("-ny", "Number of grid points in y", "", ny, &ny, NULL));
342:   PetscCall(PetscOptionsInt("-steps", "Number of time steps", "", steps, &steps, NULL));
343:   PetscCall(PetscOptionsInt("-obs_freq", "Observation frequency", "", obs_freq, &obs_freq, NULL));
344:   PetscCall(PetscOptionsInt("-obs_stride", "Observation stride (sample every Nth grid point)", "", obs_stride, &obs_stride, NULL));
345:   PetscCall(PetscOptionsReal("-g", "Gravitational constant", "", g, &g, NULL));
346:   PetscCall(PetscOptionsReal("-dt", "Time step size", "", dt, &dt, NULL));
347:   PetscCall(PetscOptionsReal("-Lx", "Domain length in x", "", Lx, &Lx, NULL));
348:   PetscCall(PetscOptionsReal("-Ly", "Domain length in y", "", Ly, &Ly, NULL));
349:   PetscCall(PetscOptionsReal("-h0", "Mean water height", "", h0, &h0, NULL));
350:   PetscCall(PetscOptionsReal("-Ax", "Wave amplitude in x", "", Ax, &Ax, NULL));
351:   PetscCall(PetscOptionsReal("-Ay", "Wave amplitude in y", "", Ay, &Ay, NULL));
352:   PetscCall(PetscOptionsReal("-init_perturb_amplitude", "Initial ensemble perturbation amplitude (uniform on [0, amplitude), then mean-centered)", "", init_perturb_amplitude, &init_perturb_amplitude, NULL));
353:   PetscCall(PetscOptionsReal("-init_h_bias", "Initial ensemble-mean bias applied to height", "", init_h_bias, &init_h_bias, NULL));
354:   PetscCall(PetscOptionsReal("-obs_error", "Observation error standard deviation", "", obs_error_std, &obs_error_std, NULL));
355:   PetscCall(PetscOptionsInt("-random_seed", "Random seed for ensemble perturbations", "", random_seed, &random_seed, NULL));
356:   PetscCall(PetscOptionsInt("-n_spin", "Number of spinup steps for the truth trajectory before assimilation starts", "", n_spin, &n_spin, NULL));
357:   PetscCall(PetscOptionsInt("-progress_freq", "Print progress every N steps (0 = only first/last)", "", progress_freq, &progress_freq, NULL));
358:   PetscCall(PetscOptionsString("-output_file", "Output file for visualization data", "", "", output_file, sizeof(output_file), &output_enabled));
359:   PetscOptionsEnd();

361:   PetscCall(ValidateParameters(&nx, &ny, &steps, &obs_freq, &ensemble_size, &dt, &g, &obs_error_std));
362:   PetscCheck(init_perturb_amplitude > 0.0, PETSC_COMM_WORLD, PETSC_ERR_ARG_OUTOFRANGE, "Initial perturbation amplitude must be positive");

364:   cfg.nx         = nx;
365:   cfg.ny         = ny;
366:   cfg.Lx         = Lx;
367:   cfg.Ly         = Ly;
368:   cfg.g          = g;
369:   cfg.dt         = dt;
370:   cfg.h0         = h0;
371:   cfg.Ax         = Ax;
372:   cfg.Ay         = Ay;
373:   cfg.verify_mms = PETSC_FALSE;
374:   PetscCall(SetupForwardProblem(&cfg, &da_state, &sw_ctx, &x0));

376:   /* Initialize random number generator */
377:   PetscCallMPI(MPI_Comm_rank(PETSC_COMM_WORLD, &rank));
378:   PetscCall(PetscRandomCreate(PETSC_COMM_WORLD, &rng));
379:   PetscCall(PetscRandomSetSeed(rng, (unsigned long)(random_seed + rank)));
380:   PetscCall(PetscRandomSetFromOptions(rng));
381:   PetscCall(PetscRandomSeed(rng));

383:   /* Set initial condition from analytic wave solution */
384:   PetscCall(SetInitialCondition(da_state, x0, sw_ctx, PETSC_FALSE));

386:   /* Initialize truth trajectory */
387:   PetscCall(VecDuplicate(x0, &truth_state));
388:   PetscCall(VecCopy(x0, truth_state));
389:   PetscCall(VecDuplicate(x0, &rmse_work));

391:   /* Spinup if needed */
392:   if (n_spin > 0) {
393:     PetscInt spinup_progress_interval = (n_spin >= 10) ? (n_spin / 10) : 1;
394:     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Spinning up truth trajectory for %" PetscInt_FMT " steps...\n", n_spin));
395:     for (PetscInt k = 0; k < n_spin; k++) {
396:       PetscCall(ShallowWaterStep2DVec(sw_ctx, truth_time, truth_state));
397:       truth_time += dt;
398:       if ((k + 1) % spinup_progress_interval == 0 || (k + 1) == n_spin) PetscCall(PetscPrintf(PETSC_COMM_WORLD, "  Spinup progress: %" PetscInt_FMT "/%" PetscInt_FMT "\n", k + 1, n_spin));
399:     }
400:     PetscCall(VecCopy(truth_state, x0));
401:     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Spinup complete.\n\n"));
402:   }

404:   /* Create observation matrix H. Pass the DMDA-partitioned state local size so that
405:      H's column layout matches the state vector. */
406:   PetscCall(VecGetLocalSize(x0, &local_state_size));
407:   PetscCheck(local_state_size % ndof == 0, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "local_state_size (%" PetscInt_FMT ") not a multiple of ndof (%" PetscInt_FMT ")", local_state_size, ndof);
408:   PetscCall(CreateObservationMatrix2D(nx, ny, ndof, obs_stride, local_state_size, &H, &H1, &nobs));

410:   /* Initialize observation vectors */
411:   PetscCall(MatCreateVecs(H, NULL, &observation));
412:   PetscCall(VecDuplicate(observation, &obs_noise));
413:   PetscCall(VecDuplicate(observation, &obs_error_var));
414:   PetscCall(VecSet(obs_error_var, obs_error_std * obs_error_std));

416:   /* Create and configure PetscDA for ensemble data assimilation */
417:   PetscCall(PetscDACreate(PETSC_COMM_WORLD, &da));
418:   PetscCall(PetscDASetSizes(da, nx * ny * ndof, nobs));
419:   PetscCall(PetscDAEnsembleSetSize(da, ensemble_size));
420:   PetscCall(VecGetLocalSize(observation, &local_obs_size));
421:   PetscCall(PetscDASetLocalSizes(da, local_state_size, local_obs_size));
422:   PetscCall(PetscDASetNDOF(da, ndof));
423:   PetscCall(PetscDASetFromOptions(da));
424:   PetscCall(PetscDAEnsembleGetSize(da, &ensemble_size));
425:   PetscCall(PetscDASetUp(da));

427:   /* Initialize ensemble statistics vectors */
428:   PetscCall(VecDuplicate(x0, &x_mean));
429:   PetscCall(VecDuplicate(x0, &x_forecast));

431:   /* Set observation error variance */
432:   PetscCall(PetscDASetObsErrorVariance(da, obs_error_var));

434:   /* Configure localization for LETKF. Built-in distance-based kernels (Gaspari-Cohn,
435:      Gaussian, boxcar) are wired through SetLocalizationCoordinates and the matrix Q
436:      is built lazily on the first analysis; the NONE kernel needs no setup. */
437:   PetscCall(ConfigureLETKFLocalization(da, da_state, nx, ny, ndof, local_state_size, Lx, Ly, H1, &localization_radius));

439:   /* Initialize ensemble members with perturbations */
440:   PetscCall(InitializeBalancedEnsemble(da, da_state, sw_ctx, random_seed, ensemble_size, init_perturb_amplitude, init_h_bias));

442:   /* Print configuration summary */
443:   dx  = Lx / nx;
444:   dy  = Ly / ny;
445:   c   = PetscSqrtReal(g * h0);
446:   cfl = dt * c * (1.0 / dx + 1.0 / dy);
447:   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "2D Shallow Water LETKF Example\n"));
448:   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "==============================\n"));
449:   PetscCall(PetscPrintf(PETSC_COMM_WORLD,
450:                         "  Mode                  : Data Assimilation\n"
451:                         "  Flux scheme           : Rusanov (1st order)\n"
452:                         "  Grid dimensions       : %" PetscInt_FMT " x %" PetscInt_FMT "\n"
453:                         "  State dimension       : %" PetscInt_FMT " (%" PetscInt_FMT " grid points x %" PetscInt_FMT " DOF)\n"
454:                         "  Observation dimension : %" PetscInt_FMT "\n"
455:                         "  Observation stride    : %" PetscInt_FMT "\n"
456:                         "  Ensemble size         : %" PetscInt_FMT "\n"
457:                         "  Domain size           : %.2f x %.2f\n"
458:                         "  Grid spacing          : dx=%.4f, dy=%.4f\n"
459:                         "  Mean height (h0)      : %.4f\n"
460:                         "  Wave amplitudes       : Ax=%.4f, Ay=%.4f\n"
461:                         "  Gravitational const   : %.4f\n"
462:                         "  Wave speed (c)        : %.4f\n"
463:                         "  Time step (dt)        : %.4f\n"
464:                         "  CFL number            : %.4f\n"
465:                         "  Total steps           : %" PetscInt_FMT "\n"
466:                         "  Observation frequency : %" PetscInt_FMT "\n"
467:                         "  Init perturb amp      : %.3f\n"
468:                         "  Init height bias      : %.3f\n"
469:                         "  Observation noise std : %.3f\n"
470:                         "  Random seed           : %" PetscInt_FMT "\n",
471:                         nx, ny, nx * ny * ndof, nx * ny, ndof, nobs, obs_stride, ensemble_size, (double)Lx, (double)Ly, (double)dx, (double)dy, (double)h0, (double)Ax, (double)Ay, (double)g, (double)c, (double)dt, (double)cfl, steps, obs_freq, (double)init_perturb_amplitude, (double)init_h_bias, (double)obs_error_std, random_seed));
472:   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n"));

474:   /* Open output file if requested - only in serial mode */
475:   if (output_enabled) {
476:     PetscMPIInt size;

478:     PetscCallMPI(MPI_Comm_size(PETSC_COMM_WORLD, &size));
479:     if (size > 1) {
480:       PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Warning: Output file generation is only supported in serial mode (currently running with %d processes)\n", (int)size));
481:       PetscCall(PetscPrintf(PETSC_COMM_WORLD, "         Disabling output file. Run with single process to enable.\n\n"));
482:       output_enabled = PETSC_FALSE;
483:       fp             = NULL;
484:     } else {
485:       PetscCall(PetscFOpen(PETSC_COMM_WORLD, output_file, "w", &fp));
486:       PetscCall(PetscFPrintf(PETSC_COMM_WORLD, fp, "# 2D Shallow Water LETKF Output\n"));
487:       PetscCall(PetscFPrintf(PETSC_COMM_WORLD, fp, "# nx=%" PetscInt_FMT ", ny=%" PetscInt_FMT ", ndof=%" PetscInt_FMT ", nobs=%" PetscInt_FMT ", ensemble_size=%" PetscInt_FMT "\n", nx, ny, ndof, nobs, ensemble_size));
488:       PetscCall(PetscFPrintf(PETSC_COMM_WORLD, fp, "# dt=%.6f, g=%.6f, obs_error_std=%.6f\n", (double)dt, (double)g, (double)obs_error_std));
489:       PetscCall(PetscFPrintf(PETSC_COMM_WORLD, fp, "# Format: step time [truth]x(nx*ny*ndof) [mean]x(nx*ny*ndof) [obs]x(nobs) rmse_forecast rmse_analysis\n"));
490:       PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Writing output to: %s\n\n", output_file));
491:     }
492:   }

494:   /* Print initial condition */
495:   PetscCall(PetscDAEnsembleComputeMean(da, x_mean));
496:   PetscCall(ComputeRMSE(x_mean, truth_state, rmse_work, nx * ny * ndof, &rmse_initial));
497:   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Step %4" PetscInt_FMT ", time %6.3f  RMSE_forecast %.5f  RMSE_analysis %.5f [initial]\n", (PetscInt)0, 0.0, (double)rmse_initial, (double)rmse_initial));

499:   if (output_enabled && fp) {
500:     const PetscScalar *truth_array, *mean_array;
501:     PetscInt           i;
502:     PetscCall(VecGetArrayRead(truth_state, &truth_array));
503:     PetscCall(VecGetArrayRead(x_mean, &mean_array));
504:     PetscCall(PetscFPrintf(PETSC_COMM_WORLD, fp, "0 0.000000"));
505:     for (i = 0; i < nx * ny * ndof; i++) PetscCall(PetscFPrintf(PETSC_COMM_WORLD, fp, " %.8e", (double)PetscRealPart(truth_array[i])));
506:     for (i = 0; i < nx * ny * ndof; i++) PetscCall(PetscFPrintf(PETSC_COMM_WORLD, fp, " %.8e", (double)PetscRealPart(mean_array[i])));
507:     for (i = 0; i < nobs; i++) PetscCall(PetscFPrintf(PETSC_COMM_WORLD, fp, " nan"));
508:     PetscCall(PetscFPrintf(PETSC_COMM_WORLD, fp, " %.8e %.8e\n", (double)rmse_initial, (double)rmse_initial));
509:     PetscCall(VecRestoreArrayRead(truth_state, &truth_array));
510:     PetscCall(VecRestoreArrayRead(x_mean, &mean_array));
511:   }

513:   /* Main simulation loop */
514:   for (step = 1; step <= steps; step++) {
515:     PetscReal time = step * dt;

517:     /* Propagate ensemble and truth trajectory */
518:     PetscCall(PetscDAEnsembleForecast(da, ShallowWaterStep2D, sw_ctx));
519:     PetscCall(ShallowWaterStep2DVec(sw_ctx, truth_time, truth_state));
520:     truth_time += dt;

522:     /* Forecast step: compute ensemble mean and forecast RMSE */
523:     PetscCall(PetscDAEnsembleComputeMean(da, x_mean));
524:     PetscCall(VecCopy(x_mean, x_forecast));
525:     PetscCall(ComputeRMSE(x_forecast, truth_state, rmse_work, nx * ny * ndof, &rmse_forecast));
526:     rmse_analysis = rmse_forecast;

528:     /* Analysis step: assimilate observations when available */
529:     if (step % obs_freq == 0) {
530:       Vec truth_obs, temp_truth;
531:       PetscCall(MatCreateVecs(H, NULL, &truth_obs));
532:       PetscCall(MatCreateVecs(H, &temp_truth, NULL));

534:       /* Generate observations from truth */
535:       PetscCall(VecCopy(truth_state, temp_truth));
536:       PetscCall(MatMult(H, temp_truth, truth_obs));

538:       /* Add observation noise */
539:       PetscCall(VecSetRandomGaussian(obs_noise, rng, 0.0, obs_error_std));
540:       PetscCall(VecWAXPY(observation, 1.0, obs_noise, truth_obs));

542:       /* Perform LETKF analysis */
543:       PetscCall(PetscDAEnsembleAnalysis(da, observation, H));

545:       /* Clean up */
546:       PetscCall(VecDestroy(&temp_truth));
547:       PetscCall(VecDestroy(&truth_obs));

549:       /* Compute analysis RMSE */
550:       PetscCall(PetscDAEnsembleComputeMean(da, x_mean));
551:       PetscCall(ComputeRMSE(x_mean, truth_state, rmse_work, nx * ny * ndof, &rmse_analysis));
552:       obs_count++;
553:     }

555:     /* Accumulate statistics */
556:     sum_rmse_forecast += rmse_forecast;
557:     sum_rmse_analysis += rmse_analysis;
558:     n_stat_steps++;

560:     /* Write data to output file */
561:     if (output_enabled && fp) {
562:       const PetscScalar *truth_array, *mean_array, *obs_array;
563:       PetscInt           i;
564:       PetscCall(VecGetArrayRead(truth_state, &truth_array));
565:       PetscCall(VecGetArrayRead(x_mean, &mean_array));
566:       PetscCall(PetscFPrintf(PETSC_COMM_WORLD, fp, "%" PetscInt_FMT " %.6f", step, (double)time));
567:       for (i = 0; i < nx * ny * ndof; i++) PetscCall(PetscFPrintf(PETSC_COMM_WORLD, fp, " %.8e", (double)PetscRealPart(truth_array[i])));
568:       for (i = 0; i < nx * ny * ndof; i++) PetscCall(PetscFPrintf(PETSC_COMM_WORLD, fp, " %.8e", (double)PetscRealPart(mean_array[i])));
569:       if (step % obs_freq == 0) {
570:         PetscCall(VecGetArrayRead(observation, &obs_array));
571:         for (i = 0; i < nobs; i++) PetscCall(PetscFPrintf(PETSC_COMM_WORLD, fp, " %.8e", (double)PetscRealPart(obs_array[i])));
572:         PetscCall(VecRestoreArrayRead(observation, &obs_array));
573:       } else
574:         for (i = 0; i < nobs; i++) PetscCall(PetscFPrintf(PETSC_COMM_WORLD, fp, " nan"));
575:       PetscCall(PetscFPrintf(PETSC_COMM_WORLD, fp, " %.8e %.8e\n", (double)rmse_forecast, (double)rmse_analysis));
576:       PetscCall(VecRestoreArrayRead(truth_state, &truth_array));
577:       PetscCall(VecRestoreArrayRead(x_mean, &mean_array));
578:     }

580:     /* Progress reporting */
581:     if (step == steps || (progress_freq > 0 && step % progress_freq == 0))
582:       PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Step %4" PetscInt_FMT ", time %6.3f  RMSE_forecast %.5f  RMSE_analysis %.5f\n", step, (double)time, (double)rmse_forecast, (double)rmse_analysis));
583:   }

585:   /* Report final statistics */
586:   if (n_stat_steps > 0) {
587:     PetscReal avg_rmse_forecast = sum_rmse_forecast / n_stat_steps;
588:     PetscReal avg_rmse_analysis = sum_rmse_analysis / n_stat_steps;
589:     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\nStatistics (%" PetscInt_FMT " steps):\n", n_stat_steps));
590:     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "==================================================\n"));
591:     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "  Mean RMSE (forecast) : %.5f\n", (double)avg_rmse_forecast));
592:     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "  Mean RMSE (analysis) : %.5f\n", (double)avg_rmse_analysis));
593:     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "  Observations used    : %" PetscInt_FMT "\n\n", obs_count));
594:   }

596:   /* Close output file */
597:   if (output_enabled && fp) {
598:     PetscCall(PetscFClose(PETSC_COMM_WORLD, fp));
599:     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Output written to: %s\n", output_file));
600:   }

602:   PetscCall(PetscDAView(da, PETSC_VIEWER_STDOUT_WORLD));

604:   /* Cleanup */
605:   PetscCall(MatDestroy(&H));
606:   PetscCall(MatDestroy(&H1));
607:   PetscCall(VecDestroy(&x_forecast));
608:   PetscCall(VecDestroy(&x_mean));
609:   PetscCall(VecDestroy(&obs_error_var));
610:   PetscCall(VecDestroy(&obs_noise));
611:   PetscCall(VecDestroy(&observation));
612:   PetscCall(PetscDADestroy(&da));
613:   PetscCall(VecDestroy(&rmse_work));
614:   PetscCall(VecDestroy(&truth_state));
615:   PetscCall(VecDestroy(&x0));
616:   PetscCall(DMDestroy(&da_state));
617:   PetscCall(ShallowWater2DContextDestroy(&sw_ctx));
618:   PetscCall(PetscRandomDestroy(&rng));

620:   PetscCall(PetscFinalize());
621:   return 0;
622: }

624: /*TEST

626:   testset:
627:     requires: !complex
628:     args: -petscda_type letkf -steps 5 -progress_freq 1 -petscda_ensemble_size 10 -obs_freq 2 -obs_error 0.03 -nx 21 -ny 21

630:     test:
631:       suffix: letkf_wave2d
632:       args: -petscda_ensemble_size 7

634:     test:
635:       nsize: 3
636:       suffix: letkf_wave2d_mpi
637:       args: -petscda_ensemble_size 5 -petscda_letkf_localization_radius 10.0

639:     test:
640:       suffix: kokkos_wave2d_serial
641:       requires: kokkos_kernels
642:       args: -mat_type aijkokkos -vec_type kokkos -petscda_ensemble_size 7

644:     test:
645:       nsize: 3
646:       suffix: kokkos_wave2d
647:       requires: kokkos_kernels
648:       args: -mat_type aijkokkos -vec_type kokkos -petscda_ensemble_size 5 -petscda_letkf_localization_radius 10.0

650:     test:
651:       suffix: letkf_none
652:       args: -petscda_ensemble_size 7 -petscda_letkf_localization_type none

654:     test:
655:       suffix: letkf_gaussian
656:       args: -petscda_ensemble_size 7 -petscda_letkf_localization_type gaussian -petscda_letkf_localization_radius 10.0

658:     test:
659:       suffix: letkf_boxcar
660:       args: -petscda_ensemble_size 3 -petscda_letkf_localization_type boxcar -petscda_letkf_localization_radius 15.0

662:   # Exercises truth-trajectory spinup (-n_spin) and the visualization-data
663:   # writer (-output_file). Serial only: the writer disables itself in multi-rank runs.
664:   test:
665:     suffix: letkf_wave2d_spinup_io
666:     requires: !complex
667:     nsize: 1
668:     args: -petscda_type letkf -steps 3 -n_spin 2 -nx 11 -ny 11 -obs_freq 2 -obs_error 0.1 -petscda_ensemble_size 5 -petscda_letkf_localization_radius 10.0 -output_file ex4_spinup_io.dat
669:     temporaries: ex4_spinup_io.dat

671:   # Sparse + noisy observation regime (16x sparser, 10x noisier than the dense testset)
672:   # exercises the localized analysis path on a configuration where it is expected to
673:   # behave qualitatively differently from the unlocalized fast path.
674:   test:
675:     suffix: letkf_wave2d_sparse
676:     requires: !complex
677:     args: -steps 5 -progress_freq 1 -nx 21 -ny 21 -obs_freq 2 -obs_stride 8 -obs_error 0.3 -petscda_type letkf -petscda_ensemble_size 7 -petscda_letkf_localization_type gaspari_cohn -petscda_letkf_localization_radius 30.0
678: TEST*/