Actual source code: letkf.h
1: #pragma once
3: #include <petsc/private/daimpl.h>
4: #include <petsc/private/daensembleimpl.h>
/* Private implementation data for the LETKF flavor of PetscDA. The shared ensemble data comes
   first, followed by analysis work objects, the localization matrix Q together with the cached
   inputs used to rebuild it, MPI observation-scatter work objects, cached H-compatible work
   vectors, and opaque device/solver handles. */
6: typedef struct {
7: PetscDA_Ensemble en; /* MUST stay first: shared ensemble code casts da->data as (PetscDA_Ensemble *) */
8: Vec mean; /* NOTE(review): presumably the ensemble mean state (the x_mean in "y_mean = H*x_mean" below) -- confirm */
9: Vec y_mean; /* y_mean = H*x_mean, per the H_temp_in/H_temp_out note below */
10: Vec delta_scaled; /* NOTE(review): presumably the "delta" of the S^T * delta projection after scaling -- confirm what scaling is applied */
11: Vec w; /* NOTE(review): presumably the analysis weight vector (cf. PetscDALETKFReplicateWeightVector) -- confirm */
12: Vec s_transpose_delta; /* PETSC_COMM_SELF Vec of length m, scratch for the per-analysis S^T * delta projection. */
13: Vec r_inv_sqrt; /* NOTE(review): name suggests the inverse square root of the observation-error covariance -- confirm */
14: Mat Z; /* Z = H*E, formed per column (see the H_temp_in/H_temp_out note below) */
15: Mat S; /* Matrix whose transpose is applied in the S^T * delta projection (see s_transpose_delta) */
16: Mat T_sqrt; /* NOTE(review): name suggests a square-root transform matrix -- confirm */
17: Mat w_ones; /* NOTE(review): purpose not visible in this header; name suggests w replicated across columns -- confirm */
18: /* Localization matrix (n_vertex_global x n_obs_global), variable nnz per row; built lazily on
19: first analysis. Setters that mutate Q-determining inputs (type, radius, coordinates) destroy
20: Q via PetscDALETKFResetLocalization() so the next analysis rebuilds. */
21: Mat Q;
22: PetscDALETKFLocalizationType type; /* Localization kernel type */
23: PetscReal localization_radius; /* Cutoff half-width for built-in kernels */
25: /* Cached inputs for lazy Q construction (built-in kernels only) */
26: Vec coord_xyz[3]; /* Coordinate vectors for grid points (per dimension) */
27: PetscReal coord_bd[3]; /* Periodic-domain extents (0 = non-periodic) */
28: Mat coord_H; /* Observation operator used to map coordinates to observation locations */
30: PetscInt max_nnz_per_row; /* Cached max nnz across all rows of Q (global) */
31: PetscInt n_nnz_local; /* Cached total local nnz of Q (sum over local rows); set by InstallQ */
32: PetscInt batch_size; /* Batch size for GPU processing */
34: /* Localization support for MPI */
35: IS obs_is_local; /* Indices of observations needed by this process */
36: VecScatter obs_scat; /* Scatter context for observations */
37: Vec obs_work; /* Local work vector for observations */
38: Vec y_mean_work; /* Local work vector for y_mean */
39: Vec r_inv_sqrt_work; /* Local work vector for r_inv_sqrt */
40: Mat Z_work; /* Local work matrix for Z (SeqDense) */
41: PetscHMapI obs_g2l; /* Map global observation index to local index in obs_work */
43: /* Cached H-compatible work vecs to bridge MATAIJKOKKOS H with possibly-different impl->Z type
44: during the per-column Z = H*E and the y_mean = H*x_mean products. Built lazily on first
45: analysis and rebuilt when H's row/col layout or vec type changes. H_vec_type stores the
46: `MatGetVecType(H)` string snapshot used to build the temps; compared against the live
47: `MatGetVecType(H)` to detect H switching between e.g. AIJ and AIJKOKKOS. We cache the
48: mat-side string (umbrella name like "kokkos") rather than `VecGetType(H_temp_in)` (concrete
49: name like "seqkokkos") so a fresh `MatGetVecType` lookup matches without normalization. */
50: Vec H_temp_in; /* H-compatible work Vec (see note above); NOTE(review): presumably the input side of H -- confirm */
51: Vec H_temp_out; /* H-compatible work Vec (see note above); NOTE(review): presumably the output side of H -- confirm */
52: char *H_vec_type; /* Snapshot of the MatGetVecType(H) string used to build H_temp_in/H_temp_out */
54: /* Device-side CSR view of Q (Kokkos Views cast to void*; backend reinterprets) */
55: void *Q_device_i; /* Row pointers (length n_vert_local + 1) */
56: void *Q_device_j; /* Column indices, LOCAL into obs_work (not global obs indices) */
57: void *Q_device_a; /* Nonzero values */
59: /* Persistent solver handles and workspace */
60: void *solver_handle; /* cusolverDnHandle_t / rocblas_handle / sycl::queue* */
61: void *eigen_work; /* EigenWorkspace* */
62: } PetscDA_LETKF;
/* Declaration only: array of constant strings associated with PetscDALETKFLocalizationType; the
   definition (and therefore the array length) lives in a source file, not in this header.
   NOTE(review): presumably the enum name table in the usual PETSc "XXXTypes[]" layout -- confirm. */
