Actual source code: daensemble.c

  1: #include <petscda.h>
  2: #include <petsc/private/daimpl.h>
  3: #include <petscblaslapack.h>
  4: #include <petsc/private/daensembleimpl.h>

  6: /*
  7:      Code that is shared by PETSCDALETKF (and any future ensemble methods).

  9: */
 10: /*  T-Matrix Factorization and Application Methods [Alg 6.4 line 7] */

 12: /*
 13:    Tolerance for matrix square root verification in debug mode
 14:    Use a more relaxed tolerance to account for accumulated floating-point errors
 15:    in multiple matrix operations (Y^T * T * Y involves 3 matrix multiplications).
 16:    A tolerance of 1e-2 (1%) is reasonable for numerical verification. */
 17: #define MATRIX_SQRT_TOLERANCE_FACTOR 1.0e-2

 19: /*
 20:   PetscDAEnsembleTFactorFromGram - Build (or refresh) en->I_StS from a host m x m gram buffer
 21:   (column-major), shift by 1/inflation, and run the eigendecomposition.

 23:   Contract: the caller supplies gram_host = S^T S, where S already contains the 1/sqrt(m-1)
 24:   normalization. This routine adds the (1/inflation) I shift and computes the eigendecomposition,
 25:   so en->I_StS = (1/inflation) I + S^T S on return.

 27:   The matrix lives on PETSC_COMM_SELF so the caller is responsible for any cross-rank reduction
 28:   on gram_host before calling.
 29: */
 30: PETSC_INTERN PetscErrorCode PetscDAEnsembleTFactorFromGram(PetscDA da, PetscInt m, const PetscScalar *gram_host)
 31: {
 32:   PetscDA_Ensemble *en = (PetscDA_Ensemble *)da->data;
 33:   PetscScalar      *dst;

 35:   PetscFunctionBegin;
 36:   if (en->I_StS) {
 37:     PetscInt rows, cols;
 38:     PetscCall(MatGetSize(en->I_StS, &rows, &cols));
 39:     if (rows != m || cols != m) {
 40:       PetscCall(MatDestroy(&en->I_StS));
 41:       PetscCall(MatDestroy(&en->V));
 42:       PetscCall(VecDestroy(&en->sqrt_eigen_vals));
 43:     }
 44:   }
 45:   if (!en->I_StS) PetscCall(MatCreateSeqDense(PETSC_COMM_SELF, m, m, NULL, &en->I_StS));
 46:   PetscCall(MatDenseGetArrayWrite(en->I_StS, &dst));
 47:   PetscCall(PetscArraycpy(dst, gram_host, (size_t)m * m));
 48:   PetscCall(MatDenseRestoreArrayWrite(en->I_StS, &dst));
 49:   PetscCall(MatShift(en->I_StS, 1.0 / en->inflation));
 50:   PetscCall(PetscDAEnsembleTFactor_Eigen(da));
 51:   PetscFunctionReturn(PETSC_SUCCESS);
 52: }

 54: /*
 55:   PetscDAEnsembleTFactor_Eigen - Compute the symmetric eigendecomposition of the m x m matrix
 56:   held in en->I_StS (the user pre-shifted it by 1/inflation). On return, en->V holds the
 57:   eigenvectors and en->sqrt_eigen_vals holds the eigenvalues (the elementwise sqrt is taken
 58:   later by PetscDAEnsembleApplySqrtTInverse_Eigen()).
 59: */
 60: PETSC_INTERN PetscErrorCode PetscDAEnsembleTFactor_Eigen(PetscDA da)
 61: {
 62:   PetscDA_Ensemble *en = (PetscDA_Ensemble *)da->data;
 63:   PetscBLASInt      n, lda, lwork, info;
 64:   PetscScalar      *a_array, *work, *eig_array;
 65:   PetscInt          m_V, N_V;
 66: #if defined(PETSC_USE_COMPLEX)
 67:   PetscReal *rwork = NULL;
 68: #endif

 70:   PetscFunctionBegin;
 71:   /* Initialize or update V matrix */
 72:   if (!en->V) PetscCall(MatDuplicate(en->I_StS, MAT_COPY_VALUES, &en->V));
 73:   else PetscCall(MatCopy(en->I_StS, en->V, SAME_NONZERO_PATTERN));

 75:   /* Initialize or update eigenvalue vector */
 76:   if (!en->sqrt_eigen_vals) PetscCall(MatCreateVecs(en->I_StS, &en->sqrt_eigen_vals, NULL));

 78:   /* Get matrix dimensions */
 79:   PetscCall(MatGetSize(en->V, &m_V, &N_V));
 80:   PetscCheck(m_V == N_V, PetscObjectComm((PetscObject)en->V), PETSC_ERR_ARG_WRONG, "Matrix must be square");
 81:   PetscCall(PetscBLASIntCast(N_V, &n));
 82:   lda = n;

 84:   /* Get arrays */
 85:   PetscCall(MatDenseGetArrayWrite(en->V, &a_array));
 86:   PetscCall(VecGetArrayWrite(en->sqrt_eigen_vals, &eig_array));

 88:   /* Query optimal workspace size */
 89:   lwork = -1;
 90:   PetscCall(PetscMalloc1(1, &work));
 91: #if defined(PETSC_USE_COMPLEX)
 92:   PetscCall(PetscMalloc1(PetscMax(1, 3 * n - 2), &rwork));
 93:   PetscCallBLAS("LAPACKsyev", LAPACKsyev_("V", "U", &n, a_array, &lda, (PetscReal *)eig_array, work, &lwork, rwork, &info));
 94: #else
 95:   PetscCallBLAS("LAPACKsyev", LAPACKsyev_("V", "U", &n, a_array, &lda, eig_array, work, &lwork, &info));
 96: #endif
 97:   PetscCheck(info == 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "Error in LAPACK routine xSYEV work query: info=%" PetscBLASInt_FMT, info);

 99:   /* Allocate workspace. LAPACK returns the optimal lwork as a double-valued integer in work[0];
100:      wrap with PetscCeilReal before narrowing so a 1-ulp shrink (some LAPACK builds return
101:      e.g. 2591.999...) cannot under-allocate. PetscBLASIntCast then checks the int range. */
102:   PetscCall(PetscBLASIntCast((PetscInt)PetscCeilReal(PetscRealPart(work[0])), &lwork));
103:   PetscCall(PetscFree(work));
104:   PetscCall(PetscMalloc1(lwork, &work));

106:   /* Compute eigendecomposition */
107: #if defined(PETSC_USE_COMPLEX)
108:   PetscCallBLAS("LAPACKsyev", LAPACKsyev_("V", "U", &n, a_array, &lda, (PetscReal *)eig_array, work, &lwork, rwork, &info));
109:   PetscCall(PetscFree(rwork));
110: #else
111:   PetscCallBLAS("LAPACKsyev", LAPACKsyev_("V", "U", &n, a_array, &lda, eig_array, work, &lwork, &info));
112: #endif
113:   PetscCheck(info == 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "Error in LAPACK routine xSYEV: info=%" PetscBLASInt_FMT, info);

115:   /* Cleanup */
116:   PetscCall(PetscFree(work));
117:   PetscCall(VecRestoreArrayWrite(en->sqrt_eigen_vals, &eig_array));
118:   PetscCall(MatDenseRestoreArrayWrite(en->V, &a_array));

120:   /* T = (1/rho)*I + S^T*S is SPD by construction (rho > 0, S^T*S is PSD), so a strongly negative
121:      eigenvalue means the decomposition went wrong upstream. Catch in debug builds before
122:      VecSqrtAbs() rewrites the sign and the analysis silently uses garbage T^{-1/2}. The tolerance
123:      is sqrt(eps_machine)*||T||_F so the test scales with both working precision and problem
124:      magnitude; this is far tighter than MATRIX_SQRT_TOLERANCE_FACTOR (used downstream for
125:      matrix-reconstruction verification) because we are checking a sign error, not the
126:      accuracy of an O(eps)-noisy reconstruction. */
127:   if (PetscDefined(USE_DEBUG)) {
128:     PetscReal lambda_min, norm_T, tol;

130:     PetscCall(VecMin(en->sqrt_eigen_vals, NULL, &lambda_min));
131:     PetscCall(MatNorm(en->I_StS, NORM_FROBENIUS, &norm_T));
132:     tol = PetscSqrtReal(PETSC_MACHINE_EPSILON) * norm_T;
133:     PetscCheck(lambda_min >= -tol, PetscObjectComm((PetscObject)da), PETSC_ERR_PLIB, "T = (1/rho)I + S^T*S has eigenvalue %g; expected >= -%g (sqrt(eps)*||T||, ||T|| = %g)", (double)lambda_min, (double)tol, (double)norm_T);
134:   }

136:   /* Compute sqrt(eigenvalues) */
137:   PetscCall(VecSqrtAbs(en->sqrt_eigen_vals));

139:   /* Debug verification: Ensure V * D * V^T == T */
140:   if (PetscDefined(USE_DEBUG)) {
141:     PetscReal norm_T, norm_diff, relative_error;
142:     Mat       V_D, VDVt;

144:     /* Compute D * V^T by scaling rows */
145:     PetscCall(MatDuplicate(en->V, MAT_COPY_VALUES, &V_D));

147:     /* Restore D for verification (since sqrt_eigen_vals currently holds sqrt(D)) */
148:     PetscCall(VecPointwiseMult(en->sqrt_eigen_vals, en->sqrt_eigen_vals, en->sqrt_eigen_vals));

150:     PetscCall(MatDiagonalScale(V_D, NULL, en->sqrt_eigen_vals));

152:     /* Compute V * D * V^T */
153:     PetscCall(MatMatTransposeMult(V_D, en->V, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &VDVt));

155:     /* Compute ||V*D*V^T - T|| / ||T|| */
156:     PetscCall(MatAXPY(VDVt, -1.0, en->I_StS, SAME_NONZERO_PATTERN));
157:     PetscCall(MatNorm(en->I_StS, NORM_FROBENIUS, &norm_T));
158:     PetscCall(MatNorm(VDVt, NORM_FROBENIUS, &norm_diff));

160:     PetscCheck(norm_T > 0, PetscObjectComm((PetscObject)da), PETSC_ERR_PLIB, "T = 0");
161:     relative_error = norm_diff / norm_T;
162:     PetscCheck(relative_error < MATRIX_SQRT_TOLERANCE_FACTOR, PetscObjectComm((PetscObject)da), PETSC_ERR_PLIB, "Eigendecomposition verification failed: ||V*D*V^T - T||/||T|| = %g", (double)relative_error);

164:     /* Restore sqrt(D) back to sqrt_eigen_vals */
165:     PetscCall(VecSqrtAbs(en->sqrt_eigen_vals));

167:     /* Cleanup debug matrices */
168:     PetscCall(MatDestroy(&V_D));
169:     PetscCall(MatDestroy(&VDVt));
170:   }
171:   PetscFunctionReturn(PETSC_SUCCESS);
172: }

