Actual source code: letkfilter.c

  1: #include <petscda.h>
  2: #include <petsc/private/daimpl.h>
  3: #include <petsc/private/daensembleimpl.h>
  4: #include <petscblaslapack.h>
  5: #include <../src/ml/da/impls/ensemble/letkf/letkf.h>

  7: static PetscErrorCode PetscDALETKFInstallQ(PetscDA, Mat, PetscInt, PetscInt, Mat);
  8: static PetscErrorCode PetscDALETKFResetLocalization_LETKF(PetscDA);

 10: /* Names must match the PetscDALETKFLocalizationType enum order in include/petscda.h. */
 11: const char *const PetscDALETKFLocalizationTypes[] = {"none", "gaspari_cohn", "gaussian", "boxcar", "PetscDALETKFLocalizationType", "PETSCDA_LETKF_LOC_", NULL};

 13: /* The Kokkos analysis paths key off the type of the obs-error covariance Mat (R), since R is
 14:    created via MatSetType + MatSetFromOptions and inherits whatever -mat_type the user requested.
 15:    Returns PETSC_FALSE when R is not yet built or when Kokkos kernels are unavailable. */
 16: static PetscErrorCode PetscDALETKFUseKokkosBackend(PetscDA da, PetscBool *use_kokkos)
 17: {
 18:   PetscFunctionBegin;
 19:   *use_kokkos = PETSC_FALSE;
 20: #if defined(PETSC_HAVE_KOKKOS_KERNELS) && !defined(PETSC_USE_COMPLEX)
 21:   if (da->R) PetscCall(PetscObjectTypeCompareAny((PetscObject)da->R, use_kokkos, MATSEQAIJKOKKOS, MATMPIAIJKOKKOS, MATAIJKOKKOS, ""));
 22: #endif
 23:   PetscFunctionReturn(PETSC_SUCCESS);
 24: }

 26: /* Free cached coordinate inputs (used only for built-in kernels). */
 27: static PetscErrorCode PetscDALETKFClearCoordinates(PetscDA_LETKF *impl)
 28: {
 29:   PetscFunctionBegin;
 30:   for (PetscInt d = 0; d < 3; d++) {
 31:     PetscCall(VecDestroy(&impl->coord_xyz[d]));
 32:     impl->coord_bd[d] = 0.0;
 33:   }
 34:   PetscCall(MatDestroy(&impl->coord_H));
 35:   PetscFunctionReturn(PETSC_SUCCESS);
 36: }

 38: /*
 39:   PetscDALETKFReplicateWeightVector - replicate weight vector w across all columns of w_ones (m x m dense).
 40:   Used only by the LOC_NONE fast path. w lives on PETSC_COMM_SELF (size m); w_ones is a SELF SeqDense m x m.
 41: */
 42: PETSC_INTERN PetscErrorCode PetscDALETKFReplicateWeightVector(Vec w, PetscInt m, Mat w_ones)
 43: {
 44:   const PetscScalar *w_array;
 45:   PetscScalar       *mat_array;
 46:   PetscInt           w_size, lda, wo_rows, wo_cols;

 48:   PetscFunctionBegin;
 49:   PetscCall(VecGetLocalSize(w, &w_size));
 50:   PetscCheck(w_size == m, PetscObjectComm((PetscObject)w), PETSC_ERR_ARG_INCOMP, "w size %" PetscInt_FMT " != m %" PetscInt_FMT, w_size, m);
 51:   PetscCall(MatGetSize(w_ones, &wo_rows, &wo_cols));
 52:   PetscCheck(wo_rows == m && wo_cols == m, PetscObjectComm((PetscObject)w_ones), PETSC_ERR_ARG_INCOMP, "w_ones must be %" PetscInt_FMT " x %" PetscInt_FMT ", got %" PetscInt_FMT " x %" PetscInt_FMT, m, m, wo_rows, wo_cols);
 53:   PetscCall(VecGetArrayRead(w, &w_array));
 54:   PetscCall(MatDenseGetArrayWrite(w_ones, &mat_array));
 55:   PetscCall(MatDenseGetLDA(w_ones, &lda));
 56:   for (PetscInt i = 0; i < m; i++) PetscCall(PetscArraycpy(mat_array + i * lda, w_array, m));
 57:   PetscCall(MatDenseRestoreArrayWrite(w_ones, &mat_array));
 58:   PetscCall(VecRestoreArrayRead(w, &w_array));
 59:   PetscFunctionReturn(PETSC_SUCCESS);
 60: }

 62: /*
 63:   PetscDALETKFEnsureGlobalScratch - Lazily allocate the per-rank-replicated m-sized SELF scratch
 64:   used by the LOC_NONE fast path (impl->w, impl->s_transpose_delta, impl->T_sqrt, impl->w_ones).
 65:   Both the CPU (PetscDALETKFGlobalAnalysis) and Kokkos (PetscDALETKFGlobalAnalysis_Kokkos) backends
 66:   call this on entry; the per-vertex paths skip it.
 67: */
 68: PETSC_INTERN PetscErrorCode PetscDALETKFEnsureGlobalScratch(PetscDA_LETKF *impl, PetscInt m)
 69: {
 70:   PetscFunctionBegin;
 71:   if (!impl->w) PetscCall(VecCreateSeq(PETSC_COMM_SELF, m, &impl->w));
 72:   if (!impl->s_transpose_delta) PetscCall(VecCreateSeq(PETSC_COMM_SELF, m, &impl->s_transpose_delta));
 73:   if (!impl->T_sqrt) PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, m, m, NULL, &impl->T_sqrt));
 74:   if (!impl->w_ones) PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, m, m, NULL, &impl->w_ones));
 75:   PetscFunctionReturn(PETSC_SUCCESS);
 76: }

 78: /*
 79:   PetscDALETKFUpdateEnsembleWithTransform - E = mean*1' + X*G.
 80:   Used only by the LOC_NONE fast path. G is replicated on PETSC_COMM_SELF (every rank holds the
 81:   same m x m), X and the ensemble share the same row distribution; the local rows of E are
 82:   X_local * G + mean_local broadcast across columns. Computed via a per-rank BLASgemm for X*G
 83:   plus a column-broadcast add of mean.
 84: */
 85: static PetscErrorCode PetscDALETKFUpdateEnsembleWithTransform(Vec mean, Mat X, Mat G, PetscInt m, Mat ensemble)
 86: {
 87:   const PetscScalar *x_array, *g_array, *mean_array;
 88:   PetscScalar       *xg_buf, *ens_array;
 89:   PetscScalar        one = 1.0, zero = 0.0;
 90:   PetscBLASInt       n_local_b, m_b, lda_x_b, lda_g_b;
 91:   PetscInt           n_local_ens, n_local_x, n_g_rows, n_g_cols, lda_x, lda_g, lda_ens;

 93:   PetscFunctionBegin;
 94:   PetscCall(MatGetLocalSize(ensemble, &n_local_ens, NULL));
 95:   PetscCall(MatGetLocalSize(X, &n_local_x, NULL));
 96:   PetscCheck(n_local_x == n_local_ens, PetscObjectComm((PetscObject)ensemble), PETSC_ERR_ARG_INCOMP, "X local rows (%" PetscInt_FMT ") must match ensemble local rows (%" PetscInt_FMT ")", n_local_x, n_local_ens);
 97:   PetscCall(MatGetSize(G, &n_g_rows, &n_g_cols));
 98:   PetscCheck(n_g_rows == m && n_g_cols == m, PetscObjectComm((PetscObject)ensemble), PETSC_ERR_ARG_INCOMP, "G must be %" PetscInt_FMT " x %" PetscInt_FMT ", got %" PetscInt_FMT " x %" PetscInt_FMT, m, m, n_g_rows, n_g_cols);
 99:   PetscCall(MatDenseGetArrayRead(X, &x_array));
100:   PetscCall(MatDenseGetLDA(X, &lda_x));
101:   PetscCall(MatDenseGetArrayRead(G, &g_array));
102:   PetscCall(MatDenseGetLDA(G, &lda_g));
103:   PetscCall(MatDenseGetArrayWrite(ensemble, &ens_array));
104:   PetscCall(MatDenseGetLDA(ensemble, &lda_ens));
105:   PetscCall(VecGetArrayRead(mean, &mean_array));
106:   PetscCall(PetscMalloc1((size_t)n_local_ens * m, &xg_buf));
107:   PetscCall(PetscBLASIntCast(n_local_ens, &n_local_b));
108:   PetscCall(PetscBLASIntCast(m, &m_b));
109:   PetscCall(PetscBLASIntCast(lda_x, &lda_x_b));
110:   PetscCall(PetscBLASIntCast(lda_g, &lda_g_b));
111:   if (n_local_ens > 0) PetscCallBLAS("BLASgemm", BLASgemm_("N", "N", &n_local_b, &m_b, &m_b, &one, x_array, &lda_x_b, g_array, &lda_g_b, &zero, xg_buf, &n_local_b));
112:   for (PetscInt j = 0; j < m; j++)
113:     for (PetscInt i = 0; i < n_local_ens; i++) ens_array[i + j * lda_ens] = mean_array[i] + xg_buf[i + j * n_local_ens];
114:   PetscCall(PetscFree(xg_buf));
115:   PetscCall(VecRestoreArrayRead(mean, &mean_array));
116:   PetscCall(MatDenseRestoreArrayWrite(ensemble, &ens_array));
117:   PetscCall(MatDenseRestoreArrayRead(G, &g_array));
118:   PetscCall(MatDenseRestoreArrayRead(X, &x_array));
119:   PetscFunctionReturn(PETSC_SUCCESS);
120: }

122: static PetscErrorCode PetscDADestroy_LETKF(PetscDA da)
123: {
124:   PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;

126:   PetscFunctionBegin;
127:   PetscCall(VecDestroy(&impl->mean));
128:   PetscCall(VecDestroy(&impl->y_mean));
129:   PetscCall(VecDestroy(&impl->delta_scaled));
130:   PetscCall(VecDestroy(&impl->w));
131:   PetscCall(VecDestroy(&impl->s_transpose_delta));
132:   PetscCall(VecDestroy(&impl->r_inv_sqrt));
133:   PetscCall(VecDestroy(&impl->H_temp_in));
134:   PetscCall(VecDestroy(&impl->H_temp_out));
135:   PetscCall(PetscFree(impl->H_vec_type));
136:   PetscCall(MatDestroy(&impl->Z));
137:   PetscCall(MatDestroy(&impl->S));
138:   PetscCall(MatDestroy(&impl->T_sqrt));
139:   PetscCall(MatDestroy(&impl->w_ones));
140:   PetscCall(MatDestroy(&impl->Q));
141:   PetscCall(PetscDALETKFClearCoordinates(impl));
142: #if defined(PETSC_HAVE_KOKKOS_KERNELS) && !defined(PETSC_USE_COMPLEX)
143:   PetscCall(PetscDALETKFDestroyLocalization_Kokkos(impl));
144: #endif
145:   PetscCall(PetscDALETKFDestroyObsScatter(impl));
146:   PetscCall(PetscDADestroy_Ensemble(da));
147:   PetscCall(PetscFree(da->data));

149:   PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFSetLocalizationRadius_C", NULL));
150:   PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFGetLocalizationRadius_C", NULL));
151:   PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFSetLocalizationType_C", NULL));
152:   PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFGetLocalizationType_C", NULL));
153:   PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFSetLocalizationCoordinates_C", NULL));
154:   PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFResetLocalization_C", NULL));
155:   PetscFunctionReturn(PETSC_SUCCESS);
156: }

158: /*
159:   ExtractLocalObservations - Extracts local observations for a vertex using localization matrix Q (CPU version)

161:   Input Parameters:
162: + Q          - localization matrix (state_size/ndof x obs_size), variable nnz per row
163: . vertex_idx - index of the vertex (row of Q)
164: . Z_global   - global observation ensemble matrix (obs_size x m) OR local work matrix
165: . y_global   - global observation vector (size obs_size) OR local work vector
166: . y_mean_global - global observation mean (size obs_size) OR local work vector
167: . r_inv_sqrt_global - global R^{-1/2} (size obs_size) OR local work vector
168: . obs_g2l    - map from global observation index to local index (if using local work vectors)
169: . m          - ensemble size

171:   Output Parameters:
172: . Z_local    - local observation ensemble (p_local x m), pre-allocated
173: . y_local    - local observation vector (size p_local), pre-allocated
174: . y_mean_local - local observation mean (size p_local), pre-allocated
175: - r_inv_sqrt_local - local R^{-1/2} (size p_local), pre-allocated
176: */
177: static PetscErrorCode ExtractLocalObservations(Mat Q, PetscInt vertex_idx, Mat Z_global, Vec y_global, Vec y_mean_global, Vec r_inv_sqrt_global, PetscHMapI obs_g2l, PetscInt m, Mat Z_local, Vec y_local, Vec y_mean_local, Vec r_inv_sqrt_local)
178: {
179:   const PetscInt    *cols;
180:   const PetscScalar *vals;
181:   PetscInt           ncols, k, j, p_local;
182:   const PetscScalar *z_global_array, *y_global_array, *y_mean_global_array, *r_inv_sqrt_global_array;
183:   PetscScalar       *z_local_array, *y_local_array, *y_mean_local_array, *r_inv_sqrt_local_array;
184:   PetscInt           lda_z_global, lda_z_local;

186:   PetscFunctionBegin;
187:   /* Get the row of Q corresponding to this vertex */
188:   PetscCall(MatGetRow(Q, vertex_idx, &ncols, &cols, &vals));

190:   /* Get array access to global data */
191:   PetscCall(MatDenseGetArrayRead(Z_global, &z_global_array));
192:   PetscCall(VecGetArrayRead(y_global, &y_global_array));
193:   PetscCall(VecGetArrayRead(y_mean_global, &y_mean_global_array));
194:   PetscCall(VecGetArrayRead(r_inv_sqrt_global, &r_inv_sqrt_global_array));

196:   /* Get array access to local data */
197:   PetscCall(MatDenseGetArrayWrite(Z_local, &z_local_array));
198:   PetscCall(VecGetArray(y_local, &y_local_array));
199:   PetscCall(VecGetArray(y_mean_local, &y_mean_local_array));
200:   PetscCall(VecGetArray(r_inv_sqrt_local, &r_inv_sqrt_local_array));

202:   /* Get leading dimensions */
203:   PetscCall(MatDenseGetLDA(Z_global, &lda_z_global));
204:   PetscCall(MatDenseGetLDA(Z_local, &lda_z_local));
205:   PetscCall(VecGetLocalSize(y_local, &p_local));

207:   /* Extract local observations and weight R^{-1/2} */
208:   for (k = 0; k < ncols; k++) {
209:     PetscInt    obs_idx   = cols[k];
210:     PetscScalar weight    = vals[k];
211:     PetscInt    local_idx = obs_idx;

213:     /* If using local work vectors, map global index to local index */
214:     if (obs_g2l) {
215:       PetscCall(PetscHMapIGet(obs_g2l, obs_idx, &local_idx));
216:       PetscCheck(local_idx >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Observation index %" PetscInt_FMT " not found in local map", obs_idx);
217:     }

219:     y_local_array[k]          = y_global_array[local_idx];
220:     y_mean_local_array[k]     = y_mean_global_array[local_idx];
221:     r_inv_sqrt_local_array[k] = r_inv_sqrt_global_array[local_idx] * PetscSqrtScalar(weight);

223:     /* Extract Z matrix row (column-major layout) */
224:     for (j = 0; j < m; j++) z_local_array[k + j * lda_z_local] = z_global_array[local_idx + j * lda_z_global];
225:   }

227:   /* Zero the unused tail [ncols, p_local) so a shorter row does not leak the previous row's
228:      trailing values into the downstream normalized-innovation computation. Caller does not need
229:      to MatZeroEntries/VecZeroEntries the workspace between iterations. */
230:   for (k = ncols; k < p_local; k++) {
231:     y_local_array[k]          = 0.0;
232:     y_mean_local_array[k]     = 0.0;
233:     r_inv_sqrt_local_array[k] = 0.0;
234:     for (j = 0; j < m; j++) z_local_array[k + j * lda_z_local] = 0.0;
235:   }

237:   /* Restore arrays */
238:   PetscCall(VecRestoreArray(r_inv_sqrt_local, &r_inv_sqrt_local_array));
239:   PetscCall(VecRestoreArray(y_mean_local, &y_mean_local_array));
240:   PetscCall(VecRestoreArray(y_local, &y_local_array));
241:   PetscCall(MatDenseRestoreArrayWrite(Z_local, &z_local_array));
242:   PetscCall(VecRestoreArrayRead(r_inv_sqrt_global, &r_inv_sqrt_global_array));
243:   PetscCall(VecRestoreArrayRead(y_mean_global, &y_mean_global_array));
244:   PetscCall(VecRestoreArrayRead(y_global, &y_global_array));
245:   PetscCall(MatDenseRestoreArrayRead(Z_global, &z_global_array));

247:   /* Restore Q row */
248:   PetscCall(MatRestoreRow(Q, vertex_idx, &ncols, &cols, &vals));
249:   PetscFunctionReturn(PETSC_SUCCESS);
250: }

252: /*
253:   PetscDALETKFLocalAnalysis - Performs local LETKF analysis for all grid points (CPU version)

255:   Input Parameters:
256: + da             - the PetscDA context
257: . impl           - LETKF implementation data
258: . m              - ensemble size
259: . n_vertices     - number of grid points
260: . X              - global anomaly matrix (state_size x m)
261: . observation    - observation vector
262: . Z_global       - global observation ensemble (obs_size x m)
263: . y_mean_global  - global observation mean
264: - r_inv_sqrt_global - global R^{-1/2}

266:   Output:
267: . da->ensemble - updated with analysis ensemble

269:   Notes:
270:   This function performs the local analysis loop for LETKF, processing each grid point
271:   independently using its local observations defined by the localization matrix Q.
272:   This is the CPU version that does not use Kokkos acceleration.

274:   All local analysis workspace objects (Z_local, S_local, T_sqrt_local, G_local, y_local,
275:   y_mean_local, delta_scaled_local, r_inv_sqrt_local, w_local, s_transpose_delta_local) are
276:   created with PETSC_COMM_SELF because the analysis at each vertex is serial and independent.
277: */
278: PetscErrorCode PetscDALETKFLocalAnalysis(PetscDA da, PetscDA_LETKF *impl, PetscInt m, PetscInt n_vertices, Mat X, Vec observation, Mat Z_global, Vec y_mean_global, Vec r_inv_sqrt_global)
279: {
280:   PetscDA_Ensemble  *en = &impl->en;
281:   Mat                Z_local, S_local, T_sqrt_local, G_local;
282:   Mat                X_rows, E_analysis_rows;
283:   Vec                y_local, y_mean_local, delta_scaled_local, r_inv_sqrt_local;
284:   Vec                w_local, s_transpose_delta_local;
285:   const PetscScalar *w_array, *x_array, *g_array, *mean_array;
286:   PetscScalar       *g_array_w, *e_array, *x_rows_array, *ea_rows_array;
287:   PetscScalar        one = 1.0, zero = 0.0;
288:   PetscBLASInt       ndof_b, m_b, lda_xrows_b, lda_g_b, lda_ea_b;
289:   PetscInt           ndof, max_nnz, rstart;
290:   PetscInt           lda_x, lda_e, lda_xrows, lda_g, lda_ea;
291:   PetscReal          sqrt_m_minus_1, scale;

293:   PetscFunctionBegin;
294:   ndof           = da->ndof;
295:   sqrt_m_minus_1 = PetscSqrtReal((PetscReal)(m - 1));
296:   scale          = 1.0 / sqrt_m_minus_1;
297:   max_nnz        = impl->max_nnz_per_row;

299:   /* X and ensemble are accessed at row offsets up to (n_vertices-1)*ndof + (ndof-1).
300:      Mirror the precondition the Kokkos path enforces so a bad LDA fails fast on either backend. */
301:   PetscCall(MatDenseGetLDA(X, &lda_x));
302:   PetscCall(MatDenseGetLDA(en->ensemble, &lda_e));
303:   PetscCheck(lda_x >= n_vertices * ndof, PetscObjectComm((PetscObject)X), PETSC_ERR_ARG_INCOMP, "X leading dimension %" PetscInt_FMT " < n_vertices*ndof %" PetscInt_FMT, lda_x, n_vertices * ndof);
304:   PetscCheck(lda_e >= n_vertices * ndof, PetscObjectComm((PetscObject)en->ensemble), PETSC_ERR_ARG_INCOMP, "Ensemble leading dimension %" PetscInt_FMT " < n_vertices*ndof %" PetscInt_FMT, lda_e, n_vertices * ndof);

306:   /* Create local analysis workspace (max_nnz x m matrices and vectors) */
307:   PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, max_nnz, m, NULL, &Z_local));
308:   PetscCall(PetscObjectSetOptionsPrefix((PetscObject)Z_local, "dense_"));
309:   PetscCall(MatSetFromOptions(Z_local));
310:   PetscCall(MatSetUp(Z_local));
311:   PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, max_nnz, m, NULL, &S_local));
312:   PetscCall(PetscObjectSetOptionsPrefix((PetscObject)S_local, "dense_"));
313:   PetscCall(MatSetFromOptions(S_local));
314:   PetscCall(MatSetUp(S_local));
315:   PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, m, m, NULL, &T_sqrt_local));
316:   PetscCall(PetscObjectSetOptionsPrefix((PetscObject)T_sqrt_local, "dense_"));
317:   PetscCall(MatSetFromOptions(T_sqrt_local));
318:   PetscCall(MatSetUp(T_sqrt_local));
319:   PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, m, m, NULL, &G_local));
320:   PetscCall(PetscObjectSetOptionsPrefix((PetscObject)G_local, "dense_"));
321:   PetscCall(MatSetFromOptions(G_local));
322:   PetscCall(MatSetUp(G_local));

324:   /* Create vectors using MatCreateVecs() from Z_local (max_nnz x m) */
325:   PetscCall(MatCreateVecs(Z_local, &w_local, &y_local));
326:   PetscCall(VecDuplicate(y_local, &y_mean_local));
327:   PetscCall(VecDuplicate(y_local, &delta_scaled_local));
328:   PetscCall(VecDuplicate(y_local, &r_inv_sqrt_local));
329:   PetscCall(VecDuplicate(w_local, &s_transpose_delta_local));

331:   PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, ndof, m, NULL, &X_rows));
332:   PetscCall(MatDuplicate(X_rows, MAT_DO_NOT_COPY_VALUES, &E_analysis_rows));

334:   /* X_rows, G_local, E_analysis_rows are loop-invariant; their LDAs and the BLAS-int
335:      casts of the gemm shape never change inside the n_vertices loop. Hoist to spare
336:      the dispatch overhead at every vertex. */
337:   PetscCall(MatDenseGetLDA(G_local, &lda_g));
338:   PetscCall(MatDenseGetLDA(X_rows, &lda_xrows));
339:   PetscCall(MatDenseGetLDA(E_analysis_rows, &lda_ea));
340:   PetscCall(PetscBLASIntCast(ndof, &ndof_b));
341:   PetscCall(PetscBLASIntCast(m, &m_b));
342:   PetscCall(PetscBLASIntCast(lda_xrows, &lda_xrows_b));
343:   PetscCall(PetscBLASIntCast(lda_g, &lda_g_b));
344:   PetscCall(PetscBLASIntCast(lda_ea, &lda_ea_b));

346:   /* LETKF: Loop over all grid points and perform local analysis */
347:   PetscCall(MatGetOwnershipRange(impl->Q, &rstart, NULL));

349:   /* X, impl->mean, and en->ensemble are loop-invariant; their array views are read or
350:      written at offsets that change per iteration but the underlying storage does not.
351:      Hoisting the Get/Restore pairs out of the n_vertices loop avoids repeated lock and
352:      validation overhead inside the hot path. Safe to hold across the loop because the
353:      inner MatDenseGetArray{,Read,Write} calls on X_rows, E_analysis_rows, and G_local
354:      operate on disjoint SELF matrices and never alias X/mean/en->ensemble. */
355:   PetscCall(MatDenseGetArrayRead(X, &x_array));
356:   PetscCall(VecGetArrayRead(impl->mean, &mean_array));
357:   PetscCall(MatDenseGetArrayWrite(en->ensemble, &e_array));

359:   for (PetscInt i_grid_point = 0; i_grid_point < n_vertices; i_grid_point++) {
360:     /* Extract local observations for this grid point using Q[i_grid_point,:].
361:        ExtractLocalObservations() zeros the unwritten [ncols, max_nnz) tail of each
362:        workspace, so we do not need to MatZeroEntries/VecZeroEntries every iteration. */
363:     /* Note: i_grid_point is local index, but MatGetRow needs global index */
364:     PetscCall(ExtractLocalObservations(impl->Q, rstart + i_grid_point, Z_global, observation, y_mean_global, r_inv_sqrt_global, impl->obs_g2l, m, Z_local, y_local, y_mean_local, r_inv_sqrt_local));

366:     /* Compute local normalized innovation matrix: S_local = R_local^{-1/2} * (Z_local - y_mean_local * 1') / sqrt(m - 1) */
367:     PetscCall(PetscDAEnsembleComputeNormalizedInnovationMatrix(Z_local, y_mean_local, r_inv_sqrt_local, m, scale, S_local));

369:     /* Compute local delta_scaled = R_local^{-1/2} * (y_local - y_mean_local) */
370:     PetscCall(VecWAXPY(delta_scaled_local, -1.0, y_mean_local, y_local));
371:     PetscCall(VecPointwiseMult(delta_scaled_local, delta_scaled_local, r_inv_sqrt_local));

373:     /* Factor local T = (I + S_local^T * S_local) */
374:     PetscCall(PetscDAEnsembleTFactor(da, S_local));

376:     /* Compute local analysis weights: w_local = T_local^{-1} * S_local^T * delta_scaled_local */
377:     PetscCall(MatMultTranspose(S_local, delta_scaled_local, s_transpose_delta_local));
378:     PetscCall(PetscDAEnsembleApplyTInverse(da, s_transpose_delta_local, w_local));

380:     /* Compute local square-root transform: T_sqrt_local = T_local^{-1/2} (U is identity, so pass NULL) */
381:     PetscCall(PetscDAEnsembleApplySqrtTInverse(da, NULL, T_sqrt_local));

383:     /* Form local transform G_local = w_local * 1' + sqrt(m - 1) * T_sqrt_local * U
384:        Instead of creating w_ones_local = w_local * 1', we add w_local to each column of G_local */
385:     PetscCall(MatCopy(T_sqrt_local, G_local, SAME_NONZERO_PATTERN));
386:     PetscCall(MatScale(G_local, sqrt_m_minus_1));
387:     PetscCall(VecGetArrayRead(w_local, &w_array));
388:     PetscCall(MatDenseGetArray(G_local, &g_array_w));
389:     for (PetscInt j = 0; j < m; j++)
390:       for (PetscInt k = 0; k < m; k++) g_array_w[k + j * lda_g] += w_array[k];
391:     PetscCall(MatDenseRestoreArray(G_local, &g_array_w));
392:     PetscCall(VecRestoreArrayRead(w_local, &w_array));

394:     /* LETKF Algorithm 2, Line 13: Update ensemble at grid point i_grid_point
395:        E_a[i,:] = x_bar_f[i] + X_f[i,:] * G_local

397:        Where:
398:        - x_bar_f[i] is the forecast mean at grid point i_grid_point (ndof values from global mean vector)
399:        - X_f[i,:] is the forecast anomaly rows at grid point i_grid_point (ndof rows from global anomaly matrix X)
400:        - G_local = w_local * 1' + sqrt(m-1) * T_local^{1/2} * U (computed above in G_local)
401:      */
402:     /* Extract ndof rows starting at (i_grid_point * ndof) from X: X_f[i_grid_point*ndof:(i_grid_point+1)*ndof, :]
403:        Hold X_rows / E_analysis_rows with a single read/write GetArray each so the fill, gemm,
404:        mean-add, and copy-out share one Get/Restore pair per vertex. */
405:     PetscCall(MatDenseGetArray(X_rows, &x_rows_array));
406:     for (PetscInt j = 0; j < m; j++)
407:       for (PetscInt k = 0; k < ndof; k++) x_rows_array[k + j * lda_xrows] = x_array[(i_grid_point * ndof + k) + j * lda_x];

409:     /* Apply local transform via direct BLASgemm: E_analysis_rows = X_rows * G_local.
410:        Replaces a per-vertex MatMatMult; ndof and m are typically small (1-100), so the
411:        MatProduct dispatch overhead dominated. */
412:     PetscCall(MatDenseGetArrayRead(G_local, &g_array));
413:     PetscCall(MatDenseGetArray(E_analysis_rows, &ea_rows_array));
414:     PetscCallBLAS("BLASgemm", BLASgemm_("N", "N", &ndof_b, &m_b, &m_b, &one, x_rows_array, &lda_xrows_b, g_array, &lda_g_b, &zero, ea_rows_array, &lda_ea_b));
415:     PetscCall(MatDenseRestoreArrayRead(G_local, &g_array));
416:     PetscCall(MatDenseRestoreArray(X_rows, &x_rows_array));

418:     /* Add local mean and store result back in ensemble at row offset i_grid_point*ndof. */
419:     for (PetscInt j = 0; j < m; j++) {
420:       for (PetscInt k = 0; k < ndof; k++) {
421:         ea_rows_array[k + j * lda_ea] += mean_array[i_grid_point * ndof + k];
422:         e_array[(i_grid_point * ndof + k) + j * lda_e] = ea_rows_array[k + j * lda_ea];
423:       }
424:     }
425:     PetscCall(MatDenseRestoreArray(E_analysis_rows, &ea_rows_array));
426:   }
427:   PetscCall(MatDenseRestoreArrayWrite(en->ensemble, &e_array));
428:   PetscCall(VecRestoreArrayRead(impl->mean, &mean_array));
429:   PetscCall(MatDenseRestoreArrayRead(X, &x_array));
430:   PetscCall(MatDestroy(&E_analysis_rows));
431:   PetscCall(MatDestroy(&X_rows));
432:   PetscCall(VecDestroy(&s_transpose_delta_local));
433:   PetscCall(VecDestroy(&w_local));
434:   PetscCall(VecDestroy(&r_inv_sqrt_local));
435:   PetscCall(VecDestroy(&delta_scaled_local));
436:   PetscCall(VecDestroy(&y_mean_local));
437:   PetscCall(VecDestroy(&y_local));
438:   PetscCall(MatDestroy(&G_local));
439:   PetscCall(MatDestroy(&T_sqrt_local));
440:   PetscCall(MatDestroy(&S_local));
441:   PetscCall(MatDestroy(&Z_local));
442:   PetscFunctionReturn(PETSC_SUCCESS);
443: }

445: /*
446:   PetscDALETKFGlobalAnalysis - LOC_NONE fast path: a single global ETKF analysis with no
447:   per-vertex localization. The m x m T factor and weight vector live on PETSC_COMM_SELF so
448:   every rank does the identical eigendecomp; only the gram S^T*S and the projection S^T*delta
449:   need an MPI reduction. Dispatches to the Kokkos backend when R is a Kokkos matrix.
450: */
451: static PetscErrorCode PetscDALETKFGlobalAnalysis(PetscDA da, PetscDA_LETKF *impl, PetscInt m, Mat X, Vec observation)
452: {
453:   const PetscScalar *s_array, *d_array;
454:   PetscScalar       *gram, *buf;
455:   PetscScalar        one = 1.0, zero = 0.0;
456:   PetscBLASInt       m_b, n_obs_local_b, s_lda_b, ione = 1;
457:   PetscMPIInt        m_squared_mpi, m_mpi;
458:   PetscInt           n_obs_local, s_lda;
459:   PetscReal          sqrt_m_minus_1, scale;
460:   PetscBool          use_kokkos;

462:   PetscFunctionBegin;
463:   sqrt_m_minus_1 = PetscSqrtReal((PetscReal)(m - 1));
464:   scale          = 1.0 / sqrt_m_minus_1;
465:   PetscCall(PetscDALETKFUseKokkosBackend(da, &use_kokkos));
466: #if defined(PETSC_HAVE_KOKKOS_KERNELS) && !defined(PETSC_USE_COMPLEX)
467:   if (use_kokkos) {
468:     PetscCall(PetscDALETKFGlobalAnalysis_Kokkos(da, impl, m, X, observation));
469:     PetscFunctionReturn(PETSC_SUCCESS);
470:   }
471: #endif
472:   PetscCheck(!use_kokkos, PetscObjectComm((PetscObject)da), PETSC_ERR_PLIB, "Kokkos backend selected but PetscDALETKFGlobalAnalysis_Kokkos is unavailable in this build");

474:   PetscCall(PetscDALETKFEnsureGlobalScratch(impl, m));

476:   /* S = R^{-1/2} * (Z - y_mean*1') / sqrt(m-1) */
477:   PetscCall(PetscDAEnsembleComputeNormalizedInnovationMatrix(impl->Z, impl->y_mean, impl->r_inv_sqrt, m, scale, impl->S));

479:   /* delta_scaled = R^{-1/2} * (y^o - y_mean) */
480:   PetscCall(VecWAXPY(impl->delta_scaled, -1.0, impl->y_mean, observation));
481:   PetscCall(VecPointwiseMult(impl->delta_scaled, impl->delta_scaled, impl->r_inv_sqrt));

483:   /* Factor T = (1/rho)I + S^T*S replicated on every rank. */
484:   PetscCall(PetscCalloc1((size_t)m * m, &gram));
485:   PetscCall(MatGetLocalSize(impl->S, &n_obs_local, NULL));
486:   PetscCall(MatDenseGetArrayRead(impl->S, &s_array));
487:   PetscCall(MatDenseGetLDA(impl->S, &s_lda));
488:   PetscCall(PetscBLASIntCast(m, &m_b));
489:   PetscCall(PetscBLASIntCast(n_obs_local, &n_obs_local_b));
490:   PetscCall(PetscBLASIntCast(s_lda, &s_lda_b));
491:   if (n_obs_local > 0) PetscCallBLAS("BLASgemm", BLASgemm_("T", "N", &m_b, &m_b, &n_obs_local_b, &one, s_array, &s_lda_b, s_array, &s_lda_b, &zero, gram, &m_b));
492:   PetscCall(PetscMPIIntCast((PetscInt64)m * m, &m_squared_mpi));
493:   PetscCallMPI(MPIU_Allreduce(MPI_IN_PLACE, gram, m_squared_mpi, MPIU_SCALAR, MPIU_SUM, PetscObjectComm((PetscObject)da)));
494:   PetscCall(PetscDAEnsembleTFactorFromGram(da, m, gram));
495:   PetscCall(PetscFree(gram));

497:   /* w = T^{-1} * (S^T * delta_scaled), with the projection reduced across ranks. Hold the
498:      buffer with VecGetArray across both the local gemv and the in-place allreduce so the
499:      reduction sees this rank's contribution (VecGetArrayWrite would make the post-gemv data
500:      undefined after restore). Ranks with no local obs zero the buffer directly because gemv
501:      (which would overwrite via beta = 0) is skipped on n_obs_local == 0. */
502:   PetscCall(VecGetArrayRead(impl->delta_scaled, &d_array));
503:   PetscCall(VecGetArray(impl->s_transpose_delta, &buf));
504:   if (n_obs_local > 0) PetscCallBLAS("BLASgemv", BLASgemv_("T", &n_obs_local_b, &m_b, &one, s_array, &s_lda_b, d_array, &ione, &zero, buf, &ione));
505:   else PetscCall(PetscArrayzero(buf, m));
506:   PetscCall(PetscMPIIntCast(m, &m_mpi));
507:   PetscCallMPI(MPIU_Allreduce(MPI_IN_PLACE, buf, m_mpi, MPIU_SCALAR, MPIU_SUM, PetscObjectComm((PetscObject)da)));
508:   PetscCall(VecRestoreArray(impl->s_transpose_delta, &buf));
509:   PetscCall(VecRestoreArrayRead(impl->delta_scaled, &d_array));
510:   PetscCall(MatDenseRestoreArrayRead(impl->S, &s_array));

512:   PetscCall(PetscDAEnsembleApplyTInverse(da, impl->s_transpose_delta, impl->w));

514:   /* T_sqrt = T^{-1/2} */
515:   PetscCall(PetscDAEnsembleApplySqrtTInverse(da, NULL, impl->T_sqrt));

517:   /* G = w*1' + sqrt(m-1) * T_sqrt (in impl->w_ones, all on PETSC_COMM_SELF). */
518:   PetscCall(PetscDALETKFReplicateWeightVector(impl->w, m, impl->w_ones));
519:   PetscCall(MatAXPY(impl->w_ones, sqrt_m_minus_1, impl->T_sqrt, SAME_NONZERO_PATTERN));

521:   /* E = mean*1' + X * G */
522:   PetscCall(PetscDALETKFUpdateEnsembleWithTransform(impl->mean, X, impl->w_ones, m, impl->en.ensemble));
523:   PetscFunctionReturn(PETSC_SUCCESS);
524: }

526: /*
527:   PetscDALETKFRebuildHTemps - ensure the cached H-compatible work vecs match H's current
528:   layout and vec type, rebuilding them (and any caches that depend on H's backend) when H has
529:   drifted since the last analysis. The two `if` blocks (invalidate, then allocate) must remain
530:   separate rather than an if/else: when the first block fires it nulls `impl->H_temp_in` via
531:   VecDestroy, and the second `if (!impl->H_temp_in)` must then re-create the temps in the same
532:   call. Collapsing to if/else would leave a freshly-destroyed cache uninitialized until the next
533:   analysis, breaking the post-condition that on return the H_temp_* and H_vec_type fields are
534:   populated and consistent with the current H.
535: */
536: static PetscErrorCode PetscDALETKFRebuildHTemps(PetscDA da, PetscDA_LETKF *impl, Mat H)
537: {
538:   PetscInt  cur_in_local, cur_out_local, want_in_local, want_out_local;
539:   VecType   want_type;
540:   PetscBool type_match;

542:   PetscFunctionBegin;
543:   PetscCall(MatGetVecType(H, &want_type));
544:   if (impl->H_temp_in) {
545:     PetscCall(VecGetLocalSize(impl->H_temp_in, &cur_in_local));
546:     PetscCall(VecGetLocalSize(impl->H_temp_out, &cur_out_local));
547:     PetscCall(MatGetLocalSize(H, &want_out_local, &want_in_local));
548:     PetscCall(PetscStrcmp(impl->H_vec_type, want_type, &type_match));
549:     if (!type_match || cur_in_local != want_in_local || cur_out_local != want_out_local) {
550:       PetscCall(VecDestroy(&impl->H_temp_in));
551:       PetscCall(VecDestroy(&impl->H_temp_out));
552:       PetscCall(PetscFree(impl->H_vec_type));
553:       /* The obs-scatter source layout is templated off H, and Q's device mirrors live in the
554:          backend matching the old H vec type (Kokkos vs host); reset the full localization
555:          cache so the next analysis rebuilds Q and its mirrors against the new H. */
556:       PetscCall(PetscDALETKFResetLocalization_LETKF(da));
557:     }
558:   }
559:   if (!impl->H_temp_in) {
560:     PetscCall(MatCreateVecs(H, &impl->H_temp_in, &impl->H_temp_out));
561:     PetscCall(PetscStrallocpy(want_type, &impl->H_vec_type));
562:   }
563:   PetscFunctionReturn(PETSC_SUCCESS);
564: }

/*
  PetscDAEnsembleAnalysis_LETKF - run one LETKF analysis cycle against `observation` using the
  observation operator `H`. PetscDACreate_LETKF() installs it as impl->en.analysis.

  Outline:
    1. Validate the ensemble size m (must be >= 2).
    2. (Re)allocate the persistent work objects (mean, Z, S, obs-space vecs) when state_size,
       obs_size, or m differs from the cached layout. A reallocation also calls
       PetscDALETKFResetLocalization_LETKF() so Q and the obs scatter are rebuilt for the new
       layout.
    3. Compute the ensemble mean and the anomaly matrix X. X is created here and destroyed
       before returning.
    4. Refresh the H-compatible work vecs, then lazily build Q for the localized kernels.
    5. Form Z = H*E, y_mean = H*x_mean and R^{-1/2} on the global observation layout. R is
       treated as diagonal, with its entries taken from da->obs_error_var.
    6. Dispatch the analysis:
       - LOC_NONE: PetscDALETKFGlobalAnalysis().
       - Otherwise: scatter the observation-space data through impl->obs_scat into rank-local
         work objects, then run the per-vertex analysis. The Kokkos variant is used when
         PetscDALETKFUseKokkosBackend() reports a Kokkos da->R; otherwise the CPU variant runs.
    7. View the PetscDA when -petscda_view is given.
*/
static PetscErrorCode PetscDAEnsembleAnalysis_LETKF(PetscDA da, Vec observation, Mat H)
{
  PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;
  Mat            X;
  PetscInt       m;
  PetscBool      reallocate = PETSC_FALSE;

  PetscFunctionBegin;
  m = impl->en.size;
  PetscCheck(m >= 2, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_OUTOFRANGE, "Ensemble size must be >= 2 for LETKF; got %" PetscInt_FMT, m);

  /* Check for reallocation needs. All comparisons use global sizes, so every rank reaches the
     same verdict. */
  if (impl->mean) {
    PetscInt mean_size;
    PetscCall(VecGetSize(impl->mean, &mean_size));
    if (mean_size != da->state_size) reallocate = PETSC_TRUE;
  }
  if (impl->Z) {
    PetscInt z_rows, z_cols;
    PetscCall(MatGetSize(impl->Z, &z_rows, &z_cols));
    if (z_rows != da->obs_size || z_cols != m) reallocate = PETSC_TRUE;
  }
  /* impl->T_sqrt is owned only by the LOC_NONE fast path (PetscDALETKFEnsureGlobalScratch());
     the per-vertex path uses its own SELF T_sqrt_local. This check fires on a NONE -> per-vertex
     transition with a changed m, where the stale T_sqrt would mis-size the next NONE cycle. */
  if (impl->T_sqrt) {
    PetscInt t_rows, t_cols;
    PetscCall(MatGetSize(impl->T_sqrt, &t_rows, &t_cols));
    if (t_rows != m || t_cols != m) reallocate = PETSC_TRUE;
  }

  /* Initialize or reallocate persistent work objects */
  if (!impl->mean || reallocate) {
    /* On reallocation the cached Q (and obs scatter / Kokkos device buffers) describe a
       prior state_size/obs_size and must be torn down so the next analysis rebuilds them
       against the new layout. Skip on first-time init (nothing to reset yet). */
    if (reallocate) PetscCall(PetscDALETKFResetLocalization_LETKF(da));
    PetscCall(VecDestroy(&impl->mean));
    PetscCall(VecDestroy(&impl->y_mean));
    PetscCall(VecDestroy(&impl->delta_scaled));
    PetscCall(VecDestroy(&impl->w));
    PetscCall(VecDestroy(&impl->s_transpose_delta));
    PetscCall(VecDestroy(&impl->r_inv_sqrt));
    PetscCall(VecDestroy(&impl->H_temp_in));
    PetscCall(VecDestroy(&impl->H_temp_out));
    PetscCall(PetscFree(impl->H_vec_type));
    PetscCall(MatDestroy(&impl->Z));
    PetscCall(MatDestroy(&impl->S));
    PetscCall(MatDestroy(&impl->T_sqrt));
    PetscCall(MatDestroy(&impl->w_ones));

    /* Create mean vector from ensemble matrix (left vector = state space) */
    PetscCall(MatCreateVecs(impl->en.ensemble, NULL, &impl->mean));

    /* Create Z matrix (obs_size x m). It honours -dense_mat_type etc. via the "dense_" prefix. */
    PetscCall(MatCreateDense(PetscObjectComm((PetscObject)impl->en.ensemble), PETSC_DECIDE, PETSC_DECIDE, da->obs_size, m, NULL, &impl->Z));
    PetscCall(PetscObjectSetOptionsPrefix((PetscObject)impl->Z, "dense_"));
    PetscCall(MatSetFromOptions(impl->Z));
    PetscCall(MatSetUp(impl->Z));

    /* Create observation space vectors from Z matrix (left vector = observation space) */
    PetscCall(MatCreateVecs(impl->Z, NULL, &impl->y_mean));
    PetscCall(VecDuplicate(impl->y_mean, &impl->delta_scaled));
    PetscCall(VecDuplicate(da->obs_error_var, &impl->r_inv_sqrt));

    /* Create S matrix (same layout as Z) */
    PetscCall(MatDuplicate(impl->Z, MAT_DO_NOT_COPY_VALUES, &impl->S));

    /* T_sqrt and w_ones are m x m and used only by the LOC_NONE fast path; allocate them
       lazily in PetscDALETKFGlobalAnalysis() so the per-vertex paths do not pay for them. */
  }

  /* Alg 6.4 line 1-2: Compute ensemble mean and scaled anomalies */
  PetscCall(PetscDAEnsembleComputeMean(da, impl->mean));

  /* Create anomaly matrix X = (E - x_mean * 1') / sqrt(m - 1). X is owned by this call and
     destroyed at the end. */
  PetscCall(PetscDAEnsembleComputeAnomalies(da, impl->mean, &X));

  /* Alg 6.4 line 3-4: Compute GLOBAL observation ensemble Z = H * E column-by-column,
     staged through H-compatible cached work vecs because impl->Z (MATDENSE) and H
     (possibly MATAIJKOKKOS) cannot share a MatMatMult product type. */
  /* Lazily allocate / rebuild the cached H-compatible work vecs (and reset Q if H's vec-type
     backend changed). */
  PetscCall(PetscDALETKFRebuildHTemps(da, impl, H));

  /* Lazily build Q for built-in distance-based kernels using cached coordinates. The dispatcher
     selects host vs Kokkos backend from the type of the cached observation operator. Setters
     destroy Q via PetscDALETKFResetLocalization() when their inputs change, so a non-NULL Q is
     guaranteed to match the current (type, radius, coord_*) tuple. Built after PetscDALETKFRebuildHTemps()
     because that call may reset Q via PetscDALETKFResetLocalization_LETKF() when H's vec-type
     backend changed since the last analysis. */
  if (impl->type != PETSCDA_LETKF_LOC_NONE && !impl->Q) {
    Mat       Q_new = NULL;
    PetscInt  max_nnz_local, n_nnz_local;
    PetscBool use_kokkos;

    PetscCheck(impl->coord_H, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "Coordinates not set; call PetscDALETKFSetLocalizationCoordinates() before analysis");
    PetscCheck(impl->localization_radius > 0, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "Localization radius not set; call PetscDALETKFSetLocalizationRadius() before analysis");
    /* Q's backend must match the analysis-time backend, which keys off da->R; otherwise a CPU
       analysis would walk a Kokkos Q (or vice versa) and pay backend-mismatch transfer overhead. */
    PetscCall(PetscDALETKFUseKokkosBackend(da, &use_kokkos));
    /* Backends compute the per-rank max-nnz/row and total local nnz from row_counts[] in hand;
       passing them through avoids both a MatGetRow walk and a MatGetInfo call in
       PetscDALETKFInstallQ() (both of which force a device->host sync on AIJKOKKOS Q).
       The analysis-time `H` is threaded into InstallQ as the obs-scatter source template so the
       scatter rows match the vectors actually being scattered, even if `H`'s row partition or
       vec type differs from the `coord_H` cached at SetLocalizationCoordinates time. */
    PetscCall(PetscDALETKFCreateLocalizationMat(impl->type, impl->localization_radius, impl->coord_xyz, impl->coord_bd, impl->coord_H, use_kokkos, &Q_new, &max_nnz_local, &n_nnz_local));
    PetscCall(PetscDALETKFInstallQ(da, Q_new, max_nnz_local, n_nnz_local, H));
    /* InstallQ took its own reference; drop the creation reference. */
    PetscCall(MatDestroy(&Q_new));
  }

  /* Compute Z = H * E column by column to avoid Kokkos vector type issues */
  for (PetscInt j = 0; j < m; j++) {
    Vec col_in, col_out;
    PetscCall(MatDenseGetColumnVecRead(impl->en.ensemble, j, &col_in));
    PetscCall(MatDenseGetColumnVecWrite(impl->Z, j, &col_out));
    PetscCall(VecCopy(col_in, impl->H_temp_in));
    PetscCall(MatMult(H, impl->H_temp_in, impl->H_temp_out));
    PetscCall(VecCopy(impl->H_temp_out, col_out));
    PetscCall(MatDenseRestoreColumnVecWrite(impl->Z, j, &col_out));
    PetscCall(MatDenseRestoreColumnVecRead(impl->en.ensemble, j, &col_in));
  }

  /* Compute GLOBAL observation mean y_mean = H * x_mean using the same cached temps. */
  PetscCall(VecCopy(impl->mean, impl->H_temp_in));
  PetscCall(MatMult(H, impl->H_temp_in, impl->H_temp_out));
  PetscCall(VecCopy(impl->H_temp_out, impl->y_mean));

  /* Compute GLOBAL R^{-1/2} (assumes diagonal R): r_inv_sqrt = 1 / sqrt(|obs_error_var|) */
  PetscCall(VecCopy(da->obs_error_var, impl->r_inv_sqrt));
  PetscCall(VecSqrtAbs(impl->r_inv_sqrt));
  PetscCall(VecReciprocal(impl->r_inv_sqrt));

  if (impl->type == PETSCDA_LETKF_LOC_NONE) PetscCall(PetscDALETKFGlobalAnalysis(da, impl, m, X, observation));
  else {
    PetscInt  n_local, n_obs_local, rows_old, cols_old;
    PetscBool use_kokkos;

    PetscCall(PetscDALETKFUseKokkosBackend(da, &use_kokkos));

    /* Per-vertex local analysis path. PetscDALETKFInstallQ() above already templated the
       obs-scatter on the live `H`, and PetscDALETKFRebuildHTemps() invalidates Q (and therefore
       triggers a re-InstallQ here) whenever H's row layout or vec type drifts, so impl->obs_scat
       is guaranteed non-NULL and compatible with the live observation vectors. */
    PetscCall(VecScatterBegin(impl->obs_scat, observation, impl->obs_work, INSERT_VALUES, SCATTER_FORWARD));
    PetscCall(VecScatterEnd(impl->obs_scat, observation, impl->obs_work, INSERT_VALUES, SCATTER_FORWARD));

    PetscCall(VecScatterBegin(impl->obs_scat, impl->y_mean, impl->y_mean_work, INSERT_VALUES, SCATTER_FORWARD));
    PetscCall(VecScatterEnd(impl->obs_scat, impl->y_mean, impl->y_mean_work, INSERT_VALUES, SCATTER_FORWARD));

    PetscCall(VecScatterBegin(impl->obs_scat, impl->r_inv_sqrt, impl->r_inv_sqrt_work, INSERT_VALUES, SCATTER_FORWARD));
    PetscCall(VecScatterEnd(impl->obs_scat, impl->r_inv_sqrt, impl->r_inv_sqrt_work, INSERT_VALUES, SCATTER_FORWARD));

    /* Z_work is a rank-local SeqDense (n_obs_local x m) holding the scattered columns of Z;
       recreate it only when its shape no longer matches. */
    PetscCall(VecGetLocalSize(impl->obs_work, &n_obs_local));
    if (impl->Z_work) {
      PetscCall(MatGetSize(impl->Z_work, &rows_old, &cols_old));
      if (rows_old != n_obs_local || cols_old != m) PetscCall(MatDestroy(&impl->Z_work));
    }
    if (!impl->Z_work) PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, n_obs_local, m, NULL, &impl->Z_work));
    for (PetscInt i = 0; i < m; i++) {
      Vec z_col_global, z_col_local;
      PetscCall(MatDenseGetColumnVecRead(impl->Z, i, &z_col_global));
      PetscCall(MatDenseGetColumnVecWrite(impl->Z_work, i, &z_col_local));
      PetscCall(VecScatterBegin(impl->obs_scat, z_col_global, z_col_local, INSERT_VALUES, SCATTER_FORWARD));
      PetscCall(VecScatterEnd(impl->obs_scat, z_col_global, z_col_local, INSERT_VALUES, SCATTER_FORWARD));
      PetscCall(MatDenseRestoreColumnVecRead(impl->Z, i, &z_col_global));
      PetscCall(MatDenseRestoreColumnVecWrite(impl->Z_work, i, &z_col_local));
    }

    /* n_local = number of Q rows (vertices) owned by this rank. */
    PetscCall(MatGetLocalSize(impl->Q, &n_local, NULL));
#if defined(PETSC_HAVE_KOKKOS_KERNELS) && !defined(PETSC_USE_COMPLEX)
    if (use_kokkos) PetscCall(PetscDALETKFLocalAnalysis_Kokkos(da, impl, m, n_local, X, impl->obs_work, impl->Z_work, impl->y_mean_work, impl->r_inv_sqrt_work));
    else PetscCall(PetscDALETKFLocalAnalysis(da, impl, m, n_local, X, impl->obs_work, impl->Z_work, impl->y_mean_work, impl->r_inv_sqrt_work));
#else
    PetscCall(PetscDALETKFLocalAnalysis(da, impl, m, n_local, X, impl->obs_work, impl->Z_work, impl->y_mean_work, impl->r_inv_sqrt_work));
#endif
  }

  PetscCall(MatDestroy(&X));
  /* Self-call ViewFromOptions at the tail, mirroring KSPSolve()/SNESSolve(). Fires every cycle
     when the user passes -petscda_view; tutorials that want a single end-of-run snapshot call
     PetscDAView() explicitly after the DA loop. */
  PetscCall(PetscDAViewFromOptions(da, NULL, "-petscda_view"));
  PetscFunctionReturn(PETSC_SUCCESS);
}

753: /*
754:   PetscDALETKFResetLocalization_LETKF - destroy the cached Q matrix, obs scatter, and any device
755:   buffers tied to Q. Coordinates/type/radius are preserved so the next analysis rebuilds Q from
756:   the current inputs. Called by the setters that mutate Q-determining inputs.
757: */
758: static PetscErrorCode PetscDALETKFResetLocalization_LETKF(PetscDA da)
759: {
760:   PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;

762:   PetscFunctionBegin;
763:   PetscCheck(impl, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "PetscDA not properly initialized for LETKF");
764: #if defined(PETSC_HAVE_KOKKOS_KERNELS) && !defined(PETSC_USE_COMPLEX)
765:   /* Drop only the Q device mirrors; the persistent cusolver/rocblas/SYCL handle and the
766:      eigensolver workspace are reused across Q rebuilds. */
767:   if (impl->Q) PetscCall(PetscDALETKFDestroyQDeviceMirrors_Kokkos(impl));
768: #endif
769:   PetscCall(PetscDALETKFDestroyObsScatter(impl));
770:   PetscCall(MatDestroy(&impl->Q));
771:   impl->max_nnz_per_row = 0;
772:   impl->n_nnz_local     = 0;
773:   PetscFunctionReturn(PETSC_SUCCESS);
774: }

776: static PetscErrorCode PetscDALETKFSetLocalizationRadius_LETKF(PetscDA da, PetscReal radius)
777: {
778:   PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;

780:   PetscFunctionBegin;
781:   PetscCheck(impl, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "PetscDA not properly initialized for LETKF");
782:   PetscCheck(radius > 0, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_OUTOFRANGE, "Localization radius must be positive, got %g", (double)radius);
783:   /* Exact equality: a tolerance would silently keep a stale Q after a small intentional bump. */
784:   if (impl->localization_radius != radius) {
785:     impl->localization_radius = radius;
786:     PetscCall(PetscDALETKFResetLocalization_LETKF(da));
787:   }
788:   PetscFunctionReturn(PETSC_SUCCESS);
789: }

791: static PetscErrorCode PetscDALETKFSetLocalizationType_LETKF(PetscDA da, PetscDALETKFLocalizationType type)
792: {
793:   PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;

795:   PetscFunctionBegin;
796:   PetscCheck(impl, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "PetscDA not properly initialized for LETKF");
797:   PetscCheck(type >= 0 && type < PETSCDA_LETKF_LOC_NUM_TYPES, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_OUTOFRANGE, "Invalid localization type %d; must be in [0,%d)", (int)type, (int)PETSCDA_LETKF_LOC_NUM_TYPES);
798:   if (impl->type != type) {
799:     impl->type = type;
800:     PetscCall(PetscDALETKFResetLocalization_LETKF(da));
801:   }
802:   PetscFunctionReturn(PETSC_SUCCESS);
803: }

805: static PetscErrorCode PetscDALETKFGetLocalizationType_LETKF(PetscDA da, PetscDALETKFLocalizationType *type)
806: {
807:   PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;

809:   PetscFunctionBegin;
810:   PetscCheck(impl, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "PetscDA not properly initialized for LETKF");
811:   *type = impl->type;
812:   PetscFunctionReturn(PETSC_SUCCESS);
813: }

815: static PetscErrorCode PetscDALETKFSetLocalizationCoordinates_LETKF(PetscDA da, const Vec xyz[3], const PetscReal bd[3], Mat H)
816: {
817:   PetscDA_LETKF *impl    = (PetscDA_LETKF *)da->data;
818:   PetscBool      changed = PETSC_FALSE;
819:   PetscInt       H_rows, H_cols, vert_global, vert_local;

821:   PetscFunctionBegin;
822:   PetscCheck(impl, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "PetscDA not properly initialized for LETKF");
823:   PetscAssertPointer(xyz, 2);
824:   /* bd[d] > 0 selects periodic handling for dimension d; bd[d] == 0 means non-periodic. Reject
825:      negative values so a stray sign flip cannot silently re-interpret as non-periodic. */
826:   if (bd)
827:     for (PetscInt d = 0; d < 3; d++) PetscCheck(bd[d] >= 0.0, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_OUTOFRANGE, "Periodic-domain extent bd[%" PetscInt_FMT "] must be non-negative (use 0 for non-periodic), got %g", d, (double)bd[d]);
828:   /* Validate H and xyz[0] against the PetscDA's recorded sizes at the API boundary so a
829:      structurally mismatched H or coordinate vector is rejected here, where the caller can fix
830:      it, rather than after the previous Q/obs-scatter has already been torn down inside the
831:      lazy-build path. */
832:   PetscCall(MatGetSize(H, &H_rows, &H_cols));
833:   PetscCheck(H_rows == da->obs_size, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_INCOMP, "H has %" PetscInt_FMT " rows; PetscDA obs_size is %" PetscInt_FMT, H_rows, da->obs_size);
834:   PetscCheck(da->ndof > 0 && da->state_size % da->ndof == 0, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "state_size (%" PetscInt_FMT ") must be a positive multiple of ndof (%" PetscInt_FMT ")", da->state_size, da->ndof);
835:   PetscCall(VecGetSize(xyz[0], &vert_global));
836:   PetscCall(VecGetLocalSize(xyz[0], &vert_local));
837:   PetscCheck(vert_global == da->state_size / da->ndof, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_INCOMP, "xyz[0] global size %" PetscInt_FMT " != vertex count state_size/ndof (%" PetscInt_FMT ")", vert_global, da->state_size / da->ndof);
838:   /* H's columns must match xyz[0]'s rows so PetscDALETKFComputeObsCoords() can MatMult(H, xyz[d]).
839:      The multi-DOF observation operator (cols == state_size) is a frequent mistake here; reject it
840:      before PetscDALETKFResetLocalization_LETKF() tears down the previous Q. */
841:   PetscCheck(H_cols == vert_global, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_INCOMP, "H has %" PetscInt_FMT " columns; expected %" PetscInt_FMT " (vertex count, == xyz[0] global size). Did you pass the multi-DOF observation operator instead of the per-vertex one?", H_cols, vert_global);
842:   /* Per-dim slots beyond xyz[0] must share xyz[0]'s global size and local partition; otherwise
843:      PetscDALETKFGatherObsBbox() would walk mismatched coordinate arrays and read past valid memory. */
844:   for (PetscInt d = 1; d < 3; d++) {
845:     PetscInt other_global, other_local;

847:     if (!xyz[d]) continue;
848:     PetscCall(VecGetSize(xyz[d], &other_global));
849:     PetscCall(VecGetLocalSize(xyz[d], &other_local));
850:     PetscCheck(other_global == vert_global, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_INCOMP, "xyz[%" PetscInt_FMT "] global size %" PetscInt_FMT " != xyz[0] global size %" PetscInt_FMT, d, other_global, vert_global);
851:     PetscCheck(other_local == vert_local, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_INCOMP, "xyz[%" PetscInt_FMT "] local size %" PetscInt_FMT " != xyz[0] local size %" PetscInt_FMT, d, other_local, vert_local);
852:   }
853:   /* Compare against the cached (xyz, bd, H) tuple by pointer/value so that re-supplying the same
854:      geometry (a common pattern when the tutorial reapplies the same observation operator each
855:      analysis cycle) does not invalidate Q and force the obs-scatter and device buffers to be
856:      rebuilt. The contract requires the user to call this again after mutating any of these objects. */
857:   for (PetscInt d = 0; d < 3; d++) {
858:     if (impl->coord_xyz[d] != xyz[d]) changed = PETSC_TRUE;
859:     if (impl->coord_bd[d] != (bd ? bd[d] : 0.0)) changed = PETSC_TRUE;
860:   }
861:   if (impl->coord_H != H) changed = PETSC_TRUE;
862:   if (!changed) PetscFunctionReturn(PETSC_SUCCESS);

864:   PetscCall(PetscDALETKFClearCoordinates(impl));
865:   for (PetscInt d = 0; d < 3; d++) {
866:     if (xyz[d]) {
867:       PetscCall(PetscObjectReference((PetscObject)xyz[d]));
868:       impl->coord_xyz[d] = xyz[d];
869:     }
870:     impl->coord_bd[d] = bd ? bd[d] : 0.0;
871:   }
872:   PetscCall(PetscObjectReference((PetscObject)H));
873:   impl->coord_H = H;
874:   PetscCall(PetscDALETKFResetLocalization_LETKF(da));
875:   PetscFunctionReturn(PETSC_SUCCESS);
876: }

878: static PetscErrorCode PetscDALETKFGetLocalizationRadius_LETKF(PetscDA da, PetscReal *radius)
879: {
880:   PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;

882:   PetscFunctionBegin;
883:   PetscCheck(impl, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "PetscDA not properly initialized for LETKF");
884:   *radius = impl->localization_radius;
885:   PetscFunctionReturn(PETSC_SUCCESS);
886: }

888: /*
889:   Install a freshly built localization matrix Q (validate sizes, cache row nnz bounds, wire Kokkos
890:   device buffers when applicable). Only called from the lazy-build path. The Kokkos device-side
891:   setup runs only when PETSc was built with Kokkos kernels; the bare CPU analysis path (used by
892:   serial non-Kokkos builds) needs only the size validation and nnz bookkeeping below.

894:   scatter_H templates the obs-scatter source layout. The caller passes the analysis-time `H` so
895:   that the scatter source matches the vectors the per-vertex path is about to scatter, even when
896:   the user's analysis-time `H` has a different row partition or vec type than the `coord_H` that
897:   was cached at SetLocalizationCoordinates time. Q's column footprint (the unique global obs
898:   indices each rank touches) is independent of scatter_H's row layout, so this is well-defined as
899:   long as scatter_H and Q agree on the global obs-space size - `PetscDALETKFSetupObsScatter()`
900:   PetscCheck's that.
901: */
902: static PetscErrorCode PetscDALETKFInstallQ(PetscDA da, Mat Q, PetscInt max_nnz_local, PetscInt n_nnz_local, Mat scatter_H)
903: {
904:   PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;
905:   PetscInt       nrows, ncols;
906:   PetscBool      use_kokkos;

908:   PetscFunctionBegin;
909:   PetscCheck(da->ndof > 0 && da->state_size % da->ndof == 0, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "state_size (%" PetscInt_FMT ") must be a positive multiple of ndof (%" PetscInt_FMT ")", da->state_size, da->ndof);
910:   PetscCall(MatGetSize(Q, &nrows, &ncols));
911:   PetscCheck(nrows == da->state_size / da->ndof, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_INCOMP, "Localization matrix rows (%" PetscInt_FMT ") must equal vertex count state_size/ndof (%" PetscInt_FMT ")", nrows, da->state_size / da->ndof);
912:   PetscCheck(ncols == da->obs_size, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_INCOMP, "Localization matrix columns (%" PetscInt_FMT ") must match observation size (%" PetscInt_FMT ")", ncols, da->obs_size);
913:   PetscCheck(max_nnz_local >= 0, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_OUTOFRANGE, "max_nnz_local must be >= 0, got %" PetscInt_FMT, max_nnz_local);
914:   PetscCheck(n_nnz_local >= 0, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_OUTOFRANGE, "n_nnz_local must be >= 0, got %" PetscInt_FMT, n_nnz_local);
915:   /* The obs-scatter is needed by both the CPU and Kokkos per-vertex paths whenever Q exists.
916:      Validate before tearing down the previous Q/obs-scatter so a missing prerequisite leaves
917:      the impl in its prior usable state instead of a half-installed one. */
919:   PetscCall(PetscDALETKFUseKokkosBackend(da, &use_kokkos));

921: #if defined(PETSC_HAVE_KOKKOS_KERNELS) && !defined(PETSC_USE_COMPLEX)
922:   /* Drop only the Q device mirrors; the eigensolver workspace and solver handle persist. */
923:   if (impl->Q) PetscCall(PetscDALETKFDestroyQDeviceMirrors_Kokkos(impl));
924: #endif
925:   /* Destroy the previous obs-scatter so SetupObsScatter() can rebuild it for the new Q footprint. */
926:   PetscCall(PetscDALETKFDestroyObsScatter(impl));

928:   PetscCall(PetscObjectReference((PetscObject)Q));
929:   PetscCall(MatDestroy(&impl->Q));
930:   impl->Q = Q;

932:   /* The CSR backends already walked row_counts[] to size their allocations; reuse the per-rank
933:      max (and total local nnz) computed there instead of MatGetRow/MatGetInfo walks, which would
934:      force a device->host sync on AIJKOKKOS Q. Allreduce the max so impl->max_nnz_per_row holds
935:      the global max all ranks need for sizing; n_nnz_local stays per-rank. */
936:   impl->max_nnz_per_row = max_nnz_local;
937:   PetscCallMPI(MPIU_Allreduce(MPI_IN_PLACE, &impl->max_nnz_per_row, 1, MPIU_INT, MPI_MAX, PetscObjectComm((PetscObject)da)));
938:   impl->n_nnz_local = n_nnz_local;

940:   PetscCall(PetscDALETKFSetupObsScatter(impl, scatter_H));
941: #if defined(PETSC_HAVE_KOKKOS_KERNELS) && !defined(PETSC_USE_COMPLEX)
942:   /* Gate the device-mirror setup on the analysis-time backend: a Kokkos-capable build with a
943:      non-Kokkos da->R runs the CPU per-vertex path, which never reads Q_device_*. Skipping the
944:      setup avoids a redundant MatGetRow walk over every local row of Q and the three device-resident
945:      Kokkos Views it would allocate. */
946:   if (use_kokkos) PetscCall(PetscDALETKFSetupLocalization_Kokkos(impl));
947: #endif
948:   PetscFunctionReturn(PETSC_SUCCESS);
949: }

951: static PetscErrorCode PetscDAView_LETKF(PetscDA da, PetscViewer viewer)
952: {
953:   PetscBool      iascii = PETSC_FALSE, is_kokkos = PETSC_FALSE;
954:   PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;

956:   PetscFunctionBegin;
957:   PetscCall(PetscDAView_Ensemble(da, viewer));
958:   PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &iascii));
959:   if (iascii) {
960:     PetscCall(PetscDALETKFUseKokkosBackend(da, &is_kokkos));
961:     PetscCall(PetscViewerASCIIPrintf(viewer, "  Local analysis: %s\n", is_kokkos ? "Kokkos" : "CPU"));
962:     PetscCall(PetscViewerASCIIPrintf(viewer, "  Localization type: %s\n", PetscDALETKFLocalizationTypes[impl->type]));
963:     if (impl->type != PETSCDA_LETKF_LOC_NONE) {
964:       if (impl->localization_radius > 0.0) PetscCall(PetscViewerASCIIPrintf(viewer, "  Localization radius: %g\n", (double)impl->localization_radius));
965:       else PetscCall(PetscViewerASCIIPrintf(viewer, "  Localization radius: (unset)\n"));
966:       if (is_kokkos) {
967:         if (impl->batch_size > 0) PetscCall(PetscViewerASCIIPrintf(viewer, "  GPU batch size: %" PetscInt_FMT "\n", impl->batch_size));
968:         else PetscCall(PetscViewerASCIIPrintf(viewer, "  GPU batch size: auto\n"));
969:       }
970:     }
971:   }
972:   PetscFunctionReturn(PETSC_SUCCESS);
973: }

975: static PetscErrorCode PetscDASetFromOptions_LETKF(PetscDA da, PetscOptionItems *PetscOptionsObjectPtr)
976: {
977:   PetscDA_LETKF   *impl               = (PetscDA_LETKF *)da->data;
978:   PetscOptionItems PetscOptionsObject = *PetscOptionsObjectPtr;
979:   PetscReal        radius;
980:   PetscInt         type_idx, batch_size;
981:   PetscBool        type_set = PETSC_FALSE, radius_set = PETSC_FALSE;

983:   PetscFunctionBegin;
984:   PetscCall(PetscDASetFromOptions_Ensemble(da, PetscOptionsObjectPtr));
985:   PetscOptionsHeadBegin(PetscOptionsObject, "PetscDA LETKF Options");
986:   batch_size = impl->batch_size;
987:   PetscCall(PetscOptionsInt("-petscda_letkf_batch_size", "Batch size for GPU processing (0 = auto)", "PETSCDALETKF", batch_size, &batch_size, NULL));
988:   PetscCheck(batch_size >= 0, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_OUTOFRANGE, "batch_size must be >= 0, got %" PetscInt_FMT, batch_size);
989:   impl->batch_size = batch_size;
990:   radius           = impl->localization_radius;
991:   PetscCall(PetscOptionsReal("-petscda_letkf_localization_radius", "Localization cutoff radius for built-in kernels", "PetscDALETKFSetLocalizationRadius", radius, &radius, &radius_set));
992:   if (radius_set) PetscCall(PetscDALETKFSetLocalizationRadius(da, radius));
993:   type_idx = (PetscInt)impl->type;
994:   PetscCall(PetscOptionsEList("-petscda_letkf_localization_type", "Localization kernel type", "PetscDALETKFSetLocalizationType", PetscDALETKFLocalizationTypes, PETSCDA_LETKF_LOC_NUM_TYPES, PetscDALETKFLocalizationTypes[type_idx], &type_idx, &type_set));
995:   if (type_set) PetscCall(PetscDALETKFSetLocalizationType(da, (PetscDALETKFLocalizationType)type_idx));
996:   PetscOptionsHeadEnd();
997:   PetscFunctionReturn(PETSC_SUCCESS);
998: }

1000: /*MC
1001:    PETSCDALETKF - The Local ETKF performs the analysis update locally around each grid point, enabling scalable assimilation on large
1002:    domains by avoiding the global ensemble covariance matrix.

1004:    Options Database Keys:
1005: + -petscda_type letkf                                                  - set the `PetscDAType` to `PETSCDALETKF`
1006: . -petscda_ensemble_size size                                          - number of ensemble members
1007: . -petscda_ensemble_inflation factor                                   - multiplicative inflation factor applied to anomalies
1008: . -petscda_letkf_batch_size batch_size                                 - set the batch size for GPU processing
1009: . -petscda_letkf_localization_radius radius                            - localization cutoff radius for the built-in kernels (must be positive)
1010: . -petscda_letkf_localization_type (none|gaspari_cohn|gaussian|boxcar) - select the localization kernel
1011: - -petscda_view                                                        - view the `PetscDA` at the end of every `PetscDAEnsembleAnalysis()` call

1013:    Level: beginner

1015:    Notes:
1016:    The default localization kernel is `PETSCDA_LETKF_LOC_GASPARI_COHN`, which requires the user to
1017:    call `PetscDALETKFSetLocalizationRadius()` and `PetscDALETKFSetLocalizationCoordinates()` before
1018:    the first analysis. To skip localization entirely use `PetscDALETKFSetLocalizationType(da, PETSCDA_LETKF_LOC_NONE)`
1019:    (or `-petscda_letkf_localization_type none`).
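   Both the CPU and Kokkos analysis paths support multi-rank runs. The Kokkos backend is selected
   when the observation-error covariance matrix `da->R` has a Kokkos AIJ type (`MATAIJKOKKOS`,
   `MATSEQAIJKOKKOS` or `MATMPIAIJKOKKOS`, e.g. via `-mat_type aijkokkos`). This requires a PETSc
   build with Kokkos kernels and real scalars. Otherwise the CPU per-vertex (or LOC_NONE
   replicated) path is used.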
1023:    `-petscda_view` fires at the tail of every `PetscDAEnsembleAnalysis()` call (mirroring `KSPSolve()`/`SNESSolve()`),
1024:    so over a multi-cycle assimilation run the view is emitted once per analysis. Code that wants a single
1025:    end-of-run snapshot should call `PetscDAView()` explicitly after the assimilation loop instead.

1027: .seealso: [](ch_da), `PetscDA`, `PetscDACreate()`, `PetscDALETKFSetLocalizationRadius()`, `PetscDALETKFGetLocalizationRadius()`,
1028:           `PetscDALETKFSetLocalizationType()`, `PetscDALETKFGetLocalizationType()`, `PetscDALETKFSetLocalizationCoordinates()`,
1029:           `PetscDALETKFResetLocalization()`, `PetscDAEnsembleSetSize()`, `PetscDASetSizes()`, `PetscDAEnsembleSetInflation()`,
1030:           `PetscDAEnsembleComputeMean()`, `PetscDAEnsembleComputeAnomalies()`, `PetscDAEnsembleAnalysis()`, `PetscDAEnsembleForecast()`
1031: M*/

1033: PETSC_INTERN PetscErrorCode PetscDACreate_LETKF(PetscDA da)
1034: {
1035:   PetscDA_LETKF *impl;

1037:   PetscFunctionBegin;
1038:   PetscCall(PetscNew(&impl));
1039:   da->data = impl;
1040:   PetscCall(PetscDACreate_Ensemble(da));
1041:   da->ops->setup          = PetscDASetUp_Ensemble;
1042:   da->ops->destroy        = PetscDADestroy_LETKF;
1043:   da->ops->view           = PetscDAView_LETKF;
1044:   da->ops->setfromoptions = PetscDASetFromOptions_LETKF;
1045:   impl->en.analysis       = PetscDAEnsembleAnalysis_LETKF;
1046:   impl->en.forecast       = PetscDAEnsembleForecast_Ensemble;

1048:   impl->type = PETSCDA_LETKF_LOC_GASPARI_COHN;

1050:   /* Register the method for setting localization */
1051:   PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFSetLocalizationRadius_C", PetscDALETKFSetLocalizationRadius_LETKF));
1052:   PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFGetLocalizationRadius_C", PetscDALETKFGetLocalizationRadius_LETKF));
1053:   PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFSetLocalizationType_C", PetscDALETKFSetLocalizationType_LETKF));
1054:   PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFGetLocalizationType_C", PetscDALETKFGetLocalizationType_LETKF));
1055:   PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFSetLocalizationCoordinates_C", PetscDALETKFSetLocalizationCoordinates_LETKF));
1056:   PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFResetLocalization_C", PetscDALETKFResetLocalization_LETKF));
1057:   PetscFunctionReturn(PETSC_SUCCESS);
1058: }

1060: /*@
1061:   PetscDALETKFSetLocalizationRadius - Sets the localization cutoff radius used by LETKF's built-in distance-based kernels.

1063:   Logically Collective

1065:   Input Parameters:
1066: + da     - the `PetscDA` context
1067: - radius - the localization cutoff radius (must be positive; use a large value for effectively no localization)

1069:   Level: advanced

1071: .seealso: [](ch_da), `PETSCDALETKF`, `PetscDA`, `PetscDALETKFSetLocalizationCoordinates()`, `PetscDALETKFGetLocalizationRadius()`
1072: @*/
1073: PetscErrorCode PetscDALETKFSetLocalizationRadius(PetscDA da, PetscReal radius)
1074: {
1075:   PetscFunctionBegin;
1078:   PetscTryMethod(da, "PetscDALETKFSetLocalizationRadius_C", (PetscDA, PetscReal), (da, radius));
1079:   PetscFunctionReturn(PETSC_SUCCESS);
1080: }

1082: /*@
1083:   PetscDALETKFGetLocalizationRadius - Gets the localization cutoff radius used by LETKF's built-in distance-based kernels.

1085:   Not Collective

1087:   Input Parameter:
1088: . da - the `PetscDA` context

1090:   Output Parameter:
1091: . radius - the localization cutoff radius

1093:   Level: advanced

1095: .seealso: [](ch_da), `PETSCDALETKF`, `PetscDA`, `PetscDALETKFSetLocalizationRadius()`
1096: @*/
1097: PetscErrorCode PetscDALETKFGetLocalizationRadius(PetscDA da, PetscReal *radius)
1098: {
1099:   PetscFunctionBegin;
1101:   PetscAssertPointer(radius, 2);
1102:   PetscUseMethod(da, "PetscDALETKFGetLocalizationRadius_C", (PetscDA, PetscReal *), (da, radius));
1103:   PetscFunctionReturn(PETSC_SUCCESS);
1104: }

1106: /*@
1107:   PetscDALETKFSetLocalizationType - Selects the localization kernel used by `PETSCDALETKF`.

1109:   Logically Collective

1111:   Input Parameters:
1112: + da   - the `PetscDA` context
1113: - type - the kernel type (see `PetscDALETKFLocalizationType`)

1115:   Level: intermediate

1117:   Notes:
1118:   Use `PETSCDA_LETKF_LOC_NONE` to bypass localization entirely; the analysis is then mathematically
1119:   equivalent to the global ETKF and dispatches through a single global eigensolve plus a dense
1120:   `BLASgemm` weight transform reduced across ranks, instead of the per-vertex local loop.

1122:   For the built-in distance-based kernels (`PETSCDA_LETKF_LOC_GASPARI_COHN`, `PETSCDA_LETKF_LOC_GAUSSIAN`,
1123:   `PETSCDA_LETKF_LOC_BOXCAR`) you must also call `PetscDALETKFSetLocalizationRadius()` and
1124:   `PetscDALETKFSetLocalizationCoordinates()`. The localization matrix is then constructed
1125:   lazily before the first analysis.
1126:   All three built-in kernels are 1 at distance 0; `radius` selects the effective support but the
1127:   cutoff distance and continuity at the cutoff differ.
1128:   `PETSCDA_LETKF_LOC_GASPARI_COHN` is compactly supported with cutoff at distance `2*radius`, and
1129:   is C^1 continuous everywhere (it tapers smoothly to zero at the cutoff).
1130:   `PETSCDA_LETKF_LOC_GAUSSIAN` is `exp(-d^2 / (2*radius^2))` truncated at distance `2*radius`; the
1131:   truncation introduces a discontinuity of `exp(-2)` (~0.135) at the cutoff, so prefer
1132:   `PETSCDA_LETKF_LOC_GASPARI_COHN` if a smooth taper at the cutoff matters.
1133:   `PETSCDA_LETKF_LOC_BOXCAR` is 1 inside `radius` and 0 outside; the discontinuity is by design.

1135: .seealso: [](ch_da), `PETSCDALETKF`, `PetscDA`, `PetscDALETKFLocalizationType`, `PetscDALETKFGetLocalizationType()`,
1136:           `PetscDALETKFSetLocalizationRadius()`, `PetscDALETKFSetLocalizationCoordinates()`
1137: @*/
1138: PetscErrorCode PetscDALETKFSetLocalizationType(PetscDA da, PetscDALETKFLocalizationType type)
1139: {
1140:   PetscFunctionBegin;
1143:   PetscCheck(type >= PETSCDA_LETKF_LOC_NONE && type < PETSCDA_LETKF_LOC_NUM_TYPES, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_OUTOFRANGE, "Localization type %d out of range [0, %d)", (int)type, (int)PETSCDA_LETKF_LOC_NUM_TYPES);
1144:   PetscTryMethod(da, "PetscDALETKFSetLocalizationType_C", (PetscDA, PetscDALETKFLocalizationType), (da, type));
1145:   PetscFunctionReturn(PETSC_SUCCESS);
1146: }

1148: /*@
1149:   PetscDALETKFGetLocalizationType - Returns the localization kernel currently used by `PETSCDALETKF`.

1151:   Not Collective

1153:   Input Parameter:
1154: . da - the `PetscDA` context

1156:   Output Parameter:
1157: . type - the kernel type

1159:   Level: intermediate

1161: .seealso: [](ch_da), `PETSCDALETKF`, `PetscDA`, `PetscDALETKFLocalizationType`, `PetscDALETKFSetLocalizationType()`
1162: @*/
1163: PetscErrorCode PetscDALETKFGetLocalizationType(PetscDA da, PetscDALETKFLocalizationType *type)
1164: {
1165:   PetscFunctionBegin;
1167:   PetscAssertPointer(type, 2);
1168:   PetscUseMethod(da, "PetscDALETKFGetLocalizationType_C", (PetscDA, PetscDALETKFLocalizationType *), (da, type));
1169:   PetscFunctionReturn(PETSC_SUCCESS);
1170: }

1172: /*@
1173:   PetscDALETKFSetLocalizationCoordinates - Provides the geometry used to lazily build the
1174:   localization matrix `Q` for a built-in `PETSCDALETKF` kernel.

1176:   Collective

1178:   Input Parameters:
1179: + da  - the `PetscDA` context
1180: . xyz - length-3 array of coordinate vectors, one per spatial dimension; set unused trailing
1181:         slots to `NULL` (the spatial dimension is taken to be the index of the first `NULL`,
1182:         so `{x, y, NULL}` is 2D and `{x, NULL, NULL}` is 1D)
1183: . bd  - length-3 array of periodic-domain extents (use 0 for non-periodic dimensions); pass
1184:         `NULL` to mean fully non-periodic
1185: - H   - the observation operator (used to map state-space coordinates to observation locations)

1187:   Level: intermediate

1189:   Notes:
1190:   The `xyz` array must always have three slots even in 1D or 2D; trailing slots are set to `NULL`.
1191:   This matches the internal cached layout `coord_xyz[3]` and the layout used by both Q backends.

1193:   The localization matrix `Q` is built on first analysis (or whenever the type, radius or
1194:   coordinates change) using the kernel selected by `PetscDALETKFSetLocalizationType()`. If the
1195:   current type is `PETSCDA_LETKF_LOC_NONE`, the coordinates are cached but the analysis continues
1196:   to run the NONE fast path; switch to a distance-based kernel via
1197:   `PetscDALETKFSetLocalizationType()` for the cached coordinates to take effect.
1198:   The reference counts on `xyz` and `H` are increased; the caller may destroy them afterwards.

1200:   The cached coordinate `Vec`s are referenced, not deep-copied. If the caller mutates the contents
1201:   of any element of `xyz` after this call (for example, after a remesh or recoordinate step), the
1202:   cached `Q` will not be rebuilt automatically; call `PetscDALETKFResetLocalization()` to invalidate
1203:   `Q` and force a rebuild on the next analysis.

1205:   The columns of `Q` are global indices into the observation vector, derived from the row
1206:   ownership and sparsity of the `H` cached here. The `H` passed to `PetscDAEnsembleAnalysis()`
1207:   must therefore use the same global row indexing (same observation ordering and the same global
1208:   obs-space size) as the `H` cached here. The analysis-time `H` may differ from the cached `H`
1209:   in MPI row partition or vec type - the obs-scatter is templated on the analysis-time `H`, so
1210:   those differences are absorbed automatically. A structurally different `H` (rows referring to
1211:   different physical observations) will produce wrong analyses without raising an error; in that
1212:   case, call `PetscDALETKFSetLocalizationCoordinates()` again with the new `H` to rebuild `Q`.

1214: .seealso: [](ch_da), `PETSCDALETKF`, `PetscDA`, `PetscDALETKFSetLocalizationType()`,
1215:           `PetscDALETKFSetLocalizationRadius()`
1216: @*/
1217: PetscErrorCode PetscDALETKFSetLocalizationCoordinates(PetscDA da, const Vec xyz[3], const PetscReal bd[3], Mat H)
1218: {
1219:   PetscFunctionBegin;
1221:   /* xyz is required; bd is optional. Use an always-on PetscCheck rather than the debug-only
1222:      PetscAssertPointer() so a NULL xyz argument is rejected cleanly in optimized builds before
1223:      the xyz[0] dereference below. */
1224:   PetscCheck(xyz, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_NULL, "xyz must be a non-NULL length-3 array of Vec");
1225:   PetscCheck(xyz[0], PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONG, "xyz[0] must be a valid Vec; the spatial dimension is taken to be the index of the first NULL slot in xyz[3]");
1227:   if (bd)
1229:   PetscTryMethod(da, "PetscDALETKFSetLocalizationCoordinates_C", (PetscDA, const Vec[3], const PetscReal[3], Mat), (da, xyz, bd, H));
1230:   PetscFunctionReturn(PETSC_SUCCESS);
1231: }

1233: /*@
1234:   PetscDALETKFResetLocalization - Discards the cached localization matrix `Q` so the next analysis
1235:   rebuilds it from the current type, radius, and coordinates.

1237:   Collective

1239:   Input Parameter:
1240: . da - the `PetscDA` context

1242:   Level: advanced

1244:   Notes:
1245:   The setters `PetscDALETKFSetLocalizationType()`, `PetscDALETKFSetLocalizationRadius()`, and
1246:   `PetscDALETKFSetLocalizationCoordinates()` already invalidate `Q` when their inputs actually
1247:   change, so most users never need to call this directly. Use it when an input was mutated outside
1248:   of the setters (for example, the entries of a cached coordinate `Vec` were edited in place, or
1249:   the cached observation operator `H` was reassembled with different sparsity).

1251: .seealso: [](ch_da), `PETSCDALETKF`, `PetscDA`, `PetscDALETKFSetLocalizationType()`,
1252:           `PetscDALETKFSetLocalizationRadius()`, `PetscDALETKFSetLocalizationCoordinates()`
1253: @*/
1254: PetscErrorCode PetscDALETKFResetLocalization(PetscDA da)
1255: {
1256:   PetscFunctionBegin;
1258:   PetscTryMethod(da, "PetscDALETKFResetLocalization_C", (PetscDA), (da));
1259:   PetscFunctionReturn(PETSC_SUCCESS);
1260: }