Actual source code: letkfilter.c

  1: #include <petscda.h>
  2: #include <petsc/private/daimpl.h>
  3: #include <petsc/private/daensembleimpl.h>
  4: #include <../src/ml/da/impls/ensemble/letkf/letkf.h>

  6: static PetscErrorCode PetscDADestroy_LETKF(PetscDA da)
  7: {
  8:   PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;

 10:   PetscFunctionBegin;
 11:   PetscCall(VecDestroy(&impl->mean));
 12:   PetscCall(VecDestroy(&impl->y_mean));
 13:   PetscCall(VecDestroy(&impl->delta_scaled));
 14:   PetscCall(VecDestroy(&impl->w));
 15:   PetscCall(VecDestroy(&impl->r_inv_sqrt));
 16:   PetscCall(MatDestroy(&impl->Z));
 17:   PetscCall(MatDestroy(&impl->S));
 18:   PetscCall(MatDestroy(&impl->T_sqrt));
 19:   PetscCall(MatDestroy(&impl->w_ones));
 20:   PetscCall(MatDestroy(&impl->Q));
 21: #if defined(PETSC_HAVE_KOKKOS_KERNELS)
 22:   PetscCall(PetscDALETKFDestroyLocalization_Kokkos(impl));
 23: #endif
 24:   PetscCall(ISDestroy(&impl->obs_is_local));
 25:   PetscCall(VecScatterDestroy(&impl->obs_scat));
 26:   PetscCall(VecDestroy(&impl->obs_work));
 27:   PetscCall(VecDestroy(&impl->y_mean_work));
 28:   PetscCall(VecDestroy(&impl->r_inv_sqrt_work));
 29:   PetscCall(MatDestroy(&impl->Z_work));
 30:   PetscCall(PetscDADestroy_Ensemble(da));
 31:   PetscCall(PetscFree(da->data));

 33:   PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFSetLocalization_C", NULL));
 34:   PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFSetObsPerVertex_C", NULL));
 35:   PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFGetObsPerVertex_C", NULL));
 36:   PetscFunctionReturn(PETSC_SUCCESS);
 37: }

 39: /*
 40:   ExtractLocalObservations - Extracts local observations for a vertex using localization matrix Q (CPU version)

 42:   Input Parameters:
 43: + Q          - localization matrix (state_size/ndof x obs_size), each row has constant non-zeros
 44: . vertex_idx - index of the vertex (row of Q)
 45: . Z_global   - global observation ensemble matrix (obs_size x m) OR local work matrix
 46: . y_global   - global observation vector (size obs_size) OR local work vector
 47: . y_mean_global - global observation mean (size obs_size) OR local work vector
 48: . r_inv_sqrt_global - global R^{-1/2} (size obs_size) OR local work vector
 49: . obs_g2l    - map from global observation index to local index (if using local work vectors)
 50: . m          - ensemble size

 52:   Output Parameters:
 53: . Z_local    - local observation ensemble (p_local x m), pre-allocated
 54: . y_local    - local observation vector (size p_local), pre-allocated
 55: . y_mean_local - local observation mean (size p_local), pre-allocated
 56: - r_inv_sqrt_local - local R^{-1/2} (size p_local), pre-allocated
 57: */
 58: static PetscErrorCode ExtractLocalObservations(Mat Q, PetscInt vertex_idx, Mat Z_global, Vec y_global, Vec y_mean_global, Vec r_inv_sqrt_global, PetscHMapI obs_g2l, PetscInt m, Mat Z_local, Vec y_local, Vec y_mean_local, Vec r_inv_sqrt_local)
 59: {
 60:   const PetscInt    *cols;
 61:   const PetscScalar *vals;
 62:   PetscInt           ncols, k, j;
 63:   const PetscScalar *z_global_array, *y_global_array, *y_mean_global_array, *r_inv_sqrt_global_array;
 64:   PetscScalar       *z_local_array, *y_local_array, *y_mean_local_array, *r_inv_sqrt_local_array;
 65:   PetscInt           lda_z_global, lda_z_local;

 67:   PetscFunctionBegin;
 68:   /* Get the row of Q corresponding to this vertex */
 69:   PetscCall(MatGetRow(Q, vertex_idx, &ncols, &cols, &vals));

 71:   /* Get array access to global data */
 72:   PetscCall(MatDenseGetArrayRead(Z_global, &z_global_array));
 73:   PetscCall(VecGetArrayRead(y_global, &y_global_array));
 74:   PetscCall(VecGetArrayRead(y_mean_global, &y_mean_global_array));
 75:   PetscCall(VecGetArrayRead(r_inv_sqrt_global, &r_inv_sqrt_global_array));

 77:   /* Get array access to local data */
 78:   PetscCall(MatDenseGetArrayWrite(Z_local, &z_local_array));
 79:   PetscCall(VecGetArray(y_local, &y_local_array));
 80:   PetscCall(VecGetArray(y_mean_local, &y_mean_local_array));
 81:   PetscCall(VecGetArray(r_inv_sqrt_local, &r_inv_sqrt_local_array));

 83:   /* Get leading dimensions */
 84:   PetscCall(MatDenseGetLDA(Z_global, &lda_z_global));
 85:   PetscCall(MatDenseGetLDA(Z_local, &lda_z_local));

 87:   /* Extract local observations and weight R^{-1/2} */
 88:   for (k = 0; k < ncols; k++) {
 89:     PetscInt    obs_idx   = cols[k];
 90:     PetscScalar weight    = vals[k];
 91:     PetscInt    local_idx = obs_idx;

 93:     /* If using local work vectors, map global index to local index */
 94:     if (obs_g2l) {
 95:       PetscCall(PetscHMapIGet(obs_g2l, obs_idx, &local_idx));
 96:       PetscCheck(local_idx >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Observation index %" PetscInt_FMT " not found in local map", obs_idx);
 97:     }

 99:     y_local_array[k]          = y_global_array[local_idx];
100:     y_mean_local_array[k]     = y_mean_global_array[local_idx];
101:     r_inv_sqrt_local_array[k] = r_inv_sqrt_global_array[local_idx] * PetscSqrtScalar(weight);

103:     /* Extract Z matrix row (column-major layout) */
104:     for (j = 0; j < m; j++) z_local_array[k + j * lda_z_local] = z_global_array[local_idx + j * lda_z_global];
105:   }

107:   /* Restore arrays */
108:   PetscCall(VecRestoreArray(r_inv_sqrt_local, &r_inv_sqrt_local_array));
109:   PetscCall(VecRestoreArray(y_mean_local, &y_mean_local_array));
110:   PetscCall(VecRestoreArray(y_local, &y_local_array));
111:   PetscCall(MatDenseRestoreArrayWrite(Z_local, &z_local_array));
112:   PetscCall(VecRestoreArrayRead(r_inv_sqrt_global, &r_inv_sqrt_global_array));
113:   PetscCall(VecRestoreArrayRead(y_mean_global, &y_mean_global_array));
114:   PetscCall(VecRestoreArrayRead(y_global, &y_global_array));
115:   PetscCall(MatDenseRestoreArrayRead(Z_global, &z_global_array));

117:   /* Restore Q row */
118:   PetscCall(MatRestoreRow(Q, vertex_idx, &ncols, &cols, &vals));

120:   /* Assemble local matrices/vectors */
121:   PetscCall(MatAssemblyBegin(Z_local, MAT_FINAL_ASSEMBLY));
122:   PetscCall(MatAssemblyEnd(Z_local, MAT_FINAL_ASSEMBLY));
123:   PetscFunctionReturn(PETSC_SUCCESS);
124: }

126: /*
127:   PetscDALETKFLocalAnalysis - Performs local LETKF analysis for all grid points (CPU version)

129:   Input Parameters:
130: + da             - the PetscDA context
131: . impl           - LETKF implementation data
132: . m              - ensemble size
133: . n_vertices     - number of grid points
134: . X              - global anomaly matrix (state_size x m)
135: . observation    - observation vector
136: . Z_global       - global observation ensemble (obs_size x m)
137: . y_mean_global  - global observation mean
138: - r_inv_sqrt_global - global R^{-1/2}

140:   Output:
141: . da->ensemble - updated with analysis ensemble

143:   Notes:
144:   This function performs the local analysis loop for LETKF, processing each grid point
145:   independently using its local observations defined by the localization matrix Q.
146:   This is the CPU version that does not use Kokkos acceleration.

148:   All local analysis workspace objects (Z_local, S_local, T_sqrt_local, G_local, y_local,
149:   y_mean_local, delta_scaled_local, r_inv_sqrt_local, w_local, s_transpose_delta) are
150:   created with PETSC_COMM_SELF because the analysis at each vertex is serial and independent.
151: */
152: PetscErrorCode PetscDALETKFLocalAnalysis(PetscDA da, PetscDA_LETKF *impl, PetscInt m, PetscInt n_vertices, Mat X, Vec observation, Mat Z_global, Vec y_mean_global, Vec r_inv_sqrt_global)
153: {
154:   PetscDA_Ensemble *en = (PetscDA_Ensemble *)da->data;
155:   Mat               Z_local, S_local, T_sqrt_local, G_local;
156:   Vec               y_local, y_mean_local, delta_scaled_local, r_inv_sqrt_local;
157:   Vec               w_local, s_transpose_delta;
158:   PetscInt          i_grid_point;
159:   PetscInt          ndof;
160:   PetscReal         sqrt_m_minus_1, scale;
161:   PetscInt          rstart;
162:   Mat               X_rows, E_analysis_rows;

164:   PetscFunctionBegin;
165:   ndof           = da->ndof;
166:   scale          = 1.0 / PetscSqrtReal((PetscReal)(m - 1));
167:   sqrt_m_minus_1 = PetscSqrtReal((PetscReal)(m - 1));
168:   /* Create local analysis workspace (n_obs_vertex x m matrices and vectors) */
169:   PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, impl->n_obs_vertex, m, NULL, &Z_local));
170:   PetscCall(PetscObjectSetOptionsPrefix((PetscObject)Z_local, "dense_"));
171:   PetscCall(MatSetFromOptions(Z_local));
172:   PetscCall(MatSetUp(Z_local));
173:   PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, impl->n_obs_vertex, m, NULL, &S_local));
174:   PetscCall(PetscObjectSetOptionsPrefix((PetscObject)S_local, "dense_"));
175:   PetscCall(MatSetFromOptions(S_local));
176:   PetscCall(MatSetUp(S_local));
177:   PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, m, m, NULL, &T_sqrt_local));
178:   PetscCall(PetscObjectSetOptionsPrefix((PetscObject)T_sqrt_local, "dense_"));
179:   PetscCall(MatSetFromOptions(T_sqrt_local));
180:   PetscCall(MatSetUp(T_sqrt_local));
181:   PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, m, m, NULL, &G_local));
182:   PetscCall(PetscObjectSetOptionsPrefix((PetscObject)G_local, "dense_"));
183:   PetscCall(MatSetFromOptions(G_local));
184:   PetscCall(MatSetUp(G_local));