174: /*@
175:   PetscDAEnsembleTFactor - Compute and store factorization of T matrix

177:   Collective

179:   Input Parameters:
180: + da - the `PetscDA` context
181: - S  - normalized innovation matrix (obs_size x m)

183:   Level: advanced

185:   Notes:
186:   This function computes $T = (1/\rho) I + S^T * S$ (where $\rho$ is the inflation factor set via
187:   `PetscDAEnsembleSetInflation()`) and stores its symmetric eigendecomposition, i.e. eigenvectors
188:   $V$ and eigenvalues $D$ such that $T = V * D * V^T$.

190:   The implementation uses matrix reuse (`MAT_REUSE_MATRIX`) to minimize memory allocation
191:   overhead when the ensemble size remains constant across analysis cycles.

193: .seealso: [](ch_da), `PetscDA`, `PETSCDALETKF`, `PetscDAEnsembleApplyTInverse()`, `PetscDAEnsembleApplySqrtTInverse()`
194: @*/
195: PetscErrorCode PetscDAEnsembleTFactor(PetscDA da, Mat S)
196: {
197:   PetscDA_Ensemble *en = (PetscDA_Ensemble *)da->data;
198:   PetscInt          m, s_rows, s_cols;
199:   MatReuse          scall = MAT_INITIAL_MATRIX;

201:   PetscFunctionBegin;
204:   PetscCall(MatGetSize(S, &s_rows, &s_cols));
205:   m = s_cols; /* Ensemble size */
206:   PetscCheck(m > 0, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_OUTOFRANGE, "Innovation matrix S must have positive columns, got %" PetscInt_FMT, m);
207:   PetscCheck(m == en->size, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_INCOMP, "S matrix columns (%" PetscInt_FMT ") must match ensemble size (%" PetscInt_FMT ") defined in PetscDA", m, en->size);

209:   /* 2. Manage Resource Reuse */
210:   /* Check if we can reuse the T matrix (I_StS) and dependent factors */
211:   if (en->I_StS) {
212:     PetscInt t_rows, t_cols;
213:     PetscCall(MatGetSize(en->I_StS, &t_rows, &t_cols));

215:     /* If dimensions have changed, drop the stale T/V/eigen state so the MAT_INITIAL_MATRIX
216:        initializer at declaration takes effect; otherwise switch to MAT_REUSE_MATRIX. */
217:     if (t_rows != m || t_cols != m) {
218:       PetscCall(MatDestroy(&en->I_StS));
219:       PetscCall(MatDestroy(&en->V));
220:       PetscCall(VecDestroy(&en->sqrt_eigen_vals));
221:       PetscCall(PetscInfo(da, "Ensemble size changed (old: %" PetscInt_FMT ", new: %" PetscInt_FMT "), reallocating T matrix and factors\n", t_rows, m));
222:     } else scall = MAT_REUSE_MATRIX;
223:   }

225:   /* 3. Compute T = (1/rho)I + S^T * S (the (1/rho) shift is added below). */
226:   /*
227:      MatTransposeMatMult computes C = A^T * B (here C = S^T * S).
228:      When using MAT_REUSE_MATRIX, the existing C is overwritten with the new result.
229:   */
230:   PetscCall(MatTransposeMatMult(S, S, scall, PETSC_DEFAULT, &en->I_StS));

232:   /* Add Identity: T = (1/rho)I + S^T*S */
233:   PetscCall(MatShift(en->I_StS, 1.0 / en->inflation));

235:   /* 4. Compute symmetric eigendecomposition T = V * D * V^T */
236:   PetscCall(PetscDAEnsembleTFactor_Eigen(da));
237:   PetscFunctionReturn(PETSC_SUCCESS);
238: }

