Actual source code: letkfilter.c
1: #include <petscda.h>
2: #include <petsc/private/daimpl.h>
3: #include <petsc/private/daensembleimpl.h>
4: #include <../src/ml/da/impls/ensemble/letkf/letkf.h>
6: static PetscErrorCode PetscDADestroy_LETKF(PetscDA da)
7: {
8: PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;
10: PetscFunctionBegin;
11: PetscCall(VecDestroy(&impl->mean));
12: PetscCall(VecDestroy(&impl->y_mean));
13: PetscCall(VecDestroy(&impl->delta_scaled));
14: PetscCall(VecDestroy(&impl->w));
15: PetscCall(VecDestroy(&impl->r_inv_sqrt));
16: PetscCall(MatDestroy(&impl->Z));
17: PetscCall(MatDestroy(&impl->S));
18: PetscCall(MatDestroy(&impl->T_sqrt));
19: PetscCall(MatDestroy(&impl->w_ones));
20: PetscCall(MatDestroy(&impl->Q));
21: #if defined(PETSC_HAVE_KOKKOS_KERNELS)
22: PetscCall(PetscDALETKFDestroyLocalization_Kokkos(impl));
23: #endif
24: PetscCall(ISDestroy(&impl->obs_is_local));
25: PetscCall(VecScatterDestroy(&impl->obs_scat));
26: PetscCall(VecDestroy(&impl->obs_work));
27: PetscCall(VecDestroy(&impl->y_mean_work));
28: PetscCall(VecDestroy(&impl->r_inv_sqrt_work));
29: PetscCall(MatDestroy(&impl->Z_work));
30: PetscCall(PetscDADestroy_Ensemble(da));
31: PetscCall(PetscFree(da->data));
33: PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFSetLocalization_C", NULL));
34: PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFSetObsPerVertex_C", NULL));
35: PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFGetObsPerVertex_C", NULL));
36: PetscFunctionReturn(PETSC_SUCCESS);
37: }
39: /*
40: ExtractLocalObservations - Extracts local observations for a vertex using localization matrix Q (CPU version)
42: Input Parameters:
43: + Q - localization matrix (state_size/ndof x obs_size), each row has constant non-zeros
44: . vertex_idx - index of the vertex (row of Q)
45: . Z_global - global observation ensemble matrix (obs_size x m) OR local work matrix
46: . y_global - global observation vector (size obs_size) OR local work vector
47: . y_mean_global - global observation mean (size obs_size) OR local work vector
48: . r_inv_sqrt_global - global R^{-1/2} (size obs_size) OR local work vector
49: . obs_g2l - map from global observation index to local index (if using local work vectors)
50: . m - ensemble size
52: Output Parameters:
53: . Z_local - local observation ensemble (p_local x m), pre-allocated
54: . y_local - local observation vector (size p_local), pre-allocated
55: . y_mean_local - local observation mean (size p_local), pre-allocated
56: - r_inv_sqrt_local - local R^{-1/2} (size p_local), pre-allocated
57: */
58: static PetscErrorCode ExtractLocalObservations(Mat Q, PetscInt vertex_idx, Mat Z_global, Vec y_global, Vec y_mean_global, Vec r_inv_sqrt_global, PetscHMapI obs_g2l, PetscInt m, Mat Z_local, Vec y_local, Vec y_mean_local, Vec r_inv_sqrt_local)
59: {
60: const PetscInt *cols;
61: const PetscScalar *vals;
62: PetscInt ncols, k, j;
63: const PetscScalar *z_global_array, *y_global_array, *y_mean_global_array, *r_inv_sqrt_global_array;
64: PetscScalar *z_local_array, *y_local_array, *y_mean_local_array, *r_inv_sqrt_local_array;
65: PetscInt lda_z_global, lda_z_local;
67: PetscFunctionBegin;
68: /* Get the row of Q corresponding to this vertex */
69: PetscCall(MatGetRow(Q, vertex_idx, &ncols, &cols, &vals));
71: /* Get array access to global data */
72: PetscCall(MatDenseGetArrayRead(Z_global, &z_global_array));
73: PetscCall(VecGetArrayRead(y_global, &y_global_array));
74: PetscCall(VecGetArrayRead(y_mean_global, &y_mean_global_array));
75: PetscCall(VecGetArrayRead(r_inv_sqrt_global, &r_inv_sqrt_global_array));
77: /* Get array access to local data */
78: PetscCall(MatDenseGetArrayWrite(Z_local, &z_local_array));
79: PetscCall(VecGetArray(y_local, &y_local_array));
80: PetscCall(VecGetArray(y_mean_local, &y_mean_local_array));
81: PetscCall(VecGetArray(r_inv_sqrt_local, &r_inv_sqrt_local_array));
83: /* Get leading dimensions */
84: PetscCall(MatDenseGetLDA(Z_global, &lda_z_global));
85: PetscCall(MatDenseGetLDA(Z_local, &lda_z_local));
87: /* Extract local observations and weight R^{-1/2} */
88: for (k = 0; k < ncols; k++) {
89: PetscInt obs_idx = cols[k];
90: PetscScalar weight = vals[k];
91: PetscInt local_idx = obs_idx;
93: /* If using local work vectors, map global index to local index */
94: if (obs_g2l) {
95: PetscCall(PetscHMapIGet(obs_g2l, obs_idx, &local_idx));
96: PetscCheck(local_idx >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Observation index %" PetscInt_FMT " not found in local map", obs_idx);
97: }
99: y_local_array[k] = y_global_array[local_idx];
100: y_mean_local_array[k] = y_mean_global_array[local_idx];
101: r_inv_sqrt_local_array[k] = r_inv_sqrt_global_array[local_idx] * PetscSqrtScalar(weight);
103: /* Extract Z matrix row (column-major layout) */
104: for (j = 0; j < m; j++) z_local_array[k + j * lda_z_local] = z_global_array[local_idx + j * lda_z_global];
105: }
107: /* Restore arrays */
108: PetscCall(VecRestoreArray(r_inv_sqrt_local, &r_inv_sqrt_local_array));
109: PetscCall(VecRestoreArray(y_mean_local, &y_mean_local_array));
110: PetscCall(VecRestoreArray(y_local, &y_local_array));
111: PetscCall(MatDenseRestoreArrayWrite(Z_local, &z_local_array));
112: PetscCall(VecRestoreArrayRead(r_inv_sqrt_global, &r_inv_sqrt_global_array));
113: PetscCall(VecRestoreArrayRead(y_mean_global, &y_mean_global_array));
114: PetscCall(VecRestoreArrayRead(y_global, &y_global_array));
115: PetscCall(MatDenseRestoreArrayRead(Z_global, &z_global_array));
117: /* Restore Q row */
118: PetscCall(MatRestoreRow(Q, vertex_idx, &ncols, &cols, &vals));
120: /* Assemble local matrices/vectors */
121: PetscCall(MatAssemblyBegin(Z_local, MAT_FINAL_ASSEMBLY));
122: PetscCall(MatAssemblyEnd(Z_local, MAT_FINAL_ASSEMBLY));
123: PetscFunctionReturn(PETSC_SUCCESS);
124: }
126: /*
127: PetscDALETKFLocalAnalysis - Performs local LETKF analysis for all grid points (CPU version)
129: Input Parameters:
130: + da - the PetscDA context
131: . impl - LETKF implementation data
132: . m - ensemble size
133: . n_vertices - number of grid points
134: . X - global anomaly matrix (state_size x m)
135: . observation - observation vector
136: . Z_global - global observation ensemble (obs_size x m)
137: . y_mean_global - global observation mean
138: - r_inv_sqrt_global - global R^{-1/2}
140: Output:
141: . da->ensemble - updated with analysis ensemble
143: Notes:
144: This function performs the local analysis loop for LETKF, processing each grid point
145: independently using its local observations defined by the localization matrix Q.
146: This is the CPU version that does not use Kokkos acceleration.
148: All local analysis workspace objects (Z_local, S_local, T_sqrt_local, G_local, y_local,
149: y_mean_local, delta_scaled_local, r_inv_sqrt_local, w_local, s_transpose_delta) are
150: created with PETSC_COMM_SELF because the analysis at each vertex is serial and independent.
151: */
152: PetscErrorCode PetscDALETKFLocalAnalysis(PetscDA da, PetscDA_LETKF *impl, PetscInt m, PetscInt n_vertices, Mat X, Vec observation, Mat Z_global, Vec y_mean_global, Vec r_inv_sqrt_global)
153: {
154: PetscDA_Ensemble *en = (PetscDA_Ensemble *)da->data;
155: Mat Z_local, S_local, T_sqrt_local, G_local;
156: Vec y_local, y_mean_local, delta_scaled_local, r_inv_sqrt_local;
157: Vec w_local, s_transpose_delta;
158: PetscInt ndof;
159: PetscReal sqrt_m_minus_1, scale;
160: PetscInt rstart;
161: Mat X_rows, E_analysis_rows;
163: PetscFunctionBegin;
164: ndof = da->ndof;
165: scale = 1.0 / PetscSqrtReal((PetscReal)(m - 1));
166: sqrt_m_minus_1 = PetscSqrtReal((PetscReal)(m - 1));
167: /* Create local analysis workspace (n_obs_vertex x m matrices and vectors) */
168: PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, impl->n_obs_vertex, m, NULL, &Z_local));
169: PetscCall(PetscObjectSetOptionsPrefix((PetscObject)Z_local, "dense_"));
170: PetscCall(MatSetFromOptions(Z_local));
171: PetscCall(MatSetUp(Z_local));
172: PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, impl->n_obs_vertex, m, NULL, &S_local));
173: PetscCall(PetscObjectSetOptionsPrefix((PetscObject)S_local, "dense_"));
174: PetscCall(MatSetFromOptions(S_local));
175: PetscCall(MatSetUp(S_local));
176: PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, m, m, NULL, &T_sqrt_local));
177: PetscCall(PetscObjectSetOptionsPrefix((PetscObject)T_sqrt_local, "dense_"));
178: PetscCall(MatSetFromOptions(T_sqrt_local));
179: PetscCall(MatSetUp(T_sqrt_local));
180: PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, m, m, NULL, &G_local));
181: PetscCall(PetscObjectSetOptionsPrefix((PetscObject)G_local, "dense_"));
182: PetscCall(MatSetFromOptions(G_local));
183: PetscCall(MatSetUp(G_local));
185: /* Create vectors using MatCreateVecs from Z_local (n_obs_vertex x m) */
186: PetscCall(MatCreateVecs(Z_local, &w_local, &y_local));
187: PetscCall(VecDuplicate(y_local, &y_mean_local));
188: PetscCall(VecDuplicate(y_local, &delta_scaled_local));
189: PetscCall(VecDuplicate(y_local, &r_inv_sqrt_local));
190: PetscCall(VecDuplicate(w_local, &s_transpose_delta));
192: PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, ndof, m, NULL, &X_rows));
193: PetscCall(MatDuplicate(X_rows, MAT_DO_NOT_COPY_VALUES, &E_analysis_rows));
195: /* LETKF: Loop over all grid points and perform local analysis */
196: PetscCall(MatGetOwnershipRange(impl->Q, &rstart, NULL));
198: for (PetscInt i_grid_point = 0; i_grid_point < n_vertices; i_grid_point++) {
199: /* Extract local observations for this grid point using Q[i_grid_point,:] */
200: /* Note: i_grid_point is local index, but MatGetRow needs global index */
201: PetscCall(ExtractLocalObservations(impl->Q, rstart + i_grid_point, Z_global, observation, y_mean_global, r_inv_sqrt_global, impl->obs_g2l, m, Z_local, y_local, y_mean_local, r_inv_sqrt_local));
203: /* Compute local normalized innovation matrix: S_local = R_local^{-1/2} * (Z_local - y_mean_local * 1') / sqrt(m - 1) */
204: PetscCall(PetscDAEnsembleComputeNormalizedInnovationMatrix(Z_local, y_mean_local, r_inv_sqrt_local, m, scale, S_local));
206: /* Compute local delta_scaled = R_local^{-1/2} * (y_local - y_mean_local) */
207: PetscCall(VecWAXPY(delta_scaled_local, -1.0, y_mean_local, y_local));
208: PetscCall(VecPointwiseMult(delta_scaled_local, delta_scaled_local, r_inv_sqrt_local));
210: /* Factor local T = (I + S_local^T * S_local) */
211: PetscCall(PetscDAEnsembleTFactor(da, S_local));
213: /* Compute local analysis weights: w_local = T_local^{-1} * S_local^T * delta_scaled_local */
214: PetscCall(MatMultTranspose(S_local, delta_scaled_local, s_transpose_delta));
215: PetscCall(PetscDAEnsembleApplyTInverse(da, s_transpose_delta, w_local));
217: /* Compute local square-root transform: T_sqrt_local = T_local^{-1/2} (U is identity, so pass NULL) */
218: PetscCall(PetscDAEnsembleApplySqrtTInverse(da, NULL, T_sqrt_local));
220: /* Form local transform G_local = w_local * 1' + sqrt(m - 1) * T_sqrt_local * U
221: Instead of creating w_ones_local = w_local * 1', we add w_local to each column of G_local */
222: PetscCall(MatCopy(T_sqrt_local, G_local, SAME_NONZERO_PATTERN));
223: PetscCall(MatScale(G_local, sqrt_m_minus_1));
224: {
225: const PetscScalar *w_array;
226: PetscScalar *g_array;
227: PetscInt j, k, lda_g;
229: PetscCall(VecGetArrayRead(w_local, &w_array));
230: PetscCall(MatDenseGetArrayWrite(G_local, &g_array));
231: PetscCall(MatDenseGetLDA(G_local, &lda_g));
232: for (j = 0; j < m; j++)
233: for (k = 0; k < m; k++) g_array[k + j * lda_g] += w_array[k];
234: PetscCall(MatDenseRestoreArrayWrite(G_local, &g_array));
235: PetscCall(VecRestoreArrayRead(w_local, &w_array));
236: }
238: /* LETKF Algorithm 2, Line 13: Update ensemble at grid point i_grid_point
239: E_a[i,:] = x_bar_f[i] + X_f[i,:] * G_local
241: Where:
242: - x_bar_f[i] is the forecast mean at grid point i_grid_point (ndof values from global mean vector)
243: - X_f[i,:] is the forecast anomaly rows at grid point i_grid_point (ndof rows from global anomaly matrix X)
244: - G_local = w_local * 1' + sqrt(m-1) * T_local^{1/2} * U (computed above in G_local)
245: */
246: {
247: const PetscScalar *x_array, *mean_array;
248: PetscScalar *e_array, *x_rows_array, *ea_rows_array;
249: PetscInt j, k, lda_x, lda_e;
251: /* Extract ndof rows starting at (i_grid_point * ndof) from X: X_f[i_grid_point*ndof:(i_grid_point+1)*ndof, :] */
252: PetscCall(MatDenseGetArrayRead(X, &x_array));
253: PetscCall(MatDenseGetArray(X_rows, &x_rows_array));
254: PetscCall(MatDenseGetLDA(X, &lda_x));
255: for (j = 0; j < m; j++) {
256: for (k = 0; k < ndof; k++) x_rows_array[k + j * ndof] = x_array[(i_grid_point * ndof + k) + j * lda_x];
257: }
258: PetscCall(MatDenseRestoreArray(X_rows, &x_rows_array));
259: PetscCall(MatDenseRestoreArrayRead(X, &x_array));
261: /* Apply local transform: E_analysis_rows = X_rows * G_local^T */
262: PetscCall(MatMatMult(X_rows, G_local, MAT_REUSE_MATRIX, PETSC_DEFAULT, &E_analysis_rows));
264: /* Add local mean: E_a[i_grid_point*ndof:(i_grid_point+1)*ndof, :] = x_bar_f[i_grid_point*ndof:(i_grid_point+1)*ndof] + X_f[...] * G_local */
265: PetscCall(VecGetArrayRead(impl->mean, &mean_array));
266: PetscCall(MatDenseGetArray(E_analysis_rows, &ea_rows_array));
267: for (j = 0; j < m; j++) {
268: for (k = 0; k < ndof; k++) ea_rows_array[k + j * ndof] += mean_array[i_grid_point * ndof + k];
269: }
270: PetscCall(MatDenseRestoreArray(E_analysis_rows, &ea_rows_array));
271: PetscCall(VecRestoreArrayRead(impl->mean, &mean_array));
273: /* Store result back in ensemble[i_grid_point*ndof:(i_grid_point+1)*ndof, :] */
274: PetscCall(MatDenseGetArrayWrite(en->ensemble, &e_array));
275: PetscCall(MatDenseGetLDA(en->ensemble, &lda_e));
276: PetscCall(MatDenseGetArrayRead(E_analysis_rows, (const PetscScalar **)&ea_rows_array));
277: for (j = 0; j < m; j++) {
278: for (k = 0; k < ndof; k++) e_array[(i_grid_point * ndof + k) + j * lda_e] = ea_rows_array[k + j * ndof];
279: }
280: PetscCall(MatDenseRestoreArrayRead(E_analysis_rows, (const PetscScalar **)&ea_rows_array));
281: PetscCall(MatDenseRestoreArrayWrite(en->ensemble, &e_array));
282: }
283: }
284: PetscCall(MatDestroy(&E_analysis_rows));
285: PetscCall(MatDestroy(&X_rows));
286: PetscCall(VecDestroy(&s_transpose_delta));
287: PetscCall(VecDestroy(&w_local));
288: PetscCall(VecDestroy(&r_inv_sqrt_local));
289: PetscCall(VecDestroy(&delta_scaled_local));
290: PetscCall(VecDestroy(&y_mean_local));
291: PetscCall(VecDestroy(&y_local));
292: PetscCall(MatDestroy(&G_local));
293: PetscCall(MatDestroy(&T_sqrt_local));
294: PetscCall(MatDestroy(&S_local));
295: PetscCall(MatDestroy(&Z_local));
296: PetscFunctionReturn(PETSC_SUCCESS);
297: }
299: static PetscErrorCode PetscDAEnsembleAnalysis_LETKF(PetscDA da, Vec observation, Mat H)
300: {
301: PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;
302: Mat X;
303: PetscInt m;
304: PetscBool reallocate = PETSC_FALSE;
306: PetscFunctionBegin;
307: m = impl->en.size;
309: /* Check if localization matrix Q is set */
310: PetscCheck(impl->Q, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "Localization matrix Q not set. Call PetscDALETKFSetLocalization() first.");
312: /* Warn if Cholesky sqrt type is used with LETKF - it produces an asymmetric
313: T^{-1/2} = L^{-T} which is incorrect for the local perturbation update.
314: LETKF requires the symmetric square root T^{-1/2} = V * D^{-1/2} * V^T. */
315: PetscCheck(impl->en.sqrt_type != PETSCDA_SQRT_CHOLESKY, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_INCOMP, "Cholesky sqrt type produces asymmetric T^{-1/2}, which is incorrect for LETKF. Use -petscda_ensemble_sqrt_type eigen or PetscDAEnsembleSetSqrtType(da, PETSCDA_SQRT_EIGEN) instead.");
317: /* Check that ensemble size <= number of local observations per vertex.
318: The eigen decomposition of T = I + S^T*S (m x m) requires that the
319: local observation count p >= m; otherwise T is rank-deficient and the
320: decomposition is ill-posed. */
321: PetscCheck(m <= impl->n_obs_vertex, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_INCOMP, "Ensemble size (%" PetscInt_FMT ") must be <= number of local observations per vertex (%" PetscInt_FMT ") for LETKF eigen decomposition to be well-posed", m,
322: impl->n_obs_vertex);
324: /* Check for reallocation needs */
325: if (impl->mean) {
326: PetscInt mean_size;
327: PetscCall(VecGetSize(impl->mean, &mean_size));
328: if (mean_size != da->state_size) reallocate = PETSC_TRUE;
329: }
330: if (impl->Z) {
331: PetscInt z_rows, z_cols;
332: PetscCall(MatGetSize(impl->Z, &z_rows, &z_cols));
333: if (z_rows != da->obs_size || z_cols != m) reallocate = PETSC_TRUE;
334: }
336: /* Initialize or reallocate persistent work objects */
337: if (!impl->mean || reallocate) {
338: PetscCall(VecDestroy(&impl->mean));
339: PetscCall(VecDestroy(&impl->y_mean));
340: PetscCall(VecDestroy(&impl->delta_scaled));
341: PetscCall(VecDestroy(&impl->w));
342: PetscCall(VecDestroy(&impl->r_inv_sqrt));
343: PetscCall(MatDestroy(&impl->Z));
344: PetscCall(MatDestroy(&impl->S));
345: PetscCall(MatDestroy(&impl->T_sqrt));
346: PetscCall(MatDestroy(&impl->w_ones));
348: /* Create mean vector from ensemble matrix (right vector = state space) */
349: PetscCall(MatCreateVecs(impl->en.ensemble, NULL, &impl->mean));
351: /* Create Z matrix (obs_size x m) */
352: PetscCall(MatCreateDense(PetscObjectComm((PetscObject)impl->en.ensemble), PETSC_DECIDE, PETSC_DECIDE, da->obs_size, m, NULL, &impl->Z));
353: PetscCall(PetscObjectSetOptionsPrefix((PetscObject)impl->Z, "dense_"));
354: PetscCall(MatSetFromOptions(impl->Z));
355: PetscCall(MatSetUp(impl->Z));
357: /* Create observation space vectors from Z matrix (left vector = observation space) */
358: PetscCall(MatCreateVecs(impl->Z, NULL, &impl->y_mean));
359: PetscCall(VecDuplicate(impl->y_mean, &impl->delta_scaled));
360: PetscCall(VecDuplicate(da->obs_error_var, &impl->r_inv_sqrt));
362: /* Create S matrix (same layout as Z) */
363: PetscCall(MatDuplicate(impl->Z, MAT_DO_NOT_COPY_VALUES, &impl->S));
365: /* Create T_sqrt matrix (m x m) - usually small */
366: /* T_sqrt will hold the result of applying T^{-1/2} to identity matrix */
367: PetscCall(MatCreateDense(PetscObjectComm((PetscObject)impl->en.ensemble), PETSC_DECIDE, PETSC_DECIDE, m, m, NULL, &impl->T_sqrt));
368: PetscCall(PetscObjectSetOptionsPrefix((PetscObject)impl->T_sqrt, "dense_"));
369: PetscCall(MatSetFromOptions(impl->T_sqrt));
370: PetscCall(MatSetUp(impl->T_sqrt));
372: /* Create w_ones matrix (m x m) */
373: PetscCall(MatCreateDense(PetscObjectComm((PetscObject)impl->en.ensemble), PETSC_DECIDE, PETSC_DECIDE, m, m, NULL, &impl->w_ones));
374: PetscCall(PetscObjectSetOptionsPrefix((PetscObject)impl->w_ones, "dense_"));
375: PetscCall(MatSetFromOptions(impl->w_ones));
376: PetscCall(MatSetUp(impl->w_ones));
377: }
379: /* Alg 6.4 line 1-2: Compute ensemble mean and scaled anomalies */
380: PetscCall(PetscDAEnsembleComputeMean(da, impl->mean));
382: /* Create anomaly matrix X = (E - x_mean * 1') / sqrt(m - 1) */
383: PetscCall(PetscDAEnsembleComputeAnomalies(da, impl->mean, &X));
385: /* Alg 6.4 line 3-4: Compute GLOBAL observation ensemble Z = H * E */
386: /* Note: When H is a Kokkos matrix type (e.g., aijkokkos), MatMatMult may fail
387: with non-Kokkos dense matrices. Use column-by-column multiplication with
388: temporary vectors that are compatible with H's type. */
389: {
390: Vec col_in, col_out, temp_in, temp_out;
392: /* Create temporary vectors compatible with H's type */
393: PetscCall(MatCreateVecs(H, &temp_in, &temp_out));
395: /* Compute Z = H * E column by column to avoid Kokkos vector type issues */
396: for (PetscInt j = 0; j < m; j++) {
397: PetscCall(MatDenseGetColumnVecRead(impl->en.ensemble, j, &col_in));
398: PetscCall(MatDenseGetColumnVecWrite(impl->Z, j, &col_out));
400: /* Copy to temp vector, multiply, then copy back */
401: PetscCall(VecCopy(col_in, temp_in));
402: PetscCall(MatMult(H, temp_in, temp_out));
403: PetscCall(VecCopy(temp_out, col_out));
405: PetscCall(MatDenseRestoreColumnVecWrite(impl->Z, j, &col_out));
406: PetscCall(MatDenseRestoreColumnVecRead(impl->en.ensemble, j, &col_in));
407: }
408: PetscCall(MatAssemblyBegin(impl->Z, MAT_FINAL_ASSEMBLY));
409: PetscCall(MatAssemblyEnd(impl->Z, MAT_FINAL_ASSEMBLY));
411: PetscCall(VecDestroy(&temp_out));
412: PetscCall(VecDestroy(&temp_in));
413: }
415: /* Compute GLOBAL observation mean y_mean = H * x_mean */
416: /* Use temporary vector compatible with H's type */
417: {
418: Vec temp_mean, temp_y_mean;
419: PetscCall(MatCreateVecs(H, &temp_mean, &temp_y_mean));
420: PetscCall(VecCopy(impl->mean, temp_mean));
421: PetscCall(MatMult(H, temp_mean, temp_y_mean));
422: PetscCall(VecCopy(temp_y_mean, impl->y_mean));
423: PetscCall(VecDestroy(&temp_y_mean));
424: PetscCall(VecDestroy(&temp_mean));
425: }
427: /* Compute GLOBAL R^{-1/2} (assumes diagonal R) */
428: PetscCall(VecCopy(da->obs_error_var, impl->r_inv_sqrt));
429: PetscCall(VecSqrtAbs(impl->r_inv_sqrt));
430: PetscCall(VecReciprocal(impl->r_inv_sqrt));
431: /* Perform local analysis for all vertices */
433: #if defined(PETSC_HAVE_KOKKOS_KERNELS)
434: /* Use GPU version only if:
435: 1. sqrt_type is eigen (GPU version only implements eigen/SVD, not cholesky)
436: 2. H matrix is a Kokkos type (aijkokkos) */
437: {
438: PetscBool use_gpu = PETSC_FALSE;
439: if (impl->en.sqrt_type == PETSCDA_SQRT_EIGEN) {
440: #if !defined(PETSC_USE_COMPLEX)
441: /* Check if H matrix is a Kokkos type */
442: PetscCall(PetscObjectTypeCompareAny((PetscObject)da->R, &use_gpu, MATSEQAIJKOKKOS, MATMPIAIJKOKKOS, MATAIJKOKKOS, ""));
443: #endif
444: }
446: /* Scatter global vectors to local work vectors if available */
447: if (impl->obs_scat) {
448: PetscCall(VecScatterBegin(impl->obs_scat, observation, impl->obs_work, INSERT_VALUES, SCATTER_FORWARD));
449: PetscCall(VecScatterEnd(impl->obs_scat, observation, impl->obs_work, INSERT_VALUES, SCATTER_FORWARD));
451: PetscCall(VecScatterBegin(impl->obs_scat, impl->y_mean, impl->y_mean_work, INSERT_VALUES, SCATTER_FORWARD));
452: PetscCall(VecScatterEnd(impl->obs_scat, impl->y_mean, impl->y_mean_work, INSERT_VALUES, SCATTER_FORWARD));
454: PetscCall(VecScatterBegin(impl->obs_scat, impl->r_inv_sqrt, impl->r_inv_sqrt_work, INSERT_VALUES, SCATTER_FORWARD));
455: PetscCall(VecScatterEnd(impl->obs_scat, impl->r_inv_sqrt, impl->r_inv_sqrt_work, INSERT_VALUES, SCATTER_FORWARD));
457: /* Handle Z matrix (scatter columns) */
458: {
459: PetscInt n_obs_local;
460: PetscCall(VecGetLocalSize(impl->obs_work, &n_obs_local));
461: if (!impl->Z_work) {
462: PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, n_obs_local, m, NULL, &impl->Z_work));
463: } else {
464: PetscInt m_old, n_old;
465: PetscCall(MatGetSize(impl->Z_work, &n_old, &m_old));
466: if (m_old != m || n_old != n_obs_local) {
467: PetscCall(MatDestroy(&impl->Z_work));
468: PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, n_obs_local, m, NULL, &impl->Z_work));
469: }
470: }
471: for (PetscInt i = 0; i < m; i++) {
472: Vec z_col_global, z_col_local;
473: PetscCall(MatDenseGetColumnVecRead(impl->Z, i, &z_col_global));
474: PetscCall(MatDenseGetColumnVecWrite(impl->Z_work, i, &z_col_local));
475: PetscCall(VecScatterBegin(impl->obs_scat, z_col_global, z_col_local, INSERT_VALUES, SCATTER_FORWARD));
476: PetscCall(VecScatterEnd(impl->obs_scat, z_col_global, z_col_local, INSERT_VALUES, SCATTER_FORWARD));
477: PetscCall(MatDenseRestoreColumnVecRead(impl->Z, i, &z_col_global));
478: PetscCall(MatDenseRestoreColumnVecWrite(impl->Z_work, i, &z_col_local));
479: }
480: }
481: }
483: if (use_gpu) {
484: PetscInt n_local;
485: PetscCall(MatGetLocalSize(impl->Q, &n_local, NULL));
486: /* Use local work vectors for GPU analysis */
487: PetscCall(PetscDALETKFLocalAnalysis_GPU(da, impl, m, n_local, X, impl->obs_work, impl->Z_work, impl->y_mean_work, impl->r_inv_sqrt_work));
488: } else {
489: PetscInt n_local;
490: PetscCall(MatGetLocalSize(impl->Q, &n_local, NULL));
491: if (impl->obs_scat) {
492: PetscCall(PetscDALETKFLocalAnalysis(da, impl, m, n_local, X, impl->obs_work, impl->Z_work, impl->y_mean_work, impl->r_inv_sqrt_work));
493: } else {
494: PetscCall(PetscDALETKFLocalAnalysis(da, impl, m, n_local, X, observation, impl->Z, impl->y_mean, impl->r_inv_sqrt));
495: }
496: }
497: }
498: #else
499: /* Without Kokkos, use CPU version */
500: {
501: PetscInt n_local;
502: PetscCall(MatGetLocalSize(impl->Q, &n_local, NULL));
503: PetscCall(PetscDALETKFLocalAnalysis(da, impl, m, n_local, X, observation, impl->Z, impl->y_mean, impl->r_inv_sqrt));
504: }
505: #endif
506: PetscCall(MatDestroy(&X));
507: PetscFunctionReturn(PETSC_SUCCESS);
508: }
510: static PetscErrorCode PetscDALETKFSetObsPerVertex_LETKF(PetscDA da, PetscInt n_obs_vertex)
511: {
512: PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;
514: PetscFunctionBegin;
515: PetscCheck(impl, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "PetscDA not properly initialized for LETKF");
516: impl->n_obs_vertex = n_obs_vertex;
517: PetscFunctionReturn(PETSC_SUCCESS);
518: }
520: static PetscErrorCode PetscDALETKFGetObsPerVertex_LETKF(PetscDA da, PetscInt *n_obs_vertex)
521: {
522: PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;
524: PetscFunctionBegin;
525: PetscCheck(impl, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "PetscDA not properly initialized for LETKF");
526: *n_obs_vertex = impl->n_obs_vertex;
527: PetscFunctionReturn(PETSC_SUCCESS);
528: }
530: static PetscErrorCode PetscDALETKFSetLocalization_LETKF(PetscDA da, Mat Q, Mat H)
531: {
532: PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;
533: PetscInt i, nrows, ncols, nnz, rstart, rend;
535: PetscFunctionBegin;
536: PetscCheck(impl, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "PetscDA not properly initialized for LETKF");
538: /* Get matrix dimensions */
539: PetscCall(MatGetSize(Q, &nrows, &ncols));
541: /* Validate matrix dimensions */
542: PetscCheck(nrows == da->state_size / da->ndof, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_INCOMP, "Localization matrix rows (%" PetscInt_FMT ") must match state size (%" PetscInt_FMT ")", nrows, da->state_size);
543: PetscCheck(ncols == da->obs_size, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_INCOMP, "Localization matrix columns (%" PetscInt_FMT ") must match observation size (%" PetscInt_FMT ")", ncols, da->obs_size);
545: /* Validate that each row has const non-zero entries */
546: PetscCall(MatGetOwnershipRange(Q, &rstart, &rend));
547: for (i = rstart; i < rend; i++) {
548: const PetscInt *cols;
549: const PetscScalar *vals;
550: PetscCall(MatGetRow(Q, i, &nnz, &cols, &vals));
551: PetscCheck(nnz == impl->n_obs_vertex, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_INCOMP, "Row %" PetscInt_FMT " has %" PetscInt_FMT " non-zeros, expected %" PetscInt_FMT, i, nnz, (PetscInt)impl->n_obs_vertex);
552: PetscCall(MatRestoreRow(Q, i, &nnz, &cols, &vals));
553: }
555: /* Store the localization matrix */
556: PetscCall(MatDestroy(&impl->Q));
557: PetscCall(PetscObjectReference((PetscObject)Q));
558: impl->Q = Q;
559: #if defined(PETSC_HAVE_KOKKOS_KERNELS)
560: PetscCall(PetscDALETKFSetupLocalization_Kokkos(impl, H));
561: #endif
562: PetscFunctionReturn(PETSC_SUCCESS);
563: }
565: static PetscErrorCode PetscDAView_LETKF(PetscDA da, PetscViewer viewer)
566: {
567: PetscBool iascii;
568: PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;
570: PetscFunctionBegin;
571: PetscCall(PetscDAView_Ensemble(da, viewer));
572: PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &iascii));
573: if (iascii) {
574: #if defined(PETSC_HAVE_KOKKOS_KERNELS)
575: if (impl->en.sqrt_type == PETSCDA_SQRT_CHOLESKY) {
576: PetscCall(PetscViewerASCIIPrintf(viewer, " Local analysis: CPU\n"));
577: } else {
578: /* Check if R matrix is Kokkos type to determine if GPU will be used */
579: if (da->R) {
580: PetscBool is_kokkos = PETSC_FALSE;
581: PetscCall(PetscObjectTypeCompareAny((PetscObject)da->R, &is_kokkos, MATSEQAIJKOKKOS, MATMPIAIJKOKKOS, MATAIJKOKKOS, ""));
582: if (is_kokkos) {
583: PetscCall(PetscViewerASCIIPrintf(viewer, " Local analysis: Kokkos\n"));
584: } else {
585: PetscCall(PetscViewerASCIIPrintf(viewer, " Local analysis: CPU\n"));
586: }
587: } else {
588: PetscCall(PetscViewerASCIIPrintf(viewer, " Local analysis: CPU or Kokkos (depending on covariance matrix type)\n"));
589: }
590: }
591: #else
592: PetscCall(PetscViewerASCIIPrintf(viewer, " Local analysis: CPU\n"));
593: #endif
594: PetscCall(PetscViewerASCIIPrintf(viewer, " Local observations per vertex: %" PetscInt_FMT "\n", impl->n_obs_vertex));
595: if (impl->batch_size > 0) {
596: PetscCall(PetscViewerASCIIPrintf(viewer, " GPU batch size: %" PetscInt_FMT "\n", impl->batch_size));
597: } else {
598: PetscCall(PetscViewerASCIIPrintf(viewer, " GPU batch size: auto\n"));
599: }
600: if (impl->Q) {
601: PetscCall(PetscViewerASCIIPrintf(viewer, " Localization matrix: set\n"));
602: } else {
603: PetscCall(PetscViewerASCIIPrintf(viewer, " Localization matrix: not set\n"));
604: }
605: }
606: PetscFunctionReturn(PETSC_SUCCESS);
607: }
609: static PetscErrorCode PetscDASetFromOptions_LETKF(PetscDA da, PetscOptionItems *PetscOptionsObjectPtr)
610: {
611: PetscDA_LETKF *impl = (PetscDA_LETKF *)da->data;
612: PetscOptionItems PetscOptionsObject = *PetscOptionsObjectPtr;
614: PetscFunctionBegin;
615: PetscCall(PetscDASetFromOptions_Ensemble(da, PetscOptionsObjectPtr));
616: PetscOptionsHeadBegin(PetscOptionsObject, "PetscDA LETKF Options");
617: PetscCall(PetscOptionsInt("-petscda_letkf_batch_size", "Batch size for GPU processing", "", impl->batch_size, &impl->batch_size, NULL));
618: PetscCall(PetscOptionsInt("-petscda_letkf_obs_per_vertex", "Number of local observations per vertex", "", impl->n_obs_vertex, &impl->n_obs_vertex, NULL));
619: PetscOptionsHeadEnd();
620: PetscFunctionReturn(PETSC_SUCCESS);
621: }
623: /*MC
624: PETSCDALETKF - The Local ETKF performs the analysis update locally around each grid point, enabling scalable assimilation on large
625: domains by avoiding the global ensemble covariance matrix.
627: Options Database Keys:
628: + -petscda_type letkf - set the `PetscDAType` to `PETSCDALETKF`
629: . -petscda_ensemble_size <size> - number of ensemble members
630: . -petscda_ensemble_sqrt_type <cholesky, eigen> - the square root of the matrix to use
631: . -petscda_letkf_batch_size <batch_size> - set the batch size for GPU processing
632: - -petscda_letkf_obs_per_vertex <n_obs_vertex> - number of observations per vertex
634: Level: beginner
636: .seealso: [](ch_da), `PetscDA`, `PetscDACreate()`, `PETSCDAETKF`, `PetscDALETKFSetObsPerVertex()`, `PetscDALETKFGetObsPerVertex()`,
637: `PetscDALETKFSetLocalization()`, `PetscDAEnsembleSetSize()`, `PetscDASetSizes()`, `PetscDAEnsembleSetSqrtType()`, `PetscDAEnsembleSetInflation()`,
638: `PetscDAEnsembleComputeMean()`, `PetscDAEnsembleComputeAnomalies()`, `PetscDAEnsembleAnalysis()`, `PetscDAEnsembleForecast()`
639: M*/
641: PETSC_INTERN PetscErrorCode PetscDACreate_LETKF(PetscDA da)
642: {
643: PetscDA_LETKF *impl;
645: PetscFunctionBegin;
646: PetscCall(PetscNew(&impl));
647: da->data = impl;
648: PetscCall(PetscDACreate_Ensemble(da));
649: da->ops->destroy = PetscDADestroy_LETKF;
650: da->ops->view = PetscDAView_LETKF;
651: da->ops->setfromoptions = PetscDASetFromOptions_LETKF;
652: impl->en.analysis = PetscDAEnsembleAnalysis_LETKF;
653: impl->en.forecast = PetscDAEnsembleForecast_Ensemble;
655: impl->n_obs_vertex = 9;
656: impl->Q = NULL;
657: impl->batch_size = 0;
659: /* Register the method for setting localization */
660: PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFSetLocalization_C", PetscDALETKFSetLocalization_LETKF));
661: PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFSetObsPerVertex_C", PetscDALETKFSetObsPerVertex_LETKF));
662: PetscCall(PetscObjectComposeFunction((PetscObject)da, "PetscDALETKFGetObsPerVertex_C", PetscDALETKFGetObsPerVertex_LETKF));
663: PetscFunctionReturn(PETSC_SUCCESS);
664: }
666: /*@
667: PetscDALETKFSetObsPerVertex - Sets the number of local observations per vertex for the LETKF algorithm.
669: Logically Collective
671: Input Parameters:
672: + da - the `PetscDA` context
673: - n_obs_vertex - number of observations per vertex
675: Level: advanced
677: .seealso: [](ch_da), `PETSCDALETKF`, `PetscDA`, `PetscDALETKFSetLocalization()`
678: @*/
679: PetscErrorCode PetscDALETKFSetObsPerVertex(PetscDA da, PetscInt n_obs_vertex)
680: {
681: PetscFunctionBegin;
684: PetscTryMethod(da, "PetscDALETKFSetObsPerVertex_C", (PetscDA, PetscInt), (da, n_obs_vertex));
685: PetscFunctionReturn(PETSC_SUCCESS);
686: }
688: /*@
689: PetscDALETKFGetObsPerVertex - Gets the number of local observations per vertex for the LETKF algorithm.
691: Not Collective
693: Input Parameter:
694: . da - the `PetscDA` context
696: Output Parameter:
697: . n_obs_vertex - number of observations per vertex
699: Level: advanced
701: .seealso: [](ch_da), `PETSCDALETKF`, `PetscDA`, `PetscDALETKFSetObsPerVertex()`
702: @*/
703: PetscErrorCode PetscDALETKFGetObsPerVertex(PetscDA da, PetscInt *n_obs_vertex)
704: {
705: PetscFunctionBegin;
707: PetscAssertPointer(n_obs_vertex, 2);
708: PetscUseMethod(da, "PetscDALETKFGetObsPerVertex_C", (PetscDA, PetscInt *), (da, n_obs_vertex));
709: PetscFunctionReturn(PETSC_SUCCESS);
710: }
712: /*@
713: PetscDALETKFSetLocalization - Sets the localization matrix for the LETKF algorithm.
715: Collective
717: Input Parameters:
718: + da - the `PetscDA` context
719: . Q - the localization matrix (N x P)
720: - H - the observation operator matrix (P x N)
722: Level: advanced
724: .seealso: [](ch_da), `PETSCDALETKF`, `PetscDA`
725: @*/
726: PetscErrorCode PetscDALETKFSetLocalization(PetscDA da, Mat Q, Mat H)
727: {
728: PetscFunctionBegin;
732: PetscTryMethod(da, "PetscDALETKFSetLocalization_C", (PetscDA, Mat, Mat), (da, Q, H));
733: PetscFunctionReturn(PETSC_SUCCESS);
734: }