186:   /* Create vectors using MatCreateVecs from Z_local (n_obs_vertex x m) */
187:   PetscCall(MatCreateVecs(Z_local, &w_local, &y_local));
188:   PetscCall(VecDuplicate(y_local, &y_mean_local));
189:   PetscCall(VecDuplicate(y_local, &delta_scaled_local));
190:   PetscCall(VecDuplicate(y_local, &r_inv_sqrt_local));
191:   PetscCall(VecDuplicate(w_local, &s_transpose_delta));

193:   PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, ndof, m, NULL, &X_rows));
194:   PetscCall(MatDuplicate(X_rows, MAT_DO_NOT_COPY_VALUES, &E_analysis_rows));

196:   /* LETKF: Loop over all grid points and perform local analysis */
197:   PetscCall(MatGetOwnershipRange(impl->Q, &rstart, NULL));

199:   for (i_grid_point = 0; i_grid_point < n_vertices; i_grid_point++) {
200:     /* Extract local observations for this grid point using Q[i_grid_point,:] */
201:     /* Note: i_grid_point is local index, but MatGetRow needs global index */
202:     PetscCall(ExtractLocalObservations(impl->Q, rstart + i_grid_point, Z_global, observation, y_mean_global, r_inv_sqrt_global, impl->obs_g2l, m, Z_local, y_local, y_mean_local, r_inv_sqrt_local));

204:     /* Compute local normalized innovation matrix: S_local = R_local^{-1/2} * (Z_local - y_mean_local * 1') / sqrt(m - 1) */
205:     PetscCall(PetscDAEnsembleComputeNormalizedInnovationMatrix(Z_local, y_mean_local, r_inv_sqrt_local, m, scale, S_local));

207:     /* Compute local delta_scaled = R_local^{-1/2} * (y_local - y_mean_local) */
208:     PetscCall(VecWAXPY(delta_scaled_local, -1.0, y_mean_local, y_local));
209:     PetscCall(VecPointwiseMult(delta_scaled_local, delta_scaled_local, r_inv_sqrt_local));

211:     /* Factor local T = (I + S_local^T * S_local) */
212:     PetscCall(PetscDAEnsembleTFactor(da, S_local));

214:     /* Compute local analysis weights: w_local = T_local^{-1} * S_local^T * delta_scaled_local */
215:     PetscCall(MatMultTranspose(S_local, delta_scaled_local, s_transpose_delta));
216:     PetscCall(PetscDAEnsembleApplyTInverse(da, s_transpose_delta, w_local));

218:     /* Compute local square-root transform: T_sqrt_local = T_local^{-1/2} (U is identity, so pass NULL) */
219:     PetscCall(PetscDAEnsembleApplySqrtTInverse(da, NULL, T_sqrt_local));

221:     /* Form local transform G_local = w_local * 1' + sqrt(m - 1) * T_sqrt_local * U
222:        Instead of creating w_ones_local = w_local * 1', we add w_local to each column of G_local */
223:     PetscCall(MatCopy(T_sqrt_local, G_local, SAME_NONZERO_PATTERN));
224:     PetscCall(MatScale(G_local, sqrt_m_minus_1));
225:     {
226:       const PetscScalar *w_array;
227:       PetscScalar       *g_array;
228:       PetscInt           j, k, lda_g;

230:       PetscCall(VecGetArrayRead(w_local, &w_array));
231:       PetscCall(MatDenseGetArrayWrite(G_local, &g_array));
232:       PetscCall(MatDenseGetLDA(G_local, &lda_g));
233:       for (j = 0; j < m; j++)
234:         for (k = 0; k < m; k++) g_array[k + j * lda_g] += w_array[k];
235:       PetscCall(MatDenseRestoreArrayWrite(G_local, &g_array));
236:       PetscCall(VecRestoreArrayRead(w_local, &w_array));
237:     }

239:     /* LETKF Algorithm 2, Line 13: Update ensemble at grid point i_grid_point
240:        E_a[i,:] = x_bar_f[i] + X_f[i,:] * G_local

242:        Where:
243:        - x_bar_f[i] is the forecast mean at grid point i_grid_point (ndof values from global mean vector)
244:        - X_f[i,:] is the forecast anomaly rows at grid point i_grid_point (ndof rows from global anomaly matrix X)
245:        - G_local = w_local * 1' + sqrt(m-1) * T_local^{1/2} * U (computed above in G_local)
246:      */
247:     {
248:       const PetscScalar *x_array, *mean_array;
249:       PetscScalar       *e_array, *x_rows_array, *ea_rows_array;
250:       PetscInt           j, k, lda_x, lda_e;

252:       /* Extract ndof rows starting at (i_grid_point * ndof) from X: X_f[i_grid_point*ndof:(i_grid_point+1)*ndof, :] */
253:       PetscCall(MatDenseGetArrayRead(X, &x_array));
254:       PetscCall(MatDenseGetArray(X_rows, &x_rows_array));
255:       PetscCall(MatDenseGetLDA(X, &lda_x));
256:       for (j = 0; j < m; j++) {
257:         for (k = 0; k < ndof; k++) x_rows_array[k + j * ndof] = x_array[(i_grid_point * ndof + k) + j * lda_x];
258:       }
259:       PetscCall(MatDenseRestoreArray(X_rows, &x_rows_array));
260:       PetscCall(MatDenseRestoreArrayRead(X, &x_array));

262:       /* Apply local transform: E_analysis_rows = X_rows * G_local^T */
263:       PetscCall(MatMatMult(X_rows, G_local, MAT_REUSE_MATRIX, PETSC_DEFAULT, &E_analysis_rows));

265:       /* Add local mean: E_a[i_grid_point*ndof:(i_grid_point+1)*ndof, :] = x_bar_f[i_grid_point*ndof:(i_grid_point+1)*ndof] + X_f[...] * G_local */
266:       PetscCall(VecGetArrayRead(impl->mean, &mean_array));
267:       PetscCall(MatDenseGetArray(E_analysis_rows, &ea_rows_array));
268:       for (j = 0; j < m; j++) {
269:         for (k = 0; k < ndof; k++) ea_rows_array[k + j * ndof] += mean_array[i_grid_point * ndof + k];
270:       }
271:       PetscCall(MatDenseRestoreArray(E_analysis_rows, &ea_rows_array));
272:       PetscCall(VecRestoreArrayRead(impl->mean, &mean_array));

274:       /* Store result back in ensemble[i_grid_point*ndof:(i_grid_point+1)*ndof, :] */
275:       PetscCall(MatDenseGetArrayWrite(en->ensemble, &e_array));
276:       PetscCall(MatDenseGetLDA(en->ensemble, &lda_e));
277:       PetscCall(MatDenseGetArrayRead(E_analysis_rows, (const PetscScalar **)&ea_rows_array));
278:       for (j = 0; j < m; j++) {
279:         for (k = 0; k < ndof; k++) e_array[(i_grid_point * ndof + k) + j * lda_e] = ea_rows_array[k + j * ndof];
280:       }
281:       PetscCall(MatDenseRestoreArrayRead(E_analysis_rows, (const PetscScalar **)&ea_rows_array));
282:       PetscCall(MatDenseRestoreArrayWrite(en->ensemble, &e_array));
283:     }
284:   }
285:   PetscCall(MatDestroy(&E_analysis_rows));
286:   PetscCall(MatDestroy(&X_rows));
287:   PetscCall(VecDestroy(&s_transpose_delta));
288:   PetscCall(VecDestroy(&w_local));
289:   PetscCall(VecDestroy(&r_inv_sqrt_local));
290:   PetscCall(VecDestroy(&delta_scaled_local));
291:   PetscCall(VecDestroy(&y_mean_local));
292:   PetscCall(VecDestroy(&y_local));
293:   PetscCall(MatDestroy(&G_local));
294:   PetscCall(MatDestroy(&T_sqrt_local));
295:   PetscCall(MatDestroy(&S_local));
296:   PetscCall(MatDestroy(&Z_local));
297:   PetscFunctionReturn(PETSC_SUCCESS);
298: }