240: /*
241:   ApplyTInverse_Eigen - Helper for Eigendecomposition solver path
242: */
243: static PetscErrorCode ApplyTInverse_Eigen(PetscDA da, Vec sdel, Vec w)
244: {
245:   PetscDA_Ensemble *en = (PetscDA_Ensemble *)da->data;
246:   Vec               temp;

248:   PetscFunctionBegin;
249:   PetscCheck(en->V, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "Eigenvectors not computed");
250:   PetscCheck(en->sqrt_eigen_vals, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "Eigenvalues not computed");

252:   /* Allocate temporary vector for projection */
253:   PetscCall(VecDuplicate(sdel, &temp));

255:   /* 1. Project onto eigenvectors: temp = V^T * sdel */
256:   PetscCall(MatMultTranspose(en->V, sdel, temp));

258:   /* 2. Scale by inverse eigenvalues: temp = D^{-1} * temp */
259:   /* We store sqrt(D), so divide twice: temp = (temp / sqrt(D)) / sqrt(D) */
260:   PetscCall(VecPointwiseDivide(temp, temp, en->sqrt_eigen_vals));
261:   PetscCall(VecPointwiseDivide(temp, temp, en->sqrt_eigen_vals));

263:   /* 3. Map back to standard basis: w = V * temp */
264:   PetscCall(MatMult(en->V, temp, w));

266:   PetscCall(VecDestroy(&temp));
267:   PetscFunctionReturn(PETSC_SUCCESS);
268: }

270: /*@
271:   PetscDAEnsembleApplyTInverse - Apply T^{-1} to a vector [Alg 6.4 line 8]

273:   Collective

275:   Input Parameters:
276: + da   - the `PetscDA` context
277: - sdel - input vector S^T-delta

279:   Output Parameter:
280: . w - output vector w = T^{-1} * sdel

282:   Level: advanced

284:   Notes:
285:   This function applies the inverse of $T = (1/\rho) I + S^T S$ (with $\rho$ the inflation factor)
286:   using the stored symmetric eigendecomposition: $T^{-1} = V D^{-1} V^T$.

288: .seealso: [](ch_da), `PetscDA`, `PETSCDALETKF`, `PetscDAEnsembleTFactor()`, `PetscDAEnsembleApplySqrtTInverse()`
289: @*/
290: PetscErrorCode PetscDAEnsembleApplyTInverse(PetscDA da, Vec sdel, Vec w)
291: {
292:   PetscDA_Ensemble *en = (PetscDA_Ensemble *)da->data;

294:   PetscFunctionBegin;

299:   PetscCheck(en->I_StS, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "T matrix not factored. Call PetscDAEnsembleTFactor first");
300:   PetscCall(ApplyTInverse_Eigen(da, sdel, w));
301:   PetscFunctionReturn(PETSC_SUCCESS);
302: }

304: /*
305:   ApplySqrtTInverse_Eigen - Computes Y = V * D^{-1/2} * V^T * U.

307:   Notes:
308:   This computes the symmetric square root T^{-1/2} = V * D^{-1/2} * V^T.
309:   The operation is performed as Y = V * (D^{-1/2} * (V^T * U)) to strictly follow
310:   linear algebra operations for general matrix U.
311: */
312: static PetscErrorCode ApplySqrtTInverse_Eigen(PetscDA da, Mat U, Mat Y)
313: {
314:   PetscDA_Ensemble *en = (PetscDA_Ensemble *)da->data;
315:   Mat               W;
316:   Vec               diag_inv;

318:   PetscFunctionBegin;
319:   PetscCheck(en->V, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "Eigenvectors not computed");
320:   PetscCheck(en->sqrt_eigen_vals, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "Eigenvalues not computed");

322:   /* Prepare inverse sqrt eigenvalues: D^{-1/2}
323:      Note: en->sqrt_eigen_vals currently stores sqrt(D) */
324:   PetscCall(VecDuplicate(en->sqrt_eigen_vals, &diag_inv));
325:   PetscCall(VecCopy(en->sqrt_eigen_vals, diag_inv));
326:   PetscCall(VecReciprocal(diag_inv)); /* Now diag_inv contains 1/sqrt(D) = D^{-1/2} */

328:   if (U) {
329:     /* General case: Compute Y = V * D^{-1/2} * V^T * U */
330:     /* Step 1: Compute W = V^T * U (Project U onto eigenbasis) */
331:     PetscCall(MatTransposeMatMult(en->V, U, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &W));

333:     /* Step 2: Scale rows of W by D^{-1/2}: W <- D^{-1/2} * W */
334:     PetscCall(MatDiagonalScale(W, diag_inv, NULL));

336:     /* Step 3: Compute Y = V * W (Project back to standard basis)
337:        Y = V * (D^{-1/2} * V^T * U) */
338:     {
339:       Mat Y_temp;
340:       PetscCall(MatMatMult(en->V, W, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &Y_temp));
341:       PetscCall(MatCopy(Y_temp, Y, SAME_NONZERO_PATTERN));
342:       PetscCall(MatDestroy(&Y_temp));
343:     }

345:     /* Cleanup */
346:     PetscCall(MatDestroy(&W));
347:   } else {
348:     /* U is NULL (identity): Compute Y = V * D^{-1/2} * V^T directly */
349:     /* Step 1: Compute W = V * D^{-1/2} (scale columns of V) */
350:     PetscCall(MatDuplicate(en->V, MAT_COPY_VALUES, &W));
351:     PetscCall(MatDiagonalScale(W, NULL, diag_inv));

353:     /* Step 2: Compute Y = W * V^T = V * D^{-1/2} * V^T */
354:     {
355:       Mat Y_temp;
356:       PetscCall(MatMatTransposeMult(W, en->V, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &Y_temp));
357:       PetscCall(MatCopy(Y_temp, Y, SAME_NONZERO_PATTERN));
358:       PetscCall(MatDestroy(&Y_temp));
359:     }

361:     /* Cleanup */
362:     PetscCall(MatDestroy(&W));
363:   }

365:   PetscCall(VecDestroy(&diag_inv));
366:   PetscFunctionReturn(PETSC_SUCCESS);
367: }

