Actual source code: plexlocalizationletkf.kokkos.cxx

  1: #include <petsc/private/dmpleximpl.h>
  2: #include <petscdmplex.h>
  3: #include <petscmat.h>
  4: #include <petsc_kokkos.hpp>
  5: #include <cmath>
  6: #include <cstdlib>
  7: #include <algorithm>
  8: #include <Kokkos_Core.hpp>

 10: typedef struct {
 11:   PetscReal distance;
 12:   PetscInt  obs_index;
 13: } DistObsPair;

 15: KOKKOS_INLINE_FUNCTION
 16: static PetscReal GaspariCohn(PetscReal distance, PetscReal radius)
 17: {
 18:   if (radius <= 0.0) return 0.0;
 19:   const PetscReal r = distance / radius;

 21:   if (r >= 2.0) return 0.0;

 23:   const PetscReal r2 = r * r;
 24:   const PetscReal r3 = r2 * r;
 25:   const PetscReal r4 = r3 * r;
 26:   const PetscReal r5 = r4 * r;

 28:   if (r <= 1.0) {
 29:     // Region [0, 1]
 30:     return -0.25 * r5 + 0.5 * r4 + 0.625 * r3 - (5.0 / 3.0) * r2 + 1.0;
 31:   } else {
 32:     // Region [1, 2]
 33:     return (1.0 / 12.0) * r5 - 0.5 * r4 + 0.625 * r3 + (5.0 / 3.0) * r2 - 5.0 * r + 4.0 - (2.0 / 3.0) / r;
 34:   }
 35: }

 37: /*@
 38:   DMPlexGetLETKFLocalizationMatrix - Compute localization weight matrix for LETKF [move to ml/da/interface]

 40:   Collective

 42:   Input Parameters:
 43: + n_obs_vertex - Number of nearest observations to use per vertex (eg, MAX_Q_NUM_LOCAL_OBSERVATIONS in LETKF)
 44: . n_obs_local - Number of local observations
 45: . n_dof - Number of degrees of freedom
 46: . Vecxyz - Array of vectors containing the coordinates
 47: - H - Observation operator matrix

 49:   Output Parameter:
 50: . Q - Localization weight matrix (sparse, AIJ format)

 52:   Notes:
 53:   The output matrix Q has dimensions (n_vert_global x n_obs_global) where
 54:   n_vert_global is the number of vertices in the DMPlex. Each row contains
 55:   exactly n_obs_vertex non-zero entries corresponding to the nearest
 56:   observations, weighted by the Gaspari-Cohn fifth-order piecewise
 57:   rational function.

 59:   The observation locations are computed as H * V where V is the vector
 60:   of vertex coordinates. The localization weights ensure smooth tapering
 61:   of observation influence with distance.

 63:   Kokkos is required for this routine.

 65:   Level: intermediate

 67: .seealso:
 68: @*/
 69: PetscErrorCode DMPlexGetLETKFLocalizationMatrix(const PetscInt n_obs_vertex, const PetscInt n_obs_local, const PetscInt n_dof, Vec Vecxyz[3], Mat H, Mat *Q)
 70: {
 71:   PetscInt dim = 0, n_vert_local, d, N, n_obs_global, n_state_local;
 72:   Vec     *obs_vecs;
 73:   MPI_Comm comm;
 74:   PetscInt n_state_global;

 76:   PetscFunctionBegin;
 78:   PetscAssertPointer(Q, 6);

 80:   PetscCall(PetscKokkosInitializeCheck());

 82:   PetscCall(PetscObjectGetComm((PetscObject)H, &comm));

 84:   /* Infer dim from the number of vectors in Vecxyz */
 85:   for (d = 0; d < 3; ++d) {
 86:     if (Vecxyz[d]) dim++;
 87:     else break;
 88:   }

 90:   PetscCheck(dim > 0, comm, PETSC_ERR_ARG_WRONG, "Dim must be > 0");
 91:   PetscCheck(n_obs_vertex > 0, comm, PETSC_ERR_ARG_WRONG, "n_obs_vertex must be > 0");

 93:   PetscCall(VecGetSize(Vecxyz[0], &n_state_global));
 94:   PetscCall(VecGetLocalSize(Vecxyz[0], &n_state_local));
 95:   n_vert_local = n_state_local / n_dof;

 97:   /* Check H dimensions */
 98:   PetscCall(MatGetSize(H, &n_obs_global, &N));
 99:   PetscCheck(N == n_state_global, comm, PETSC_ERR_ARG_SIZ, "H number of columns %" PetscInt_FMT " != global state size %" PetscInt_FMT, N, n_state_global);
100:   // If n_obs_global < n_obs_vertex, we will pad with -1 indices and 0.0 weights.
101:   // This is not an error condition, but rather a case where we have fewer observations than requested neighbors.

103:   /* Allocate storage for observation locations */
104:   PetscCall(PetscMalloc1(dim, &obs_vecs));

106:   /* Compute observation locations per dimension */
107:   for (d = 0; d < dim; ++d) {
108:     PetscCall(MatCreateVecs(H, NULL, &obs_vecs[d]));
109:     PetscCall(MatMult(H, Vecxyz[d], obs_vecs[d]));
110:   }

112:   /* Create output matrix Q in N/n_dof x P */
113:   PetscCall(MatCreate(comm, Q));
114:   PetscCall(MatSetSizes(*Q, n_vert_local, n_obs_local, PETSC_DETERMINE, n_obs_global));
115:   PetscCall(MatSetType(*Q, MATAIJ));
116:   PetscCall(MatSeqAIJSetPreallocation(*Q, n_obs_vertex, NULL));
117:   PetscCall(MatMPIAIJSetPreallocation(*Q, n_obs_vertex, NULL, n_obs_vertex, NULL));
118:   PetscCall(MatSetFromOptions(*Q));
119:   PetscCall(MatSetUp(*Q));

121:   PetscCall(PetscInfo((PetscObject)*Q, "Computing LETKF localization matrix: %" PetscInt_FMT " vertices, %" PetscInt_FMT " observations, %" PetscInt_FMT " neighbors\n", n_vert_local, n_obs_global, n_obs_vertex));

123:   /* Prepare Kokkos Views */
124:   using ExecSpace = Kokkos::DefaultExecutionSpace;
125:   using MemSpace  = ExecSpace::memory_space;

127:   /* Vertex Coordinates */
128:   // Use LayoutLeft for coalesced access on GPU (i is contiguous)
129:   Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, MemSpace> vertex_coords_dev("vertex_coords", n_vert_local, dim);
130:   {
131:     // Host view must match the data layout from VecGetArray (d-major, i-minor implies LayoutLeft for (i,d) view)
132:     Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::HostSpace> vertex_coords_host("vertex_coords_host", n_vert_local, dim);
133:     for (d = 0; d < dim; ++d) {
134:       const PetscScalar *local_coords_array;
135:       PetscCall(VecGetArrayRead(Vecxyz[d], &local_coords_array));
136:       // Copy data. Since vertex_coords_host is LayoutLeft, &vertex_coords_host(0, d) is the start of column d.
137:       for (PetscInt i = 0; i < n_vert_local; ++i) vertex_coords_host(i, d) = local_coords_array[i];
138:       PetscCall(VecRestoreArrayRead(Vecxyz[d], &local_coords_array));
139:     }
140:     Kokkos::deep_copy(vertex_coords_dev, vertex_coords_host);
141:   }

143:   /* Observation Coordinates */
144:   Kokkos::View<PetscReal **, Kokkos::LayoutRight, MemSpace> obs_coords_dev("obs_coords", n_obs_global, dim);
145:   {
146:     Kokkos::View<PetscReal **, Kokkos::LayoutRight, Kokkos::HostSpace> obs_coords_host("obs_coords_host", n_obs_global, dim);
147:     for (d = 0; d < dim; ++d) {
148:       VecScatter         ctx;
149:       Vec                seq_vec;
150:       const PetscScalar *array;

152:       PetscCall(VecScatterCreateToAll(obs_vecs[d], &ctx, &seq_vec));
153:       PetscCall(VecScatterBegin(ctx, obs_vecs[d], seq_vec, INSERT_VALUES, SCATTER_FORWARD));
154:       PetscCall(VecScatterEnd(ctx, obs_vecs[d], seq_vec, INSERT_VALUES, SCATTER_FORWARD));

156:       PetscCall(VecGetArrayRead(seq_vec, &array));
157:       for (PetscInt j = 0; j < n_obs_global; ++j) obs_coords_host(j, d) = PetscRealPart(array[j]);
158:       PetscCall(VecRestoreArrayRead(seq_vec, &array));
159:       PetscCall(VecScatterDestroy(&ctx));
160:       PetscCall(VecDestroy(&seq_vec));
161:     }
162:     Kokkos::deep_copy(obs_coords_dev, obs_coords_host);
163:   }

165:   PetscInt rstart;
166:   PetscCall(VecGetOwnershipRange(Vecxyz[0], &rstart, NULL));

168:   /* Output Views */
169:   // LayoutLeft for coalesced access on GPU
170:   Kokkos::View<PetscInt **, Kokkos::LayoutLeft, MemSpace>    indices_dev("indices", n_vert_local, n_obs_vertex);
171:   Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, MemSpace> values_dev("values", n_vert_local, n_obs_vertex);

173:   /* Temporary storage for top-k per vertex */
174:   // LayoutLeft for coalesced access on GPU.
175:   // Note: For the insertion sort within a thread, LayoutRight would offer better cache locality for the thread's private list.
176:   // However, LayoutLeft is preferred for coalesced access across threads during the final weight computation and initialization.
177:   // Given the random access nature of the sort (divergence), we stick to the default GPU layout (Left).
178:   Kokkos::View<PetscReal **, Kokkos::LayoutLeft, MemSpace> best_dists_dev("best_dists", n_vert_local, n_obs_vertex);
179:   Kokkos::View<PetscInt **, Kokkos::LayoutLeft, MemSpace>  best_idxs_dev("best_idxs", n_vert_local, n_obs_vertex);

181:   /* Main Kernel */
182:   Kokkos::parallel_for(
183:     "ComputeLocalization", Kokkos::RangePolicy<ExecSpace>(0, n_vert_local), KOKKOS_LAMBDA(const PetscInt i) {
184:       PetscReal current_max_dist = PETSC_MAX_REAL;

186:       // Cache vertex coordinates in registers to avoid repeated global memory access
187:       // dim is small (<= 3), so this fits easily in registers
188:       PetscReal v_coords[3] = {0.0, 0.0, 0.0};
189:       for (PetscInt d = 0; d < dim; ++d) v_coords[d] = PetscRealPart(vertex_coords_dev(i, d));

191:       // Initialize with infinity
192:       for (PetscInt k = 0; k < n_obs_vertex; ++k) {
193:         best_dists_dev(i, k) = PETSC_MAX_REAL;
194:         best_idxs_dev(i, k)  = -1;
195:       }

197:       // Iterate over all observations
198:       for (PetscInt j = 0; j < n_obs_global; ++j) {
199:         PetscReal dist2 = 0.0;
200:         for (PetscInt d = 0; d < dim; ++d) {
201:           PetscReal diff = v_coords[d] - obs_coords_dev(j, d);
202:           dist2 += diff * diff;
203:         }

205:         // Check if this observation is closer than the furthest stored observation
206:         if (dist2 < current_max_dist) {
207:           // Insert sorted
208:           PetscInt pos = n_obs_vertex - 1;
209:           while (pos > 0 && best_dists_dev(i, pos - 1) > dist2) {
210:             best_dists_dev(i, pos) = best_dists_dev(i, pos - 1);
211:             best_idxs_dev(i, pos)  = best_idxs_dev(i, pos - 1);
212:             pos--;
213:           }
214:           best_dists_dev(i, pos) = dist2;
215:           best_idxs_dev(i, pos)  = j;

217:           // Update current max distance
218:           current_max_dist = best_dists_dev(i, n_obs_vertex - 1);
219:         }
220:       }

222:       // Compute weights
223:       PetscReal radius2 = best_dists_dev(i, n_obs_vertex - 1);
224:       PetscReal radius  = std::sqrt(radius2);
225:       if (radius == 0.0) radius = 1.0;

227:       for (PetscInt k = 0; k < n_obs_vertex; ++k) {
228:         if (best_idxs_dev(i, k) != -1) {
229:           PetscReal dist    = std::sqrt(best_dists_dev(i, k));
230:           indices_dev(i, k) = best_idxs_dev(i, k);
231:           values_dev(i, k)  = GaspariCohn(dist, radius);
232:         } else {
233:           indices_dev(i, k) = -1; // Ignore this entry
234:           values_dev(i, k)  = 0.0;
235:         }
236:       }
237:     });

239:   /* Copy back to host and fill matrix */
240:   // Host views must be LayoutRight for MatSetValues (row-major)
241:   Kokkos::View<PetscInt **, Kokkos::LayoutRight, Kokkos::HostSpace>    indices_host("indices_host", n_vert_local, n_obs_vertex);
242:   Kokkos::View<PetscScalar **, Kokkos::LayoutRight, Kokkos::HostSpace> values_host("values_host", n_vert_local, n_obs_vertex);

244:   // Deep copy will handle layout conversion (transpose) if device views are LayoutLeft
245:   // Note: Kokkos::deep_copy cannot copy between different layouts if the memory spaces are different (e.g. GPU to Host).
246:   // We need an intermediate mirror view on the host with the same layout as the device view.
247:   Kokkos::View<PetscInt **, Kokkos::LayoutLeft, Kokkos::HostSpace>    indices_host_left = Kokkos::create_mirror_view(indices_dev);
248:   Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::HostSpace> values_host_left  = Kokkos::create_mirror_view(values_dev);

250:   Kokkos::deep_copy(indices_host_left, indices_dev);
251:   Kokkos::deep_copy(values_host_left, values_dev);

253:   // Now copy from LayoutLeft host view to LayoutRight host view
254:   Kokkos::deep_copy(indices_host, indices_host_left);
255:   Kokkos::deep_copy(values_host, values_host_left);

257:   for (PetscInt i = 0; i < n_vert_local; ++i) {
258:     PetscInt globalRow = rstart + i;
259:     PetscCall(MatSetValues(*Q, 1, &globalRow, n_obs_vertex, &indices_host(i, 0), &values_host(i, 0), INSERT_VALUES));
260:   }

262:   /* Cleanup Phase 2 storage */
263:   for (d = 0; d < dim; ++d) PetscCall(VecDestroy(&obs_vecs[d]));
264:   PetscCall(PetscFree(obs_vecs));

266:   /* Assemble matrix */
267:   PetscCall(MatAssemblyBegin(*Q, MAT_FINAL_ASSEMBLY));
268:   PetscCall(MatAssemblyEnd(*Q, MAT_FINAL_ASSEMBLY));
269:   PetscFunctionReturn(PETSC_SUCCESS);
270: }