300: static PetscErrorCode PetscDAEnsembleAnalysis_LETKF(PetscDA da, Vec observation, Mat H)
301: {
302:   PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;
303:   Mat            X;
304:   PetscInt       m;
305:   PetscBool      reallocate = PETSC_FALSE;

307:   PetscFunctionBegin;
308:   m = impl->en.size;

310:   /* Check if localization matrix Q is set */
311:   PetscCheck(impl->Q, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "Localization matrix Q not set. Call PetscDALETKFSetLocalization() first.");

313:   /* Warn if Cholesky sqrt type is used with LETKF - it produces an asymmetric
314:      T^{-1/2} = L^{-T} which is incorrect for the local perturbation update.
315:      LETKF requires the symmetric square root T^{-1/2} = V * D^{-1/2} * V^T. */
316:   PetscCheck(impl->en.sqrt_type != PETSCDA_SQRT_CHOLESKY, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_INCOMP, "Cholesky sqrt type produces asymmetric T^{-1/2}, which is incorrect for LETKF. Use -petscda_ensemble_sqrt_type eigen or PetscDAEnsembleSetSqrtType(da, PETSCDA_SQRT_EIGEN) instead.");

318:   /* Check that ensemble size <= number of local observations per vertex.
319:      The eigen decomposition of T = I + S^T*S (m x m) requires that the
320:      local observation count p >= m; otherwise T is rank-deficient and the
321:      decomposition is ill-posed. */
322:   PetscCheck(m <= impl->n_obs_vertex, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_INCOMP, "Ensemble size (%" PetscInt_FMT ") must be <= number of local observations per vertex (%" PetscInt_FMT ") for LETKF eigen decomposition to be well-posed", m,
323:              impl->n_obs_vertex);

325:   /* Check for reallocation needs */
326:   if (impl->mean) {
327:     PetscInt mean_size;
328:     PetscCall(VecGetSize(impl->mean, &mean_size));
329:     if (mean_size != da->state_size) reallocate = PETSC_TRUE;
330:   }
331:   if (impl->Z) {
332:     PetscInt z_rows, z_cols;
333:     PetscCall(MatGetSize(impl->Z, &z_rows, &z_cols));
334:     if (z_rows != da->obs_size || z_cols != m) reallocate = PETSC_TRUE;
335:   }

337:   /* Initialize or reallocate persistent work objects */
338:   if (!impl->mean || reallocate) {
339:     PetscCall(VecDestroy(&impl->mean));
340:     PetscCall(VecDestroy(&impl->y_mean));
341:     PetscCall(VecDestroy(&impl->delta_scaled));
342:     PetscCall(VecDestroy(&impl->w));
343:     PetscCall(VecDestroy(&impl->r_inv_sqrt));
344:     PetscCall(MatDestroy(&impl->Z));
345:     PetscCall(MatDestroy(&impl->S));
346:     PetscCall(MatDestroy(&impl->T_sqrt));
347:     PetscCall(MatDestroy(&impl->w_ones));

349:     /* Create mean vector from ensemble matrix (right vector = state space) */
350:     PetscCall(MatCreateVecs(impl->en.ensemble, NULL, &impl->mean));

352:     /* Create Z matrix (obs_size x m) */
353:     PetscCall(MatCreateDense(PetscObjectComm((PetscObject)impl->en.ensemble), PETSC_DECIDE, PETSC_DECIDE, da->obs_size, m, NULL, &impl->Z));
354:     PetscCall(PetscObjectSetOptionsPrefix((PetscObject)impl->Z, "dense_"));
355:     PetscCall(MatSetFromOptions(impl->Z));
356:     PetscCall(MatSetUp(impl->Z));

358:     /* Create observation space vectors from Z matrix (left vector = observation space) */
359:     PetscCall(MatCreateVecs(impl->Z, NULL, &impl->y_mean));
360:     PetscCall(VecDuplicate(impl->y_mean, &impl->delta_scaled));
361:     PetscCall(VecDuplicate(da->obs_error_var, &impl->r_inv_sqrt));

363:     /* Create S matrix (same layout as Z) */
364:     PetscCall(MatDuplicate(impl->Z, MAT_DO_NOT_COPY_VALUES, &impl->S));

366:     /* Create T_sqrt matrix (m x m) - usually small */
367:     /* T_sqrt will hold the result of applying T^{-1/2} to identity matrix */
368:     PetscCall(MatCreateDense(PetscObjectComm((PetscObject)impl->en.ensemble), PETSC_DECIDE, PETSC_DECIDE, m, m, NULL, &impl->T_sqrt));
369:     PetscCall(PetscObjectSetOptionsPrefix((PetscObject)impl->T_sqrt, "dense_"));
370:     PetscCall(MatSetFromOptions(impl->T_sqrt));
371:     PetscCall(MatSetUp(impl->T_sqrt));

373:     /* Create w_ones matrix (m x m) */
374:     PetscCall(MatCreateDense(PetscObjectComm((PetscObject)impl->en.ensemble), PETSC_DECIDE, PETSC_DECIDE, m, m, NULL, &impl->w_ones));
375:     PetscCall(PetscObjectSetOptionsPrefix((PetscObject)impl->w_ones, "dense_"));
376:     PetscCall(MatSetFromOptions(impl->w_ones));
377:     PetscCall(MatSetUp(impl->w_ones));
378:   }

380:   /* Alg 6.4 line 1-2: Compute ensemble mean and scaled anomalies */
381:   PetscCall(PetscDAEnsembleComputeMean(da, impl->mean));

383:   /* Create anomaly matrix X = (E - x_mean * 1') / sqrt(m - 1) */
384:   PetscCall(PetscDAEnsembleComputeAnomalies(da, impl->mean, &X));

386:   /* Alg 6.4 line 3-4: Compute GLOBAL observation ensemble Z = H * E */
387:   /* Note: When H is a Kokkos matrix type (e.g., aijkokkos), MatMatMult may fail
388:      with non-Kokkos dense matrices. Use column-by-column multiplication with
389:      temporary vectors that are compatible with H's type. */
390:   {
391:     Vec      col_in, col_out, temp_in, temp_out;
392:     PetscInt j;

394:     /* Create temporary vectors compatible with H's type */
395:     PetscCall(MatCreateVecs(H, &temp_in, &temp_out));

397:     /* Compute Z = H * E column by column to avoid Kokkos vector type issues */
398:     for (j = 0; j < m; j++) {
399:       PetscCall(MatDenseGetColumnVecRead(impl->en.ensemble, j, &col_in));
400:       PetscCall(MatDenseGetColumnVecWrite(impl->Z, j, &col_out));

402:       /* Copy to temp vector, multiply, then copy back */
403:       PetscCall(VecCopy(col_in, temp_in));
404:       PetscCall(MatMult(H, temp_in, temp_out));
405:       PetscCall(VecCopy(temp_out, col_out));

407:       PetscCall(MatDenseRestoreColumnVecWrite(impl->Z, j, &col_out));
408:       PetscCall(MatDenseRestoreColumnVecRead(impl->en.ensemble, j, &col_in));
409:     }
410:     PetscCall(MatAssemblyBegin(impl->Z, MAT_FINAL_ASSEMBLY));
411:     PetscCall(MatAssemblyEnd(impl->Z, MAT_FINAL_ASSEMBLY));

413:     PetscCall(VecDestroy(&temp_out));
414:     PetscCall(VecDestroy(&temp_in));
415:   }

417:   /* Compute GLOBAL observation mean y_mean = H * x_mean */
418:   /* Use temporary vector compatible with H's type */
419:   {
420:     Vec temp_mean, temp_y_mean;
421:     PetscCall(MatCreateVecs(H, &temp_mean, &temp_y_mean));
422:     PetscCall(VecCopy(impl->mean, temp_mean));
423:     PetscCall(MatMult(H, temp_mean, temp_y_mean));
424:     PetscCall(VecCopy(temp_y_mean, impl->y_mean));
425:     PetscCall(VecDestroy(&temp_y_mean));
426:     PetscCall(VecDestroy(&temp_mean));
427:   }

429:   /* Compute GLOBAL R^{-1/2} (assumes diagonal R) */
430:   PetscCall(VecCopy(da->obs_error_var, impl->r_inv_sqrt));
431:   PetscCall(VecSqrtAbs(impl->r_inv_sqrt));
432:   PetscCall(VecReciprocal(impl->r_inv_sqrt));
433:   /* Perform local analysis for all vertices */

435: #if defined(PETSC_HAVE_KOKKOS_KERNELS)
436:   /* Use GPU version only if:
437:      1. sqrt_type is eigen (GPU version only implements eigen/SVD, not cholesky)
438:      2. H matrix is a Kokkos type (aijkokkos) */
439:   {
440:     PetscBool use_gpu = PETSC_FALSE;
441:     if (impl->en.sqrt_type == PETSCDA_SQRT_EIGEN) {
442:   #if !defined(PETSC_USE_COMPLEX)
443:       /* Check if H matrix is a Kokkos type */
444:       PetscCall(PetscObjectTypeCompareAny((PetscObject)da->R, &use_gpu, MATSEQAIJKOKKOS, MATMPIAIJKOKKOS, MATAIJKOKKOS, ""));
445:   #endif
446:     }

448:     /* Scatter global vectors to local work vectors if available */
449:     if (impl->obs_scat) {
450:       PetscCall(VecScatterBegin(impl->obs_scat, observation, impl->obs_work, INSERT_VALUES, SCATTER_FORWARD));
451:       PetscCall(VecScatterEnd(impl->obs_scat, observation, impl->obs_work, INSERT_VALUES, SCATTER_FORWARD));

453:       PetscCall(VecScatterBegin(impl->obs_scat, impl->y_mean, impl->y_mean_work, INSERT_VALUES, SCATTER_FORWARD));
454:       PetscCall(VecScatterEnd(impl->obs_scat, impl->y_mean, impl->y_mean_work, INSERT_VALUES, SCATTER_FORWARD));

456:       PetscCall(VecScatterBegin(impl->obs_scat, impl->r_inv_sqrt, impl->r_inv_sqrt_work, INSERT_VALUES, SCATTER_FORWARD));
457:       PetscCall(VecScatterEnd(impl->obs_scat, impl->r_inv_sqrt, impl->r_inv_sqrt_work, INSERT_VALUES, SCATTER_FORWARD));

459:       /* Handle Z matrix (scatter columns) */
460:       {
461:         PetscInt n_obs_local;
462:         PetscCall(VecGetLocalSize(impl->obs_work, &n_obs_local));
463:         if (!impl->Z_work) {
464:           PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, n_obs_local, m, NULL, &impl->Z_work));
465:         } else {
466:           PetscInt m_old, n_old;
467:           PetscCall(MatGetSize(impl->Z_work, &n_old, &m_old));
468:           if (m_old != m || n_old != n_obs_local) {
469:             PetscCall(MatDestroy(&impl->Z_work));
470:             PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, n_obs_local, m, NULL, &impl->Z_work));
471:           }
472:         }
473:         for (PetscInt i = 0; i < m; i++) {
474:           Vec z_col_global, z_col_local;
475:           PetscCall(MatDenseGetColumnVecRead(impl->Z, i, &z_col_global));
476:           PetscCall(MatDenseGetColumnVecWrite(impl->Z_work, i, &z_col_local));
477:           PetscCall(VecScatterBegin(impl->obs_scat, z_col_global, z_col_local, INSERT_VALUES, SCATTER_FORWARD));
478:           PetscCall(VecScatterEnd(impl->obs_scat, z_col_global, z_col_local, INSERT_VALUES, SCATTER_FORWARD));
479:           PetscCall(MatDenseRestoreColumnVecRead(impl->Z, i, &z_col_global));
480:           PetscCall(MatDenseRestoreColumnVecWrite(impl->Z_work, i, &z_col_local));
481:         }
482:       }
483:     }