369: /*@
370:   PetscDAEnsembleApplySqrtTInverse - Apply T^{-1/2} to a matrix U [Alg 6.4 line 9]

372:   Collective

374:   Input Parameters:
375: + da - the `PetscDA` context
376: - U  - input matrix (usually Identity, but can be general)

378:   Output Parameter:
379: . Y - output matrix Y = T^{-1/2} * U

381:   Level: advanced

383:   Notes:
384:   This function applies the symmetric inverse square root of $T = (1/\rho) I + S^T * S$ (with $\rho$
385:   the inflation factor) using the stored eigendecomposition: $Y = V D^{-1/2} V^T U$. The result
386:   satisfies $Y^T * T * Y = U^T * U$, preserving the metric.

388: .seealso: [](ch_da), `PetscDA`, `PETSCDALETKF`, `PetscDAEnsembleTFactor()`, `PetscDAEnsembleApplyTInverse()`
389: @*/
390: PetscErrorCode PetscDAEnsembleApplySqrtTInverse(PetscDA da, Mat U, Mat Y)
391: {
392:   PetscDA_Ensemble *en = (PetscDA_Ensemble *)da->data;

394:   PetscFunctionBegin;

399:   PetscCheck(en->I_StS, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONGSTATE, "I_StS matrix not created. Call PetscDAEnsembleTFactor first");
400:   PetscCall(ApplySqrtTInverse_Eigen(da, U, Y));

402:   /* Debugging verification: Check that metric is preserved
403:      Verify that Y^T * T * Y = U^T * U (or Y^T * T * Y = I if U is NULL) */
404:   if (PetscDefined(USE_DEBUG)) {
405:     Mat       YtTY, T_Y;
406:     PetscReal norm_T, norm_diff;

408:     /* Compute LHS: Y^T * T * Y */
409:     PetscCall(MatMatMult(en->I_StS, Y, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &T_Y));     /* T * Y */
410:     PetscCall(MatTransposeMatMult(Y, T_Y, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &YtTY)); /* Y^T * (T * Y) */

412:     if (U) {
413:       Mat       UtU;
414:       PetscReal norm_ref;

416:       /* Compute RHS: U^T * U and difference YtTY <- YtTY - U^T*U */
417:       PetscCall(MatTransposeMatMult(U, U, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &UtU));
418:       PetscCall(MatAXPY(YtTY, -1.0, UtU, SAME_NONZERO_PATTERN));

420:       /* Check norms. When ||U^T*U|| == 0 the relative form is undefined, so fall back to an
421:          absolute tolerance scaled by ||T|| (the only nonzero scale we have on hand) instead of
422:          silently passing on any norm_diff. */
423:       PetscCall(MatNorm(UtU, NORM_FROBENIUS, &norm_ref));
424:       PetscCall(MatNorm(YtTY, NORM_FROBENIUS, &norm_diff));
425:       if (norm_ref > 0.0) PetscCheck(norm_diff / norm_ref < MATRIX_SQRT_TOLERANCE_FACTOR, PetscObjectComm((PetscObject)da), PETSC_ERR_PLIB, "T^{-1/2} verification failed. ||Y^T*T*Y - U^T*U||/||U^T*U|| = %g", (double)(norm_diff / norm_ref));
426:       else {
427:         PetscCall(MatNorm(en->I_StS, NORM_FROBENIUS, &norm_T));
428:         PetscCheck(norm_diff <= MATRIX_SQRT_TOLERANCE_FACTOR * norm_T, PetscObjectComm((PetscObject)da), PETSC_ERR_PLIB, "T^{-1/2} verification failed (U^T*U is zero). ||Y^T*T*Y|| = %g, ||T|| = %g", (double)norm_diff, (double)norm_T);
429:       }
430:       PetscCall(MatDestroy(&UtU));
431:     } else {
432:       /* RHS is the identity: form YtTY - I via MatShift, then compare against ||T|| */
433:       PetscCall(MatShift(YtTY, -1.0));
434:       PetscCall(MatNorm(YtTY, NORM_FROBENIUS, &norm_diff));
435:       PetscCall(MatNorm(en->I_StS, NORM_FROBENIUS, &norm_T));
436:       PetscCheck(norm_diff <= MATRIX_SQRT_TOLERANCE_FACTOR * norm_T, PetscObjectComm((PetscObject)da), PETSC_ERR_PLIB, "T^{-1/2} verification failed (U is NULL). ||Y^T*T*Y - I|| = %g, ||T|| = %g", (double)norm_diff, (double)norm_T);
437:     }

439:     /* Cleanup debug matrices */
440:     PetscCall(MatDestroy(&T_Y));
441:     PetscCall(MatDestroy(&YtTY));
442:   }
443:   PetscFunctionReturn(PETSC_SUCCESS);
444: }

446: /*@
447:   PetscDAEnsembleSetInflation - Sets the inflation factor for the data assimilation method.

449:   Logically Collective

451:   Input Parameters:
452: + da        - the `PetscDA` context
453: - inflation - the inflation factor (must be >= 1.0)

455:   Level: intermediate

457: .seealso: [](ch_da), `PetscDA`, `PETSCDALETKF`, `PetscDAEnsembleGetInflation()`
458: @*/
459: PetscErrorCode PetscDAEnsembleSetInflation(PetscDA da, PetscReal inflation)
460: {
461:   PetscDA_Ensemble *en = (PetscDA_Ensemble *)da->data;

463:   PetscFunctionBegin;
466:   PetscCheck(inflation >= 1.0, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_OUTOFRANGE, "Inflation factor must be >= 1.0, got %g", (double)inflation);
467:   en->inflation = inflation;
468:   PetscFunctionReturn(PETSC_SUCCESS);
469: }

471: /*@
472:   PetscDAEnsembleGetInflation - Gets the inflation factor for the data assimilation method.

474:   Not Collective

476:   Input Parameter:
477: . da - the `PetscDA` context

479:   Output Parameter:
480: . inflation - the inflation factor

482:   Level: intermediate

484: .seealso: [](ch_da), `PetscDA`, `PETSCDALETKF`, `PetscDAEnsembleSetInflation()`
485: @*/
486: PetscErrorCode PetscDAEnsembleGetInflation(PetscDA da, PetscReal *inflation)
487: {
488:   PetscDA_Ensemble *en = (PetscDA_Ensemble *)da->data;

490:   PetscFunctionBegin;
492:   PetscAssertPointer(inflation, 2);
493:   *inflation = en->inflation;
494:   PetscFunctionReturn(PETSC_SUCCESS);
495: }

