Actual source code: etkfilter.c
1: #include <petscda.h>
2: #include <petsc/private/daimpl.h>
3: #include <petsc/private/daensembleimpl.h>
5: typedef struct {
6: PetscDA_Ensemble en;
7: Vec mean;
8: Vec y_mean;
9: Vec delta_scaled;
10: Vec w;
11: Vec r_inv_sqrt;
12: Mat Z;
13: Mat S;
14: Mat T_sqrt;
15: Mat w_ones;
16: } PetscDA_ETKF;
18: /*
19: BroadcastWeightVector - Creates matrix with weight vector replicated across all columns
21: Input Parameters:
22: + w - weight vector of size m (analysis weights from ETKF update)
23: - m - ensemble size (number of columns to replicate, must equal vector size)
25: Output Parameter:
26: . w_ones - m x m dense matrix where each column is a copy of w (i.e., w * 1^T)
28: Notes:
29: This function constructs the broadcast matrix w * 1^T, where w is the m-dimensional
30: weight vector and 1 is an m-dimensional vector of ones. This matrix is a fundamental
31: component in the ETKF transform: G = w * 1^T + sqrt(m-1) * T^{1/2} * U.
33: The implementation uses direct array access for performance, avoiding the overhead of
34: repeated vector wrapping and copying. This is particularly efficient for dense matrices
35: where memory is contiguous column-wise.
37: Complexity: O(m^2) time and memory.
39: Level: developer
41: */
42: static PetscErrorCode BroadcastWeightVector(Vec w, PetscInt m, Mat w_ones)
43: {
44: const PetscScalar *w_array;
45: PetscScalar *mat_array;
46: PetscInt w_size, w_size_local, mat_rows_local, mat_cols_local;
47: PetscInt i, lda;
49: PetscFunctionBegin;
50: PetscCheck(m > 0, PetscObjectComm((PetscObject)w), PETSC_ERR_ARG_OUTOFRANGE, "Ensemble size m must be positive for broadcasting, got %" PetscInt_FMT, m);
51: /* Check for potential overflow in matrix size calculation */
52: PetscCheck(m <= PETSC_MAX_INT / m, PetscObjectComm((PetscObject)w), PETSC_ERR_ARG_OUTOFRANGE, "Ensemble size m = %" PetscInt_FMT " too large", m);
54: /* Verify dimensions */
55: PetscCall(VecGetSize(w, &w_size));
56: PetscCall(VecGetLocalSize(w, &w_size_local));
57: PetscCheck(w_size == m, PetscObjectComm((PetscObject)w), PETSC_ERR_ARG_INCOMP, "Weight vector global size (%" PetscInt_FMT ") must match ensemble size (%" PetscInt_FMT ")", w_size, m);
59: /* Verify consistent parallel layout between vector and matrix */
60: PetscCall(MatGetLocalSize(w_ones, &mat_rows_local, &mat_cols_local));
61: PetscCheck(mat_rows_local == w_size_local, PetscObjectComm((PetscObject)w), PETSC_ERR_PLIB, "Matrix row distribution (%" PetscInt_FMT ") inconsistent with vector distribution (%" PetscInt_FMT ")", mat_rows_local, w_size_local);
62: PetscCheck(mat_cols_local == m, PetscObjectComm((PetscObject)w), PETSC_ERR_PLIB, "Matrix local columns (%" PetscInt_FMT ") must equal global columns m (%" PetscInt_FMT ") for MPIDense", mat_cols_local, m);
64: /* Access raw arrays for efficient broadcasting */
65: PetscCall(VecGetArrayRead(w, &w_array));
66: PetscCall(MatDenseGetArrayWrite(w_ones, &mat_array));
67: PetscCall(MatDenseGetLDA(w_ones, &lda));
69: /* Copy w to each column of w_ones */
70: /* Note: MatDense uses column-major storage. We copy the vector w into each column. */
71: for (i = 0; i < m; i++) PetscCall(PetscArraycpy(mat_array + i * lda, w_array, w_size_local));
73: /* Restore arrays */
74: PetscCall(MatDenseRestoreArrayWrite(w_ones, &mat_array));
75: PetscCall(VecRestoreArrayRead(w, &w_array));
77: /* Finalize matrix assembly */
78: PetscCall(MatAssemblyBegin(w_ones, MAT_FINAL_ASSEMBLY));
79: PetscCall(MatAssemblyEnd(w_ones, MAT_FINAL_ASSEMBLY));
80: PetscFunctionReturn(PETSC_SUCCESS);
81: }
83: /*
84: UpdateEnsembleWithTransform - Updates ensemble via ETKF transform: E = mean * 1' + X * G [Alg 6.4 line 9]
86: Input Parameters:
87: + mean - ensemble mean vector (size state_size), must be initialized
88: . X - scaled anomaly matrix (state_size x ensemble_size), X = (E - mean*1')/sqrt(m-1)
89: . G - ETKF transform matrix (ensemble_size x ensemble_size), G = w*1' + sqrt(m-1)*T^{1/2}*U
90: . m - ensemble size (number of columns in ensemble), must be > 0
91: - ensemble - ensemble matrix to update in-place (state_size x ensemble_size)
93: Notes:
94: This function performs the final step (Step 10) of the ETKF analysis algorithm from
95: Asch, M., Bocquet, M., and Nodet, M., transforming the forecast ensemble into the analysis ensemble.
96: The operation E^a = mean + X * G is computed using matrix-matrix multiplication followed
97: by column-wise addition to efficiently handle large state spaces.
99: Error Handling:
100: - Validates all input dimensions for consistency
101: - Checks for positive ensemble size
102: - Ensures proper matrix/vector initialization
103: - Handles parallel assembly correctly
105: Performance Considerations:
106: - Memory: Creates one temporary matrix X_G of size (state_size x m)
107: - Time complexity: O(state_size * m^2) for matrix multiply + O(state_size * m) for additions
108: - Optimization: Uses direct array access for dense matrices to avoid Vec overhead
109: - Parallel: Fully parallelizable across both matrix multiply and column updates
111: Level: developer
113: */
114: static PetscErrorCode UpdateEnsembleWithTransform(Vec mean, Mat X, Mat G, PetscInt m, Mat ensemble)
115: {
116: Mat X_G;
117: const PetscScalar *xg_array, *mean_array;
118: PetscScalar *ens_array;
119: PetscInt x_rows, x_cols, g_rows, g_cols, ens_rows, ens_cols;
120: PetscInt n_local_ens, n_local_xg, mean_local_size;
121: PetscInt lda_ens, lda_xg;
122: PetscInt mean_size, i, j;
124: PetscFunctionBegin;
125: /* Validate input parameters for correct types and null pointers */
132: /* Retrieve and validate matrix dimensions for compatibility */
133: PetscCall(MatGetSize(X, &x_rows, &x_cols));
134: PetscCall(MatGetSize(G, &g_rows, &g_cols));
135: PetscCall(MatGetSize(ensemble, &ens_rows, &ens_cols));
136: PetscCall(VecGetSize(mean, &mean_size));
138: /* Verify dimension consistency across all inputs */
139: PetscCheck(x_cols == m, PetscObjectComm((PetscObject)X), PETSC_ERR_ARG_INCOMP, "Anomaly matrix X columns (%" PetscInt_FMT ") must equal ensemble size (%" PetscInt_FMT ")", x_cols, m);
140: PetscCheck(g_rows == m, PetscObjectComm((PetscObject)G), PETSC_ERR_ARG_INCOMP, "Transform matrix G rows (%" PetscInt_FMT ") must equal ensemble size (%" PetscInt_FMT ")", g_rows, m);
141: PetscCheck(g_cols == m, PetscObjectComm((PetscObject)G), PETSC_ERR_ARG_INCOMP, "Transform matrix G must be square, got %" PetscInt_FMT " x %" PetscInt_FMT, g_rows, g_cols);
142: PetscCheck(ens_rows == x_rows, PetscObjectComm((PetscObject)ensemble), PETSC_ERR_ARG_INCOMP, "Ensemble rows (%" PetscInt_FMT ") must match anomaly matrix X rows (%" PetscInt_FMT ")", ens_rows, x_rows);
143: PetscCheck(ens_cols == m, PetscObjectComm((PetscObject)ensemble), PETSC_ERR_ARG_INCOMP, "Ensemble columns (%" PetscInt_FMT ") must equal ensemble size (%" PetscInt_FMT ")", ens_cols, m);
144: PetscCheck(mean_size == x_rows, PetscObjectComm((PetscObject)mean), PETSC_ERR_ARG_INCOMP, "Mean vector size (%" PetscInt_FMT ") must match state size (%" PetscInt_FMT ")", mean_size, x_rows);
146: /* Compute transformed anomaly matrix: X_G = X * G (state_size x m) */
147: PetscCall(MatMatMult(X, G, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &X_G));
149: /* Access underlying data arrays for direct performance access
150: This avoids creating/destroying m Vec objects and calling VecWAXPY m times. */
151: PetscCall(MatDenseGetArrayRead(X_G, &xg_array));
152: PetscCall(MatDenseGetArrayWrite(ensemble, &ens_array));
153: PetscCall(VecGetArrayRead(mean, &mean_array));
155: /* Get local dimensions and strides for array traversal */
156: PetscCall(MatGetLocalSize(ensemble, &n_local_ens, NULL));
157: PetscCall(MatGetLocalSize(X_G, &n_local_xg, NULL));
158: PetscCall(VecGetLocalSize(mean, &mean_local_size));
159: PetscCall(MatDenseGetLDA(ensemble, &lda_ens));
160: PetscCall(MatDenseGetLDA(X_G, &lda_xg));
162: /* Verify local dimensions match before direct array access */
163: PetscCheck(n_local_ens == n_local_xg, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Local row size mismatch: ensemble (%" PetscInt_FMT ") vs X_G (%" PetscInt_FMT ")", n_local_ens, n_local_xg);
164: PetscCheck(n_local_ens == mean_local_size, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Local row size mismatch: ensemble (%" PetscInt_FMT ") vs mean (%" PetscInt_FMT ")", n_local_ens, mean_local_size);
166: /* Update each ensemble member: E_ij = (XG)_ij + mean_i
167: Loop over columns (j) and rows (i) of the local data block */
168: for (j = 0; j < m; j++) {
169: const PetscScalar *xg_col = xg_array + j * lda_xg;
170: PetscScalar *ens_col = ens_array + j * lda_ens;
171: for (i = 0; i < n_local_ens; i++) ens_col[i] = xg_col[i] + mean_array[i];
172: }
174: /* Restore arrays and finalize assembly */
175: PetscCall(VecRestoreArrayRead(mean, &mean_array));
176: PetscCall(MatDenseRestoreArrayWrite(ensemble, &ens_array));
177: PetscCall(MatDenseRestoreArrayRead(X_G, &xg_array));
179: PetscCall(MatAssemblyBegin(ensemble, MAT_FINAL_ASSEMBLY));
180: PetscCall(MatAssemblyEnd(ensemble, MAT_FINAL_ASSEMBLY));
182: /* Clean up temporary transformed anomaly matrix */
183: PetscCall(MatDestroy(&X_G));
184: PetscFunctionReturn(PETSC_SUCCESS);
185: }
187: static PetscErrorCode PetscDADestroy_ETKF(PetscDA da)
188: {
189: PetscDA_ETKF *impl = (PetscDA_ETKF *)da->data;
191: PetscFunctionBegin;
192: PetscCall(VecDestroy(&impl->mean));
193: PetscCall(VecDestroy(&impl->y_mean));
194: PetscCall(VecDestroy(&impl->delta_scaled));
195: PetscCall(VecDestroy(&impl->w));
196: PetscCall(VecDestroy(&impl->r_inv_sqrt));
197: PetscCall(MatDestroy(&impl->Z));
198: PetscCall(MatDestroy(&impl->S));
199: PetscCall(MatDestroy(&impl->T_sqrt));
200: PetscCall(MatDestroy(&impl->w_ones));
201: PetscCall(PetscDADestroy_Ensemble(da));
202: PetscCall(PetscFree(da->data));
203: PetscFunctionReturn(PETSC_SUCCESS);
204: }
206: static PetscErrorCode PetscDAEnsembleAnalysis_ETKF(PetscDA da, Vec observation, Mat H)
207: {
208: PetscDA_ETKF *impl = (PetscDA_ETKF *)da->data;
209: Mat X;
210: PetscInt m = impl->en.size;
211: PetscScalar scale, sqrt_m_minus_1;
212: PetscBool reallocate = PETSC_FALSE;
214: PetscFunctionBegin;
215: scale = 1.0 / PetscSqrtReal((PetscReal)(m - 1));
216: sqrt_m_minus_1 = PetscSqrtReal((PetscReal)(m - 1));
217: PetscCall(PetscInfo(da, "square root type %s, %" PetscInt_FMT " ensembles\n", (impl->en.sqrt_type == PETSCDA_SQRT_EIGEN) ? "eigen" : "cholesky", m));
219: /* Check for reallocation needs */
220: if (impl->mean) {
221: PetscInt mean_size;
222: PetscCall(VecGetSize(impl->mean, &mean_size));
223: if (mean_size != da->state_size) reallocate = PETSC_TRUE;
224: }
225: if (impl->Z) {
226: PetscInt z_rows, z_cols;
227: PetscCall(MatGetSize(impl->Z, &z_rows, &z_cols));
228: if (z_rows != da->obs_size || z_cols != impl->en.size) reallocate = PETSC_TRUE;
229: }
230: if (impl->w) {
231: PetscInt w_size;
232: PetscCall(VecGetSize(impl->w, &w_size));
233: if (w_size != impl->en.size) reallocate = PETSC_TRUE;
234: }
236: /* Initialize or reallocate persistent work objects */
237: if (!impl->mean || reallocate) {
238: PetscCall(VecDestroy(&impl->mean));
239: PetscCall(VecDestroy(&impl->y_mean));
240: PetscCall(VecDestroy(&impl->delta_scaled));
241: PetscCall(VecDestroy(&impl->w));
242: PetscCall(VecDestroy(&impl->r_inv_sqrt));
243: PetscCall(MatDestroy(&impl->Z));
244: PetscCall(MatDestroy(&impl->S));
245: PetscCall(MatDestroy(&impl->T_sqrt));
246: PetscCall(MatDestroy(&impl->w_ones));
248: /* Create mean vector from ensemble matrix (right vector = state space) */
249: PetscCall(MatCreateVecs(impl->en.ensemble, NULL, &impl->mean));
251: /* Create Z matrix (obs_size x m) */
252: PetscCall(MatCreateDense(PetscObjectComm((PetscObject)impl->en.ensemble), PETSC_DECIDE, PETSC_DECIDE, da->obs_size, m, NULL, &impl->Z));
253: PetscCall(PetscObjectSetOptionsPrefix((PetscObject)impl->Z, "dense_"));
254: PetscCall(MatSetFromOptions(impl->Z));
255: PetscCall(MatSetUp(impl->Z));
257: /* Create observation space vectors from Z matrix (left vector = observation space) */
258: PetscCall(MatCreateVecs(impl->Z, NULL, &impl->y_mean));
259: PetscCall(VecDuplicate(impl->y_mean, &impl->delta_scaled));
260: PetscCall(VecDuplicate(da->obs_error_var, &impl->r_inv_sqrt));
262: /* Create w vector (size m) for analysis weights */
263: PetscCall(MatCreateVecs(impl->Z, &impl->w, NULL));
265: /* Create S matrix (same layout as Z) */
266: PetscCall(MatDuplicate(impl->Z, MAT_DO_NOT_COPY_VALUES, &impl->S));
268: /* Create T_sqrt matrix (m x m) - usually small */
269: /* T_sqrt will hold the result of applying T^{-1/2} to identity matrix */
270: PetscCall(MatCreateDense(PetscObjectComm((PetscObject)impl->en.ensemble), PETSC_DECIDE, PETSC_DECIDE, m, m, NULL, &impl->T_sqrt));
271: PetscCall(PetscObjectSetOptionsPrefix((PetscObject)impl->T_sqrt, "dense_"));
272: PetscCall(MatSetFromOptions(impl->T_sqrt));
273: PetscCall(MatSetUp(impl->T_sqrt));
275: /* Create w_ones matrix (m x m) */
276: PetscCall(MatCreateDense(PetscObjectComm((PetscObject)impl->en.ensemble), PETSC_DECIDE, PETSC_DECIDE, m, m, NULL, &impl->w_ones));
277: PetscCall(PetscObjectSetOptionsPrefix((PetscObject)impl->w_ones, "dense_"));
278: PetscCall(MatSetFromOptions(impl->w_ones));
279: PetscCall(MatSetUp(impl->w_ones));
280: }
282: /* Alg 6.4 line 1-2: Compute ensemble mean and scaled anomalies */
283: PetscCall(PetscDAEnsembleComputeMean(da, impl->mean));
285: /* X = (E - x_mean * 1') / sqrt(m - 1) */
286: /* Note: PetscDAComputeAnomalies creates a NEW matrix X every time.
287: We should probably optimize this too in the future, but for now we follow the API. */
288: PetscCall(PetscDAEnsembleComputeAnomalies(da, impl->mean, &X));
290: /* Alg 6.4 line 3-4: Compute observation ensemble Z = H * E */
291: {
292: MatReuse scall = MAT_INITIAL_MATRIX;
293: if (impl->Z) {
294: PetscInt z_rows, z_cols;
295: PetscCall(MatGetSize(impl->Z, &z_rows, &z_cols));
296: if (z_rows == da->obs_size && z_cols == impl->en.size) scall = MAT_REUSE_MATRIX;
297: else {
298: PetscCall(MatDestroy(&impl->Z));
299: scall = MAT_INITIAL_MATRIX;
300: }
301: }
302: PetscCall(MatMatMult(H, impl->en.ensemble, scall, PETSC_DEFAULT, &impl->Z));
303: }
305: /* Compute observation mean y_mean = H * x_mean */
306: PetscCall(MatMult(H, impl->mean, impl->y_mean));
308: /* Alg 6.4 line 5-6: Build normalized innovation statistics */
309: PetscCall(VecCopy(da->obs_error_var, impl->r_inv_sqrt));
310: PetscCall(VecSqrtAbs(impl->r_inv_sqrt));
311: PetscCall(VecReciprocal(impl->r_inv_sqrt));
313: /* S = R^{-1/2} * (Z - y_mean * 1') / sqrt(m - 1) */
314: PetscCall(PetscDAEnsembleComputeNormalizedInnovationMatrix(impl->Z, impl->y_mean, impl->r_inv_sqrt, m, scale, impl->S));
316: /* delta_scaled = R^{-1/2} * (y^o - y_mean) [Alg 6.4 line 6] */
317: PetscCall(VecWAXPY(impl->delta_scaled, -1.0, impl->y_mean, observation));
318: PetscCall(VecPointwiseMult(impl->delta_scaled, impl->delta_scaled, impl->r_inv_sqrt));
320: /* Alg 6.4 line 7: Factor T = (I + S^T S) and store factorization */
321: /* Note: Inflation is handled inside PetscDAEnsembleTFactor by shifting the diagonal of T */
322: PetscCall(PetscDAEnsembleTFactor(da, impl->S));
324: /* Alg 6.4 line 8: Compute analysis weights w = T^{-1} * S^T * delta_scaled */
325: {
326: Vec s_transpose_delta;
327: /* Create temporary vector for S^T * delta_scaled */
328: PetscCall(MatCreateVecs(impl->Z, &s_transpose_delta, NULL));
329: PetscCall(MatMultTranspose(impl->S, impl->delta_scaled, s_transpose_delta));
331: PetscCall(PetscDAEnsembleApplyTInverse(da, s_transpose_delta, impl->w));
332: PetscCall(VecDestroy(&s_transpose_delta));
333: }
335: /* Alg 6.4 line 9: Compute square-root transform T^{-1/2} */
336: PetscCall(PetscDAEnsembleApplySqrtTInverse(da, NULL, impl->T_sqrt));
338: /* Alg 6.4 line 9: Form transform G = w * 1' + sqrt(m - 1) * T^{1/2} * U */
339: {
340: Mat T_sqrt_scaled;
341: PetscCall(MatDuplicate(impl->T_sqrt, MAT_COPY_VALUES, &T_sqrt_scaled));
342: PetscCall(MatScale(T_sqrt_scaled, sqrt_m_minus_1));
344: /* w_ones = w * 1' (broadcast weight vector to all columns) */
345: PetscCall(BroadcastWeightVector(impl->w, m, impl->w_ones));
347: /* G = w_ones + sqrt(m-1)*T_sqrt
348: Accumulate the scaled T_sqrt into w_ones to form the transform matrix G */
349: PetscCall(MatAXPY(impl->w_ones, 1.0, T_sqrt_scaled, SAME_NONZERO_PATTERN));
351: PetscCall(MatDestroy(&T_sqrt_scaled));
352: }
354: /* Alg 6.4 line 9: Update ensemble E = x_mean * 1' + X * G */
355: PetscCall(UpdateEnsembleWithTransform(impl->mean, X, impl->w_ones, m, impl->en.ensemble));
357: /* Cleanup temporary X matrix */
358: PetscCall(MatDestroy(&X));
359: PetscFunctionReturn(PETSC_SUCCESS);
360: }
362: PETSC_INTERN PetscErrorCode PetscDAEnsembleForecast_Ensemble(PetscDA da, PetscErrorCode (*model)(Vec, Vec, PetscCtx), PetscCtx ctx)
363: {
364: PetscDA_Ensemble *en = (PetscDA_Ensemble *)da->data;
365: Vec col_in, col_out, temp;
366: PetscInt i;
368: PetscFunctionBegin;
371: /* Create temp vector from ensemble matrix (right vector = state space) */
372: PetscCall(MatCreateVecs(en->ensemble, NULL, &temp));
374: for (i = 0; i < en->size; i++) {
375: PetscCall(MatDenseGetColumnVecRead(en->ensemble, i, &col_in));
376: PetscCall(model(col_in, temp, ctx));
377: PetscCall(MatDenseRestoreColumnVecRead(en->ensemble, i, &col_in));
379: PetscCall(MatDenseGetColumnVecWrite(en->ensemble, i, &col_out));
380: PetscCall(VecCopy(temp, col_out));
381: PetscCall(MatDenseRestoreColumnVecWrite(en->ensemble, i, &col_out));
382: }
384: PetscCall(VecDestroy(&temp));
385: PetscFunctionReturn(PETSC_SUCCESS);
386: }
388: /*MC
389: PETSCDAETKF - Ensemble transform Kalman filter data assimilation using a deterministic square-root update that avoids stochastic perturbations.
391: Options Database Keys:
392: + -petscda_type etkf - set the `PetscDAType` to `PETSCDAETKF`
393: . -petscda_ensemble_size <size> - number of ensemble members
394: - -petscda_ensemble_sqrt_type <cholesky, eigen> - the square root of the matrix to use
396: Level: beginner
398: Note:
399: The ETKF algorithm is based on Algorithm 6.4 in {cite}`da2016`
401: .seealso: [](ch_da), `PetscDA`, `PetscDACreate()`, `PETSCDALETKF`, `PetscDAEnsembleSetSize()`, `PetscDASetSizes()`, `PetscDAEnsembleSetSqrtType()`,
402: `PetscDAEnsembleSetInflation()`, `PetscDAType`,
403: `PetscDAEnsembleComputeMean()`, `PetscDAEnsembleComputeAnomalies()`, `PetscDAEnsembleAnalysis()`, `PetscDAEnsembleForecast()`
404: M*/
405: PETSC_INTERN PetscErrorCode PetscDACreate_ETKF(PetscDA da)
406: {
407: PetscDA_ETKF *impl;
409: PetscFunctionBegin;
410: PetscCall(PetscNew(&impl));
411: da->data = impl;
412: PetscCall(PetscDACreate_Ensemble(da));
413: da->ops->setup = PetscDASetUp_Ensemble;
414: da->ops->destroy = PetscDADestroy_ETKF;
415: da->ops->view = PetscDAView_Ensemble;
416: da->ops->setfromoptions = PetscDASetFromOptions_Ensemble;
417: impl->en.analysis = PetscDAEnsembleAnalysis_ETKF;
418: impl->en.forecast = PetscDAEnsembleForecast_Ensemble;
419: PetscFunctionReturn(PETSC_SUCCESS);
420: }