485:     if (use_gpu) {
486:       PetscInt n_local;
487:       PetscCall(MatGetLocalSize(impl->Q, &n_local, NULL));
488:       /* Use local work vectors for GPU analysis */
489:       PetscCall(PetscDALETKFLocalAnalysis_GPU(da, impl, m, n_local, X, impl->obs_work, impl->Z_work, impl->y_mean_work, impl->r_inv_sqrt_work));
490:     } else {
491:       PetscInt n_local;
492:       PetscCall(MatGetLocalSize(impl->Q, &n_local, NULL));
493:       if (impl->obs_scat) {
494:         PetscCall(PetscDALETKFLocalAnalysis(da, impl, m, n_local, X, impl->obs_work, impl->Z_work, impl->y_mean_work, impl->r_inv_sqrt_work));
495:       } else {
496:         PetscCall(PetscDALETKFLocalAnalysis(da, impl, m, n_local, X, observation, impl->Z, impl->y_mean, impl->r_inv_sqrt));
497:       }
498:     }
499:   }
500: #else
501:   /* Without Kokkos, use CPU version */
502:   {
503:     PetscInt n_local;
504:     PetscCall(MatGetLocalSize(impl->Q, &n_local, NULL));
505:     PetscCall(PetscDALETKFLocalAnalysis(da, impl, m, n_local, X, observation, impl->Z, impl->y_mean, impl->r_inv_sqrt));
506:   }
507: #endif
508:   PetscCall(MatDestroy(&X));
509:   PetscFunctionReturn(PETSC_SUCCESS);
510: }