497: /*@
498:   PetscDAEnsembleGetMember - Returns a read-only view of an ensemble member stored in the `PetscDA`.

500:   Collective

502:   Input Parameters:
503: + da         - the `PetscDA` context
504: - member_idx - index of the requested member (0 <= idx < ensemble_size)

506:   Output Parameter:
507: . member - read-only vector view; call `PetscDAEnsembleRestoreMember()` when done

509:   Level: intermediate

511: .seealso: [](ch_da), `PetscDA`, `PETSCDALETKF`, `PetscDAEnsembleRestoreMember()`, `PetscDAEnsembleSetMember()`
512: @*/
513: PetscErrorCode PetscDAEnsembleGetMember(PetscDA da, PetscInt member_idx, Vec *member)
514: {
515:   PetscDA_Ensemble *en = (PetscDA_Ensemble *)da->data;

517:   PetscFunctionBegin;
519:   PetscAssertPointer(member, 3);
520:   PetscCheck(en->ensemble, PetscObjectComm((PetscObject)da), PETSC_ERR_ORDER, "PetscDASetUp() must be called before accessing ensemble members");
521:   PetscCheck(member_idx >= 0 && member_idx < en->size, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Member index %" PetscInt_FMT " out of range [0, %" PetscInt_FMT ")", member_idx, en->size);

523:   PetscCall(MatDenseGetColumnVecRead(en->ensemble, member_idx, member));
524:   PetscFunctionReturn(PETSC_SUCCESS);
525: }

527: /*@
528:   PetscDAEnsembleRestoreMember - Returns a column view obtained with `PetscDAEnsembleGetMember()`.

530:   Collective

532:   Input Parameters:
533: + da         - the `PetscDA` context
534: . member_idx - index that was previously requested
535: - member     - location that holds the view to restore

537:   Level: intermediate

539: .seealso: [](ch_da), `PetscDA`, `PETSCDALETKF`, `PetscDAEnsembleGetMember()`
540: @*/
541: PetscErrorCode PetscDAEnsembleRestoreMember(PetscDA da, PetscInt member_idx, Vec *member)
542: {
543:   PetscDA_Ensemble *en = (PetscDA_Ensemble *)da->data;

545:   PetscFunctionBegin;
547:   PetscAssertPointer(member, 3);
548:   PetscCheck(member_idx >= 0 && member_idx < en->size, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Member index %" PetscInt_FMT " out of range [0, %" PetscInt_FMT ")", member_idx, en->size);

550:   PetscCall(MatDenseRestoreColumnVecRead(en->ensemble, member_idx, member));
551:   PetscFunctionReturn(PETSC_SUCCESS);
552: }

554: /*@
555:   PetscDAEnsembleSetMember - Overwrites an ensemble member with user-provided state data.

557:   Collective

559:   Input Parameters:
560: + da         - the `PetscDA` context
561: . member_idx - index of the entry to modify
562: - member     - vector containing the new state values

564:   Level: intermediate

566: .seealso: [](ch_da), `PetscDA`, `PETSCDALETKF`, `PetscDAEnsembleGetMember()`
567: @*/
568: PetscErrorCode PetscDAEnsembleSetMember(PetscDA da, PetscInt member_idx, Vec member)
569: {
570:   Vec               col;
571:   PetscDA_Ensemble *en = (PetscDA_Ensemble *)da->data;

573:   PetscFunctionBegin;
576:   PetscCheck(en->ensemble, PetscObjectComm((PetscObject)da), PETSC_ERR_ORDER, "PetscDASetUp() must be called before setting ensemble members");
577:   PetscCheck(member_idx >= 0 && member_idx < en->size, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Member index %" PetscInt_FMT " out of range [0, %" PetscInt_FMT ")", member_idx, en->size);

579:   PetscCall(MatDenseGetColumnVecWrite(en->ensemble, member_idx, &col));
580:   PetscCall(VecCopy(member, col));
581:   PetscCall(MatDenseRestoreColumnVecWrite(en->ensemble, member_idx, &col));
582:   PetscFunctionReturn(PETSC_SUCCESS);
583: }

585: /*@
586:   PetscDAEnsembleComputeMean - Computes ensemble mean for a `PetscDA`

588:   Collective

590:   Input Parameter:
591: . da - the `PetscDA` context

593:   Output Parameter:
594: . mean - vector that will hold the ensemble mean

596:   Level: intermediate

598: .seealso: [](ch_da), `PetscDA`, `PETSCDALETKF`, `PetscDAEnsembleComputeAnomalies()`
599: @*/
600: PetscErrorCode PetscDAEnsembleComputeMean(PetscDA da, Vec mean)
601: {
602:   PetscDA_Ensemble *en = (PetscDA_Ensemble *)da->data;
603:   PetscScalar       inv_m;
604:   PetscInt          m;

606:   PetscFunctionBegin;
609:   PetscCheck(en->ensemble, PetscObjectComm((PetscObject)da), PETSC_ERR_ORDER, "PetscDASetUp() must be called before computing the ensemble mean");
610:   PetscCheck(en->size > 0, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_WRONG, "Ensemble size must be positive");

612:   m     = en->size;
613:   inv_m = 1.0 / (PetscScalar)m;
614:   PetscCall(MatGetRowSum(en->ensemble, mean));
615:   PetscCall(VecScale(mean, inv_m));
616:   PetscFunctionReturn(PETSC_SUCCESS);
617: }

619: /*@
620:   PetscDAEnsembleInitialize - Initialize ensemble members with Gaussian perturbations

622:   Collective

624:   Input Parameters:
625: + da            - PetscDA context
626: . x0            - Background state
627: . obs_error_std - Target ensemble spread (standard deviation) after sample-mean removal
628: - rng           - Random number generator

630:   Level: beginner

632:   Notes:
633:   Each member is drawn as `Gaussian(0, obs_error_std * sqrt(m / (m - 1)))` (with `m` the ensemble size),
634:   the sample mean across the ensemble is subtracted, and `x0` is added. The pre-mean-removal scale
635:   by `sqrt(m / (m - 1))` compensates for the variance reduction from centering, so the per-member
636:   spread after the subtraction is approximately `obs_error_std` regardless of `m`.

638: .seealso: [](ch_da), `PETSCDALETKF`, `PetscDA`
639: @*/
640: PetscErrorCode PetscDAEnsembleInitialize(PetscDA da, Vec x0, PetscReal obs_error_std, PetscRandom rng)
641: {
642:   PetscDA_Ensemble *en = (PetscDA_Ensemble *)da->data;
643:   Vec               member, col, x_mean;
644:   PetscReal         scale;

646:   PetscFunctionBegin;
650:   PetscCall(VecDuplicate(x0, &member));
651:   PetscCall(VecDuplicate(x0, &x_mean));

653:   /*
654:      Scale factor to maintain consistent ensemble spread across different ensemble sizes.
655:      After removing the sample mean, the ensemble variance is approximately:
656:        Var_final ~= Var_initial * (m-1)/m
657:      To maintain consistent initial spread regardless of m, we scale by sqrt(m/(m-1)).
658:      This ensures the final ensemble spread is approximately obs_error_std^2. */
659:   scale = PetscSqrtReal((PetscReal)en->size / (PetscReal)(en->size - 1));

661:   /* Populate the Gaussian draws with scaled standard deviation */
662:   for (PetscInt i = 0; i < en->size; i++) {
663:     PetscCall(VecSetRandomGaussian(member, rng, 0.0, obs_error_std * scale));
664:     PetscCall(PetscDAEnsembleSetMember(da, i, member));
665:   }
666:   /* get mean of perturbations */
667:   PetscCall(PetscDAEnsembleComputeMean(da, x_mean));
668:   /* remove mean and add x0 */
669:   for (PetscInt i = 0; i < en->size; i++) {
670:     PetscCall(MatDenseGetColumnVecWrite(en->ensemble, i, &col));
671:     PetscCall(VecAXPY(col, -1.0, x_mean));
672:     PetscCall(VecAXPY(col, 1.0, x0));
673:     PetscCall(MatDenseRestoreColumnVecWrite(en->ensemble, i, &col));
674:   }

676:   PetscCall(VecDestroy(&member));
677:   PetscCall(VecDestroy(&x_mean));
678:   PetscFunctionReturn(PETSC_SUCCESS);
679: }

681: /*@
682:   PetscDAEnsembleComputeAnomalies - Forms the state-space anomalies matrix for a `PetscDA`.

684:   Collective

686:   Input Parameters:
687: + da      - the `PetscDA` context
688: - mean_in - optional mean state vector (pass `NULL` to compute internally)

690:   Output Parameter:
691: . anomalies_out - location to store the newly created anomalies matrix

693:   Level: intermediate

695:   Notes:
696:   If `mean` is `NULL`, the function will create a temporary vector and compute
697:   the ensemble mean using `PetscDAEnsembleComputeMean()`. If `mean` is provided,
698:   it will be used directly, which can improve performance when the mean has
699:   already been computed.

701: .seealso: [](ch_da), `PetscDA`, `PETSCDALETKF`, `PetscDAEnsembleComputeMean()`
702: @*/
703: PetscErrorCode PetscDAEnsembleComputeAnomalies(PetscDA da, Vec mean_in, Mat *anomalies_out)
704: {
705:   PetscDA_Ensemble *en   = (PetscDA_Ensemble *)da->data;
706:   Vec               mean = NULL;
707:   Vec               col_in, col_out;
708:   Mat               anomalies;
709:   MPI_Comm          comm;
710:   PetscReal         scale;
711:   PetscInt          ensemble_size;
712:   PetscInt          j;
713:   PetscBool         mean_created = PETSC_FALSE;

715:   PetscFunctionBegin;
718:   PetscAssertPointer(anomalies_out, 3);
719:   PetscCheck(en->ensemble, PetscObjectComm((PetscObject)da), PETSC_ERR_ORDER, "PetscDASetUp() must be called before computing anomalies");
720:   PetscCheck(en->size > 1, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_OUTOFRANGE, "Ensemble size must be at least 2 to form anomalies");
721:   PetscCheck(da->state_size > 0, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_OUTOFRANGE, "State size must be positive");

723:   /* Cache frequently-used values for clarity and efficiency */
724:   ensemble_size = en->size;
725:   comm          = PetscObjectComm((PetscObject)en->ensemble);

727:   /*
728:     Compute normalization scale for anomalies.
729:     Alg 6.4 line 2: anomalies are normalized by 1/sqrt(m-1) so that
730:     the anomalies matrix X satisfies X*X^T = ensemble covariance matrix.
731:     This ensures proper statistical properties for ensemble-based methods.
732:   */
733:   scale = 1.0 / PetscSqrtReal((PetscReal)(ensemble_size - 1));

735:   /* Allocate anomalies matrix (state_size x ensemble_size) */
736:   PetscCall(MatCreateDense(comm, da->local_state_size, PETSC_DECIDE, da->state_size, ensemble_size, NULL, &anomalies));
737:   PetscCall(PetscObjectSetOptionsPrefix((PetscObject)anomalies, "dense_"));
738:   PetscCall(MatSetFromOptions(anomalies));
739:   PetscCall(MatSetUp(anomalies));

741:   /* Use provided mean or create and compute it */
742:   if (mean_in) {
743:     mean = mean_in;
744:   } else {
745:     /* Create and compute ensemble mean vector */
746:     PetscCall(MatCreateVecs(anomalies, NULL, &mean));
747:     PetscCall(VecSetFromOptions(mean));
748:     mean_created = PETSC_TRUE;

750:     /* Alg 6.4 line 1: \bar{x} = (1/m)\sum_j x^{(j)} */
751:     PetscCall(PetscDAEnsembleComputeMean(da, mean));
752:   }

754:   /*
755:     Form anomalies by subtracting mean from each ensemble member and scaling.
756:     For each column j: anomaly_j = (ensemble_j - mean) / sqrt(m-1)
757:   */
758:   for (j = 0; j < ensemble_size; ++j) {
759:     PetscCall(MatDenseGetColumnVecRead(en->ensemble, j, &col_in));
760:     PetscCall(MatDenseGetColumnVecWrite(anomalies, j, &col_out));

762:     /* Alg 6.4 line 2: subtract the mean column-wise to form x^{(j)} - \bar{x} */
763:     PetscCall(VecWAXPY(col_out, -1.0, mean, col_in));
764:     /* Alg 6.4 line 2: scale anomalies by 1/\sqrt{m-1} */
765:     PetscCall(VecScale(col_out, scale));

767:     PetscCall(MatDenseRestoreColumnVecWrite(anomalies, j, &col_out));
768:     PetscCall(MatDenseRestoreColumnVecRead(en->ensemble, j, &col_in));
769:   }
770:   /* Transfer ownership to output and clean up temporary resources */
771:   *anomalies_out = anomalies;
772:   if (mean_created) PetscCall(VecDestroy(&mean));
773:   PetscFunctionReturn(PETSC_SUCCESS);
774: }

776: /*@
777:   PetscDAEnsembleAnalysis - Executes the analysis (update) step using sparse observation matrix H

779:   Collective

781:   Input Parameters:
782: + da          - the `PetscDA` context
783: . observation - observation vector y in R^P
784: - H           - observation operator matrix (P x N), sparse AIJ format

786:   Level: intermediate

788:   Notes:
789:   The observation matrix H maps from state space (N dimensions) to observation
790:   space (P dimensions): y = H*x + noise

792:   H must be a sparse AIJ matrix

794:   For identity observations (observe entire state), use an identity matrix for H.
795:   For partial observations, set appropriate rows and columns to observe
796:   specific state components. On return, the ensemble matrix held by `da` has
797:   been updated in place: every member has been replaced by its analysis update.
798:   Read the analysis state with `PetscDAEnsembleGetMember()` or `PetscDAEnsembleComputeMean()`.

800: .seealso: [](ch_da), `PetscDA`, `PETSCDALETKF`, `PetscDAEnsembleForecast()`, `PetscDASetObsErrorVariance()`,
801:           `PetscDAEnsembleGetMember()`, `PetscDAEnsembleComputeMean()`
802: @*/
803: PetscErrorCode PetscDAEnsembleAnalysis(PetscDA da, Vec observation, Mat H)
804: {
805:   PetscDA_Ensemble *en = (PetscDA_Ensemble *)da->data;
806:   PetscInt          h_rows, h_cols;

808:   PetscFunctionBegin;
812:   PetscCheck(en->size > 1, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_OUTOFRANGE, "Ensemble size must be > 1, got %" PetscInt_FMT, en->size);
813:   PetscCall(MatGetSize(H, &h_rows, &h_cols));
814:   PetscCheck(h_rows == da->obs_size, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_INCOMP, "H matrix rows (%" PetscInt_FMT ") must match obs_size (%" PetscInt_FMT ")", h_rows, da->obs_size);
815:   PetscCheck(h_cols == da->state_size, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_INCOMP, "H matrix cols (%" PetscInt_FMT ") must match state_size (%" PetscInt_FMT ")", h_cols, da->state_size);
816:   PetscCall(VecGetSize(observation, &h_rows));
817:   PetscCheck(h_rows == da->obs_size, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_INCOMP, "observation vector size (%" PetscInt_FMT ") must match obs_size (%" PetscInt_FMT ")", h_rows, da->obs_size);

819:   PetscCall(PetscLogEventBegin(PetscDA_Analysis, (PetscObject)da, 0, 0, 0));
820:   PetscCall((*en->analysis)(da, observation, H));
821:   PetscCall(PetscLogEventEnd(PetscDA_Analysis, (PetscObject)da, 0, 0, 0));
822:   PetscFunctionReturn(PETSC_SUCCESS);
823: }

825: /*@C
826:   PetscDAEnsembleForecast - Advances the entire ensemble through the user-supplied forecast model.

828:   Collective

830:   Input Parameters:
831: + da    - the `PetscDA` context
832: . model - routine that advances the ensemble matrix in place; if the model can only advance one state
833:           at a time (e.g. a `TS`-driven step), it must loop over columns itself
834: - ctx   - optional context for `model`

836:   Level: intermediate

838:   Note:
839:   The columns of the ensemble matrix are the individual members; `model` advances them in place.

841: .seealso: [](ch_da), `PetscDA`, `PETSCDALETKF`, `PetscDAEnsembleAnalysis()`
842: @*/
843: PetscErrorCode PetscDAEnsembleForecast(PetscDA da, PetscDAEnsembleForecastFn *model, PetscCtx ctx)
844: {
845:   PetscDA_Ensemble *en = (PetscDA_Ensemble *)da->data;

847:   PetscFunctionBegin;
849:   PetscCall((*en->forecast)(da, model, ctx));
850:   PetscFunctionReturn(PETSC_SUCCESS);
851: }

853: PetscErrorCode PetscDAView_Ensemble(PetscDA da, PetscViewer viewer)
854: {
855:   PetscDA_Ensemble *en = (PetscDA_Ensemble *)da->data;
856:   PetscBool         iascii;

858:   PetscFunctionBegin;
859:   PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &iascii));
860:   if (iascii) {
861:     PetscCall(PetscViewerASCIIPrintf(viewer, "  Ensemble size: %" PetscInt_FMT "\n", en->size));
862:     PetscCall(PetscViewerASCIIPrintf(viewer, "  Assembled: %s\n", en->assembled ? "true" : "false"));
863:     PetscCall(PetscViewerASCIIPrintf(viewer, "  Inflation: %g\n", (double)en->inflation));
864:   }
865:   PetscFunctionReturn(PETSC_SUCCESS);
866: }

868: PetscErrorCode PetscDASetUp_Ensemble(PetscDA da)
869: {
870:   PetscDA_Ensemble *en = (PetscDA_Ensemble *)da->data;
871:   MPI_Comm          comm;

873:   PetscFunctionBegin;
874:   if (en->assembled) PetscFunctionReturn(PETSC_SUCCESS);

876:   PetscCheck(da->state_size > 0, PetscObjectComm((PetscObject)da), PETSC_ERR_ORDER, "Must set state size before calling PetscDASetUp()");
877:   PetscCheck(da->obs_size > 0, PetscObjectComm((PetscObject)da), PETSC_ERR_ORDER, "Must set observation size before calling PetscDASetUp()");
878:   PetscCheck(en->size > 0, PetscObjectComm((PetscObject)da), PETSC_ERR_ORDER, "Must set ensemble size before calling PetscDASetUp()");

880:   comm = PetscObjectComm((PetscObject)da);
881:   if (!en->ensemble) {
882:     PetscCall(MatCreateDense(comm, da->local_state_size, PETSC_DECIDE, da->state_size, en->size, NULL, &en->ensemble));
883:     PetscCall(PetscObjectSetOptionsPrefix((PetscObject)en->ensemble, "dense_"));
884:     PetscCall(MatSetFromOptions(en->ensemble));
885:     PetscCall(MatSetUp(en->ensemble));
886:   }
887:   en->assembled = PETSC_TRUE;
888:   PetscFunctionReturn(PETSC_SUCCESS);
889: }

891: /*@
892:   PetscDAEnsembleSetSize - Sets the ensemble dimensions used by a `PetscDA`.

894:   Collective

896:   Input Parameters:
897: + da            - the `PetscDA` context
898: - ensemble_size - number of ensemble members

900:   Options Database Key:
901: . -petscda_ensemble_size size - number of ensemble members

903:   Level: beginner

905:   Note:
906:   The size must be greater than or equal to two. See the scale factor in `PetscDAEnsembleInitialize()` and `PetscDALETKFLocalAnalysis()`

908: .seealso: [](ch_da), `PetscDA`, `PETSCDALETKF`, `PetscDAGetSizes()`, `PetscDASetSizes()`, `PetscDASetUp()`
909: @*/
910: PetscErrorCode PetscDAEnsembleSetSize(PetscDA da, PetscInt ensemble_size)
911: {
912:   PetscDA_Ensemble *en = (PetscDA_Ensemble *)da->data;

914:   PetscFunctionBegin;
917:   PetscCheck(!en->assembled, PetscObjectComm((PetscObject)da), PETSC_ERR_ORDER, "Cannot change sizes after PetscDASetUp() has been called");
918:   PetscCheck(ensemble_size > 1, PetscObjectComm((PetscObject)da), PETSC_ERR_ARG_SIZ, "Ensemble size must be at least two");
919:   en->size = ensemble_size;
920:   PetscFunctionReturn(PETSC_SUCCESS);
921: }

923: /*@
924:   PetscDAEnsembleGetSize - Retrieves the dimension of the ensemble in a `PetscDA`.

926:   Not Collective

928:   Input Parameter:
929: . da - the `PetscDA` context

931:   Output Parameters:
932: . ensemble_size - number of ensemble members

934:   Level: beginner

936: .seealso: [](ch_da), `PetscDA`, `PETSCDALETKF`, `PetscDASetSizes()`, `PetscDAGetSizes()`
937: @*/
938: PetscErrorCode PetscDAEnsembleGetSize(PetscDA da, PetscInt *ensemble_size)
939: {
940:   PetscDA_Ensemble *en = (PetscDA_Ensemble *)da->data;

942:   PetscFunctionBegin;
944:   PetscAssertPointer(ensemble_size, 2);
945:   *ensemble_size = en->size;
946:   PetscFunctionReturn(PETSC_SUCCESS);
947: }

949: PetscErrorCode PetscDASetFromOptions_Ensemble(PetscDA da, PetscOptionItems *PetscOptionsObjectPtr)
950: {
951:   PetscDA_Ensemble *en                 = (PetscDA_Ensemble *)da->data;
952:   PetscOptionItems  PetscOptionsObject = *PetscOptionsObjectPtr;
953:   PetscReal         inflation_val      = en->inflation;
954:   PetscBool         inflation_set, flg;
955:   PetscInt          ensemble_size;

957:   PetscFunctionBegin;
958:   PetscOptionsHeadBegin(PetscOptionsObject, "PetscDA Ensemble Options");

960:   PetscCall(PetscOptionsReal("-petscda_ensemble_inflation", "Inflation factor", "PetscDAEnsembleSetInflation", en->inflation, &inflation_val, &inflation_set));
961:   if (inflation_set) PetscCall(PetscDAEnsembleSetInflation(da, inflation_val));

963:   PetscCall(PetscOptionsInt("-petscda_ensemble_size", "Number of ensemble members", "PetscDAEnsembleSetSize", en->size, &ensemble_size, &flg));
964:   if (flg) PetscCall(PetscDAEnsembleSetSize(da, ensemble_size));
965:   PetscOptionsHeadEnd();
966:   PetscFunctionReturn(PETSC_SUCCESS);
967: }

969: PetscErrorCode PetscDADestroy_Ensemble(PetscDA da)
970: {
971:   PetscDA_Ensemble *en = (PetscDA_Ensemble *)da->data;

973:   PetscFunctionBegin;
974:   PetscCall(MatDestroy(&en->ensemble));
975:   PetscCall(VecDestroy(&da->obs_error_var));
976:   PetscCall(MatDestroy(&da->R));

978:   /* Destroy T-matrix factorization data */
979:   PetscCall(MatDestroy(&en->V));
980:   PetscCall(VecDestroy(&en->sqrt_eigen_vals));
981:   PetscCall(MatDestroy(&en->I_StS));
982:   PetscFunctionReturn(PETSC_SUCCESS);
983: }

985: PetscErrorCode PetscDACreate_Ensemble(PetscDA da)
986: {
987:   PetscDA_Ensemble *en = (PetscDA_Ensemble *)da->data;

989:   PetscFunctionBegin;
990:   en->size      = 0;
991:   en->ensemble  = NULL;
992:   en->assembled = PETSC_FALSE;
993:   en->inflation = 1.0;

995:   /* Initialize T-matrix factorization fields */
996:   en->V               = NULL;
997:   en->sqrt_eigen_vals = NULL;
998:   en->I_StS           = NULL;
999:   PetscFunctionReturn(PETSC_SUCCESS);
1000: }

1002: /*@
1003:   PetscDAEnsembleComputeNormalizedInnovationMatrix - Computes S = R^{-1/2}(Z - y_mean * 1')/sqrt(m-1) [Alg 6.4 line 5]

1005:   Collective

1007:   Input Parameters:
1008: + Z          - observation ensemble matrix
1009: . y_mean     - mean of observations
1010: . r_inv_sqrt - R^{-1/2}
1011: . m          - ensemble size
1012: - scale      - 1/sqrt(m-1)

1014:   Output Parameter:
1015: . S - normalized innovation matrix

1017:   Level: developer

1019: .seealso: [](ch_da), `PetscDA`, `PETSCDALETKF`, `PetscDASetSizes()`, `PetscDAGetSizes()`
1020: @*/
1021: PetscErrorCode PetscDAEnsembleComputeNormalizedInnovationMatrix(Mat Z, Vec y_mean, Vec r_inv_sqrt, PetscInt m, PetscScalar scale, Mat S)
1022: {
1023:   const PetscScalar *z_array, *y_array, *r_array;
1024:   PetscScalar       *s_array;
1025:   PetscInt           obs_size, obs_size_local, z_cols, i, j;
1026:   PetscInt           y_local_size, r_local_size;
1027:   PetscInt           lda_z, lda_s;

1029:   PetscFunctionBegin;
1036:   PetscCheck(m > 0, PetscObjectComm((PetscObject)Z), PETSC_ERR_ARG_OUTOFRANGE, "Ensemble size m must be positive, got %" PetscInt_FMT, m);
1037:   PetscCall(MatGetSize(Z, &obs_size, &z_cols));
1038:   PetscCall(MatGetLocalSize(Z, &obs_size_local, NULL));
1039:   PetscCheck(z_cols == m, PetscObjectComm((PetscObject)Z), PETSC_ERR_ARG_INCOMP, "Matrix Z has %" PetscInt_FMT " columns but ensemble size is %" PetscInt_FMT, z_cols, m);

1041:   /* Verify vector dimensions match observation size (both global and local) */
1042:   PetscCall(VecGetLocalSize(y_mean, &y_local_size));
1043:   PetscCall(VecGetLocalSize(r_inv_sqrt, &r_local_size));
1044:   PetscCheck(y_local_size == obs_size_local, PetscObjectComm((PetscObject)Z), PETSC_ERR_ARG_INCOMP, "Vector y_mean local size %" PetscInt_FMT " does not match matrix local rows %" PetscInt_FMT, y_local_size, obs_size_local);
1045:   PetscCheck(r_local_size == obs_size_local, PetscObjectComm((PetscObject)Z), PETSC_ERR_ARG_INCOMP, "Vector r_inv_sqrt local size %" PetscInt_FMT " does not match matrix local rows %" PetscInt_FMT, r_local_size, obs_size_local);

1047:   /* Get direct access to arrays for performance */
1048:   PetscCall(MatDenseGetArrayRead(Z, &z_array));
1049:   PetscCall(MatDenseGetArrayWrite(S, &s_array));
1050:   PetscCall(VecGetArrayRead(y_mean, &y_array));
1051:   PetscCall(VecGetArrayRead(r_inv_sqrt, &r_array));

1053:   /* Get Leading Dimension (LDA) to handle padding/strides correctly */
1054:   PetscCall(MatDenseGetLDA(Z, &lda_z));
1055:   PetscCall(MatDenseGetLDA(S, &lda_s));

1057:   /* Compute normalized innovation: S_ij = (Z_ij - y_mean_i) * scale * r_inv_sqrt_i
1058:      Iterate column-wise (j) then row-wise (i) for optimal cache access with column-major storage */
1059:   for (j = 0; j < m; j++) {
1060:     const PetscScalar *z_col = z_array + j * lda_z;
1061:     PetscScalar       *s_col = s_array + j * lda_s;

1063:     for (i = 0; i < obs_size_local; i++) s_col[i] = (z_col[i] - y_array[i]) * scale * r_array[i];
1064:   }

1066:   /* Restore arrays */
1067:   PetscCall(VecRestoreArrayRead(r_inv_sqrt, &r_array));
1068:   PetscCall(VecRestoreArrayRead(y_mean, &y_array));
1069:   PetscCall(MatDenseRestoreArrayWrite(S, &s_array));
1070:   PetscCall(MatDenseRestoreArrayRead(Z, &z_array));
1071:   PetscFunctionReturn(PETSC_SUCCESS);
1072: }

1074: PETSC_INTERN PetscErrorCode PetscDAEnsembleForecast_Ensemble(PetscDA da, PetscDAEnsembleForecastFn *model, PetscCtx ctx)
1075: {
1076:   PetscDA_Ensemble *en = (PetscDA_Ensemble *)da->data;

1078:   PetscFunctionBegin;
1080:   PetscCall((*model)(en->ensemble, ctx));
1081:   PetscFunctionReturn(PETSC_SUCCESS);
1082: }