Actual source code: letkfilter.c
1: #include <petscda.h>
2: #include <petsc/private/daimpl.h>
3: #include <petsc/private/daensembleimpl.h>
4: #include <petscblaslapack.h>
5: #include <../src/ml/da/impls/ensemble/letkf/letkf.h>
7: static PetscErrorCode PetscDALETKFInstallQ(PetscDA, Mat, PetscInt, PetscInt, Mat);
8: static PetscErrorCode PetscDALETKFResetLocalization_LETKF(PetscDA);
/*
  String names for PetscDALETKFLocalizationType, in the layout PETSc uses for enum-name tables
  (e.g. the list argument of PetscOptionsEnum()):
    - one name per enum value, in declaration order;
    - then the enum type name;
    - then the common prefix of the enum constants;
    - then a NULL terminator.
  The value names must match the PetscDALETKFLocalizationType enum order in include/petscda.h.
*/
const char *const PetscDALETKFLocalizationTypes[] = {
  "none",
  "gaspari_cohn",
  "gaussian",
  "boxcar",
  "PetscDALETKFLocalizationType", /* enum type name */
  "PETSCDA_LETKF_LOC_",           /* prefix shared by the enum constants */
  NULL,                           /* terminator */
};
13: /* The Kokkos analysis paths key off the type of the obs-error covariance Mat (R), since R is
14: created via MatSetType + MatSetFromOptions and inherits whatever -mat_type the user requested.
15: Returns PETSC_FALSE when R is not yet built or when Kokkos kernels are unavailable. */
16: static PetscErrorCode PetscDALETKFUseKokkosBackend(PetscDA da, PetscBool *use_kokkos)
17: {
18: PetscFunctionBegin;
19: *use_kokkos = PETSC_FALSE;
20: #if defined(PETSC_HAVE_KOKKOS_KERNELS) && !defined(PETSC_USE_COMPLEX)
21: if (da->R) PetscCall(PetscObjectTypeCompareAny((PetscObject)da->R, use_kokkos, MATSEQAIJKOKKOS, MATMPIAIJKOKKOS, MATAIJKOKKOS, ""));
22: #endif
23: PetscFunctionReturn(PETSC_SUCCESS);
24: }
26: /* Free cached coordinate inputs (used only for built-in kernels). */
27: static PetscErrorCode PetscDALETKFClearCoordinates(PetscDA_LETKF *impl)
28: {
29: PetscFunctionBegin;
30: for (PetscInt d = 0; d < 3; d++) {
31: PetscCall(VecDestroy(&impl->coord_xyz[d]));
32: impl->coord_bd[d] = 0.0;
33: }
34: PetscCall(MatDestroy(&impl->coord_H));
35: PetscFunctionReturn(PETSC_SUCCESS);
36: }
38: /*
39: PetscDALETKFReplicateWeightVector - replicate weight vector w across all columns of w_ones (m x m dense).
40: Used only by the LOC_NONE fast path. w lives on PETSC_COMM_SELF (size m); w_ones is a SELF SeqDense m x m.
41: */
42: PETSC_INTERN PetscErrorCode PetscDALETKFReplicateWeightVector(Vec w, PetscInt m, Mat w_ones)
43: {
44: const PetscScalar *w_array;
45: PetscScalar *mat_array;
46: PetscInt w_size, lda, wo_rows, wo_cols;
48: PetscFunctionBegin;
49: PetscCall(VecGetLocalSize(w, &w_size));
50: PetscCheck(w_size == m, PetscObjectComm((PetscObject)w), PETSC_ERR_ARG_INCOMP, "w size %" PetscInt_FMT " != m %" PetscInt_FMT, w_size, m);
51: PetscCall(MatGetSize(w_ones, &wo_rows, &wo_cols));
52: PetscCheck(wo_rows == m && wo_cols == m, PetscObjectComm((PetscObject)w_ones), PETSC_ERR_ARG_INCOMP, "w_ones must be %" PetscInt_FMT " x %" PetscInt_FMT ", got %" PetscInt_FMT " x %" PetscInt_FMT, m, m, wo_rows, wo_cols);
53: PetscCall(VecGetArrayRead(w, &w_array));
54: PetscCall(MatDenseGetArrayWrite(w_ones, &mat_array));
55: PetscCall(MatDenseGetLDA(w_ones, &lda));
56: for (PetscInt i = 0; i < m; i++) PetscCall(PetscArraycpy(mat_array + i * lda, w_array, m));
57: PetscCall(MatDenseRestoreArrayWrite(w_ones, &mat_array));
58: PetscCall(VecRestoreArrayRead(w, &w_array));
59: PetscFunctionReturn(PETSC_SUCCESS);
60: }
62: /*
63: PetscDALETKFEnsureGlobalScratch - Lazily allocate the per-rank-replicated m-sized SELF scratch
64: used by the LOC_NONE fast path (impl->w, impl->s_transpose_delta, impl->T_sqrt, impl->w_ones).
65: Both the CPU (PetscDALETKFGlobalAnalysis) and Kokkos (PetscDALETKFGlobalAnalysis_Kokkos) backends
66: call this on entry; the per-vertex paths skip it.
67: */
68: PETSC_INTERN PetscErrorCode PetscDALETKFEnsureGlobalScratch(PetscDA_LETKF *impl, PetscInt m)
69: {
70: PetscFunctionBegin;
71: if (!impl->w) PetscCall(VecCreateSeq(PETSC_COMM_SELF, m, &impl->w));
72: if (!impl->s_transpose_delta) PetscCall(VecCreateSeq(PETSC_COMM_SELF, m, &impl->s_transpose_delta));
73: if (!impl->T_sqrt) PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, m, m, NULL, &impl->T_sqrt));
74: if (!impl->w_ones) PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, m, m, NULL, &impl->w_ones));
75: PetscFunctionReturn(PETSC_SUCCESS);
76: }
78: /*
79: PetscDALETKFUpdateEnsembleWithTransform - E = mean*1' + X*G.
80: Used only by the LOC_NONE fast path. G is replicated on PETSC_COMM_SELF (every rank holds the
81: same m x m), X and the ensemble share the same row distribution; the local rows of E are
82: X_local * G + mean_local broadcast across columns. Computed via a per-rank BLASgemm for X*G
83: plus a column-broadcast add of mean.
84: */
85: static PetscErrorCode PetscDALETKFUpdateEnsembleWithTransform(Vec mean, Mat X, Mat G, PetscInt m, Mat ensemble)
86: {
87: const PetscScalar *x_array, *g_array, *mean_array;
88: PetscScalar *xg_buf, *ens_array;
89: PetscScalar one = 1.0, zero = 0.0;
90: PetscBLASInt n_local_b, m_b, lda_x_b, lda_g_b;
91: PetscInt n_local_ens, n_local_x, n_g_rows, n_g_cols, lda_x, lda_g, lda_ens;
93: PetscFunctionBegin;
94: PetscCall(MatGetLocalSize(ensemble, &n_local_ens, NULL));
95: PetscCall(MatGetLocalSize(X, &n_local_x, NULL));
96: PetscCheck(n_local_x == n_local_ens, PetscObjectComm((PetscObject)ensemble), PETSC_ERR_ARG_INCOMP, "X local rows (%" PetscInt_FMT ") must match ensemble local rows (%" PetscInt_FMT ")", n_local_x, n_local_ens);
97: PetscCall(MatGetSize(G, &n_g_rows, &n_g_cols));
98: PetscCheck(n_g_rows == m && n_g_cols == m, PetscObjectComm((PetscObject)ensemble), PETSC_ERR_ARG_INCOMP, "G must be %" PetscInt_FMT " x %" PetscInt_FMT ", got %" PetscInt_FMT " x %" PetscInt_FMT, m, m, n_g_rows, n_g_cols);
99: PetscCall(MatDenseGetArrayRead(X, &x_array));
100: PetscCall(MatDenseGetLDA(X, &lda_x));
101: PetscCall(MatDenseGetArrayRead(G, &g_array));
102: PetscCall(MatDenseGetLDA(G, &lda_g));
103: PetscCall(MatDenseGetArrayWrite(ensemble, &ens_array));
104: PetscCall(MatDenseGetLDA(ensemble, &lda_ens));
105: PetscCall(VecGetArrayRead(mean, &mean_array));
106: PetscCall(PetscMalloc1((size_t)n_local_ens * m, &xg_buf));
107: PetscCall(PetscBLASIntCast(n_local_ens, &n_local_b));
108: PetscCall(PetscBLASIntCast(m, &m_b));
109: PetscCall(PetscBLASIntCast(lda_x, &lda_x_b));
110: PetscCall(PetscBLASIntCast(lda_g, &lda_g_b));
111: if (n_local_ens > 0) PetscCallBLAS("BLASgemm", BLASgemm_("N", "N", &n_local_b, &m_b, &m_b, &one, x_array, &lda_x_b, g_array, &lda_g_b, &zero, xg_buf, &n_local_b));
112: for (PetscInt j = 0; j < m; j++)
113: for (PetscInt i = 0; i < n_local_ens; i++) ens_array[i + j * lda_ens] = mean_array[i] + xg_buf[i + j * n_local_ens];
114: PetscCall(PetscFree(xg_buf));
115: PetscCall(VecRestoreArrayRead(mean, &mean_array));
116: PetscCall(MatDenseRestoreArrayWrite(ensemble, &ens_array));
117: PetscCall(MatDenseRestoreArrayRead(G, &g_array));
118: PetscCall(MatDenseRestoreArrayRead(X, &x_array));
119: PetscFunctionReturn(PETSC_SUCCESS);
120: }
122: static PetscErrorCode PetscDADestroy_LETKF(PetscDA da)
123: {
124: PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;
126: PetscFunctionBegin;
127: PetscCall(VecDestroy(&impl->mean));
128: PetscCall(VecDestroy(&impl->y_mean));
129: PetscCall(VecDestroy(&impl->delta_scaled));
130: PetscCall(VecDestroy(&impl->w));
131: PetscCall(VecDestroy(&impl->s_transpose_delta));
132: PetscCall(VecDestroy(&impl->r_inv_sqrt));
133: PetscCall(VecDestroy(&impl->H_temp_in));
134: PetscCall(VecDestroy(&impl->H_temp_out));
135: PetscCall(PetscFree(impl->H_vec_type));
136: PetscCall(MatDestroy(&impl->Z));
137: PetscCall(MatDestroy(&impl->S));
138: PetscCall(MatDestroy(&impl->T_sqrt));
139: PetscCall(MatDestroy(&impl->w_ones));
140: PetscCall(MatDestroy(&impl->Q));
141: PetscCall(PetscDALETKFClearCoordinates(impl));
142: #if defined(PETSC_HAVE_KOKKOS_KERNELS) && !defined(PETSC_USE_COMPLEX)
143: PetscCall(PetscDALETKFDestroyLocalization_Kokkos(impl));
144: #endif
145: PetscCall(PetscDALETKFDestroyObsScatter(impl));
146: PetscCall(PetscDADestroy_Ensemble(da));
147: PetscCall(PetscFree(da->data));
149: PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFSetLocalizationRadius_C", NULL));
150: PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFGetLocalizationRadius_C", NULL));
151: PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFSetLocalizationType_C", NULL));
152: PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFGetLocalizationType_C", NULL));
153: PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFSetLocalizationCoordinates_C", NULL));
154: PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFResetLocalization_C", NULL));
155: PetscFunctionReturn(PETSC_SUCCESS);
156: }
158: /*
159: ExtractLocalObservations - Extracts local observations for a vertex using localization matrix Q (CPU version)
161: Input Parameters:
162: + Q - localization matrix (state_size/ndof x obs_size), variable nnz per row
163: . vertex_idx - index of the vertex (row of Q)
164: . Z_global - global observation ensemble matrix (obs_size x m) OR local work matrix
165: . y_global - global observation vector (size obs_size) OR local work vector
166: . y_mean_global - global observation mean (size obs_size) OR local work vector
167: . r_inv_sqrt_global - global R^{-1/2} (size obs_size) OR local work vector
168: . obs_g2l - map from global observation index to local index (if using local work vectors)
169: . m - ensemble size
171: Output Parameters:
172: . Z_local - local observation ensemble (p_local x m), pre-allocated
173: . y_local - local observation vector (size p_local), pre-allocated
174: . y_mean_local - local observation mean (size p_local), pre-allocated
175: - r_inv_sqrt_local - local R^{-1/2} (size p_local), pre-allocated
176: */
177: static PetscErrorCode ExtractLocalObservations(Mat Q, PetscInt vertex_idx, Mat Z_global, Vec y_global, Vec y_mean_global, Vec r_inv_sqrt_global, PetscHMapI obs_g2l, PetscInt m, Mat Z_local, Vec y_local, Vec y_mean_local, Vec r_inv_sqrt_local)
178: {
179: const PetscInt *cols;
180: const PetscScalar *vals;
181: PetscInt ncols, k, j, p_local;
182: const PetscScalar *z_global_array, *y_global_array, *y_mean_global_array, *r_inv_sqrt_global_array;
183: PetscScalar *z_local_array, *y_local_array, *y_mean_local_array, *r_inv_sqrt_local_array;
184: PetscInt lda_z_global, lda_z_local;
186: PetscFunctionBegin;
187: /* Get the row of Q corresponding to this vertex */
188: PetscCall(MatGetRow(Q, vertex_idx, &ncols, &cols, &vals));
190: /* Get array access to global data */
191: PetscCall(MatDenseGetArrayRead(Z_global, &z_global_array));
192: PetscCall(VecGetArrayRead(y_global, &y_global_array));
193: PetscCall(VecGetArrayRead(y_mean_global, &y_mean_global_array));
194: PetscCall(VecGetArrayRead(r_inv_sqrt_global, &r_inv_sqrt_global_array));
196: /* Get array access to local data */
197: PetscCall(MatDenseGetArrayWrite(Z_local, &z_local_array));
198: PetscCall(VecGetArray(y_local, &y_local_array));
199: PetscCall(VecGetArray(y_mean_local, &y_mean_local_array));
200: PetscCall(VecGetArray(r_inv_sqrt_local, &r_inv_sqrt_local_array));
202: /* Get leading dimensions */
203: PetscCall(MatDenseGetLDA(Z_global, &lda_z_global));
204: PetscCall(MatDenseGetLDA(Z_local, &lda_z_local));
205: PetscCall(VecGetLocalSize(y_local, &p_local));
207: /* Extract local observations and weight R^{-1/2} */
208: for (k = 0; k < ncols; k++) {
209: PetscInt obs_idx = cols[k];
210: PetscScalar weight = vals[k];
211: PetscInt local_idx = obs_idx;
213: /* If using local work vectors, map global index to local index */
214: if (obs_g2l) {
215: PetscCall(PetscHMapIGet(obs_g2l, obs_idx, &local_idx));
216: PetscCheck(local_idx >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Observation index %" PetscInt_FMT " not found in local map", obs_idx);
217: }
219: y_local_array[k] = y_global_array[local_idx];
220: y_mean_local_array[k] = y_mean_global_array[local_idx];
221: r_inv_sqrt_local_array[k] = r_inv_sqrt_global_array[local_idx] * PetscSqrtScalar(weight);
223: /* Extract Z matrix row (column-major layout) */
224: for (j = 0; j < m; j++) z_local_array[k + j * lda_z_local] = z_global_array[local_idx + j * lda_z_global];
225: }
227: /* Zero the unused tail [ncols, p_local) so a shorter row does not leak the previous row's
228: trailing values into the downstream normalized-innovation computation. Caller does not need
229: to MatZeroEntries/VecZeroEntries the workspace between iterations. */
230: for (k = ncols; k < p_local; k++) {
231: y_local_array[k] = 0.0;
232: y_mean_local_array[k] = 0.0;
233: r_inv_sqrt_local_array[k] = 0.0;
234: for (j = 0; j < m; j++) z_local_array[k + j * lda_z_local] = 0.0;
235: }
237: /* Restore arrays */
238: PetscCall(VecRestoreArray(r_inv_sqrt_local, &r_inv_sqrt_local_array));
239: PetscCall(VecRestoreArray(y_mean_local, &y_mean_local_array));
240: PetscCall(VecRestoreArray(y_local, &y_local_array));
241: PetscCall(MatDenseRestoreArrayWrite(Z_local, &z_local_array));
242: PetscCall(VecRestoreArrayRead(r_inv_sqrt_global, &r_inv_sqrt_global_array));
243: PetscCall(VecRestoreArrayRead(y_mean_global, &y_mean_global_array));
244: PetscCall(VecRestoreArrayRead(y_global, &y_global_array));
245: PetscCall(MatDenseRestoreArrayRead(Z_global, &z_global_array));
247: /* Restore Q row */
248: PetscCall(MatRestoreRow(Q, vertex_idx, &ncols, &cols, &vals));
249: PetscFunctionReturn(PETSC_SUCCESS);
250: }
252: /*
253: PetscDALETKFLocalAnalysis - Performs local LETKF analysis for all grid points (CPU version)
255: Input Parameters:
256: + da - the PetscDA context
257: . impl - LETKF implementation data
258: . m - ensemble size
259: . n_vertices - number of grid points
260: . X - global anomaly matrix (state_size x m)
261: . observation - observation vector
262: . Z_global - global observation ensemble (obs_size x m)
263: . y_mean_global - global observation mean
264: - r_inv_sqrt_global - global R^{-1/2}
266: Output:
267: . da->ensemble - updated with analysis ensemble
269: Notes:
270: This function performs the local analysis loop for LETKF, processing each grid point
271: independently using its local observations defined by the localization matrix Q.
272: This is the CPU version that does not use Kokkos acceleration.
274: All local analysis workspace objects (Z_local, S_local, T_sqrt_local, G_local, y_local,
275: y_mean_local, delta_scaled_local, r_inv_sqrt_local, w_local, s_transpose_delta_local) are
276: created with PETSC_COMM_SELF because the analysis at each vertex is serial and independent.
277: */
278: PetscErrorCode PetscDALETKFLocalAnalysis(PetscDA da, PetscDA_LETKF *impl, PetscInt m, PetscInt n_vertices, Mat X, Vec observation, Mat Z_global, Vec y_mean_global, Vec r_inv_sqrt_global)
279: {
280: PetscDA_Ensemble *en = &impl->en;
281: Mat Z_local, S_local, T_sqrt_local, G_local;
282: Mat X_rows, E_analysis_rows;
283: Vec y_local, y_mean_local, delta_scaled_local, r_inv_sqrt_local;
284: Vec w_local, s_transpose_delta_local;
285: const PetscScalar *w_array, *x_array, *g_array, *mean_array;
286: PetscScalar *g_array_w, *e_array, *x_rows_array, *ea_rows_array;
287: PetscScalar one = 1.0, zero = 0.0;
288: PetscBLASInt ndof_b, m_b, lda_xrows_b, lda_g_b, lda_ea_b;
289: PetscInt ndof, max_nnz, rstart;
290: PetscInt lda_x, lda_e, lda_xrows, lda_g, lda_ea;
291: PetscReal sqrt_m_minus_1, scale;
293: PetscFunctionBegin;
294: ndof = da->ndof;
295: sqrt_m_minus_1 = PetscSqrtReal((PetscReal)(m - 1));
296: scale = 1.0 / sqrt_m_minus_1;
297: max_nnz = impl->max_nnz_per_row;
299: /* X and ensemble are accessed at row offsets up to (n_vertices-1)*ndof + (ndof-1).
300: Mirror the precondition the Kokkos path enforces so a bad LDA fails fast on either backend. */
301: PetscCall(MatDenseGetLDA(X, &lda_x));
302: PetscCall(MatDenseGetLDA(en->ensemble, &lda_e));
303: PetscCheck(lda_x >= n_vertices * ndof, PetscObjectComm((PetscObject)X), PETSC_ERR_ARG_INCOMP, "X leading dimension %" PetscInt_FMT " < n_vertices*ndof %" PetscInt_FMT, lda_x, n_vertices * ndof);
304: PetscCheck(lda_e >= n_vertices * ndof, PetscObjectComm((PetscObject)en->ensemble), PETSC_ERR_ARG_INCOMP, "Ensemble leading dimension %" PetscInt_FMT " < n_vertices*ndof %" PetscInt_FMT, lda_e, n_vertices * ndof);
306: /* Create local analysis workspace (max_nnz x m matrices and vectors) */
307: PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, max_nnz, m, NULL, &Z_local));
308: PetscCall(PetscObjectSetOptionsPrefix((PetscObject)Z_local, "dense_"));
309: PetscCall(MatSetFromOptions(Z_local));
310: PetscCall(MatSetUp(Z_local));
311: PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, max_nnz, m, NULL, &S_local));
312: PetscCall(PetscObjectSetOptionsPrefix((PetscObject)S_local, "dense_"));
313: PetscCall(MatSetFromOptions(S_local));
314: PetscCall(MatSetUp(S_local));
315: PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, m, m, NULL, &T_sqrt_local));
316: PetscCall(PetscObjectSetOptionsPrefix((PetscObject)T_sqrt_local, "dense_"));
317: PetscCall(MatSetFromOptions(T_sqrt_local));
318: PetscCall(MatSetUp(T_sqrt_local));
319: PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, m, m, NULL, &G_local));
320: PetscCall(PetscObjectSetOptionsPrefix((PetscObject)G_local, "dense_"));
321: PetscCall(MatSetFromOptions(G_local));
322: PetscCall(MatSetUp(G_local));
324: /* Create vectors using MatCreateVecs() from Z_local (max_nnz x m) */
325: PetscCall(MatCreateVecs(Z_local, &w_local, &y_local));
326: PetscCall(VecDuplicate(y_local, &y_mean_local));
327: PetscCall(VecDuplicate(y_local, &delta_scaled_local));
328: PetscCall(VecDuplicate(y_local, &r_inv_sqrt_local));
329: PetscCall(VecDuplicate(w_local, &s_transpose_delta_local));
331: PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, ndof, m, NULL, &X_rows));
332: PetscCall(MatDuplicate(X_rows, MAT_DO_NOT_COPY_VALUES, &E_analysis_rows));
334: /* X_rows, G_local, E_analysis_rows are loop-invariant; their LDAs and the BLAS-int
335: casts of the gemm shape never change inside the n_vertices loop. Hoist to spare
336: the dispatch overhead at every vertex. */
337: PetscCall(MatDenseGetLDA(G_local, &lda_g));
338: PetscCall(MatDenseGetLDA(X_rows, &lda_xrows));
339: PetscCall(MatDenseGetLDA(E_analysis_rows, &lda_ea));
340: PetscCall(PetscBLASIntCast(ndof, &ndof_b));
341: PetscCall(PetscBLASIntCast(m, &m_b));
342: PetscCall(PetscBLASIntCast(lda_xrows, &lda_xrows_b));
343: PetscCall(PetscBLASIntCast(lda_g, &lda_g_b));
344: PetscCall(PetscBLASIntCast(lda_ea, &lda_ea_b));
346: /* LETKF: Loop over all grid points and perform local analysis */
347: PetscCall(MatGetOwnershipRange(impl->Q, &rstart, NULL));
349: /* X, impl->mean, and en->ensemble are loop-invariant; their array views are read or
350: written at offsets that change per iteration but the underlying storage does not.
351: Hoisting the Get/Restore pairs out of the n_vertices loop avoids repeated lock and
352: validation overhead inside the hot path. Safe to hold across the loop because the
353: inner MatDenseGetArray{,Read,Write} calls on X_rows, E_analysis_rows, and G_local
354: operate on disjoint SELF matrices and never alias X/mean/en->ensemble. */
355: PetscCall(MatDenseGetArrayRead(X, &x_array));
356: PetscCall(VecGetArrayRead(impl->mean, &mean_array));
357: PetscCall(MatDenseGetArrayWrite(en->ensemble, &e_array));
359: for (PetscInt i_grid_point = 0; i_grid_point < n_vertices; i_grid_point++) {
360: /* Extract local observations for this grid point using Q[i_grid_point,:].
361: ExtractLocalObservations() zeros the unwritten [ncols, max_nnz) tail of each
362: workspace, so we do not need to MatZeroEntries/VecZeroEntries every iteration. */
363: /* Note: i_grid_point is local index, but MatGetRow needs global index */
364: PetscCall(ExtractLocalObservations(impl->Q, rstart + i_grid_point, Z_global, observation, y_mean_global, r_inv_sqrt_global, impl->obs_g2l, m, Z_local, y_local, y_mean_local, r_inv_sqrt_local));
366: /* Compute local normalized innovation matrix: S_local = R_local^{-1/2} * (Z_local - y_mean_local * 1') / sqrt(m - 1) */
367: PetscCall(PetscDAEnsembleComputeNormalizedInnovationMatrix(Z_local, y_mean_local, r_inv_sqrt_local, m, scale, S_local));
369: /* Compute local delta_scaled = R_local^{-1/2} * (y_local - y_mean_local) */
370: PetscCall(VecWAXPY(delta_scaled_local, -1.0, y_mean_local, y_local));
371: PetscCall(VecPointwiseMult(delta_scaled_local, delta_scaled_local, r_inv_sqrt_local));
373: /* Factor local T = (I + S_local^T * S_local) */
374: PetscCall(PetscDAEnsembleTFactor(da, S_local));
376: /* Compute local analysis weights: w_local = T_local^{-1} * S_local^T * delta_scaled_local */
377: PetscCall(MatMultTranspose(S_local, delta_scaled_local, s_transpose_delta_local));
378: PetscCall(PetscDAEnsembleApplyTInverse(da, s_transpose_delta_local, w_local));
380: /* Compute local square-root transform: T_sqrt_local = T_local^{-1/2} (U is identity, so pass NULL) */
381: PetscCall(PetscDAEnsembleApplySqrtTInverse(da, NULL, T_sqrt_local));
383: /* Form local transform G_local = w_local * 1' + sqrt(m - 1) * T_sqrt_local * U
384: Instead of creating w_ones_local = w_local * 1', we add w_local to each column of G_local */
385: PetscCall(MatCopy(T_sqrt_local, G_local, SAME_NONZERO_PATTERN));
386: PetscCall(MatScale(G_local, sqrt_m_minus_1));
387: PetscCall(VecGetArrayRead(w_local, &w_array));
388: PetscCall(MatDenseGetArray(G_local, &g_array_w));
389: for (PetscInt j = 0; j < m; j++)
390: for (PetscInt k = 0; k < m; k++) g_array_w[k + j * lda_g] += w_array[k];
391: PetscCall(MatDenseRestoreArray(G_local, &g_array_w));
392: PetscCall(VecRestoreArrayRead(w_local, &w_array));
394: /* LETKF Algorithm 2, Line 13: Update ensemble at grid point i_grid_point
395: E_a[i,:] = x_bar_f[i] + X_f[i,:] * G_local
397: Where:
398: - x_bar_f[i] is the forecast mean at grid point i_grid_point (ndof values from global mean vector)
399: - X_f[i,:] is the forecast anomaly rows at grid point i_grid_point (ndof rows from global anomaly matrix X)
400: - G_local = w_local * 1' + sqrt(m-1) * T_local^{1/2} * U (computed above in G_local)
401: */
402: /* Extract ndof rows starting at (i_grid_point * ndof) from X: X_f[i_grid_point*ndof:(i_grid_point+1)*ndof, :]
403: Hold X_rows / E_analysis_rows with a single read/write GetArray each so the fill, gemm,
404: mean-add, and copy-out share one Get/Restore pair per vertex. */
405: PetscCall(MatDenseGetArray(X_rows, &x_rows_array));
406: for (PetscInt j = 0; j < m; j++)
407: for (PetscInt k = 0; k < ndof; k++) x_rows_array[k + j * lda_xrows] = x_array[(i_grid_point * ndof + k) + j * lda_x];
409: /* Apply local transform via direct BLASgemm: E_analysis_rows = X_rows * G_local.
410: Replaces a per-vertex MatMatMult; ndof and m are typically small (1-100), so the
411: MatProduct dispatch overhead dominated. */
412: PetscCall(MatDenseGetArrayRead(G_local, &g_array));
413: PetscCall(MatDenseGetArray(E_analysis_rows, &ea_rows_array));
414: PetscCallBLAS("BLASgemm", BLASgemm_("N", "N", &ndof_b, &m_b, &m_b, &one, x_rows_array, &lda_xrows_b, g_array, &lda_g_b, &zero, ea_rows_array, &lda_ea_b));
415: PetscCall(MatDenseRestoreArrayRead(G_local, &g_array));
416: PetscCall(MatDenseRestoreArray(X_rows, &x_rows_array));
418: /* Add local mean and store result back in ensemble at row offset i_grid_point*ndof. */
419: for (PetscInt j = 0; j < m; j++) {
420: for (PetscInt k = 0; k < ndof; k++) {
421: ea_rows_array[k + j * lda_ea] += mean_array[i_grid_point * ndof + k];
422: e_array[(i_grid_point * ndof + k) + j * lda_e] = ea_rows_array[k + j * lda_ea];
423: }
424: }
425: PetscCall(MatDenseRestoreArray(E_analysis_rows, &ea_rows_array));
426: }
427: PetscCall(MatDenseRestoreArrayWrite(en->ensemble, &e_array));
428: PetscCall(VecRestoreArrayRead(impl->mean, &mean_array));
429: PetscCall(MatDenseRestoreArrayRead(X, &x_array));
430: PetscCall(MatDestroy(&E_analysis_rows));
431: PetscCall(MatDestroy(&X_rows));
432: PetscCall(VecDestroy(&s_transpose_delta_local));
433: PetscCall(VecDestroy(&w_local));
434: PetscCall(VecDestroy(&r_inv_sqrt_local));
435: PetscCall(VecDestroy(&delta_scaled_local));
436: PetscCall(VecDestroy(&y_mean_local));
437: PetscCall(VecDestroy(&y_local));
438: PetscCall(MatDestroy(&G_local));
439: PetscCall(MatDestroy(&T_sqrt_local));
440: PetscCall(MatDestroy(&S_local));
441: PetscCall(MatDestroy(&Z_local));
442: PetscFunctionReturn(PETSC_SUCCESS);
443: }
445: /*
446: PetscDALETKFGlobalAnalysis - LOC_NONE fast path: a single global ETKF analysis with no
447: per-vertex localization. The m x m T factor and weight vector live on PETSC_COMM_SELF so
448: every rank does the identical eigendecomp; only the gram S^T*S and the projection S^T*delta
449: need an MPI reduction. Dispatches to the Kokkos backend when R is a Kokkos matrix.
450: */
451: static PetscErrorCode PetscDALETKFGlobalAnalysis(PetscDA da, PetscDA_LETKF *impl, PetscInt m, Mat X, Vec observation)
452: {
453: const PetscScalar *s_array, *d_array;
454: PetscScalar *gram, *buf;
455: PetscScalar one = 1.0, zero = 0.0;
456: PetscBLASInt m_b, n_obs_local_b, s_lda_b, ione = 1;
457: PetscMPIInt m_squared_mpi, m_mpi;
458: PetscInt n_obs_local, s_lda;
459: PetscReal sqrt_m_minus_1, scale;
460: PetscBool use_kokkos;
462: PetscFunctionBegin;
463: sqrt_m_minus_1 = PetscSqrtReal((PetscReal)(m - 1));
464: scale = 1.0 / sqrt_m_minus_1;
465: PetscCall(PetscDALETKFUseKokkosBackend(da, &use_kokkos));
466: #if defined(PETSC_HAVE_KOKKOS_KERNELS) && !defined(PETSC_USE_COMPLEX)
467: if (use_kokkos) {
468: PetscCall(PetscDALETKFGlobalAnalysis_Kokkos(da, impl, m, X, observation));
469: PetscFunctionReturn(PETSC_SUCCESS);
470: }
471: #endif
472: PetscCheck(!use_kokkos, PetscObjectComm((PetscObject)da), PETSC_ERR_PLIB, "Kokkos backend selected but PetscDALETKFGlobalAnalysis_Kokkos is unavailable in this build");
474: PetscCall(PetscDALETKFEnsureGlobalScratch(impl, m));
476: /* S = R^{-1/2} * (Z - y_mean*1') / sqrt(m-1) */
477: PetscCall(PetscDAEnsembleComputeNormalizedInnovationMatrix(impl->Z, impl->y_mean, impl->r_inv_sqrt, m, scale, impl->S));
479: /* delta_scaled = R^{-1/2} * (y^o - y_mean) */
480: PetscCall(VecWAXPY(impl->delta_scaled, -1.0, impl->y_mean, observation));
481: PetscCall(VecPointwiseMult(impl->delta_scaled, impl->delta_scaled, impl->r_inv_sqrt));
483: /* Factor T = (1/rho)I + S^T*S replicated on every rank. */
484: PetscCall(PetscCalloc1((size_t)m * m, &gram));
485: PetscCall(MatGetLocalSize(impl->S, &n_obs_local, NULL));
486: PetscCall(MatDenseGetArrayRead(impl->S, &s_array));
487: PetscCall(MatDenseGetLDA(impl->S, &s_lda));
488: PetscCall(PetscBLASIntCast(m, &m_b));
489: PetscCall(PetscBLASIntCast(n_obs_local, &n_obs_local_b));
490: PetscCall(PetscBLASIntCast(s_lda, &s_lda_b));
491: if (n_obs_local > 0) PetscCallBLAS("BLASgemm", BLASgemm_("T", "N", &m_b, &m_b, &n_obs_local_b, &one, s_array, &s_lda_b, s_array, &s_lda_b, &zero, gram, &m_b));
492: PetscCall(PetscMPIIntCast((PetscInt64)m * m, &m_squared_mpi));
493: PetscCallMPI(MPIU_Allreduce(MPI_IN_PLACE, gram, m_squared_mpi, MPIU_SCALAR, MPIU_SUM, PetscObjectComm((PetscObject)da)));
494: PetscCall(PetscDAEnsembleTFactorFromGram(da, m, gram));
495: PetscCall(PetscFree(gram));
497: /* w = T^{-1} * (S^T * delta_scaled), with the projection reduced across ranks. Hold the
498: buffer with VecGetArray across both the local gemv and the in-place allreduce so the
499: reduction sees this rank's contribution (VecGetArrayWrite would make the post-gemv data
500: undefined after restore). Ranks with no local obs zero the buffer directly because gemv
501: (which would overwrite via beta = 0) is skipped on n_obs_local == 0. */
502: PetscCall(VecGetArrayRead(impl->delta_scaled, &d_array));
503: PetscCall(VecGetArray(impl->s_transpose_delta, &buf));
504: if (n_obs_local > 0) PetscCallBLAS("BLASgemv", BLASgemv_("T", &n_obs_local_b, &m_b, &one, s_array, &s_lda_b, d_array, &ione, &zero, buf, &ione));
505: else PetscCall(PetscArrayzero(buf, m));
506: PetscCall(PetscMPIIntCast(m, &m_mpi));
507: PetscCallMPI(MPIU_Allreduce(MPI_IN_PLACE, buf, m_mpi, MPIU_SCALAR, MPIU_SUM, PetscObjectComm((PetscObject)da)));
508: PetscCall(VecRestoreArray(impl->s_transpose_delta, &buf));
509: PetscCall(VecRestoreArrayRead(impl->delta_scaled, &d_array));
510: PetscCall(MatDenseRestoreArrayRead(impl->S, &s_array));
512: PetscCall(PetscDAEnsembleApplyTInverse(da, impl->s_transpose_delta, impl->w));
514: /* T_sqrt = T^{-1/2} */
515: PetscCall(PetscDAEnsembleApplySqrtTInverse(da, NULL, impl->T_sqrt));
517: /* G = w*1' + sqrt(m-1) * T_sqrt (in impl->w_ones, all on PETSC_COMM_SELF). */
518: PetscCall(PetscDALETKFReplicateWeightVector(impl->w, m, impl->w_ones));
519: PetscCall(MatAXPY(impl->w_ones, sqrt_m_minus_1, impl->T_sqrt, SAME_NONZERO_PATTERN));
521: /* E = mean*1' + X * G */
522: PetscCall(PetscDALETKFUpdateEnsembleWithTransform(impl->mean, X, impl->w_ones, m, impl->en.ensemble));
523: PetscFunctionReturn(PETSC_SUCCESS);
524: }
526: /*
527: PetscDALETKFRebuildHTemps - ensure the cached H-compatible work vecs match H's current
528: layout and vec type, rebuilding them (and any caches that depend on H's backend) when H has
529: drifted since the last analysis. The two `if` blocks (invalidate, then allocate) must remain
530: separate rather than an if/else: when the first block fires it nulls `impl->H_temp_in` via
531: VecDestroy, and the second `if (!impl->H_temp_in)` must then re-create the temps in the same
532: call. Collapsing to if/else would leave a freshly-destroyed cache uninitialized until the next
533: analysis, breaking the post-condition that on return the H_temp_* and H_vec_type fields are
534: populated and consistent with the current H.
535: */
536: static PetscErrorCode PetscDALETKFRebuildHTemps(PetscDA da, PetscDA_LETKF *impl, Mat H)
537: {
538: PetscInt cur_in_local, cur_out_local, want_in_local, want_out_local;
539: VecType want_type;
540: PetscBool type_match;
542: PetscFunctionBegin;
543: PetscCall(MatGetVecType(H, &want_type));
544: if (impl->H_temp_in) {
545: PetscCall(VecGetLocalSize(impl->H_temp_in, &cur_in_local));
546: PetscCall(VecGetLocalSize(impl->H_temp_out, &cur_out_local));
547: PetscCall(MatGetLocalSize(H, &want_out_local, &want_in_local));
548: PetscCall(PetscStrcmp(impl->H_vec_type, want_type, &type_match));
549: if (!type_match || cur_in_local != want_in_local || cur_out_local != want_out_local) {
550: PetscCall(VecDestroy(&impl->H_temp_in));
551: PetscCall(VecDestroy(&impl->H_temp_out));
552: PetscCall(PetscFree(impl->H_vec_type));
553: /* The obs-scatter source layout is templated off H, and Q's device mirrors live in the
554: backend matching the old H vec type (Kokkos vs host); reset the full localization
555: cache so the next analysis rebuilds Q and its mirrors against the new H. */
556: PetscCall(PetscDALETKFResetLocalization_LETKF(da));
557: }
558: }
559: if (!impl->H_temp_in) {
560: PetscCall(MatCreateVecs(H, &impl->H_temp_in, &impl->H_temp_out));
561: PetscCall(PetscStrallocpy(want_type, &impl->H_vec_type));
562: }
563: PetscFunctionReturn(PETSC_SUCCESS);
564: }
/*
  PetscDAEnsembleAnalysis_LETKF - LETKF analysis step; installed as impl->en.analysis by PetscDACreate_LETKF().

  Input Parameters:
+ da          - the PetscDA; its LETKF data holds the ensemble matrix impl->en.ensemble with m = impl->en.size columns
. observation - the observation vector
- H           - the observation operator; it is applied column-by-column to the ensemble and its output is
                copied into obs-space vectors of size da->obs_size

  Outline:
  1. (Re)allocate the persistent work objects when state_size, obs_size, or m changed since the last call.
  2. Compute the ensemble mean and the anomaly matrix X.
  3. Refresh the H-compatible temps and, for localized kernels, lazily build Q.
  4. Form Z = H*E, y_mean = H*x_mean and R^{-1/2} from da->obs_error_var.
  5. Dispatch to the LOC_NONE global analysis, or to the per-vertex local analysis (CPU or Kokkos,
     keyed off the type of da->R).
  6. Honor -petscda_view.
*/
static PetscErrorCode PetscDAEnsembleAnalysis_LETKF(PetscDA da, Vec observation, Mat H)
{
  PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;
  Mat            X;
  PetscInt       m;
  PetscBool      reallocate = PETSC_FALSE;

  PetscFunctionBegin;
  m = impl->en.size;
  PetscCheck(m >= 2, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_OUTOFRANGE, "Ensemble size must be >= 2 for LETKF; got %" PetscInt_FMT, m);

  /* Check for reallocation needs: any cached work object whose size no longer matches
     da->state_size, da->obs_size, or the current ensemble size m forces a full rebuild below. */
  if (impl->mean) {
    PetscInt mean_size;
    PetscCall(VecGetSize(impl->mean, &mean_size));
    if (mean_size != da->state_size) reallocate = PETSC_TRUE;
  }
  if (impl->Z) {
    PetscInt z_rows, z_cols;
    PetscCall(MatGetSize(impl->Z, &z_rows, &z_cols));
    if (z_rows != da->obs_size || z_cols != m) reallocate = PETSC_TRUE;
  }
  /* impl->T_sqrt is owned only by the LOC_NONE fast path (PetscDALETKFEnsureGlobalScratch());
     the per-vertex path uses its own SELF T_sqrt_local. This check fires on a NONE -> per-vertex
     transition with a changed m, where the stale T_sqrt would mis-size the next NONE cycle. */
  if (impl->T_sqrt) {
    PetscInt t_rows, t_cols;
    PetscCall(MatGetSize(impl->T_sqrt, &t_rows, &t_cols));
    if (t_rows != m || t_cols != m) reallocate = PETSC_TRUE;
  }

  /* Initialize or reallocate persistent work objects */
  if (!impl->mean || reallocate) {
    /* On reallocation the cached Q (and obs scatter / Kokkos device buffers) describe a
       prior state_size/obs_size and must be torn down so the next analysis rebuilds them
       against the new layout. Skip on first-time init (nothing to reset yet). */
    if (reallocate) PetscCall(PetscDALETKFResetLocalization_LETKF(da));
    /* NOTE(review): w and s_transpose_delta are destroyed here but not recreated in this function;
       presumably the analysis helpers allocate them lazily - confirm. */
    PetscCall(VecDestroy(&impl->mean));
    PetscCall(VecDestroy(&impl->y_mean));
    PetscCall(VecDestroy(&impl->delta_scaled));
    PetscCall(VecDestroy(&impl->w));
    PetscCall(VecDestroy(&impl->s_transpose_delta));
    PetscCall(VecDestroy(&impl->r_inv_sqrt));
    /* Dropping the H temps (and their recorded vec type) makes PetscDALETKFRebuildHTemps() below recreate them. */
    PetscCall(VecDestroy(&impl->H_temp_in));
    PetscCall(VecDestroy(&impl->H_temp_out));
    PetscCall(PetscFree(impl->H_vec_type));
    PetscCall(MatDestroy(&impl->Z));
    PetscCall(MatDestroy(&impl->S));
    PetscCall(MatDestroy(&impl->T_sqrt));
    PetscCall(MatDestroy(&impl->w_ones));

    /* Create mean vector from ensemble matrix (left vector = state space) */
    PetscCall(MatCreateVecs(impl->en.ensemble, NULL, &impl->mean));

    /* Create Z matrix (obs_size x m); the "dense_" prefix lets -dense_mat_type etc. select its type */
    PetscCall(MatCreateDense(PetscObjectComm((PetscObject)impl->en.ensemble), PETSC_DECIDE, PETSC_DECIDE, da->obs_size, m, NULL, &impl->Z));
    PetscCall(PetscObjectSetOptionsPrefix((PetscObject)impl->Z, "dense_"));
    PetscCall(MatSetFromOptions(impl->Z));
    PetscCall(MatSetUp(impl->Z));

    /* Create observation space vectors from Z matrix (left vector = observation space) */
    PetscCall(MatCreateVecs(impl->Z, NULL, &impl->y_mean));
    PetscCall(VecDuplicate(impl->y_mean, &impl->delta_scaled));
    /* r_inv_sqrt follows da->obs_error_var's layout because it is filled from that vector below */
    PetscCall(VecDuplicate(da->obs_error_var, &impl->r_inv_sqrt));

    /* Create S matrix (same layout as Z) */
    PetscCall(MatDuplicate(impl->Z, MAT_DO_NOT_COPY_VALUES, &impl->S));

    /* T_sqrt and w_ones are m x m and used only by the LOC_NONE fast path; allocate them
       lazily in PetscDALETKFGlobalAnalysis() so the per-vertex paths do not pay for them. */
  }

  /* Alg 6.4 line 1-2: Compute ensemble mean and scaled anomalies */
  PetscCall(PetscDAEnsembleComputeMean(da, impl->mean));

  /* Create anomaly matrix X = (E - x_mean * 1') / sqrt(m - 1); X is owned here and destroyed at the end */
  PetscCall(PetscDAEnsembleComputeAnomalies(da, impl->mean, &X));

  /* Alg 6.4 line 3-4: Compute GLOBAL observation ensemble Z = H * E column-by-column,
     staged through H-compatible cached work vecs because impl->Z (MATDENSE) and H
     (possibly MATAIJKOKKOS) cannot share a MatMatMult product type. */
  /* Lazily allocate / rebuild the cached H-compatible work vecs (and reset Q if H's vec-type
     backend changed). */
  PetscCall(PetscDALETKFRebuildHTemps(da, impl, H));

  /* Lazily build Q for built-in distance-based kernels using cached coordinates. The dispatcher
     selects host vs Kokkos backend from the type of the cached observation operator. Setters
     destroy Q via PetscDALETKFResetLocalization() when their inputs change, so a non-NULL Q is
     guaranteed to match the current (type, radius, coord_*) tuple. Built after PetscDALETKFRebuildHTemps()
     because that call may reset Q via PetscDALETKFResetLocalization_LETKF() when H's vec-type
     backend changed since the last analysis. */
  if (impl->type != PETSCDA_LETKF_LOC_NONE && !impl->Q) {
    Mat       Q_new = NULL;
    PetscInt  max_nnz_local, n_nnz_local;
    PetscBool use_kokkos;

    PetscCheck(impl->coord_H, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "Coordinates not set; call PetscDALETKFSetLocalizationCoordinates() before analysis");
    PetscCheck(impl->localization_radius > 0, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "Localization radius not set; call PetscDALETKFSetLocalizationRadius() before analysis");
    /* Q's backend must match the analysis-time backend, which keys off da->R; otherwise a CPU
       analysis would walk a Kokkos Q (or vice versa) and pay backend-mismatch transfer overhead. */
    PetscCall(PetscDALETKFUseKokkosBackend(da, &use_kokkos));
    /* Backends compute the per-rank max-nnz/row and total local nnz from row_counts[] in hand;
       passing them through avoids both a MatGetRow walk and a MatGetInfo call in
       PetscDALETKFInstallQ() (both of which force a device->host sync on AIJKOKKOS Q).
       The analysis-time `H` is threaded into InstallQ as the obs-scatter source template so the
       scatter rows match the vectors actually being scattered, even if `H`'s row partition or
       vec type differs from the `coord_H` cached at SetLocalizationCoordinates time. */
    PetscCall(PetscDALETKFCreateLocalizationMat(impl->type, impl->localization_radius, impl->coord_xyz, impl->coord_bd, impl->coord_H, use_kokkos, &Q_new, &max_nnz_local, &n_nnz_local));
    PetscCall(PetscDALETKFInstallQ(da, Q_new, max_nnz_local, n_nnz_local, H));
    /* InstallQ took its own reference; release the local one */
    PetscCall(MatDestroy(&Q_new));
  }

  /* Compute Z = H * E column by column to avoid Kokkos vector type issues */
  for (PetscInt j = 0; j < m; j++) {
    Vec col_in, col_out;
    PetscCall(MatDenseGetColumnVecRead(impl->en.ensemble, j, &col_in));
    PetscCall(MatDenseGetColumnVecWrite(impl->Z, j, &col_out));
    PetscCall(VecCopy(col_in, impl->H_temp_in));
    PetscCall(MatMult(H, impl->H_temp_in, impl->H_temp_out));
    PetscCall(VecCopy(impl->H_temp_out, col_out));
    PetscCall(MatDenseRestoreColumnVecWrite(impl->Z, j, &col_out));
    PetscCall(MatDenseRestoreColumnVecRead(impl->en.ensemble, j, &col_in));
  }

  /* Compute GLOBAL observation mean y_mean = H * x_mean using the same cached temps. */
  PetscCall(VecCopy(impl->mean, impl->H_temp_in));
  PetscCall(MatMult(H, impl->H_temp_in, impl->H_temp_out));
  PetscCall(VecCopy(impl->H_temp_out, impl->y_mean));

  /* Compute GLOBAL R^{-1/2} (assumes diagonal R): 1 / sqrt(|var|) entrywise from da->obs_error_var */
  PetscCall(VecCopy(da->obs_error_var, impl->r_inv_sqrt));
  PetscCall(VecSqrtAbs(impl->r_inv_sqrt));
  PetscCall(VecReciprocal(impl->r_inv_sqrt));

  if (impl->type == PETSCDA_LETKF_LOC_NONE) PetscCall(PetscDALETKFGlobalAnalysis(da, impl, m, X, observation));
  else {
    PetscInt  n_local, n_obs_local, rows_old, cols_old;
    PetscBool use_kokkos;

    PetscCall(PetscDALETKFUseKokkosBackend(da, &use_kokkos));

    /* Per-vertex local analysis path. PetscDALETKFInstallQ() above already templated the
       obs-scatter on the live `H`, and PetscDALETKFRebuildHTemps() invalidates Q (and therefore
       triggers a re-InstallQ here) whenever H's row layout or vec type drifts, so impl->obs_scat
       is guaranteed non-NULL and compatible with the live observation vectors. */
    /* Gather the global obs-space vectors onto this rank's obs footprint (impl->*_work). */
    PetscCall(VecScatterBegin(impl->obs_scat, observation, impl->obs_work, INSERT_VALUES, SCATTER_FORWARD));
    PetscCall(VecScatterEnd(impl->obs_scat, observation, impl->obs_work, INSERT_VALUES, SCATTER_FORWARD));

    PetscCall(VecScatterBegin(impl->obs_scat, impl->y_mean, impl->y_mean_work, INSERT_VALUES, SCATTER_FORWARD));
    PetscCall(VecScatterEnd(impl->obs_scat, impl->y_mean, impl->y_mean_work, INSERT_VALUES, SCATTER_FORWARD));

    PetscCall(VecScatterBegin(impl->obs_scat, impl->r_inv_sqrt, impl->r_inv_sqrt_work, INSERT_VALUES, SCATTER_FORWARD));
    PetscCall(VecScatterEnd(impl->obs_scat, impl->r_inv_sqrt, impl->r_inv_sqrt_work, INSERT_VALUES, SCATTER_FORWARD));

    /* Z_work is a SELF dense (n_obs_local x m) matrix reused across cycles; rebuild it only when its size changes */
    PetscCall(VecGetLocalSize(impl->obs_work, &n_obs_local));
    if (impl->Z_work) {
      PetscCall(MatGetSize(impl->Z_work, &rows_old, &cols_old));
      if (rows_old != n_obs_local || cols_old != m) PetscCall(MatDestroy(&impl->Z_work));
    }
    if (!impl->Z_work) PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, n_obs_local, m, NULL, &impl->Z_work));
    /* Gather Z column by column through the same obs scatter */
    for (PetscInt i = 0; i < m; i++) {
      Vec z_col_global, z_col_local;
      PetscCall(MatDenseGetColumnVecRead(impl->Z, i, &z_col_global));
      PetscCall(MatDenseGetColumnVecWrite(impl->Z_work, i, &z_col_local));
      PetscCall(VecScatterBegin(impl->obs_scat, z_col_global, z_col_local, INSERT_VALUES, SCATTER_FORWARD));
      PetscCall(VecScatterEnd(impl->obs_scat, z_col_global, z_col_local, INSERT_VALUES, SCATTER_FORWARD));
      PetscCall(MatDenseRestoreColumnVecRead(impl->Z, i, &z_col_global));
      PetscCall(MatDenseRestoreColumnVecWrite(impl->Z_work, i, &z_col_local));
    }

    /* n_local = number of rows of Q owned by this rank (one row per local vertex, per InstallQ's size check) */
    PetscCall(MatGetLocalSize(impl->Q, &n_local, NULL));
#if defined(PETSC_HAVE_KOKKOS_KERNELS) && !defined(PETSC_USE_COMPLEX)
    if (use_kokkos) PetscCall(PetscDALETKFLocalAnalysis_Kokkos(da, impl, m, n_local, X, impl->obs_work, impl->Z_work, impl->y_mean_work, impl->r_inv_sqrt_work));
    else PetscCall(PetscDALETKFLocalAnalysis(da, impl, m, n_local, X, impl->obs_work, impl->Z_work, impl->y_mean_work, impl->r_inv_sqrt_work));
#else
    PetscCall(PetscDALETKFLocalAnalysis(da, impl, m, n_local, X, impl->obs_work, impl->Z_work, impl->y_mean_work, impl->r_inv_sqrt_work));
#endif
  }

  PetscCall(MatDestroy(&X));
  /* Self-call ViewFromOptions at the tail, mirroring KSPSolve()/SNESSolve(). Fires every cycle
     when the user passes -petscda_view; tutorials that want a single end-of-run snapshot call
     PetscDAView() explicitly after the DA loop. */
  PetscCall(PetscDAViewFromOptions(da, NULL, "-petscda_view"));
  PetscFunctionReturn(PETSC_SUCCESS);
}
753: /*
754: PetscDALETKFResetLocalization_LETKF - destroy the cached Q matrix, obs scatter, and any device
755: buffers tied to Q. Coordinates/type/radius are preserved so the next analysis rebuilds Q from
756: the current inputs. Called by the setters that mutate Q-determining inputs.
757: */
758: static PetscErrorCode PetscDALETKFResetLocalization_LETKF(PetscDA da)
759: {
760: PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;
762: PetscFunctionBegin;
763: PetscCheck(impl, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "PetscDA not properly initialized for LETKF");
764: #if defined(PETSC_HAVE_KOKKOS_KERNELS) && !defined(PETSC_USE_COMPLEX)
765: /* Drop only the Q device mirrors; the persistent cusolver/rocblas/SYCL handle and the
766: eigensolver workspace are reused across Q rebuilds. */
767: if (impl->Q) PetscCall(PetscDALETKFDestroyQDeviceMirrors_Kokkos(impl));
768: #endif
769: PetscCall(PetscDALETKFDestroyObsScatter(impl));
770: PetscCall(MatDestroy(&impl->Q));
771: impl->max_nnz_per_row = 0;
772: impl->n_nnz_local = 0;
773: PetscFunctionReturn(PETSC_SUCCESS);
774: }
776: static PetscErrorCode PetscDALETKFSetLocalizationRadius_LETKF(PetscDA da, PetscReal radius)
777: {
778: PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;
780: PetscFunctionBegin;
781: PetscCheck(impl, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "PetscDA not properly initialized for LETKF");
782: PetscCheck(radius > 0, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_OUTOFRANGE, "Localization radius must be positive, got %g", (double)radius);
783: /* Exact equality: a tolerance would silently keep a stale Q after a small intentional bump. */
784: if (impl->localization_radius != radius) {
785: impl->localization_radius = radius;
786: PetscCall(PetscDALETKFResetLocalization_LETKF(da));
787: }
788: PetscFunctionReturn(PETSC_SUCCESS);
789: }
791: static PetscErrorCode PetscDALETKFSetLocalizationType_LETKF(PetscDA da, PetscDALETKFLocalizationType type)
792: {
793: PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;
795: PetscFunctionBegin;
796: PetscCheck(impl, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "PetscDA not properly initialized for LETKF");
797: PetscCheck(type >= 0 && type < PETSCDA_LETKF_LOC_NUM_TYPES, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_OUTOFRANGE, "Invalid localization type %d; must be in [0,%d)", (int)type, (int)PETSCDA_LETKF_LOC_NUM_TYPES);
798: if (impl->type != type) {
799: impl->type = type;
800: PetscCall(PetscDALETKFResetLocalization_LETKF(da));
801: }
802: PetscFunctionReturn(PETSC_SUCCESS);
803: }
805: static PetscErrorCode PetscDALETKFGetLocalizationType_LETKF(PetscDA da, PetscDALETKFLocalizationType *type)
806: {
807: PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;
809: PetscFunctionBegin;
810: PetscCheck(impl, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "PetscDA not properly initialized for LETKF");
811: *type = impl->type;
812: PetscFunctionReturn(PETSC_SUCCESS);
813: }
815: static PetscErrorCode PetscDALETKFSetLocalizationCoordinates_LETKF(PetscDA da, const Vec xyz[3], const PetscReal bd[3], Mat H)
816: {
817: PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;
818: PetscBool changed = PETSC_FALSE;
819: PetscInt H_rows, H_cols, vert_global, vert_local;
821: PetscFunctionBegin;
822: PetscCheck(impl, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "PetscDA not properly initialized for LETKF");
823: PetscAssertPointer(xyz, 2);
824: /* bd[d] > 0 selects periodic handling for dimension d; bd[d] == 0 means non-periodic. Reject
825: negative values so a stray sign flip cannot silently re-interpret as non-periodic. */
826: if (bd)
827: for (PetscInt d = 0; d < 3; d++) PetscCheck(bd[d] >= 0.0, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_OUTOFRANGE, "Periodic-domain extent bd[%" PetscInt_FMT "] must be non-negative (use 0 for non-periodic), got %g", d, (double)bd[d]);
828: /* Validate H and xyz[0] against the PetscDA's recorded sizes at the API boundary so a
829: structurally mismatched H or coordinate vector is rejected here, where the caller can fix
830: it, rather than after the previous Q/obs-scatter has already been torn down inside the
831: lazy-build path. */
832: PetscCall(MatGetSize(H, &H_rows, &H_cols));
833: PetscCheck(H_rows == da->obs_size, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_INCOMP, "H has %" PetscInt_FMT " rows; PetscDA obs_size is %" PetscInt_FMT, H_rows, da->obs_size);
834: PetscCheck(da->ndof > 0 && da->state_size % da->ndof == 0, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "state_size (%" PetscInt_FMT ") must be a positive multiple of ndof (%" PetscInt_FMT ")", da->state_size, da->ndof);
835: PetscCall(VecGetSize(xyz[0], &vert_global));
836: PetscCall(VecGetLocalSize(xyz[0], &vert_local));
837: PetscCheck(vert_global == da->state_size / da->ndof, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_INCOMP, "xyz[0] global size %" PetscInt_FMT " != vertex count state_size/ndof (%" PetscInt_FMT ")", vert_global, da->state_size / da->ndof);
838: /* H's columns must match xyz[0]'s rows so PetscDALETKFComputeObsCoords() can MatMult(H, xyz[d]).
839: The multi-DOF observation operator (cols == state_size) is a frequent mistake here; reject it
840: before PetscDALETKFResetLocalization_LETKF() tears down the previous Q. */
841: PetscCheck(H_cols == vert_global, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_INCOMP, "H has %" PetscInt_FMT " columns; expected %" PetscInt_FMT " (vertex count, == xyz[0] global size). Did you pass the multi-DOF observation operator instead of the per-vertex one?", H_cols, vert_global);
842: /* Per-dim slots beyond xyz[0] must share xyz[0]'s global size and local partition; otherwise
843: PetscDALETKFGatherObsBbox() would walk mismatched coordinate arrays and read past valid memory. */
844: for (PetscInt d = 1; d < 3; d++) {
845: PetscInt other_global, other_local;
847: if (!xyz[d]) continue;
848: PetscCall(VecGetSize(xyz[d], &other_global));
849: PetscCall(VecGetLocalSize(xyz[d], &other_local));
850: PetscCheck(other_global == vert_global, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_INCOMP, "xyz[%" PetscInt_FMT "] global size %" PetscInt_FMT " != xyz[0] global size %" PetscInt_FMT, d, other_global, vert_global);
851: PetscCheck(other_local == vert_local, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_INCOMP, "xyz[%" PetscInt_FMT "] local size %" PetscInt_FMT " != xyz[0] local size %" PetscInt_FMT, d, other_local, vert_local);
852: }
853: /* Compare against the cached (xyz, bd, H) tuple by pointer/value so that re-supplying the same
854: geometry (a common pattern when the tutorial reapplies the same observation operator each
855: analysis cycle) does not invalidate Q and force the obs-scatter and device buffers to be
856: rebuilt. The contract requires the user to call this again after mutating any of these objects. */
857: for (PetscInt d = 0; d < 3; d++) {
858: if (impl->coord_xyz[d] != xyz[d]) changed = PETSC_TRUE;
859: if (impl->coord_bd[d] != (bd ? bd[d] : 0.0)) changed = PETSC_TRUE;
860: }
861: if (impl->coord_H != H) changed = PETSC_TRUE;
862: if (!changed) PetscFunctionReturn(PETSC_SUCCESS);
864: PetscCall(PetscDALETKFClearCoordinates(impl));
865: for (PetscInt d = 0; d < 3; d++) {
866: if (xyz[d]) {
867: PetscCall(PetscObjectReference((PetscObject)xyz[d]));
868: impl->coord_xyz[d] = xyz[d];
869: }
870: impl->coord_bd[d] = bd ? bd[d] : 0.0;
871: }
872: PetscCall(PetscObjectReference((PetscObject)H));
873: impl->coord_H = H;
874: PetscCall(PetscDALETKFResetLocalization_LETKF(da));
875: PetscFunctionReturn(PETSC_SUCCESS);
876: }
878: static PetscErrorCode PetscDALETKFGetLocalizationRadius_LETKF(PetscDA da, PetscReal *radius)
879: {
880: PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;
882: PetscFunctionBegin;
883: PetscCheck(impl, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "PetscDA not properly initialized for LETKF");
884: *radius = impl->localization_radius;
885: PetscFunctionReturn(PETSC_SUCCESS);
886: }
888: /*
889: Install a freshly built localization matrix Q (validate sizes, cache row nnz bounds, wire Kokkos
890: device buffers when applicable). Only called from the lazy-build path. The Kokkos device-side
891: setup runs only when PETSc was built with Kokkos kernels; the bare CPU analysis path (used by
892: serial non-Kokkos builds) needs only the size validation and nnz bookkeeping below.
894: scatter_H templates the obs-scatter source layout. The caller passes the analysis-time `H` so
895: that the scatter source matches the vectors the per-vertex path is about to scatter, even when
896: the user's analysis-time `H` has a different row partition or vec type than the `coord_H` that
897: was cached at SetLocalizationCoordinates time. Q's column footprint (the unique global obs
898: indices each rank touches) is independent of scatter_H's row layout, so this is well-defined as
899: long as scatter_H and Q agree on the global obs-space size - `PetscDALETKFSetupObsScatter()`
900: PetscCheck's that.
901: */
902: static PetscErrorCode PetscDALETKFInstallQ(PetscDA da, Mat Q, PetscInt max_nnz_local, PetscInt n_nnz_local, Mat scatter_H)
903: {
904: PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;
905: PetscInt nrows, ncols;
906: PetscBool use_kokkos;
908: PetscFunctionBegin;
909: PetscCheck(da->ndof > 0 && da->state_size % da->ndof == 0, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "state_size (%" PetscInt_FMT ") must be a positive multiple of ndof (%" PetscInt_FMT ")", da->state_size, da->ndof);
910: PetscCall(MatGetSize(Q, &nrows, &ncols));
911: PetscCheck(nrows == da->state_size / da->ndof, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_INCOMP, "Localization matrix rows (%" PetscInt_FMT ") must equal vertex count state_size/ndof (%" PetscInt_FMT ")", nrows, da->state_size / da->ndof);
912: PetscCheck(ncols == da->obs_size, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_INCOMP, "Localization matrix columns (%" PetscInt_FMT ") must match observation size (%" PetscInt_FMT ")", ncols, da->obs_size);
913: PetscCheck(max_nnz_local >= 0, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_OUTOFRANGE, "max_nnz_local must be >= 0, got %" PetscInt_FMT, max_nnz_local);
914: PetscCheck(n_nnz_local >= 0, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_OUTOFRANGE, "n_nnz_local must be >= 0, got %" PetscInt_FMT, n_nnz_local);
915: /* The obs-scatter is needed by both the CPU and Kokkos per-vertex paths whenever Q exists.
916: Validate before tearing down the previous Q/obs-scatter so a missing prerequisite leaves
917: the impl in its prior usable state instead of a half-installed one. */
919: PetscCall(PetscDALETKFUseKokkosBackend(da, &use_kokkos));
921: #if defined(PETSC_HAVE_KOKKOS_KERNELS) && !defined(PETSC_USE_COMPLEX)
922: /* Drop only the Q device mirrors; the eigensolver workspace and solver handle persist. */
923: if (impl->Q) PetscCall(PetscDALETKFDestroyQDeviceMirrors_Kokkos(impl));
924: #endif
925: /* Destroy the previous obs-scatter so SetupObsScatter() can rebuild it for the new Q footprint. */
926: PetscCall(PetscDALETKFDestroyObsScatter(impl));
928: PetscCall(PetscObjectReference((PetscObject)Q));
929: PetscCall(MatDestroy(&impl->Q));
930: impl->Q = Q;
932: /* The CSR backends already walked row_counts[] to size their allocations; reuse the per-rank
933: max (and total local nnz) computed there instead of MatGetRow/MatGetInfo walks, which would
934: force a device->host sync on AIJKOKKOS Q. Allreduce the max so impl->max_nnz_per_row holds
935: the global max all ranks need for sizing; n_nnz_local stays per-rank. */
936: impl->max_nnz_per_row = max_nnz_local;
937: PetscCallMPI(MPIU_Allreduce(MPI_IN_PLACE, &impl->max_nnz_per_row, 1, MPIU_INT, MPI_MAX, PetscObjectComm((PetscObject)da)));
938: impl->n_nnz_local = n_nnz_local;
940: PetscCall(PetscDALETKFSetupObsScatter(impl, scatter_H));
941: #if defined(PETSC_HAVE_KOKKOS_KERNELS) && !defined(PETSC_USE_COMPLEX)
942: /* Gate the device-mirror setup on the analysis-time backend: a Kokkos-capable build with a
943: non-Kokkos da->R runs the CPU per-vertex path, which never reads Q_device_*. Skipping the
944: setup avoids a redundant MatGetRow walk over every local row of Q and the three device-resident
945: Kokkos Views it would allocate. */
946: if (use_kokkos) PetscCall(PetscDALETKFSetupLocalization_Kokkos(impl));
947: #endif
948: PetscFunctionReturn(PETSC_SUCCESS);
949: }
951: static PetscErrorCode PetscDAView_LETKF(PetscDA da, PetscViewer viewer)
952: {
953: PetscBool iascii = PETSC_FALSE, is_kokkos = PETSC_FALSE;
954: PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;
956: PetscFunctionBegin;
957: PetscCall(PetscDAView_Ensemble(da, viewer));
958: PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &iascii));
959: if (iascii) {
960: PetscCall(PetscDALETKFUseKokkosBackend(da, &is_kokkos));
961: PetscCall(PetscViewerASCIIPrintf(viewer, " Local analysis: %s\n", is_kokkos ? "Kokkos" : "CPU"));
962: PetscCall(PetscViewerASCIIPrintf(viewer, " Localization type: %s\n", PetscDALETKFLocalizationTypes[impl->type]));
963: if (impl->type != PETSCDA_LETKF_LOC_NONE) {
964: if (impl->localization_radius > 0.0) PetscCall(PetscViewerASCIIPrintf(viewer, " Localization radius: %g\n", (double)impl->localization_radius));
965: else PetscCall(PetscViewerASCIIPrintf(viewer, " Localization radius: (unset)\n"));
966: if (is_kokkos) {
967: if (impl->batch_size > 0) PetscCall(PetscViewerASCIIPrintf(viewer, " GPU batch size: %" PetscInt_FMT "\n", impl->batch_size));
968: else PetscCall(PetscViewerASCIIPrintf(viewer, " GPU batch size: auto\n"));
969: }
970: }
971: }
972: PetscFunctionReturn(PETSC_SUCCESS);
973: }
975: static PetscErrorCode PetscDASetFromOptions_LETKF(PetscDA da, PetscOptionItems *PetscOptionsObjectPtr)
976: {
977: PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;
978: PetscOptionItems PetscOptionsObject = *PetscOptionsObjectPtr;
979: PetscReal radius;
980: PetscInt type_idx, batch_size;
981: PetscBool type_set = PETSC_FALSE, radius_set = PETSC_FALSE;
983: PetscFunctionBegin;
984: PetscCall(PetscDASetFromOptions_Ensemble(da, PetscOptionsObjectPtr));
985: PetscOptionsHeadBegin(PetscOptionsObject, "PetscDA LETKF Options");
986: batch_size = impl->batch_size;
987: PetscCall(PetscOptionsInt("-petscda_letkf_batch_size", "Batch size for GPU processing (0 = auto)", "PETSCDALETKF", batch_size, &batch_size, NULL));
988: PetscCheck(batch_size >= 0, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_OUTOFRANGE, "batch_size must be >= 0, got %" PetscInt_FMT, batch_size);
989: impl->batch_size = batch_size;
990: radius = impl->localization_radius;
991: PetscCall(PetscOptionsReal("-petscda_letkf_localization_radius", "Localization cutoff radius for built-in kernels", "PetscDALETKFSetLocalizationRadius", radius, &radius, &radius_set));
992: if (radius_set) PetscCall(PetscDALETKFSetLocalizationRadius(da, radius));
993: type_idx = (PetscInt)impl->type;
994: PetscCall(PetscOptionsEList("-petscda_letkf_localization_type", "Localization kernel type", "PetscDALETKFSetLocalizationType", PetscDALETKFLocalizationTypes, PETSCDA_LETKF_LOC_NUM_TYPES, PetscDALETKFLocalizationTypes[type_idx], &type_idx, &type_set));
995: if (type_set) PetscCall(PetscDALETKFSetLocalizationType(da, (PetscDALETKFLocalizationType)type_idx));
996: PetscOptionsHeadEnd();
997: PetscFunctionReturn(PETSC_SUCCESS);
998: }
1000: /*MC
1001: PETSCDALETKF - The Local ETKF performs the analysis update locally around each grid point, enabling scalable assimilation on large
1002: domains by avoiding the global ensemble covariance matrix.
1004: Options Database Keys:
1005: + -petscda_type letkf - set the `PetscDAType` to `PETSCDALETKF`
1006: . -petscda_ensemble_size size - number of ensemble members
1007: . -petscda_ensemble_inflation factor - multiplicative inflation factor applied to anomalies
1008: . -petscda_letkf_batch_size batch_size - set the batch size for GPU processing
1009: . -petscda_letkf_localization_radius radius - localization cutoff radius for the built-in kernels (must be positive)
1010: . -petscda_letkf_localization_type (none|gaspari_cohn|gaussian|boxcar) - select the localization kernel
1011: - -petscda_view - view the `PetscDA` at the end of every `PetscDAEnsembleAnalysis()` call
1013: Level: beginner
1015: Notes:
1016: The default localization kernel is `PETSCDA_LETKF_LOC_GASPARI_COHN`, which requires the user to
1017: call `PetscDALETKFSetLocalizationRadius()` and `PetscDALETKFSetLocalizationCoordinates()` before
1018: the first analysis. To skip localization entirely use `PetscDALETKFSetLocalizationType(da, PETSCDA_LETKF_LOC_NONE)`
1019: (or `-petscda_letkf_localization_type none`).
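  Both the CPU and Kokkos analysis paths support multi-rank runs. The Kokkos backend is selected
  when the observation-error covariance matrix `da->R` has a Kokkos AIJ type (`MATSEQAIJKOKKOS`,
  `MATMPIAIJKOKKOS`, or `MATAIJKOKKOS`, e.g. via `-mat_type aijkokkos`). Otherwise the CPU per-vertex
  path is used, or, for `PETSCDA_LETKF_LOC_NONE`, the replicated global path.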
1023: `-petscda_view` fires at the tail of every `PetscDAEnsembleAnalysis()` call (mirroring `KSPSolve()`/`SNESSolve()`),
1024: so over a multi-cycle assimilation run the view is emitted once per analysis. Code that wants a single
1025: end-of-run snapshot should call `PetscDAView()` explicitly after the assimilation loop instead.
1027: .seealso: [](ch_da), `PetscDA`, `PetscDACreate()`, `PetscDALETKFSetLocalizationRadius()`, `PetscDALETKFGetLocalizationRadius()`,
1028: `PetscDALETKFSetLocalizationType()`, `PetscDALETKFGetLocalizationType()`, `PetscDALETKFSetLocalizationCoordinates()`,
1029: `PetscDALETKFResetLocalization()`, `PetscDAEnsembleSetSize()`, `PetscDASetSizes()`, `PetscDAEnsembleSetInflation()`,
1030: `PetscDAEnsembleComputeMean()`, `PetscDAEnsembleComputeAnomalies()`, `PetscDAEnsembleAnalysis()`, `PetscDAEnsembleForecast()`
1031: M*/
1033: PETSC_INTERN PetscErrorCode PetscDACreate_LETKF(PetscDA da)
1034: {
1035: PetscDA_LETKF *impl;
1037: PetscFunctionBegin;
1038: PetscCall(PetscNew(&impl));
1039: da->data = impl;
1040: PetscCall(PetscDACreate_Ensemble(da));
1041: da->ops->setup = PetscDASetUp_Ensemble;
1042: da->ops->destroy = PetscDADestroy_LETKF;
1043: da->ops->view = PetscDAView_LETKF;
1044: da->ops->setfromoptions = PetscDASetFromOptions_LETKF;
1045: impl->en.analysis = PetscDAEnsembleAnalysis_LETKF;
1046: impl->en.forecast = PetscDAEnsembleForecast_Ensemble;
1048: impl->type = PETSCDA_LETKF_LOC_GASPARI_COHN;
1050: /* Register the method for setting localization */
1051: PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFSetLocalizationRadius_C", PetscDALETKFSetLocalizationRadius_LETKF));
1052: PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFGetLocalizationRadius_C", PetscDALETKFGetLocalizationRadius_LETKF));
1053: PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFSetLocalizationType_C", PetscDALETKFSetLocalizationType_LETKF));
1054: PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFGetLocalizationType_C", PetscDALETKFGetLocalizationType_LETKF));
1055: PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFSetLocalizationCoordinates_C", PetscDALETKFSetLocalizationCoordinates_LETKF));
1056: PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFResetLocalization_C", PetscDALETKFResetLocalization_LETKF));
1057: PetscFunctionReturn(PETSC_SUCCESS);
1058: }
1060: /*@
1061: PetscDALETKFSetLocalizationRadius - Sets the localization cutoff radius used by LETKF's built-in distance-based kernels.
1063: Logically Collective
1065: Input Parameters:
1066: + da - the `PetscDA` context
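  - radius - the localization cutoff radius (must be positive). A very large radius effectively disables localization, but `PetscDALETKFSetLocalizationType()` with `PETSCDA_LETKF_LOC_NONE` does so without building a localization matrix at all.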
1069: Level: advanced
1071: .seealso: [](ch_da), `PETSCDALETKF`, `PetscDA`, `PetscDALETKFSetLocalizationCoordinates()`, `PetscDALETKFGetLocalizationRadius()`
1072: @*/
1073: PetscErrorCode PetscDALETKFSetLocalizationRadius(PetscDA da, PetscReal radius)
1074: {
1075: PetscFunctionBegin;
1078: PetscTryMethod(da, "PetscDALETKFSetLocalizationRadius_C", (PetscDA, PetscReal), (da, radius));
1079: PetscFunctionReturn(PETSC_SUCCESS);
1080: }
1082: /*@
1083: PetscDALETKFGetLocalizationRadius - Gets the localization cutoff radius used by LETKF's built-in distance-based kernels.
1085: Not Collective
1087: Input Parameter:
1088: . da - the `PetscDA` context
1090: Output Parameter:
1091: . radius - the localization cutoff radius
1093: Level: advanced
1095: .seealso: [](ch_da), `PETSCDALETKF`, `PetscDA`, `PetscDALETKFSetLocalizationRadius()`
1096: @*/
PetscErrorCode PetscDALETKFGetLocalizationRadius(PetscDA da, PetscReal *radius)
{
  PetscFunctionBegin;
  PetscAssertPointer(radius, 2);
  /* PetscUseMethod() (unlike PetscTryMethod()) errors when no "PetscDALETKFGetLocalizationRadius_C"
     is composed on `da`, e.g. for a PetscDA that is not PETSCDALETKF, so *radius is never left unset. */
  PetscUseMethod(da, "PetscDALETKFGetLocalizationRadius_C", (PetscDA, PetscReal *), (da, radius));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1106: /*@
1107: PetscDALETKFSetLocalizationType - Selects the localization kernel used by `PETSCDALETKF`.
1109: Logically Collective
1111: Input Parameters:
1112: + da - the `PetscDA` context
1113: - type - the kernel type (see `PetscDALETKFLocalizationType`)
1115: Level: intermediate
1117: Notes:
1118: Use `PETSCDA_LETKF_LOC_NONE` to bypass localization entirely; the analysis is then mathematically
1119: equivalent to the global ETKF and dispatches through a single global eigensolve plus a dense
1120: `BLASgemm` weight transform reduced across ranks, instead of the per-vertex local loop.
1122: For the built-in distance-based kernels (`PETSCDA_LETKF_LOC_GASPARI_COHN`, `PETSCDA_LETKF_LOC_GAUSSIAN`,
1123: `PETSCDA_LETKF_LOC_BOXCAR`) you must also call `PetscDALETKFSetLocalizationRadius()` and
1124: `PetscDALETKFSetLocalizationCoordinates()`. The localization matrix is then constructed
1125: lazily before the first analysis.
1126: All three built-in kernels are 1 at distance 0; `radius` selects the effective support but the
1127: cutoff distance and continuity at the cutoff differ.
1128: `PETSCDA_LETKF_LOC_GASPARI_COHN` is compactly supported with cutoff at distance `2*radius`, and
1129: is C^1 continuous everywhere (it tapers smoothly to zero at the cutoff).
1130: `PETSCDA_LETKF_LOC_GAUSSIAN` is `exp(-d^2 / (2*radius^2))` truncated at distance `2*radius`; the
1131: truncation introduces a discontinuity of `exp(-2)` (~0.135) at the cutoff, so prefer
1132: `PETSCDA_LETKF_LOC_GASPARI_COHN` if a smooth taper at the cutoff matters.
1133: `PETSCDA_LETKF_LOC_BOXCAR` is 1 inside `radius` and 0 outside; the discontinuity is by design.
1135: .seealso: [](ch_da), `PETSCDALETKF`, `PetscDA`, `PetscDALETKFLocalizationType`, `PetscDALETKFGetLocalizationType()`,
1136: `PetscDALETKFSetLocalizationRadius()`, `PetscDALETKFSetLocalizationCoordinates()`
1137: @*/
PetscErrorCode PetscDALETKFSetLocalizationType(PetscDA da, PetscDALETKFLocalizationType type)
{
  PetscFunctionBegin;
  /* Range-check here so an invalid enum value is rejected even when PetscTryMethod() finds no
     composed LETKF implementation and skips the call. */
  PetscCheck(type >= PETSCDA_LETKF_LOC_NONE && type < PETSCDA_LETKF_LOC_NUM_TYPES, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_OUTOFRANGE, "Localization type %d out of range [0, %d)", (int)type, (int)PETSCDA_LETKF_LOC_NUM_TYPES);
  PetscTryMethod(da, "PetscDALETKFSetLocalizationType_C", (PetscDA, PetscDALETKFLocalizationType), (da, type));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1148: /*@
1149: PetscDALETKFGetLocalizationType - Returns the localization kernel currently used by `PETSCDALETKF`.
1151: Not Collective
1153: Input Parameter:
1154: . da - the `PetscDA` context
1156: Output Parameter:
1157: . type - the kernel type
1159: Level: intermediate
1161: .seealso: [](ch_da), `PETSCDALETKF`, `PetscDA`, `PetscDALETKFLocalizationType`, `PetscDALETKFSetLocalizationType()`
1162: @*/
1163: PetscErrorCode PetscDALETKFGetLocalizationType(PetscDA da, PetscDALETKFLocalizationType *type)
1164: {
1165: PetscFunctionBegin;
1167: PetscAssertPointer(type, 2);
1168: PetscUseMethod(da, "PetscDALETKFGetLocalizationType_C", (PetscDA, PetscDALETKFLocalizationType *), (da, type));
1169: PetscFunctionReturn(PETSC_SUCCESS);
1170: }
1172: /*@
1173: PetscDALETKFSetLocalizationCoordinates - Provides the geometry used to lazily build the
1174: localization matrix `Q` for a built-in `PETSCDALETKF` kernel.
1176: Collective
1178: Input Parameters:
1179: + da - the `PetscDA` context
1180: . xyz - length-3 array of coordinate vectors, one per spatial dimension; set unused trailing
1181: slots to `NULL` (the spatial dimension is taken to be the index of the first `NULL`,
1182: so `{x, y, NULL}` is 2D and `{x, NULL, NULL}` is 1D)
1183: . bd - length-3 array of periodic-domain extents (use 0 for non-periodic dimensions); pass
1184: `NULL` to mean fully non-periodic
1185: - H - the observation operator (used to map state-space coordinates to observation locations)
1187: Level: intermediate
1189: Notes:
1190: The `xyz` array must always have three slots even in 1D or 2D; trailing slots are set to `NULL`.
1191: This matches the internal cached layout `coord_xyz[3]` and the layout used by both Q backends.
1193: The localization matrix `Q` is built on first analysis (or whenever the type, radius or
1194: coordinates change) using the kernel selected by `PetscDALETKFSetLocalizationType()`. If the
1195: current type is `PETSCDA_LETKF_LOC_NONE`, the coordinates are cached but the analysis continues
1196: to run the NONE fast path; switch to a distance-based kernel via
1197: `PetscDALETKFSetLocalizationType()` for the cached coordinates to take effect.
1198: The reference counts on `xyz` and `H` are increased; the caller may destroy them afterwards.
1200: The cached coordinate `Vec`s are referenced, not deep-copied. If the caller mutates the contents
1201: of any element of `xyz` after this call (for example, after a remesh or recoordinate step), the
1202: cached `Q` will not be rebuilt automatically; call `PetscDALETKFResetLocalization()` to invalidate
1203: `Q` and force a rebuild on the next analysis.
1205: The columns of `Q` are global indices into the observation vector, derived from the row
1206: ownership and sparsity of the `H` cached here. The `H` passed to `PetscDAEnsembleAnalysis()`
1207: must therefore use the same global row indexing (same observation ordering and the same global
1208: obs-space size) as the `H` cached here. The analysis-time `H` may differ from the cached `H`
1209: in MPI row partition or vec type - the obs-scatter is templated on the analysis-time `H`, so
1210: those differences are absorbed automatically. A structurally different `H` (rows referring to
1211: different physical observations) will produce wrong analyses without raising an error; in that
1212: case, call `PetscDALETKFSetLocalizationCoordinates()` again with the new `H` to rebuild `Q`.
1214: .seealso: [](ch_da), `PETSCDALETKF`, `PetscDA`, `PetscDALETKFSetLocalizationType()`,
1215: `PetscDALETKFSetLocalizationRadius()`
1216: @*/
1217: PetscErrorCode PetscDALETKFSetLocalizationCoordinates(PetscDA da, const Vec xyz[3], const PetscReal bd[3], Mat H)
1218: {
1219: PetscFunctionBegin;
1221: /* xyz is required; bd is optional. Use an always-on PetscCheck rather than the debug-only
1222: PetscAssertPointer() so a NULL xyz argument is rejected cleanly in optimized builds before
1223: the xyz[0] dereference below. */
1224: PetscCheck(xyz, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_NULL, "xyz must be a non-NULL length-3 array of Vec");
1225: PetscCheck(xyz[0], PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONG, "xyz[0] must be a valid Vec; the spatial dimension is taken to be the index of the first NULL slot in xyz[3]");
1227: if (bd)
1229: PetscTryMethod(da, "PetscDALETKFSetLocalizationCoordinates_C", (PetscDA, const Vec[3], const PetscReal[3], Mat), (da, xyz, bd, H));
1230: PetscFunctionReturn(PETSC_SUCCESS);
1231: }
1233: /*@
1234: PetscDALETKFResetLocalization - Discards the cached localization matrix `Q` so the next analysis
1235: rebuilds it from the current type, radius, and coordinates.
1237: Collective
1239: Input Parameter:
1240: . da - the `PetscDA` context
1242: Level: advanced
1244: Notes:
1245: The setters `PetscDALETKFSetLocalizationType()`, `PetscDALETKFSetLocalizationRadius()`, and
1246: `PetscDALETKFSetLocalizationCoordinates()` already invalidate `Q` when their inputs actually
1247: change, so most users never need to call this directly. Use it when an input was mutated outside
1248: of the setters (for example, the entries of a cached coordinate `Vec` were edited in place, or
1249: the cached observation operator `H` was reassembled with different sparsity).
1251: .seealso: [](ch_da), `PETSCDALETKF`, `PetscDA`, `PetscDALETKFSetLocalizationType()`,
1252: `PetscDALETKFSetLocalizationRadius()`, `PetscDALETKFSetLocalizationCoordinates()`
1253: @*/
1254: PetscErrorCode PetscDALETKFResetLocalization(PetscDA da)
1255: {
1256: PetscFunctionBegin;
1258: PetscTryMethod(da, "PetscDALETKFResetLocalization_C", (PetscDA), (da));
1259: PetscFunctionReturn(PETSC_SUCCESS);
1260: }