512: static PetscErrorCode PetscDALETKFSetObsPerVertex_LETKF(PetscDA da, PetscInt n_obs_vertex)
513: {
514:   PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;

516:   PetscFunctionBegin;
517:   PetscCheck(impl, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "PetscDA not properly initialized for LETKF");
518:   impl->n_obs_vertex = n_obs_vertex;
519:   PetscFunctionReturn(PETSC_SUCCESS);
520: }

522: static PetscErrorCode PetscDALETKFGetObsPerVertex_LETKF(PetscDA da, PetscInt *n_obs_vertex)
523: {
524:   PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;

526:   PetscFunctionBegin;
527:   PetscCheck(impl, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "PetscDA not properly initialized for LETKF");
528:   *n_obs_vertex = impl->n_obs_vertex;
529:   PetscFunctionReturn(PETSC_SUCCESS);
530: }

532: static PetscErrorCode PetscDALETKFSetLocalization_LETKF(PetscDA da, Mat Q, Mat H)
533: {
534:   PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;
535:   PetscInt       i, nrows, ncols, nnz, rstart, rend;

537:   PetscFunctionBegin;
538:   PetscCheck(impl, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "PetscDA not properly initialized for LETKF");

540:   /* Get matrix dimensions */
541:   PetscCall(MatGetSize(Q, &nrows, &ncols));

543:   /* Validate matrix dimensions */
544:   PetscCheck(nrows == da->state_size / da->ndof, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_INCOMP, "Localization matrix rows (%" PetscInt_FMT ") must match state size (%" PetscInt_FMT ")", nrows, da->state_size);
545:   PetscCheck(ncols == da->obs_size, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_INCOMP, "Localization matrix columns (%" PetscInt_FMT ") must match observation size (%" PetscInt_FMT ")", ncols, da->obs_size);

547:   /* Validate that each row has const non-zero entries */
548:   PetscCall(MatGetOwnershipRange(Q, &rstart, &rend));
549:   for (i = rstart; i < rend; i++) {
550:     const PetscInt    *cols;
551:     const PetscScalar *vals;
552:     PetscCall(MatGetRow(Q, i, &nnz, &cols, &vals));
553:     PetscCheck(nnz == impl->n_obs_vertex, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_INCOMP, "Row %" PetscInt_FMT " has %" PetscInt_FMT " non-zeros, expected %" PetscInt_FMT, i, nnz, (PetscInt)impl->n_obs_vertex);
554:     PetscCall(MatRestoreRow(Q, i, &nnz, &cols, &vals));
555:   }

557:   /* Store the localization matrix */
558:   PetscCall(MatDestroy(&impl->Q));
559:   PetscCall(PetscObjectReference((PetscObject)Q));
560:   impl->Q = Q;
561: #if defined(PETSC_HAVE_KOKKOS_KERNELS)
562:   PetscCall(PetscDALETKFSetupLocalization_Kokkos(impl, H));
563: #endif
564:   PetscFunctionReturn(PETSC_SUCCESS);
565: }