64: PETSC_INTERN const char *const PetscDALETKFLocalizationTypes[];
/* Prototypes of the LETKF helper routines declared PETSC_INTERN. Every routine returns a
   PetscErrorCode. Parameter names are omitted in these declarations, so the meaning of each
   argument must be read from the corresponding definition. Routines taking a PetscDA_LETKF *
   operate on the implementation struct declared above.
   NOTE(review): PetscDALETKFCreateLocalizationMat() takes one PetscBool more than its _Kokkos
   counterpart declared below -- confirm what that flag selects. */
66: PETSC_INTERN PetscErrorCode PetscDALETKFCreateLocalizationMat(PetscDALETKFLocalizationType, PetscReal, Vec[], PetscReal[], Mat, PetscBool, Mat *, PetscInt *, PetscInt *);
67: PETSC_INTERN PetscErrorCode PetscDALETKFGatherObsBbox(PetscInt, Vec[], PetscReal[], PetscReal, Mat, Vec[], PetscInt *, PetscInt **, PetscReal **);
68: PETSC_INTERN PetscErrorCode PetscDALETKFComputeObsCoords(Mat, Vec[], PetscInt *, Vec **);
69: PETSC_INTERN PetscErrorCode PetscDALETKFDestroyObsCoords(PetscInt, Vec **);
70: PETSC_INTERN PetscErrorCode PetscDALETKFAssembleQFromCSR(Mat, PetscInt, PetscInt, PetscInt, MatType, const PetscInt[], const PetscInt[], const PetscInt[], const PetscScalar[], Mat *);
71: PETSC_INTERN PetscErrorCode PetscDALETKFLogQStats(Mat, PetscDALETKFLocalizationType, PetscReal, PetscInt, PetscInt, const PetscInt[]);
72: PETSC_INTERN PetscErrorCode PetscDALETKFCoalesceNnzMinMax(MPI_Comm, PetscInt *, PetscInt *);
73: PETSC_INTERN PetscErrorCode PetscDALETKFSetupObsScatter(PetscDA_LETKF *, Mat);
74: PETSC_INTERN PetscErrorCode PetscDALETKFDestroyObsScatter(PetscDA_LETKF *);
75: PETSC_INTERN PetscErrorCode PetscDALETKFReplicateWeightVector(Vec, PetscInt, Mat);
76: PETSC_INTERN PetscErrorCode PetscDALETKFEnsureGlobalScratch(PetscDA_LETKF *, PetscInt);
77: PETSC_INTERN PetscErrorCode PetscDALETKFLocalAnalysis(PetscDA, PetscDA_LETKF *, PetscInt, PetscInt, Mat, Vec, Mat, Vec, Vec);
78: #if defined(PETSC_HAVE_KOKKOS_KERNELS)
/* Kokkos-suffixed routines, declared only when PETSC_HAVE_KOKKOS_KERNELS is defined. Every routine
   returns a PetscErrorCode. PetscDALETKFLocalAnalysis_Kokkos() has the same argument types as
   PetscDALETKFLocalAnalysis(); PetscDALETKFCreateLocalizationMat_Kokkos() has the argument types
   of PetscDALETKFCreateLocalizationMat() without the PetscBool. Parameter names are omitted, so
   argument meaning must be read from the definitions. */
79: PETSC_INTERN PetscErrorCode PetscDALETKFCreateLocalizationMat_Kokkos(PetscDALETKFLocalizationType, PetscReal, Vec[], PetscReal[], Mat, Mat *, PetscInt *, PetscInt *);
80: PETSC_INTERN PetscErrorCode PetscDALETKFLocalAnalysis_Kokkos(PetscDA, PetscDA_LETKF *, PetscInt, PetscInt, Mat, Vec, Mat, Vec, Vec);
81: PETSC_INTERN PetscErrorCode PetscDALETKFGlobalAnalysis_Kokkos(PetscDA, PetscDA_LETKF *, PetscInt, Mat, Vec);
82: PETSC_INTERN PetscErrorCode PetscDALETKFSetupLocalization_Kokkos(PetscDA_LETKF *);
83: PETSC_INTERN PetscErrorCode PetscDALETKFDestroyQDeviceMirrors_Kokkos(PetscDA_LETKF *);
84: PETSC_INTERN PetscErrorCode PetscDALETKFDestroyLocalization_Kokkos(PetscDA_LETKF *);
85: #endif