Actual source code: letkf.h

  1: #pragma once

  3: #include <petsc/private/daimpl.h>
  4: #include <petsc/private/daensembleimpl.h>

  6: typedef struct { /* LETKF-specific state; per the note on `en`, it is reached through da->data */
  7:   PetscDA_Ensemble en; /* MUST stay first: shared ensemble code casts da->data as (PetscDA_Ensemble *) */
  8:   Vec              mean;              /* presumably the ensemble (state) mean, i.e. the x_mean of the H_temp note below -- TODO confirm */
  9:   Vec              y_mean;            /* y_mean = H*x_mean (see the H_temp_in/H_temp_out note below) */
 10:   Vec              delta_scaled;      /* NOTE(review): undocumented; presumably a scaled innovation (delta) vector -- confirm in the .c file */
 11:   Vec              w;                 /* presumably the analysis weight vector (cf. PetscDALETKFReplicateWeightVector) -- TODO confirm */
 12:   Vec              s_transpose_delta; /* PETSC_COMM_SELF Vec of length m, scratch for the per-analysis S^T * delta projection. */
 13:   Vec              r_inv_sqrt;        /* presumably the diagonal of R^{-1/2} (observation error covariance) -- TODO confirm */
 14:   Mat              Z;                 /* Z = H*E, formed column by column (see the H_temp_in/H_temp_out note below) */
 15:   Mat              S;                 /* presumably the S of the S^T * delta projection above; its definition is not shown here */
 16:   Mat              T_sqrt;            /* NOTE(review): undocumented; presumably a square-root transform matrix -- confirm */
 17:   Mat              w_ones;            /* NOTE(review): undocumented; possibly w replicated across columns (cf. PetscDALETKFReplicateWeightVector) -- confirm */
 18:   /* Localization matrix (n_vertex_global x n_obs_global), variable nnz per row; built lazily on
 19:      first analysis. Setters that mutate Q-determining inputs (type, radius, coordinates) destroy
 20:      Q via PetscDALETKFResetLocalization() so the next analysis rebuilds. */
 21:   Mat                          Q;
 22:   PetscDALETKFLocalizationType type;                /* Localization kernel type */
 23:   PetscReal                    localization_radius; /* Cutoff half-width for built-in kernels */

 25:   /* Cached inputs for lazy Q construction (built-in kernels only) */
 26:   Vec       coord_xyz[3]; /* Coordinate vectors for grid points (per dimension) */
 27:   PetscReal coord_bd[3];  /* Periodic-domain extents (0 = non-periodic) */
 28:   Mat       coord_H;      /* Observation operator used to map coordinates to observation locations */

 30:   PetscInt max_nnz_per_row; /* Cached max nnz across all rows of Q (global) */
 31:   PetscInt n_nnz_local;     /* Cached total local nnz of Q (sum over local rows); set by InstallQ */
 32:   PetscInt batch_size;      /* Batch size for GPU processing */

 34:   /* Localization support for MPI */
 35:   IS         obs_is_local;    /* Indices of observations needed by this process */
 36:   VecScatter obs_scat;        /* Scatter context for observations */
 37:   Vec        obs_work;        /* Local work vector for observations */
 38:   Vec        y_mean_work;     /* Local work vector for y_mean */
 39:   Vec        r_inv_sqrt_work; /* Local work vector for r_inv_sqrt */
 40:   Mat        Z_work;          /* Local work matrix for Z (SeqDense) */
 41:   PetscHMapI obs_g2l;         /* Map global observation index to local index in obs_work */

 43:   /* Cached H-compatible work vecs to bridge MATAIJKOKKOS H with possibly-different impl->Z type
 44:      during the per-column Z = H*E and the y_mean = H*x_mean products. Built lazily on first
 45:      analysis and rebuilt when H's row/col layout or vec type changes. H_vec_type stores the
 46:      `MatGetVecType(H)` string snapshot used to build the temps; compared against the live
 47:      `MatGetVecType(H)` to detect H switching between e.g. AIJ and AIJKOKKOS. We cache the
 48:      mat-side string (umbrella name like "kokkos") rather than `VecGetType(H_temp_in)` (concrete
 49:      name like "seqkokkos") so a fresh `MatGetVecType` lookup matches without normalization. */
 50:   Vec   H_temp_in;  /* presumably the input side of H*x (H's column layout) -- TODO confirm */
 51:   Vec   H_temp_out; /* presumably the output side of H*x (H's row layout) -- TODO confirm */
 52:   char *H_vec_type; /* MatGetVecType(H) snapshot taken when the temps were built; NOTE(review): ownership of the string is not visible here */

 54:   /* Device-side CSR view of Q (Kokkos Views cast to void*; backend reinterprets) */
 55:   void *Q_device_i; /* Row pointers (length n_vert_local + 1) */
 56:   void *Q_device_j; /* Column indices, LOCAL into obs_work (not global obs indices) */
 57:   void *Q_device_a; /* Nonzero values */

 59:   /* Persistent solver handles and workspace */
 60:   void *solver_handle; /* cusolverDnHandle_t / rocblas_handle / sycl::queue* */
 61:   void *eigen_work;    /* EigenWorkspace* */
 62: } PetscDA_LETKF;

 64: PETSC_INTERN const char *const PetscDALETKFLocalizationTypes[]; /* Declaration only (size and contents come from the definition elsewhere); presumably the name table for PetscDALETKFLocalizationType -- TODO confirm */

 66: PETSC_INTERN PetscErrorCode PetscDALETKFCreateLocalizationMat(PetscDALETKFLocalizationType, PetscReal, Vec[], PetscReal[], Mat, PetscBool, Mat *, PetscInt *, PetscInt *); /* NOTE(review): every prototype in this group omits parameter names, so argument roles must be read from the definitions. Presumably builds the localization matrix Q and reports two nnz counts -- confirm */
 67: PETSC_INTERN PetscErrorCode PetscDALETKFGatherObsBbox(PetscInt, Vec[], PetscReal[], PetscReal, Mat, Vec[], PetscInt *, PetscInt **, PetscReal **); /* NOTE(review): argument roles and ownership of the PetscInt**/PetscReal** outputs are not visible here -- see definition */
 68: PETSC_INTERN PetscErrorCode PetscDALETKFComputeObsCoords(Mat, Vec[], PetscInt *, Vec **); /* presumably produces a Vec array later released by PetscDALETKFDestroyObsCoords -- TODO confirm */
 69: PETSC_INTERN PetscErrorCode PetscDALETKFDestroyObsCoords(PetscInt, Vec **); /* presumably the counterpart of PetscDALETKFComputeObsCoords -- TODO confirm */
 70: PETSC_INTERN PetscErrorCode PetscDALETKFAssembleQFromCSR(Mat, PetscInt, PetscInt, PetscInt, MatType, const PetscInt[], const PetscInt[], const PetscInt[], const PetscScalar[], Mat *); /* presumably assembles Q (returned through the Mat *) from caller-provided CSR-style arrays -- TODO confirm which array is which */
 71: PETSC_INTERN PetscErrorCode PetscDALETKFLogQStats(Mat, PetscDALETKFLocalizationType, PetscReal, PetscInt, PetscInt, const PetscInt[]); /* presumably logging/diagnostics only; all arguments are inputs by type (no non-const pointers) */
 72: PETSC_INTERN PetscErrorCode PetscDALETKFCoalesceNnzMinMax(MPI_Comm, PetscInt *, PetscInt *); /* presumably reduces a min and a max nnz across the communicator in place -- TODO confirm */
 73: PETSC_INTERN PetscErrorCode PetscDALETKFSetupObsScatter(PetscDA_LETKF *, Mat); /* presumably fills the MPI localization members of PetscDA_LETKF (obs_is_local, obs_scat, *_work, obs_g2l) -- TODO confirm */
 74: PETSC_INTERN PetscErrorCode PetscDALETKFDestroyObsScatter(PetscDA_LETKF *); /* presumably the counterpart of PetscDALETKFSetupObsScatter -- TODO confirm */
 75: PETSC_INTERN PetscErrorCode PetscDALETKFReplicateWeightVector(Vec, PetscInt, Mat); /* NOTE(review): presumably copies a weight Vec into the columns of the Mat -- confirm */
 76: PETSC_INTERN PetscErrorCode PetscDALETKFEnsureGlobalScratch(PetscDA_LETKF *, PetscInt); /* NOTE(review): which scratch objects are (re)created is not visible here -- see definition */
 77: PETSC_INTERN PetscErrorCode PetscDALETKFLocalAnalysis(PetscDA, PetscDA_LETKF *, PetscInt, PetscInt, Mat, Vec, Mat, Vec, Vec); /* Host-side entry point; PetscDALETKFLocalAnalysis_Kokkos is declared with the same parameter type list */
 78: #if defined(PETSC_HAVE_KOKKOS_KERNELS) /* The declarations below exist only when this macro is defined */
 79: PETSC_INTERN PetscErrorCode PetscDALETKFCreateLocalizationMat_Kokkos(PetscDALETKFLocalizationType, PetscReal, Vec[], PetscReal[], Mat, Mat *, PetscInt *, PetscInt *); /* Same parameter types as PetscDALETKFCreateLocalizationMat minus the PetscBool argument */
 80: PETSC_INTERN PetscErrorCode PetscDALETKFLocalAnalysis_Kokkos(PetscDA, PetscDA_LETKF *, PetscInt, PetscInt, Mat, Vec, Mat, Vec, Vec); /* Same parameter type list as PetscDALETKFLocalAnalysis */
 81: PETSC_INTERN PetscErrorCode PetscDALETKFGlobalAnalysis_Kokkos(PetscDA, PetscDA_LETKF *, PetscInt, Mat, Vec); /* NOTE(review): no host-side GlobalAnalysis prototype appears in this header -- argument roles: see definition */
 82: PETSC_INTERN PetscErrorCode PetscDALETKFSetupLocalization_Kokkos(PetscDA_LETKF *); /* presumably builds the device-side CSR view of Q (Q_device_i/j/a) -- TODO confirm */
 83: PETSC_INTERN PetscErrorCode PetscDALETKFDestroyQDeviceMirrors_Kokkos(PetscDA_LETKF *); /* presumably releases Q_device_i/j/a only -- TODO confirm */
 84: PETSC_INTERN PetscErrorCode PetscDALETKFDestroyLocalization_Kokkos(PetscDA_LETKF *); /* NOTE(review): how this differs from DestroyQDeviceMirrors_Kokkos is not visible here -- confirm */
 85: #endif /* PETSC_HAVE_KOKKOS_KERNELS */