567: static PetscErrorCode PetscDAView_LETKF(PetscDA da, PetscViewer viewer)
568: {
569:   PetscBool      iascii;
570:   PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;

572:   PetscFunctionBegin;
573:   PetscCall(PetscDAView_Ensemble(da, viewer));
574:   PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &iascii));
575:   if (iascii) {
576: #if defined(PETSC_HAVE_KOKKOS_KERNELS)
577:     if (impl->en.sqrt_type == PETSCDA_SQRT_CHOLESKY) {
578:       PetscCall(PetscViewerASCIIPrintf(viewer, "  Local analysis: CPU\n"));
579:     } else {
580:       /* Check if R matrix is Kokkos type to determine if GPU will be used */
581:       if (da->R) {
582:         PetscBool is_kokkos = PETSC_FALSE;
583:         PetscCall(PetscObjectTypeCompareAny((PetscObject)da->R, &is_kokkos, MATSEQAIJKOKKOS, MATMPIAIJKOKKOS, MATAIJKOKKOS, ""));
584:         if (is_kokkos) {
585:           PetscCall(PetscViewerASCIIPrintf(viewer, "  Local analysis: Kokkos\n"));
586:         } else {
587:           PetscCall(PetscViewerASCIIPrintf(viewer, "  Local analysis: CPU\n"));
588:         }
589:       } else {
590:         PetscCall(PetscViewerASCIIPrintf(viewer, "  Local analysis: CPU or Kokkos (depending on covariance matrix type)\n"));
591:       }
592:     }
593: #else
594:     PetscCall(PetscViewerASCIIPrintf(viewer, "  Local analysis: CPU\n"));
595: #endif
596:     PetscCall(PetscViewerASCIIPrintf(viewer, "  Local observations per vertex: %" PetscInt_FMT "\n", impl->n_obs_vertex));
597:     if (impl->batch_size > 0) {
598:       PetscCall(PetscViewerASCIIPrintf(viewer, "  GPU batch size: %" PetscInt_FMT "\n", impl->batch_size));
599:     } else {
600:       PetscCall(PetscViewerASCIIPrintf(viewer, "  GPU batch size: auto\n"));
601:     }
602:     if (impl->Q) {
603:       PetscCall(PetscViewerASCIIPrintf(viewer, "  Localization matrix: set\n"));
604:     } else {
605:       PetscCall(PetscViewerASCIIPrintf(viewer, "  Localization matrix: not set\n"));
606:     }
607:   }
608:   PetscFunctionReturn(PETSC_SUCCESS);
609: }

