Actual source code: letkfilter.c
1: #include <petscda.h>
2: #include <petsc/private/daimpl.h>
3: #include <petsc/private/daensembleimpl.h>
4: #include <../src/ml/da/impls/ensemble/letkf/letkf.h>
6: static PetscErrorCode PetscDADestroy_LETKF(PetscDA da)
7: {
8: PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;
10: PetscFunctionBegin;
11: PetscCall(VecDestroy(&impl->mean));
12: PetscCall(VecDestroy(&impl->y_mean));
13: PetscCall(VecDestroy(&impl->delta_scaled));
14: PetscCall(VecDestroy(&impl->w));
15: PetscCall(VecDestroy(&impl->r_inv_sqrt));
16: PetscCall(MatDestroy(&impl->Z));
17: PetscCall(MatDestroy(&impl->S));
18: PetscCall(MatDestroy(&impl->T_sqrt));
19: PetscCall(MatDestroy(&impl->w_ones));
20: PetscCall(MatDestroy(&impl->Q));
21: #if defined(PETSC_HAVE_KOKKOS_KERNELS)
22: PetscCall(PetscDALETKFDestroyLocalization_Kokkos(impl));
23: #endif
24: PetscCall(ISDestroy(&impl->obs_is_local));
25: PetscCall(VecScatterDestroy(&impl->obs_scat));
26: PetscCall(VecDestroy(&impl->obs_work));
27: PetscCall(VecDestroy(&impl->y_mean_work));
28: PetscCall(VecDestroy(&impl->r_inv_sqrt_work));
29: PetscCall(MatDestroy(&impl->Z_work));
30: PetscCall(PetscDADestroy_Ensemble(da));
31: PetscCall(PetscFree(da->data));
33: PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFSetLocalization_C", NULL));
34: PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFSetObsPerVertex_C", NULL));
35: PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFGetObsPerVertex_C", NULL));
36: PetscFunctionReturn(PETSC_SUCCESS);
37: }
39: /*
40: ExtractLocalObservations - Extracts local observations for a vertex using localization matrix Q (CPU version)
42: Input Parameters:
43: + Q - localization matrix (state_size/ndof x obs_size), each row has constant non-zeros
44: . vertex_idx - index of the vertex (row of Q)
45: . Z_global - global observation ensemble matrix (obs_size x m) OR local work matrix
46: . y_global - global observation vector (size obs_size) OR local work vector
47: . y_mean_global - global observation mean (size obs_size) OR local work vector
48: . r_inv_sqrt_global - global R^{-1/2} (size obs_size) OR local work vector
49: . obs_g2l - map from global observation index to local index (if using local work vectors)
50: . m - ensemble size
52: Output Parameters:
53: . Z_local - local observation ensemble (p_local x m), pre-allocated
54: . y_local - local observation vector (size p_local), pre-allocated
55: . y_mean_local - local observation mean (size p_local), pre-allocated
56: - r_inv_sqrt_local - local R^{-1/2} (size p_local), pre-allocated
57: */
58: static PetscErrorCode ExtractLocalObservations(Mat Q, PetscInt vertex_idx, Mat Z_global, Vec y_global, Vec y_mean_global, Vec r_inv_sqrt_global, PetscHMapI obs_g2l, PetscInt m, Mat Z_local, Vec y_local, Vec y_mean_local, Vec r_inv_sqrt_local)
59: {
60: const PetscInt *cols;
61: const PetscScalar *vals;
62: PetscInt ncols, k, j;
63: const PetscScalar *z_global_array, *y_global_array, *y_mean_global_array, *r_inv_sqrt_global_array;
64: PetscScalar *z_local_array, *y_local_array, *y_mean_local_array, *r_inv_sqrt_local_array;
65: PetscInt lda_z_global, lda_z_local;
67: PetscFunctionBegin;
68: /* Get the row of Q corresponding to this vertex */
69: PetscCall(MatGetRow(Q, vertex_idx, &ncols, &cols, &vals));
71: /* Get array access to global data */
72: PetscCall(MatDenseGetArrayRead(Z_global, &z_global_array));
73: PetscCall(VecGetArrayRead(y_global, &y_global_array));
74: PetscCall(VecGetArrayRead(y_mean_global, &y_mean_global_array));
75: PetscCall(VecGetArrayRead(r_inv_sqrt_global, &r_inv_sqrt_global_array));
77: /* Get array access to local data */
78: PetscCall(MatDenseGetArrayWrite(Z_local, &z_local_array));
79: PetscCall(VecGetArray(y_local, &y_local_array));
80: PetscCall(VecGetArray(y_mean_local, &y_mean_local_array));
81: PetscCall(VecGetArray(r_inv_sqrt_local, &r_inv_sqrt_local_array));
83: /* Get leading dimensions */
84: PetscCall(MatDenseGetLDA(Z_global, &lda_z_global));
85: PetscCall(MatDenseGetLDA(Z_local, &lda_z_local));
87: /* Extract local observations and weight R^{-1/2} */
88: for (k = 0; k < ncols; k++) {
89: PetscInt obs_idx = cols[k];
90: PetscScalar weight = vals[k];
91: PetscInt local_idx = obs_idx;
93: /* If using local work vectors, map global index to local index */
94: if (obs_g2l) {
95: PetscCall(PetscHMapIGet(obs_g2l, obs_idx, &local_idx));
96: PetscCheck(local_idx >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Observation index %" PetscInt_FMT " not found in local map", obs_idx);
97: }
99: y_local_array[k] = y_global_array[local_idx];
100: y_mean_local_array[k] = y_mean_global_array[local_idx];
101: r_inv_sqrt_local_array[k] = r_inv_sqrt_global_array[local_idx] * PetscSqrtScalar(weight);
103: /* Extract Z matrix row (column-major layout) */
104: for (j = 0; j < m; j++) z_local_array[k + j * lda_z_local] = z_global_array[local_idx + j * lda_z_global];
105: }
107: /* Restore arrays */
108: PetscCall(VecRestoreArray(r_inv_sqrt_local, &r_inv_sqrt_local_array));
109: PetscCall(VecRestoreArray(y_mean_local, &y_mean_local_array));
110: PetscCall(VecRestoreArray(y_local, &y_local_array));
111: PetscCall(MatDenseRestoreArrayWrite(Z_local, &z_local_array));
112: PetscCall(VecRestoreArrayRead(r_inv_sqrt_global, &r_inv_sqrt_global_array));
113: PetscCall(VecRestoreArrayRead(y_mean_global, &y_mean_global_array));
114: PetscCall(VecRestoreArrayRead(y_global, &y_global_array));
115: PetscCall(MatDenseRestoreArrayRead(Z_global, &z_global_array));
117: /* Restore Q row */
118: PetscCall(MatRestoreRow(Q, vertex_idx, &ncols, &cols, &vals));
120: /* Assemble local matrices/vectors */
121: PetscCall(MatAssemblyBegin(Z_local, MAT_FINAL_ASSEMBLY));
122: PetscCall(MatAssemblyEnd(Z_local, MAT_FINAL_ASSEMBLY));
123: PetscFunctionReturn(PETSC_SUCCESS);
124: }
126: /*
127: PetscDALETKFLocalAnalysis - Performs local LETKF analysis for all grid points (CPU version)
129: Input Parameters:
130: + da - the PetscDA context
131: . impl - LETKF implementation data
132: . m - ensemble size
133: . n_vertices - number of grid points
134: . X - global anomaly matrix (state_size x m)
135: . observation - observation vector
136: . Z_global - global observation ensemble (obs_size x m)
137: . y_mean_global - global observation mean
138: - r_inv_sqrt_global - global R^{-1/2}
140: Output:
141: . da->ensemble - updated with analysis ensemble
143: Notes:
144: This function performs the local analysis loop for LETKF, processing each grid point
145: independently using its local observations defined by the localization matrix Q.
146: This is the CPU version that does not use Kokkos acceleration.
148: All local analysis workspace objects (Z_local, S_local, T_sqrt_local, G_local, y_local,
149: y_mean_local, delta_scaled_local, r_inv_sqrt_local, w_local, s_transpose_delta) are
150: created with PETSC_COMM_SELF because the analysis at each vertex is serial and independent.
151: */
152: PetscErrorCode PetscDALETKFLocalAnalysis(PetscDA da, PetscDA_LETKF *impl, PetscInt m, PetscInt n_vertices, Mat X, Vec observation, Mat Z_global, Vec y_mean_global, Vec r_inv_sqrt_global)
153: {
154: PetscDA_Ensemble *en = (PetscDA_Ensemble *)da->data;
155: Mat Z_local, S_local, T_sqrt_local, G_local;
156: Vec y_local, y_mean_local, delta_scaled_local, r_inv_sqrt_local;
157: Vec w_local, s_transpose_delta;
158: PetscInt i_grid_point;
159: PetscInt ndof;
160: PetscReal sqrt_m_minus_1, scale;
161: PetscInt rstart;
162: Mat X_rows, E_analysis_rows;
164: PetscFunctionBegin;
165: ndof = da->ndof;
166: scale = 1.0 / PetscSqrtReal((PetscReal)(m - 1));
167: sqrt_m_minus_1 = PetscSqrtReal((PetscReal)(m - 1));
168: /* Create local analysis workspace (n_obs_vertex x m matrices and vectors) */
169: PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, impl->n_obs_vertex, m, NULL, &Z_local));
170: PetscCall(PetscObjectSetOptionsPrefix((PetscObject)Z_local, "dense_"));
171: PetscCall(MatSetFromOptions(Z_local));
172: PetscCall(MatSetUp(Z_local));
173: PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, impl->n_obs_vertex, m, NULL, &S_local));
174: PetscCall(PetscObjectSetOptionsPrefix((PetscObject)S_local, "dense_"));
175: PetscCall(MatSetFromOptions(S_local));
176: PetscCall(MatSetUp(S_local));
177: PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, m, m, NULL, &T_sqrt_local));
178: PetscCall(PetscObjectSetOptionsPrefix((PetscObject)T_sqrt_local, "dense_"));
179: PetscCall(MatSetFromOptions(T_sqrt_local));
180: PetscCall(MatSetUp(T_sqrt_local));
181: PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, m, m, NULL, &G_local));
182: PetscCall(PetscObjectSetOptionsPrefix((PetscObject)G_local, "dense_"));
183: PetscCall(MatSetFromOptions(G_local));
184: PetscCall(MatSetUp(G_local));
186: /* Create vectors using MatCreateVecs from Z_local (n_obs_vertex x m) */
187: PetscCall(MatCreateVecs(Z_local, &w_local, &y_local));
188: PetscCall(VecDuplicate(y_local, &y_mean_local));
189: PetscCall(VecDuplicate(y_local, &delta_scaled_local));
190: PetscCall(VecDuplicate(y_local, &r_inv_sqrt_local));
191: PetscCall(VecDuplicate(w_local, &s_transpose_delta));
193: PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, ndof, m, NULL, &X_rows));
194: PetscCall(MatDuplicate(X_rows, MAT_DO_NOT_COPY_VALUES, &E_analysis_rows));
196: /* LETKF: Loop over all grid points and perform local analysis */
197: PetscCall(MatGetOwnershipRange(impl->Q, &rstart, NULL));
199: for (i_grid_point = 0; i_grid_point < n_vertices; i_grid_point++) {
200: /* Extract local observations for this grid point using Q[i_grid_point,:] */
201: /* Note: i_grid_point is local index, but MatGetRow needs global index */
202: PetscCall(ExtractLocalObservations(impl->Q, rstart + i_grid_point, Z_global, observation, y_mean_global, r_inv_sqrt_global, impl->obs_g2l, m, Z_local, y_local, y_mean_local, r_inv_sqrt_local));
204: /* Compute local normalized innovation matrix: S_local = R_local^{-1/2} * (Z_local - y_mean_local * 1') / sqrt(m - 1) */
205: PetscCall(PetscDAEnsembleComputeNormalizedInnovationMatrix(Z_local, y_mean_local, r_inv_sqrt_local, m, scale, S_local));
207: /* Compute local delta_scaled = R_local^{-1/2} * (y_local - y_mean_local) */
208: PetscCall(VecWAXPY(delta_scaled_local, -1.0, y_mean_local, y_local));
209: PetscCall(VecPointwiseMult(delta_scaled_local, delta_scaled_local, r_inv_sqrt_local));
211: /* Factor local T = (I + S_local^T * S_local) */
212: PetscCall(PetscDAEnsembleTFactor(da, S_local));
214: /* Compute local analysis weights: w_local = T_local^{-1} * S_local^T * delta_scaled_local */
215: PetscCall(MatMultTranspose(S_local, delta_scaled_local, s_transpose_delta));
216: PetscCall(PetscDAEnsembleApplyTInverse(da, s_transpose_delta, w_local));
218: /* Compute local square-root transform: T_sqrt_local = T_local^{-1/2} (U is identity, so pass NULL) */
219: PetscCall(PetscDAEnsembleApplySqrtTInverse(da, NULL, T_sqrt_local));
221: /* Form local transform G_local = w_local * 1' + sqrt(m - 1) * T_sqrt_local * U
222: Instead of creating w_ones_local = w_local * 1', we add w_local to each column of G_local */
223: PetscCall(MatCopy(T_sqrt_local, G_local, SAME_NONZERO_PATTERN));
224: PetscCall(MatScale(G_local, sqrt_m_minus_1));
225: {
226: const PetscScalar *w_array;
227: PetscScalar *g_array;
228: PetscInt j, k, lda_g;
230: PetscCall(VecGetArrayRead(w_local, &w_array));
231: PetscCall(MatDenseGetArrayWrite(G_local, &g_array));
232: PetscCall(MatDenseGetLDA(G_local, &lda_g));
233: for (j = 0; j < m; j++)
234: for (k = 0; k < m; k++) g_array[k + j * lda_g] += w_array[k];
235: PetscCall(MatDenseRestoreArrayWrite(G_local, &g_array));
236: PetscCall(VecRestoreArrayRead(w_local, &w_array));
237: }
239: /* LETKF Algorithm 2, Line 13: Update ensemble at grid point i_grid_point
240: E_a[i,:] = x_bar_f[i] + X_f[i,:] * G_local
242: Where:
243: - x_bar_f[i] is the forecast mean at grid point i_grid_point (ndof values from global mean vector)
244: - X_f[i,:] is the forecast anomaly rows at grid point i_grid_point (ndof rows from global anomaly matrix X)
245: - G_local = w_local * 1' + sqrt(m-1) * T_local^{1/2} * U (computed above in G_local)
246: */
247: {
248: const PetscScalar *x_array, *mean_array;
249: PetscScalar *e_array, *x_rows_array, *ea_rows_array;
250: PetscInt j, k, lda_x, lda_e;
252: /* Extract ndof rows starting at (i_grid_point * ndof) from X: X_f[i_grid_point*ndof:(i_grid_point+1)*ndof, :] */
253: PetscCall(MatDenseGetArrayRead(X, &x_array));
254: PetscCall(MatDenseGetArray(X_rows, &x_rows_array));
255: PetscCall(MatDenseGetLDA(X, &lda_x));
256: for (j = 0; j < m; j++) {
257: for (k = 0; k < ndof; k++) x_rows_array[k + j * ndof] = x_array[(i_grid_point * ndof + k) + j * lda_x];
258: }
259: PetscCall(MatDenseRestoreArray(X_rows, &x_rows_array));
260: PetscCall(MatDenseRestoreArrayRead(X, &x_array));
262: /* Apply local transform: E_analysis_rows = X_rows * G_local^T */
263: PetscCall(MatMatMult(X_rows, G_local, MAT_REUSE_MATRIX, PETSC_DEFAULT, &E_analysis_rows));
265: /* Add local mean: E_a[i_grid_point*ndof:(i_grid_point+1)*ndof, :] = x_bar_f[i_grid_point*ndof:(i_grid_point+1)*ndof] + X_f[...] * G_local */
266: PetscCall(VecGetArrayRead(impl->mean, &mean_array));
267: PetscCall(MatDenseGetArray(E_analysis_rows, &ea_rows_array));
268: for (j = 0; j < m; j++) {
269: for (k = 0; k < ndof; k++) ea_rows_array[k + j * ndof] += mean_array[i_grid_point * ndof + k];
270: }
271: PetscCall(MatDenseRestoreArray(E_analysis_rows, &ea_rows_array));
272: PetscCall(VecRestoreArrayRead(impl->mean, &mean_array));
274: /* Store result back in ensemble[i_grid_point*ndof:(i_grid_point+1)*ndof, :] */
275: PetscCall(MatDenseGetArrayWrite(en->ensemble, &e_array));
276: PetscCall(MatDenseGetLDA(en->ensemble, &lda_e));
277: PetscCall(MatDenseGetArrayRead(E_analysis_rows, (const PetscScalar **)&ea_rows_array));
278: for (j = 0; j < m; j++) {
279: for (k = 0; k < ndof; k++) e_array[(i_grid_point * ndof + k) + j * lda_e] = ea_rows_array[k + j * ndof];
280: }
281: PetscCall(MatDenseRestoreArrayRead(E_analysis_rows, (const PetscScalar **)&ea_rows_array));
282: PetscCall(MatDenseRestoreArrayWrite(en->ensemble, &e_array));
283: }
284: }
285: PetscCall(MatDestroy(&E_analysis_rows));
286: PetscCall(MatDestroy(&X_rows));
287: PetscCall(VecDestroy(&s_transpose_delta));
288: PetscCall(VecDestroy(&w_local));
289: PetscCall(VecDestroy(&r_inv_sqrt_local));
290: PetscCall(VecDestroy(&delta_scaled_local));
291: PetscCall(VecDestroy(&y_mean_local));
292: PetscCall(VecDestroy(&y_local));
293: PetscCall(MatDestroy(&G_local));
294: PetscCall(MatDestroy(&T_sqrt_local));
295: PetscCall(MatDestroy(&S_local));
296: PetscCall(MatDestroy(&Z_local));
297: PetscFunctionReturn(PETSC_SUCCESS);
298: }
300: static PetscErrorCode PetscDAEnsembleAnalysis_LETKF(PetscDA da, Vec observation, Mat H)
301: {
302: PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;
303: Mat X;
304: PetscInt m;
305: PetscBool reallocate = PETSC_FALSE;
307: PetscFunctionBegin;
308: m = impl->en.size;
310: /* Check if localization matrix Q is set */
311: PetscCheck(impl->Q, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "Localization matrix Q not set. Call PetscDALETKFSetLocalization() first.");
313: /* Warn if Cholesky sqrt type is used with LETKF - it produces an asymmetric
314: T^{-1/2} = L^{-T} which is incorrect for the local perturbation update.
315: LETKF requires the symmetric square root T^{-1/2} = V * D^{-1/2} * V^T. */
316: PetscCheck(impl->en.sqrt_type != PETSCDA_SQRT_CHOLESKY, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_INCOMP, "Cholesky sqrt type produces asymmetric T^{-1/2}, which is incorrect for LETKF. Use -petscda_ensemble_sqrt_type eigen or PetscDAEnsembleSetSqrtType(da, PETSCDA_SQRT_EIGEN) instead.");
318: /* Check that ensemble size <= number of local observations per vertex.
319: The eigen decomposition of T = I + S^T*S (m x m) requires that the
320: local observation count p >= m; otherwise T is rank-deficient and the
321: decomposition is ill-posed. */
322: PetscCheck(m <= impl->n_obs_vertex, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_INCOMP, "Ensemble size (%" PetscInt_FMT ") must be <= number of local observations per vertex (%" PetscInt_FMT ") for LETKF eigen decomposition to be well-posed", m,
323: impl->n_obs_vertex);
325: /* Check for reallocation needs */
326: if (impl->mean) {
327: PetscInt mean_size;
328: PetscCall(VecGetSize(impl->mean, &mean_size));
329: if (mean_size != da->state_size) reallocate = PETSC_TRUE;
330: }
331: if (impl->Z) {
332: PetscInt z_rows, z_cols;
333: PetscCall(MatGetSize(impl->Z, &z_rows, &z_cols));
334: if (z_rows != da->obs_size || z_cols != m) reallocate = PETSC_TRUE;
335: }
337: /* Initialize or reallocate persistent work objects */
338: if (!impl->mean || reallocate) {
339: PetscCall(VecDestroy(&impl->mean));
340: PetscCall(VecDestroy(&impl->y_mean));
341: PetscCall(VecDestroy(&impl->delta_scaled));
342: PetscCall(VecDestroy(&impl->w));
343: PetscCall(VecDestroy(&impl->r_inv_sqrt));
344: PetscCall(MatDestroy(&impl->Z));
345: PetscCall(MatDestroy(&impl->S));
346: PetscCall(MatDestroy(&impl->T_sqrt));
347: PetscCall(MatDestroy(&impl->w_ones));
349: /* Create mean vector from ensemble matrix (right vector = state space) */
350: PetscCall(MatCreateVecs(impl->en.ensemble, NULL, &impl->mean));
352: /* Create Z matrix (obs_size x m) */
353: PetscCall(MatCreateDense(PetscObjectComm((PetscObject)impl->en.ensemble), PETSC_DECIDE, PETSC_DECIDE, da->obs_size, m, NULL, &impl->Z));
354: PetscCall(PetscObjectSetOptionsPrefix((PetscObject)impl->Z, "dense_"));
355: PetscCall(MatSetFromOptions(impl->Z));
356: PetscCall(MatSetUp(impl->Z));
358: /* Create observation space vectors from Z matrix (left vector = observation space) */
359: PetscCall(MatCreateVecs(impl->Z, NULL, &impl->y_mean));
360: PetscCall(VecDuplicate(impl->y_mean, &impl->delta_scaled));
361: PetscCall(VecDuplicate(da->obs_error_var, &impl->r_inv_sqrt));
363: /* Create S matrix (same layout as Z) */
364: PetscCall(MatDuplicate(impl->Z, MAT_DO_NOT_COPY_VALUES, &impl->S));
366: /* Create T_sqrt matrix (m x m) - usually small */
367: /* T_sqrt will hold the result of applying T^{-1/2} to identity matrix */
368: PetscCall(MatCreateDense(PetscObjectComm((PetscObject)impl->en.ensemble), PETSC_DECIDE, PETSC_DECIDE, m, m, NULL, &impl->T_sqrt));
369: PetscCall(PetscObjectSetOptionsPrefix((PetscObject)impl->T_sqrt, "dense_"));
370: PetscCall(MatSetFromOptions(impl->T_sqrt));
371: PetscCall(MatSetUp(impl->T_sqrt));
373: /* Create w_ones matrix (m x m) */
374: PetscCall(MatCreateDense(PetscObjectComm((PetscObject)impl->en.ensemble), PETSC_DECIDE, PETSC_DECIDE, m, m, NULL, &impl->w_ones));
375: PetscCall(PetscObjectSetOptionsPrefix((PetscObject)impl->w_ones, "dense_"));
376: PetscCall(MatSetFromOptions(impl->w_ones));
377: PetscCall(MatSetUp(impl->w_ones));
378: }
380: /* Alg 6.4 line 1-2: Compute ensemble mean and scaled anomalies */
381: PetscCall(PetscDAEnsembleComputeMean(da, impl->mean));
383: /* Create anomaly matrix X = (E - x_mean * 1') / sqrt(m - 1) */
384: PetscCall(PetscDAEnsembleComputeAnomalies(da, impl->mean, &X));
386: /* Alg 6.4 line 3-4: Compute GLOBAL observation ensemble Z = H * E */
387: /* Note: When H is a Kokkos matrix type (e.g., aijkokkos), MatMatMult may fail
388: with non-Kokkos dense matrices. Use column-by-column multiplication with
389: temporary vectors that are compatible with H's type. */
390: {
391: Vec col_in, col_out, temp_in, temp_out;
392: PetscInt j;
394: /* Create temporary vectors compatible with H's type */
395: PetscCall(MatCreateVecs(H, &temp_in, &temp_out));
397: /* Compute Z = H * E column by column to avoid Kokkos vector type issues */
398: for (j = 0; j < m; j++) {
399: PetscCall(MatDenseGetColumnVecRead(impl->en.ensemble, j, &col_in));
400: PetscCall(MatDenseGetColumnVecWrite(impl->Z, j, &col_out));
402: /* Copy to temp vector, multiply, then copy back */
403: PetscCall(VecCopy(col_in, temp_in));
404: PetscCall(MatMult(H, temp_in, temp_out));
405: PetscCall(VecCopy(temp_out, col_out));
407: PetscCall(MatDenseRestoreColumnVecWrite(impl->Z, j, &col_out));
408: PetscCall(MatDenseRestoreColumnVecRead(impl->en.ensemble, j, &col_in));
409: }
410: PetscCall(MatAssemblyBegin(impl->Z, MAT_FINAL_ASSEMBLY));
411: PetscCall(MatAssemblyEnd(impl->Z, MAT_FINAL_ASSEMBLY));
413: PetscCall(VecDestroy(&temp_out));
414: PetscCall(VecDestroy(&temp_in));
415: }
417: /* Compute GLOBAL observation mean y_mean = H * x_mean */
418: /* Use temporary vector compatible with H's type */
419: {
420: Vec temp_mean, temp_y_mean;
421: PetscCall(MatCreateVecs(H, &temp_mean, &temp_y_mean));
422: PetscCall(VecCopy(impl->mean, temp_mean));
423: PetscCall(MatMult(H, temp_mean, temp_y_mean));
424: PetscCall(VecCopy(temp_y_mean, impl->y_mean));
425: PetscCall(VecDestroy(&temp_y_mean));
426: PetscCall(VecDestroy(&temp_mean));
427: }
429: /* Compute GLOBAL R^{-1/2} (assumes diagonal R) */
430: PetscCall(VecCopy(da->obs_error_var, impl->r_inv_sqrt));
431: PetscCall(VecSqrtAbs(impl->r_inv_sqrt));
432: PetscCall(VecReciprocal(impl->r_inv_sqrt));
433: /* Perform local analysis for all vertices */
435: #if defined(PETSC_HAVE_KOKKOS_KERNELS)
436: /* Use GPU version only if:
437: 1. sqrt_type is eigen (GPU version only implements eigen/SVD, not cholesky)
438: 2. H matrix is a Kokkos type (aijkokkos) */
439: {
440: PetscBool use_gpu = PETSC_FALSE;
441: if (impl->en.sqrt_type == PETSCDA_SQRT_EIGEN) {
442: #if !defined(PETSC_USE_COMPLEX)
443: /* Check if H matrix is a Kokkos type */
444: PetscCall(PetscObjectTypeCompareAny((PetscObject)da->R, &use_gpu, MATSEQAIJKOKKOS, MATMPIAIJKOKKOS, MATAIJKOKKOS, ""));
445: #endif
446: }
448: /* Scatter global vectors to local work vectors if available */
449: if (impl->obs_scat) {
450: PetscCall(VecScatterBegin(impl->obs_scat, observation, impl->obs_work, INSERT_VALUES, SCATTER_FORWARD));
451: PetscCall(VecScatterEnd(impl->obs_scat, observation, impl->obs_work, INSERT_VALUES, SCATTER_FORWARD));
453: PetscCall(VecScatterBegin(impl->obs_scat, impl->y_mean, impl->y_mean_work, INSERT_VALUES, SCATTER_FORWARD));
454: PetscCall(VecScatterEnd(impl->obs_scat, impl->y_mean, impl->y_mean_work, INSERT_VALUES, SCATTER_FORWARD));
456: PetscCall(VecScatterBegin(impl->obs_scat, impl->r_inv_sqrt, impl->r_inv_sqrt_work, INSERT_VALUES, SCATTER_FORWARD));
457: PetscCall(VecScatterEnd(impl->obs_scat, impl->r_inv_sqrt, impl->r_inv_sqrt_work, INSERT_VALUES, SCATTER_FORWARD));
459: /* Handle Z matrix (scatter columns) */
460: {
461: PetscInt n_obs_local;
462: PetscCall(VecGetLocalSize(impl->obs_work, &n_obs_local));
463: if (!impl->Z_work) {
464: PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, n_obs_local, m, NULL, &impl->Z_work));
465: } else {
466: PetscInt m_old, n_old;
467: PetscCall(MatGetSize(impl->Z_work, &n_old, &m_old));
468: if (m_old != m || n_old != n_obs_local) {
469: PetscCall(MatDestroy(&impl->Z_work));
470: PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, n_obs_local, m, NULL, &impl->Z_work));
471: }
472: }
473: for (PetscInt i = 0; i < m; i++) {
474: Vec z_col_global, z_col_local;
475: PetscCall(MatDenseGetColumnVecRead(impl->Z, i, &z_col_global));
476: PetscCall(MatDenseGetColumnVecWrite(impl->Z_work, i, &z_col_local));
477: PetscCall(VecScatterBegin(impl->obs_scat, z_col_global, z_col_local, INSERT_VALUES, SCATTER_FORWARD));
478: PetscCall(VecScatterEnd(impl->obs_scat, z_col_global, z_col_local, INSERT_VALUES, SCATTER_FORWARD));
479: PetscCall(MatDenseRestoreColumnVecRead(impl->Z, i, &z_col_global));
480: PetscCall(MatDenseRestoreColumnVecWrite(impl->Z_work, i, &z_col_local));
481: }
482: }
483: }
485: if (use_gpu) {
486: PetscInt n_local;
487: PetscCall(MatGetLocalSize(impl->Q, &n_local, NULL));
488: /* Use local work vectors for GPU analysis */
489: PetscCall(PetscDALETKFLocalAnalysis_GPU(da, impl, m, n_local, X, impl->obs_work, impl->Z_work, impl->y_mean_work, impl->r_inv_sqrt_work));
490: } else {
491: PetscInt n_local;
492: PetscCall(MatGetLocalSize(impl->Q, &n_local, NULL));
493: if (impl->obs_scat) {
494: PetscCall(PetscDALETKFLocalAnalysis(da, impl, m, n_local, X, impl->obs_work, impl->Z_work, impl->y_mean_work, impl->r_inv_sqrt_work));
495: } else {
496: PetscCall(PetscDALETKFLocalAnalysis(da, impl, m, n_local, X, observation, impl->Z, impl->y_mean, impl->r_inv_sqrt));
497: }
498: }
499: }
500: #else
501: /* Without Kokkos, use CPU version */
502: {
503: PetscInt n_local;
504: PetscCall(MatGetLocalSize(impl->Q, &n_local, NULL));
505: PetscCall(PetscDALETKFLocalAnalysis(da, impl, m, n_local, X, observation, impl->Z, impl->y_mean, impl->r_inv_sqrt));
506: }
507: #endif
508: PetscCall(MatDestroy(&X));
509: PetscFunctionReturn(PETSC_SUCCESS);
510: }
512: static PetscErrorCode PetscDALETKFSetObsPerVertex_LETKF(PetscDA da, PetscInt n_obs_vertex)
513: {
514: PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;
516: PetscFunctionBegin;
517: PetscCheck(impl, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "PetscDA not properly initialized for LETKF");
518: impl->n_obs_vertex = n_obs_vertex;
519: PetscFunctionReturn(PETSC_SUCCESS);
520: }
522: static PetscErrorCode PetscDALETKFGetObsPerVertex_LETKF(PetscDA da, PetscInt *n_obs_vertex)
523: {
524: PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;
526: PetscFunctionBegin;
527: PetscCheck(impl, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "PetscDA not properly initialized for LETKF");
528: *n_obs_vertex = impl->n_obs_vertex;
529: PetscFunctionReturn(PETSC_SUCCESS);
530: }
532: static PetscErrorCode PetscDALETKFSetLocalization_LETKF(PetscDA da, Mat Q, Mat H)
533: {
534: PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;
535: PetscInt i, nrows, ncols, nnz, rstart, rend;
537: PetscFunctionBegin;
538: PetscCheck(impl, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "PetscDA not properly initialized for LETKF");
540: /* Get matrix dimensions */
541: PetscCall(MatGetSize(Q, &nrows, &ncols));
543: /* Validate matrix dimensions */
544: PetscCheck(nrows == da->state_size / da->ndof, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_INCOMP, "Localization matrix rows (%" PetscInt_FMT ") must match state size (%" PetscInt_FMT ")", nrows, da->state_size);
545: PetscCheck(ncols == da->obs_size, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_INCOMP, "Localization matrix columns (%" PetscInt_FMT ") must match observation size (%" PetscInt_FMT ")", ncols, da->obs_size);
547: /* Validate that each row has const non-zero entries */
548: PetscCall(MatGetOwnershipRange(Q, &rstart, &rend));
549: for (i = rstart; i < rend; i++) {
550: const PetscInt *cols;
551: const PetscScalar *vals;
552: PetscCall(MatGetRow(Q, i, &nnz, &cols, &vals));
553: PetscCheck(nnz == impl->n_obs_vertex, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_INCOMP, "Row %" PetscInt_FMT " has %" PetscInt_FMT " non-zeros, expected %" PetscInt_FMT, i, nnz, (PetscInt)impl->n_obs_vertex);
554: PetscCall(MatRestoreRow(Q, i, &nnz, &cols, &vals));
555: }
557: /* Store the localization matrix */
558: PetscCall(MatDestroy(&impl->Q));
559: PetscCall(PetscObjectReference((PetscObject)Q));
560: impl->Q = Q;
561: #if defined(PETSC_HAVE_KOKKOS_KERNELS)
562: PetscCall(PetscDALETKFSetupLocalization_Kokkos(impl, H));
563: #endif
564: PetscFunctionReturn(PETSC_SUCCESS);
565: }
567: static PetscErrorCode PetscDAView_LETKF(PetscDA da, PetscViewer viewer)
568: {
569: PetscBool iascii;
570: PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;
572: PetscFunctionBegin;
573: PetscCall(PetscDAView_Ensemble(da, viewer));
574: PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &iascii));
575: if (iascii) {
576: #if defined(PETSC_HAVE_KOKKOS_KERNELS)
577: if (impl->en.sqrt_type == PETSCDA_SQRT_CHOLESKY) {
578: PetscCall(PetscViewerASCIIPrintf(viewer, " Local analysis: CPU\n"));
579: } else {
580: /* Check if R matrix is Kokkos type to determine if GPU will be used */
581: if (da->R) {
582: PetscBool is_kokkos = PETSC_FALSE;
583: PetscCall(PetscObjectTypeCompareAny((PetscObject)da->R, &is_kokkos, MATSEQAIJKOKKOS, MATMPIAIJKOKKOS, MATAIJKOKKOS, ""));
584: if (is_kokkos) {
585: PetscCall(PetscViewerASCIIPrintf(viewer, " Local analysis: Kokkos\n"));
586: } else {
587: PetscCall(PetscViewerASCIIPrintf(viewer, " Local analysis: CPU\n"));
588: }
589: } else {
590: PetscCall(PetscViewerASCIIPrintf(viewer, " Local analysis: CPU or Kokkos (depending on covariance matrix type)\n"));
591: }
592: }
593: #else
594: PetscCall(PetscViewerASCIIPrintf(viewer, " Local analysis: CPU\n"));
595: #endif
596: PetscCall(PetscViewerASCIIPrintf(viewer, " Local observations per vertex: %" PetscInt_FMT "\n", impl->n_obs_vertex));
597: if (impl->batch_size > 0) {
598: PetscCall(PetscViewerASCIIPrintf(viewer, " GPU batch size: %" PetscInt_FMT "\n", impl->batch_size));
599: } else {
600: PetscCall(PetscViewerASCIIPrintf(viewer, " GPU batch size: auto\n"));
601: }
602: if (impl->Q) {
603: PetscCall(PetscViewerASCIIPrintf(viewer, " Localization matrix: set\n"));
604: } else {
605: PetscCall(PetscViewerASCIIPrintf(viewer, " Localization matrix: not set\n"));
606: }
607: }
608: PetscFunctionReturn(PETSC_SUCCESS);
609: }
611: static PetscErrorCode PetscDASetFromOptions_LETKF(PetscDA da, PetscOptionItems *PetscOptionsObjectPtr)
612: {
613: PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;
614: PetscOptionItems PetscOptionsObject = *PetscOptionsObjectPtr;
616: PetscFunctionBegin;
617: PetscCall(PetscDASetFromOptions_Ensemble(da, PetscOptionsObjectPtr));
618: PetscOptionsHeadBegin(PetscOptionsObject, "PetscDA LETKF Options");
619: PetscCall(PetscOptionsInt("-petscda_letkf_batch_size", "Batch size for GPU processing", "", impl->batch_size, &impl->batch_size, NULL));
620: PetscCall(PetscOptionsInt("-petscda_letkf_obs_per_vertex", "Number of local observations per vertex", "", impl->n_obs_vertex, &impl->n_obs_vertex, NULL));
621: PetscOptionsHeadEnd();
622: PetscFunctionReturn(PETSC_SUCCESS);
623: }
625: /*MC
626: PETSCDALETKF - The Local ETKF performs the analysis update locally around each grid point, enabling scalable assimilation on large
627: domains by avoiding the global ensemble covariance matrix.
629: Options Database Keys:
630: + -petscda_type letkf - set the `PetscDAType` to `PETSCDALETKF`
631: . -petscda_ensemble_size <size> - number of ensemble members
632: . -petscda_ensemble_sqrt_type <cholesky, eigen> - the square root of the matrix to use
633: . -petscda_letkf_batch_size <batch_size> - set the batch size for GPU processing
634: - -petscda_letkf_obs_per_vertex <n_obs_vertex> - number of observations per vertex
636: Level: beginner
638: .seealso: [](ch_da), `PetscDA`, `PetscDACreate()`, `PETSCDAETKF`, `PetscDALETKFSetObsPerVertex()`, `PetscDALETKFGetObsPerVertex()`,
639: `PetscDALETKFSetLocalization()`, `PetscDAEnsembleSetSize()`, `PetscDASetSizes()`, `PetscDAEnsembleSetSqrtType()`, `PetscDAEnsembleSetInflation()`,
640: `PetscDAEnsembleComputeMean()`, `PetscDAEnsembleComputeAnomalies()`, `PetscDAEnsembleAnalysis()`, `PetscDAEnsembleForecast()`
641: M*/
643: PETSC_INTERN PetscErrorCode PetscDACreate_LETKF(PetscDA da)
644: {
645: PetscDA_LETKF *impl;
647: PetscFunctionBegin;
648: PetscCall(PetscNew(&impl));
649: da->data = impl;
650: PetscCall(PetscDACreate_Ensemble(da));
651: da->ops->destroy = PetscDADestroy_LETKF;
652: da->ops->view = PetscDAView_LETKF;
653: da->ops->setfromoptions = PetscDASetFromOptions_LETKF;
654: impl->en.analysis = PetscDAEnsembleAnalysis_LETKF;
655: impl->en.forecast = PetscDAEnsembleForecast_Ensemble;
657: impl->n_obs_vertex = 9;
658: impl->Q = NULL;
659: impl->batch_size = 0;
661: /* Register the method for setting localization */
662: PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFSetLocalization_C", PetscDALETKFSetLocalization_LETKF));
663: PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFSetObsPerVertex_C", PetscDALETKFSetObsPerVertex_LETKF));
664: PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFGetObsPerVertex_C", PetscDALETKFGetObsPerVertex_LETKF));
665: PetscFunctionReturn(PETSC_SUCCESS);
666: }
668: /*@
669: PetscDALETKFSetObsPerVertex - Sets the number of local observations per vertex for the LETKF algorithm.
671: Logically Collective
673: Input Parameters:
674: + da - the `PetscDA` context
675: - n_obs_vertex - number of observations per vertex
677: Level: advanced
679: .seealso: [](ch_da), `PETSCDALETKF`, `PetscDA`, `PetscDALETKFSetLocalization()`
680: @*/
681: PetscErrorCode PetscDALETKFSetObsPerVertex(PetscDA da, PetscInt n_obs_vertex)
682: {
683: PetscFunctionBegin;
686: PetscTryMethod(da, "PetscDALETKFSetObsPerVertex_C", (PetscDA, PetscInt), (da, n_obs_vertex));
687: PetscFunctionReturn(PETSC_SUCCESS);
688: }
690: /*@
691: PetscDALETKFGetObsPerVertex - Gets the number of local observations per vertex for the LETKF algorithm.
693: Not Collective
695: Input Parameter:
696: . da - the `PetscDA` context
698: Output Parameter:
699: . n_obs_vertex - number of observations per vertex
701: Level: advanced
703: .seealso: [](ch_da), `PETSCDALETKF`, `PetscDA`, `PetscDALETKFSetObsPerVertex()`
704: @*/
705: PetscErrorCode PetscDALETKFGetObsPerVertex(PetscDA da, PetscInt *n_obs_vertex)
706: {
707: PetscFunctionBegin;
709: PetscAssertPointer(n_obs_vertex, 2);
710: PetscUseMethod(da, "PetscDALETKFGetObsPerVertex_C", (PetscDA, PetscInt *), (da, n_obs_vertex));
711: PetscFunctionReturn(PETSC_SUCCESS);
712: }
714: /*@
715: PetscDALETKFSetLocalization - Sets the localization matrix for the LETKF algorithm.
717: Collective
719: Input Parameters:
720: + da - the `PetscDA` context
721: . Q - the localization matrix (N x P)
722: - H - the observation operator matrix (P x N)
724: Level: advanced
726: .seealso: [](ch_da), `PETSCDALETKF`, `PetscDA`
727: @*/
728: PetscErrorCode PetscDALETKFSetLocalization(PetscDA da, Mat Q, Mat H)
729: {
730: PetscFunctionBegin;
734: PetscTryMethod(da, "PetscDALETKFSetLocalization_C", (PetscDA, Mat, Mat), (da, Q, H));
735: PetscFunctionReturn(PETSC_SUCCESS);
736: }