Actual source code: dalocalizationletkf.kokkos.cxx

  1: #include <petsc.h>
  2: #include <petscmat.h>
  3: #include <petsc_kokkos.hpp>
  4: #include <cmath>
  5: #include <cstdlib>
  6: #include <algorithm>
  7: #include <Kokkos_Core.hpp>

  9: typedef struct {
 10:   PetscReal distance;
 11:   PetscInt  obs_index;
 12: } DistObsPair;

 14: KOKKOS_INLINE_FUNCTION
 15: static PetscReal GaspariCohn(PetscReal distance, PetscReal radius)
 16: {
 17:   if (radius <= 0.0) return 0.0;
 18:   const PetscReal r = distance / radius;

 20:   if (r >= 2.0) return 0.0;

 22:   const PetscReal r2 = r * r;
 23:   const PetscReal r3 = r2 * r;
 24:   const PetscReal r4 = r3 * r;
 25:   const PetscReal r5 = r4 * r;

 27:   if (r <= 1.0) {
 28:     // Region [0, 1]
 29:     return -0.25 * r5 + 0.5 * r4 + 0.625 * r3 - (5.0 / 3.0) * r2 + 1.0;
 30:   } else {
 31:     // Region [1, 2]
 32:     return (1.0 / 12.0) * r5 - 0.5 * r4 + 0.625 * r3 + (5.0 / 3.0) * r2 - 5.0 * r + 4.0 - (2.0 / 3.0) / r;
 33:   }
 34: }

 36: #define RADIUS_FACTOR 1.1

 38: template <class ViewType>
 39: struct RadiusStatsFunctor {
 40:   ViewType best_dists;
 41:   PetscInt n_obs_vertex;

 43:   struct value_type {
 44:     double sum, sq_sum;
 45:   };

 47:   KOKKOS_INLINE_FUNCTION void operator()(const PetscInt i, value_type &update) const
 48:   {
 49:     PetscReal r2 = best_dists(i, n_obs_vertex - 1);
 50:     PetscReal r  = std::sqrt(r2);
 51:     r *= RADIUS_FACTOR;
 52:     if (r == 0.0) r = 1.0;
 53:     update.sum += r;
 54:     update.sq_sum += r * r;
 55:   }

 57:   KOKKOS_INLINE_FUNCTION void init(value_type &update) const
 58:   {
 59:     update.sum    = 0.0;
 60:     update.sq_sum = 0.0;
 61:   }

 63:   KOKKOS_INLINE_FUNCTION void join(value_type &dest, const value_type &src) const
 64:   {
 65:     dest.sum += src.sum;
 66:     dest.sq_sum += src.sq_sum;
 67:   }
 68: };

 70: /*@
 71:   PetscDALETKFGetLocalizationMatrix - Compute localization weight matrix for LETKF [move to ml/da/interface]

 73:   Collective

 75:   Input Parameters:
 76: + n_obs_vertex - Number of observations to localize to per vertex
 77: . n_dof        - Number of degrees of freedom
 78: . Vecxyz       - Array of vectors containing the vertex coordinates
 79: . bd           - Array of boundary extents per dimension (used for periodicity)
 80: - H            - Observation operator matrix

 82:   Output Parameter:
 83: . Q - Localization weight matrix (sparse, AIJ format)

 85:   Level: intermediate

 87:   Notes:
 88:   The output matrix Q has dimensions (n_vert_global x n_obs_global) where
 89:   n_vert_global is the number of vertices in the DMPlex. Each row contains
 90:   exactly n_obs_vertex non-zero entries corresponding to the nearest
 91:   observations, weighted by the Gaspari-Cohn fifth-order piecewise
 92:   rational function.

 94:   The observation locations are computed as H * V where V is the vector
 95:   of vertex coordinates. The localization weights ensure smooth tapering
 96:   of observation influence with distance.

 98:   Kokkos is required for this routine.

100: .seealso: [](ch_da), `PetscDALETKFSetLocalization()`
101: @*/
102: PetscErrorCode PetscDALETKFGetLocalizationMatrix(const PetscInt n_obs_vertex, const PetscInt n_dof, Vec Vecxyz[3], PetscReal bd[3], Mat H, Mat *Q)
103: {
104:   PetscInt dim = 0, n_vert_local, d, n_obs_global, n_obs_local;
105:   Vec     *obs_vecs;
106:   MPI_Comm comm;

108:   PetscFunctionBegin;
110:   PetscAssertPointer(Q, 6);

112:   PetscCall(PetscKokkosInitializeCheck());
113:   PetscCall(PetscObjectGetComm((PetscObject)H, &comm));
114:   PetscCall(MatGetLocalSize(H, &n_obs_local, NULL));
115:   PetscCall(MatGetSize(H, &n_obs_global, NULL));
116:   /* Infer dim from the number of vectors in Vecxyz */
117:   for (d = 0; d < 3; ++d) {
118:     if (Vecxyz[d]) dim++;
119:     else break;
120:   }
121:   PetscCall(VecGetLocalSize(Vecxyz[0], &n_vert_local));

123:   /* Check H dimensions */
124:   // If n_obs_global < n_obs_vertex, we will pad with -1 indices and 0.0 weights. ???
125:   // This is not an error condition, but rather a case where we have fewer observations than requested neighbors.
126:   PetscCheck(dim > 0, comm, PETSC_ERR_ARG_WRONG, "Dim must be > 0");
127:   PetscCheck(n_obs_vertex > 0 && n_obs_vertex <= n_obs_global, comm, PETSC_ERR_ARG_WRONG, "n_obs_vertex must be > 0 and <= n_obs_global");

129:   /* Allocate storage for observation locations */
130:   PetscCall(PetscMalloc1(dim, &obs_vecs));

132:   /* Compute observation locations per dimension */
133:   for (d = 0; d < dim; ++d) {
134:     PetscCall(MatCreateVecs(H, NULL, &obs_vecs[d]));
135:     PetscCall(MatMult(H, Vecxyz[d], obs_vecs[d]));
136:   }

138:   /* Create output matrix Q in N/n_dof x P */
139:   PetscCall(MatCreate(comm, Q));
140:   PetscCall(MatSetSizes(*Q, n_vert_local, n_obs_local, PETSC_DETERMINE, n_obs_global));
141:   PetscCall(MatSetType(*Q, MATAIJ));
142:   PetscCall(MatSeqAIJSetPreallocation(*Q, n_obs_vertex, NULL));
143:   PetscCall(MatMPIAIJSetPreallocation(*Q, n_obs_vertex, NULL, n_obs_vertex, NULL));
144:   PetscCall(MatSetFromOptions(*Q));
145:   PetscCall(MatSetUp(*Q));

147:   PetscCall(PetscInfo((PetscObject)*Q, "Computing LETKF localization matrix: %" PetscInt_FMT " vertices, %" PetscInt_FMT " observations, %" PetscInt_FMT " observations per vertex\n", n_vert_local, n_obs_global, n_obs_vertex));

149:   /* Prepare Kokkos Views */
150:   using ExecSpace = Kokkos::DefaultExecutionSpace;
151:   using MemSpace  = ExecSpace::memory_space;

153:   /* Vertex Coordinates */
154:   // Use LayoutLeft for coalesced access on GPU (i is contiguous)
155:   Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, MemSpace> vertex_coords_dev("vertex_coords", n_vert_local, dim);
156:   {
157:     // Host view must match the data layout from VecGetArray (d-major, i-minor implies LayoutLeft for (i,d) view)
158:     Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::HostSpace> vertex_coords_host("vertex_coords_host", n_vert_local, dim);
159:     for (d = 0; d < dim; ++d) {
160:       const PetscScalar *local_coords_array;
161:       PetscCall(VecGetArrayRead(Vecxyz[d], &local_coords_array));
162:       // Copy data. Since vertex_coords_host is LayoutLeft, &vertex_coords_host(0, d) is the start of column d.
163:       for (PetscInt i = 0; i < n_vert_local; ++i) vertex_coords_host(i, d) = local_coords_array[i];
164:       PetscCall(VecRestoreArrayRead(Vecxyz[d], &local_coords_array));
165:     }
166:     Kokkos::deep_copy(vertex_coords_dev, vertex_coords_host);
167:   }

169:   /* Observation Coordinates */
170:   Kokkos::View<PetscReal **, Kokkos::LayoutRight, MemSpace> obs_coords_dev("obs_coords", n_obs_global, dim);
171:   {
172:     Kokkos::View<PetscReal **, Kokkos::LayoutRight, Kokkos::HostSpace> obs_coords_host("obs_coords_host", n_obs_global, dim);
173:     for (d = 0; d < dim; ++d) {
174:       VecScatter         ctx;
175:       Vec                seq_vec;
176:       const PetscScalar *array;

178:       PetscCall(VecScatterCreateToAll(obs_vecs[d], &ctx, &seq_vec));
179:       PetscCall(VecScatterBegin(ctx, obs_vecs[d], seq_vec, INSERT_VALUES, SCATTER_FORWARD));
180:       PetscCall(VecScatterEnd(ctx, obs_vecs[d], seq_vec, INSERT_VALUES, SCATTER_FORWARD));

182:       PetscCall(VecGetArrayRead(seq_vec, &array));
183:       for (PetscInt j = 0; j < n_obs_global; ++j) obs_coords_host(j, d) = PetscRealPart(array[j]);
184:       PetscCall(VecRestoreArrayRead(seq_vec, &array));
185:       PetscCall(VecScatterDestroy(&ctx));
186:       PetscCall(VecDestroy(&seq_vec));
187:     }
188:     Kokkos::deep_copy(obs_coords_dev, obs_coords_host);
189:   }

191:   PetscInt rstart;
192:   PetscCall(VecGetOwnershipRange(Vecxyz[0], &rstart, NULL));

194:   /* Output Views */
195:   // LayoutLeft for coalesced access on GPU
196:   Kokkos::View<PetscInt **, Kokkos::LayoutLeft, MemSpace>    indices_dev("indices", n_vert_local, n_obs_vertex);
197:   Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, MemSpace> values_dev("values", n_vert_local, n_obs_vertex);

199:   /* Temporary storage for top-k per vertex */
200:   // LayoutLeft for coalesced access on GPU.
201:   // Note: For the insertion sort within a thread, LayoutRight would offer better cache locality for the thread's private list.
202:   // However, LayoutLeft is preferred for coalesced access across threads during the final weight computation and initialization.
203:   // Given the random access nature of the sort (divergence), we stick to the default GPU layout (Left).
204:   Kokkos::View<PetscReal **, Kokkos::LayoutLeft, MemSpace> best_dists_dev("best_dists", n_vert_local, n_obs_vertex);
205:   Kokkos::View<PetscInt **, Kokkos::LayoutLeft, MemSpace>  best_idxs_dev("best_idxs", n_vert_local, n_obs_vertex);

207:   /* Copy boundary data to device */
208:   Kokkos::View<PetscReal *, MemSpace> bd_dev("bd_dev", dim);
209:   {
210:     Kokkos::View<PetscReal *, Kokkos::HostSpace> bd_host("bd_host", dim);
211:     for (PetscInt d = 0; d < dim; ++d) bd_host(d) = bd[d];
212:     Kokkos::deep_copy(bd_dev, bd_host);
213:   }

215:   /* Main Kernel */
216:   Kokkos::parallel_for(
217:     "ComputeLocalization", Kokkos::RangePolicy<ExecSpace>(0, n_vert_local), KOKKOS_LAMBDA(const PetscInt i) {
218:       PetscReal current_max_dist = PETSC_MAX_REAL;

220:       // Cache vertex coordinates in registers to avoid repeated global memory access
221:       // dim is small (<= 3), so this fits easily in registers
222:       PetscReal v_coords[3] = {0.0, 0.0, 0.0};
223:       for (PetscInt d = 0; d < dim; ++d) v_coords[d] = PetscRealPart(vertex_coords_dev(i, d));

225:       // Initialize with infinity
226:       for (PetscInt k = 0; k < n_obs_vertex; ++k) {
227:         best_dists_dev(i, k) = PETSC_MAX_REAL;
228:         best_idxs_dev(i, k)  = -1;
229:       }

231:       // Iterate over all observations
232:       for (PetscInt j = 0; j < n_obs_global; ++j) {
233:         PetscReal dist2 = 0.0;
234:         for (PetscInt d = 0; d < dim; ++d) {
235:           PetscReal diff = v_coords[d] - obs_coords_dev(j, d);
236:           if (bd_dev(d) != 0) { // Periodic boundary
237:             PetscReal domain_size = bd_dev(d);
238:             if (diff > 0.5 * domain_size) diff -= domain_size;
239:             else if (diff < -0.5 * domain_size) diff += domain_size;
240:           }
241:           dist2 += diff * diff;
242:         }

244:         // Check if this observation is closer than the furthest stored observation
245:         if (dist2 < current_max_dist) {
246:           // Insert sorted
247:           PetscInt pos = n_obs_vertex - 1;
248:           while (pos > 0 && best_dists_dev(i, pos - 1) > dist2) {
249:             best_dists_dev(i, pos) = best_dists_dev(i, pos - 1);
250:             best_idxs_dev(i, pos)  = best_idxs_dev(i, pos - 1);
251:             pos--;
252:           }
253:           best_dists_dev(i, pos) = dist2;
254:           best_idxs_dev(i, pos)  = j;

256:           // Update current max distance
257:           current_max_dist = best_dists_dev(i, n_obs_vertex - 1);
258:         }
259:       }
260:       // Compute weights
261:       PetscReal radius2 = best_dists_dev(i, n_obs_vertex - 1);
262:       PetscReal radius  = std::sqrt(radius2);
263:       radius *= RADIUS_FACTOR;
264:       if (radius == 0.0) radius = 1.0;

266:       for (PetscInt k = 0; k < n_obs_vertex; ++k) {
267:         if (best_idxs_dev(i, k) != -1) {
268:           PetscReal dist    = std::sqrt(best_dists_dev(i, k));
269:           indices_dev(i, k) = best_idxs_dev(i, k);
270:           values_dev(i, k)  = GaspariCohn(dist, radius);
271:         } else {
272:           indices_dev(i, k) = -1; // Ignore this entry
273:           values_dev(i, k)  = 0.0;
274:         }
275:       }
276:     });

278:   /* Copy back to host and fill matrix */
279:   // Host views must be LayoutRight for MatSetValues (row-major)
280:   Kokkos::View<PetscInt **, Kokkos::LayoutRight, Kokkos::HostSpace>    indices_host("indices_host", n_vert_local, n_obs_vertex);
281:   Kokkos::View<PetscScalar **, Kokkos::LayoutRight, Kokkos::HostSpace> values_host("values_host", n_vert_local, n_obs_vertex);

283:   // Deep copy will handle layout conversion (transpose) if device views are LayoutLeft
284:   // Note: Kokkos::deep_copy cannot copy between different layouts if the memory spaces are different (e.g. GPU to Host).
285:   // We need an intermediate mirror view on the host with the same layout as the device view.
286:   Kokkos::View<PetscInt **, Kokkos::LayoutLeft, Kokkos::HostSpace>    indices_host_left = Kokkos::create_mirror_view(indices_dev);
287:   Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::HostSpace> values_host_left  = Kokkos::create_mirror_view(values_dev);

289:   Kokkos::deep_copy(indices_host_left, indices_dev);
290:   Kokkos::deep_copy(values_host_left, values_dev);

292:   // Now copy from LayoutLeft host view to LayoutRight host view
293:   Kokkos::deep_copy(indices_host, indices_host_left);
294:   Kokkos::deep_copy(values_host, values_host_left);

296:   for (PetscInt i = 0; i < n_vert_local; ++i) {
297:     PetscInt globalRow = rstart + i;
298:     PetscCall(MatSetValues(*Q, 1, &globalRow, n_obs_vertex, &indices_host(i, 0), &values_host(i, 0), INSERT_VALUES));
299:   }

301:   /* Compute mean and std dev of localization radius */
302:   {
303:     using FunctorType = RadiusStatsFunctor<decltype(best_dists_dev)>;
304:     typename FunctorType::value_type result;
305:     Kokkos::parallel_reduce("ComputeRadiusStats", Kokkos::RangePolicy<ExecSpace>(0, n_vert_local), FunctorType{best_dists_dev, n_obs_vertex}, result);

307:     if (n_vert_local > 0) {
308:       double mean   = result.sum / n_vert_local;
309:       double var    = (result.sq_sum / n_vert_local) - (mean * mean);
310:       double stddev = (var > 1e-1 * PETSC_SQRT_MACHINE_EPSILON) ? std::sqrt(var) : 0.0;
311:       PetscCall(PetscInfo((PetscObject)obs_vecs[0], "LETKF localization radius: mean %g, std dev %g\n", mean, stddev));
312:     }
313:   }

315:   /* Cleanup storage */
316:   for (d = 0; d < dim; ++d) PetscCall(VecDestroy(&obs_vecs[d]));
317:   PetscCall(PetscFree(obs_vecs));

319:   /* Assemble matrix */
320:   PetscCall(MatAssemblyBegin(*Q, MAT_FINAL_ASSEMBLY));
321:   PetscCall(MatAssemblyEnd(*Q, MAT_FINAL_ASSEMBLY));
322:   PetscFunctionReturn(PETSC_SUCCESS);
323: }