611: static PetscErrorCode PetscDASetFromOptions_LETKF(PetscDA da, PetscOptionItems *PetscOptionsObjectPtr)
612: {
613:   PetscDA_LETKF   *impl               = (PetscDA_LETKF *)da->data;
614:   PetscOptionItems PetscOptionsObject = *PetscOptionsObjectPtr;

616:   PetscFunctionBegin;
617:   PetscCall(PetscDASetFromOptions_Ensemble(da, PetscOptionsObjectPtr));
618:   PetscOptionsHeadBegin(PetscOptionsObject, "PetscDA LETKF Options");
619:   PetscCall(PetscOptionsInt("-petscda_letkf_batch_size", "Batch size for GPU processing", "", impl->batch_size, &impl->batch_size, NULL));
620:   PetscCall(PetscOptionsInt("-petscda_letkf_obs_per_vertex", "Number of local observations per vertex", "", impl->n_obs_vertex, &impl->n_obs_vertex, NULL));
621:   PetscOptionsHeadEnd();
622:   PetscFunctionReturn(PETSC_SUCCESS);
623: }

625: /*MC
626:    PETSCDALETKF - The Local ETKF performs the analysis update locally around each grid point, enabling scalable assimilation on large
627:    domains by avoiding the global ensemble covariance matrix.

629:    Options Database Keys:
630: +  -petscda_type letkf                           - set the `PetscDAType` to `PETSCDALETKF`
631: .  -petscda_ensemble_size <size>                 - number of ensemble members
632: .  -petscda_ensemble_sqrt_type <cholesky, eigen> - the square root of the matrix to use
633: .  -petscda_letkf_batch_size <batch_size>        - set the batch size for GPU processing
634: -  -petscda_letkf_obs_per_vertex <n_obs_vertex>  - number of observations per vertex

636:    Level: beginner

638: .seealso: [](ch_da), `PetscDA`, `PetscDACreate()`, `PETSCDAETKF`, `PetscDALETKFSetObsPerVertex()`, `PetscDALETKFGetObsPerVertex()`,
639:           `PetscDALETKFSetLocalization()`, `PetscDAEnsembleSetSize()`, `PetscDASetSizes()`, `PetscDAEnsembleSetSqrtType()`, `PetscDAEnsembleSetInflation()`,
640:           `PetscDAEnsembleComputeMean()`, `PetscDAEnsembleComputeAnomalies()`, `PetscDAEnsembleAnalysis()`, `PetscDAEnsembleForecast()`
641: M*/

643: PETSC_INTERN PetscErrorCode PetscDACreate_LETKF(PetscDA da)
644: {
645:   PetscDA_LETKF *impl;

647:   PetscFunctionBegin;
648:   PetscCall(PetscNew(&impl));
649:   da->data = impl;
650:   PetscCall(PetscDACreate_Ensemble(da));
651:   da->ops->destroy        = PetscDADestroy_LETKF;
652:   da->ops->view           = PetscDAView_LETKF;
653:   da->ops->setfromoptions = PetscDASetFromOptions_LETKF;
654:   impl->en.analysis       = PetscDAEnsembleAnalysis_LETKF;
655:   impl->en.forecast       = PetscDAEnsembleForecast_Ensemble;

657:   impl->n_obs_vertex = 9;
658:   impl->Q            = NULL;
659:   impl->batch_size   = 0;

661:   /* Register the method for setting localization */
662:   PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFSetLocalization_C", PetscDALETKFSetLocalization_LETKF));
663:   PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFSetObsPerVertex_C", PetscDALETKFSetObsPerVertex_LETKF));
664:   PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFGetObsPerVertex_C", PetscDALETKFGetObsPerVertex_LETKF));
665:   PetscFunctionReturn(PETSC_SUCCESS);
666: }

668: /*@
669:   PetscDALETKFSetObsPerVertex - Sets the number of local observations per vertex for the LETKF algorithm.

671:   Logically Collective

673:   Input Parameters:
674: + da           - the `PetscDA` context
675: - n_obs_vertex - number of observations per vertex

677:   Level: advanced

679: .seealso: [](ch_da), `PETSCDALETKF`, `PetscDA`, `PetscDALETKFSetLocalization()`
680: @*/
681: PetscErrorCode PetscDALETKFSetObsPerVertex(PetscDA da, PetscInt n_obs_vertex)
682: {
683:   PetscFunctionBegin;
686:   PetscTryMethod(da, "PetscDALETKFSetObsPerVertex_C", (PetscDA, PetscInt), (da, n_obs_vertex));
687:   PetscFunctionReturn(PETSC_SUCCESS);
688: }

690: /*@
691:   PetscDALETKFGetObsPerVertex - Gets the number of local observations per vertex for the LETKF algorithm.

693:   Not Collective

695:   Input Parameter:
696: . da - the `PetscDA` context

698:   Output Parameter:
699: . n_obs_vertex - number of observations per vertex

701:   Level: advanced

703: .seealso: [](ch_da), `PETSCDALETKF`, `PetscDA`, `PetscDALETKFSetObsPerVertex()`
704: @*/
705: PetscErrorCode PetscDALETKFGetObsPerVertex(PetscDA da, PetscInt *n_obs_vertex)
706: {
707:   PetscFunctionBegin;
709:   PetscAssertPointer(n_obs_vertex, 2);
710:   PetscUseMethod(da, "PetscDALETKFGetObsPerVertex_C", (PetscDA, PetscInt *), (da, n_obs_vertex));
711:   PetscFunctionReturn(PETSC_SUCCESS);
712: }

714: /*@
715:   PetscDALETKFSetLocalization - Sets the localization matrix for the LETKF algorithm.

717:   Collective

719:   Input Parameters:
720: + da - the `PetscDA` context
721: . Q  - the localization matrix (N x P)
722: - H  - the observation operator matrix (P x N)

724:   Level: advanced

726: .seealso: [](ch_da), `PETSCDALETKF`, `PetscDA`
727: @*/
728: PetscErrorCode PetscDALETKFSetLocalization(PetscDA da, Mat Q, Mat H)
729: {
730:   PetscFunctionBegin;
734:   PetscTryMethod(da, "PetscDALETKFSetLocalization_C", (PetscDA, Mat, Mat), (da, Q, H));
735:   PetscFunctionReturn(PETSC_SUCCESS);
736: }