Actual source code: rosenbrock4.h

  1: #pragma once

  3: #include <petsctao.h>
  4: #include <petscsf.h>
  5: #include <petscdevice.h>
  6: #include <petscdevice_cupm.h>

  8: /*
  9:    User-defined application context - contains data needed by the
 10:    application-provided call-back routines that evaluate the function,
 11:    gradient, and hessian.
 12: */

 14: typedef struct _Rosenbrock {
 15:   PetscInt  bs; // each block of bs variables is one chained multidimensional rosenbrock problem
 16:   PetscInt  i_start, i_end;
 17:   PetscInt  c_start, c_end;
 18:   PetscReal alpha; // condition parameter
 19: } Rosenbrock;

 21: typedef struct _AppCtx *AppCtx;
 22: struct _AppCtx {
 23:   MPI_Comm      comm;
 24:   PetscInt      n; /* dimension */
 25:   PetscInt      n_local;
 26:   PetscInt      n_local_comp;
 27:   Rosenbrock    problem;
 28:   Vec           Hvalues; /* vector for writing COO values of this MPI process */
 29:   Vec           gvalues; /* vector for writing gradient values of this mpi process */
 30:   Vec           fvector;
 31:   PetscSF       off_process_scatter;
 32:   PetscSF       gscatter;
 33:   Vec           off_process_values; /* buffer for off-process values if chained */
 34:   PetscBool     test_lmvm;
 35:   PetscLogEvent event_f, event_g, event_fg;
 36: };

 38: /* -------------- User-defined routines ---------- */

 40: static PETSC_HOSTDEVICE_INLINE_DECL PetscReal RosenbrockObjective(PetscScalar alpha, PetscScalar x_1, PetscScalar x_2)
 41: {
 42:   PetscScalar d = x_2 - x_1 * x_1;
 43:   PetscScalar e = 1.0 - x_1;
 44:   return alpha * d * d + e * e;
 45: }

 47: static const PetscLogDouble RosenbrockObjectiveFlops = 7.0;

 49: static PETSC_HOSTDEVICE_INLINE_DECL void RosenbrockGradient(PetscScalar alpha, PetscScalar x_1, PetscScalar x_2, PetscScalar g[2])
 50: {
 51:   PetscScalar d  = x_2 - x_1 * x_1;
 52:   PetscScalar e  = 1.0 - x_1;
 53:   PetscScalar g2 = alpha * d * 2.0;

 55:   g[0] = -2.0 * x_1 * g2 - 2.0 * e;
 56:   g[1] = g2;
 57: }

 59: static const PetscInt RosenbrockGradientFlops = 9.0;

 61: static PETSC_HOSTDEVICE_INLINE_DECL PetscReal RosenbrockObjectiveGradient(PetscScalar alpha, PetscScalar x_1, PetscScalar x_2, PetscScalar g[2])
 62: {
 63:   PetscScalar d  = x_2 - x_1 * x_1;
 64:   PetscScalar e  = 1.0 - x_1;
 65:   PetscScalar ad = alpha * d;
 66:   PetscScalar g2 = ad * 2.0;

 68:   g[0] = -2.0 * x_1 * g2 - 2.0 * e;
 69:   g[1] = g2;
 70:   return ad * d + e * e;
 71: }

 73: static const PetscLogDouble RosenbrockObjectiveGradientFlops = 12.0;

 75: static PETSC_HOSTDEVICE_INLINE_DECL void RosenbrockHessian(PetscScalar alpha, PetscScalar x_1, PetscScalar x_2, PetscScalar h[4])
 76: {
 77:   PetscScalar d  = x_2 - x_1 * x_1;
 78:   PetscScalar g2 = alpha * d * 2.0;
 79:   PetscScalar h2 = -4.0 * alpha * x_1;

 81:   h[0] = -2.0 * (g2 + x_1 * h2) + 2.0;
 82:   h[1] = h[2] = h2;
 83:   h[3]        = 2.0 * alpha;
 84: }

 86: static const PetscLogDouble RosenbrockHessianFlops = 11.0;

 88: static PetscErrorCode AppCtxCreate(MPI_Comm comm, AppCtx *ctx)
 89: {
 90:   AppCtx             user;
 91:   PetscDeviceContext dctx;

 93:   PetscFunctionBegin;
 94:   PetscCall(PetscNew(ctx));
 95:   user       = *ctx;
 96:   user->comm = PETSC_COMM_WORLD;

 98:   /* Initialize problem parameters */
 99:   user->n             = 2;
100:   user->problem.alpha = 99.0;
101:   user->problem.bs    = 2; // bs = 2 is block Rosenbrock, bs = n is chained Rosenbrock
102:   user->test_lmvm     = PETSC_FALSE;
103:   /* Check for command line arguments to override defaults */
104:   PetscOptionsBegin(user->comm, NULL, "Rosenbrock example", NULL);
105:   PetscCall(PetscOptionsInt("-n", "Rosenbrock problem size", NULL, user->n, &user->n, NULL));
106:   PetscCall(PetscOptionsInt("-bs", "Rosenbrock block size (2 <= bs <= n)", NULL, user->problem.bs, &user->problem.bs, NULL));
107:   PetscCall(PetscOptionsReal("-alpha", "Rosenbrock off-diagonal coefficient", NULL, user->problem.alpha, &user->problem.alpha, NULL));
108:   PetscCall(PetscOptionsBool("-test_lmvm", "Test LMVM solve against LMVM mult", NULL, user->test_lmvm, &user->test_lmvm, NULL));
109:   PetscOptionsEnd();
110:   PetscCheck(user->problem.bs >= 1, comm, PETSC_ERR_ARG_INCOMP, "Block size %" PetscInt_FMT " is not bigger than 1", user->problem.bs);
111:   PetscCheck((user->n % user->problem.bs) == 0, comm, PETSC_ERR_ARG_INCOMP, "Block size %" PetscInt_FMT " does not divide problem size % " PetscInt_FMT, user->problem.bs, user->n);
112:   PetscCall(PetscLogEventRegister("Rbock_Obj", TAO_CLASSID, &user->event_f));
113:   PetscCall(PetscLogEventRegister("Rbock_Grad", TAO_CLASSID, &user->event_g));
114:   PetscCall(PetscLogEventRegister("Rbock_ObjGrad", TAO_CLASSID, &user->event_fg));
115:   PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
116:   PetscCall(PetscDeviceContextSetUp(dctx));
117:   PetscFunctionReturn(PETSC_SUCCESS);
118: }

120: static PetscErrorCode AppCtxDestroy(AppCtx *ctx)
121: {
122:   AppCtx user;

124:   PetscFunctionBegin;
125:   user = *ctx;
126:   *ctx = NULL;
127:   PetscCall(VecDestroy(&user->Hvalues));
128:   PetscCall(VecDestroy(&user->gvalues));
129:   PetscCall(VecDestroy(&user->fvector));
130:   PetscCall(VecDestroy(&user->off_process_values));
131:   PetscCall(PetscSFDestroy(&user->off_process_scatter));
132:   PetscCall(PetscSFDestroy(&user->gscatter));
133:   PetscCall(PetscFree(user));
134:   PetscFunctionReturn(PETSC_SUCCESS);
135: }

137: static PetscErrorCode CreateHessian(AppCtx user, Mat *Hessian)
138: {
139:   Mat         H;
140:   PetscLayout layout;
141:   PetscInt    i_start, i_end, n_local_comp, nnz_local;
142:   PetscInt    c_start, c_end;
143:   PetscInt   *coo_i;
144:   PetscInt   *coo_j;
145:   PetscInt    bs = user->problem.bs;
146:   VecType     vec_type;

148:   PetscFunctionBegin;
149:   /* Partition the optimization variables and the computations.
150:      There are (bs - 1) contributions to the objective function for every (bs)
151:      degrees of freedom. */
152:   PetscCall(PetscLayoutCreateFromSizes(user->comm, PETSC_DECIDE, user->n, 1, &layout));
153:   PetscCall(PetscLayoutSetUp(layout));
154:   PetscCall(PetscLayoutGetRange(layout, &i_start, &i_end));
155:   user->problem.i_start = i_start;
156:   user->problem.i_end   = i_end;
157:   user->n_local         = i_end - i_start;
158:   user->problem.c_start = c_start = (i_start / bs) * (bs - 1) + (i_start % bs);
159:   user->problem.c_end = c_end = (i_end / bs) * (bs - 1) + (i_end % bs);
160:   user->n_local_comp = n_local_comp = c_end - c_start;

162:   PetscCall(MatCreate(user->comm, Hessian));
163:   H = *Hessian;
164:   PetscCall(MatSetLayouts(H, layout, layout));
165:   PetscCall(PetscLayoutDestroy(&layout));
166:   PetscCall(MatSetType(H, MATAIJ));
167:   PetscCall(MatSetOption(H, MAT_HERMITIAN, PETSC_TRUE));
168:   PetscCall(MatSetOption(H, MAT_SYMMETRIC, PETSC_TRUE));
169:   PetscCall(MatSetOption(H, MAT_SYMMETRY_ETERNAL, PETSC_TRUE));
170:   PetscCall(MatSetOption(H, MAT_STRUCTURALLY_SYMMETRIC, PETSC_TRUE));
171:   PetscCall(MatSetOption(H, MAT_STRUCTURAL_SYMMETRY_ETERNAL, PETSC_TRUE));
172:   PetscCall(MatSetFromOptions(H)); /* set from options so that we can change the underlying matrix type */

174:   nnz_local = n_local_comp * 4;
175:   PetscCall(PetscMalloc2(nnz_local, &coo_i, nnz_local, &coo_j));
176:   /* Instead of having one computation thread per row of the matrix,
177:      this example uses one thread per contribution to the objective
178:      function.  Each contribution to the objective function relates
179:      two adjacent degrees of freedom, so each contribution to
180:      the objective function adds a 2x2 block into the matrix.
181:      We describe these 2x2 blocks in COO format. */
182:   for (PetscInt c = c_start, k = 0; c < c_end; c++, k += 4) {
183:     PetscInt i = (c / (bs - 1)) * bs + c % (bs - 1);

185:     coo_i[k + 0] = i;
186:     coo_i[k + 1] = i;
187:     coo_i[k + 2] = i + 1;
188:     coo_i[k + 3] = i + 1;

190:     coo_j[k + 0] = i;
191:     coo_j[k + 1] = i + 1;
192:     coo_j[k + 2] = i;
193:     coo_j[k + 3] = i + 1;
194:   }
195:   PetscCall(MatSetPreallocationCOO(H, nnz_local, coo_i, coo_j));
196:   PetscCall(PetscFree2(coo_i, coo_j));

198:   PetscCall(MatGetVecType(H, &vec_type));
199:   PetscCall(VecCreate(user->comm, &user->Hvalues));
200:   PetscCall(VecSetSizes(user->Hvalues, nnz_local, PETSC_DETERMINE));
201:   PetscCall(VecSetType(user->Hvalues, vec_type));

203:   // vector to collect contributions to the objective
204:   PetscCall(VecCreate(user->comm, &user->fvector));
205:   PetscCall(VecSetSizes(user->fvector, user->n_local_comp, PETSC_DETERMINE));
206:   PetscCall(VecSetType(user->fvector, vec_type));

208:   { /* If we are using a device (such as a GPU), run some computations that will
209:        warm up its linear algebra runtime before the problem we actually want
210:        to profile */

212:     PetscMemType       memtype;
213:     const PetscScalar *a;

215:     PetscCall(VecGetArrayReadAndMemType(user->fvector, &a, &memtype));
216:     PetscCall(VecRestoreArrayReadAndMemType(user->fvector, &a));

218:     if (memtype == PETSC_MEMTYPE_DEVICE) {
219:       PetscLogStage      warmup;
220:       Mat                A, AtA;
221:       Vec                x, b;
222:       PetscInt           warmup_size = 1000;
223:       PetscDeviceContext dctx;

225:       PetscCall(PetscLogStageRegister("Device Warmup", &warmup));
226:       PetscCall(PetscLogStageSetActive(warmup, PETSC_FALSE));

228:       PetscCall(PetscLogStagePush(warmup));
229:       PetscCall(MatCreateDenseFromVecType(PETSC_COMM_SELF, vec_type, warmup_size, warmup_size, warmup_size, warmup_size, PETSC_DEFAULT, NULL, &A));
230:       PetscCall(MatSetRandom(A, NULL));
231:       PetscCall(MatCreateVecs(A, &x, &b));
232:       PetscCall(VecSetRandom(x, NULL));

234:       PetscCall(MatMult(A, x, b));
235:       PetscCall(MatTransposeMatMult(A, A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &AtA));
236:       PetscCall(MatShift(AtA, (PetscScalar)warmup_size));
237:       PetscCall(MatSetOption(AtA, MAT_SPD, PETSC_TRUE));
238:       PetscCall(MatCholeskyFactor(AtA, NULL, NULL));
239:       PetscCall(MatDestroy(&AtA));
240:       PetscCall(VecDestroy(&b));
241:       PetscCall(VecDestroy(&x));
242:       PetscCall(MatDestroy(&A));
243:       PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
244:       PetscCall(PetscDeviceContextSynchronize(dctx));
245:       PetscCall(PetscLogStagePop());
246:     }
247:   }
248:   PetscFunctionReturn(PETSC_SUCCESS);
249: }

251: static PetscErrorCode CreateVectors(AppCtx user, Mat H, Vec *solution, Vec *gradient)
252: {
253:   VecType     vec_type;
254:   PetscInt    n_coo, *coo_i, i_start, i_end;
255:   Vec         x;
256:   PetscInt    n_recv;
257:   PetscSFNode recv;
258:   PetscLayout layout;
259:   PetscInt    c_start = user->problem.c_start, c_end = user->problem.c_end, bs = user->problem.bs;

261:   PetscFunctionBegin;
262:   PetscCall(MatCreateVecs(H, solution, gradient));
263:   x = *solution;
264:   PetscCall(VecGetOwnershipRange(x, &i_start, &i_end));
265:   PetscCall(VecGetType(x, &vec_type));
266:   // create scatter for communicating values
267:   PetscCall(VecGetLayout(x, &layout));
268:   n_recv = 0;
269:   if (user->n_local_comp && i_end < user->n) {
270:     PetscMPIInt rank;
271:     PetscInt    index;

273:     n_recv = 1;
274:     PetscCall(PetscLayoutFindOwnerIndex(layout, i_end, &rank, &index));
275:     recv.rank  = rank;
276:     recv.index = index;
277:   }
278:   PetscCall(PetscSFCreate(user->comm, &user->off_process_scatter));
279:   PetscCall(PetscSFSetGraph(user->off_process_scatter, user->n_local, n_recv, NULL, PETSC_USE_POINTER, &recv, PETSC_COPY_VALUES));
280:   PetscCall(VecCreate(user->comm, &user->off_process_values));
281:   PetscCall(VecSetSizes(user->off_process_values, 1, PETSC_DETERMINE));
282:   PetscCall(VecSetType(user->off_process_values, vec_type));
283:   PetscCall(VecZeroEntries(user->off_process_values));

285:   // create COO data for writing the gradient
286:   n_coo = user->n_local_comp * 2;
287:   PetscCall(PetscMalloc1(n_coo, &coo_i));
288:   for (PetscInt c = c_start, k = 0; c < c_end; c++, k += 2) {
289:     PetscInt i = (c / (bs - 1)) * bs + (c % (bs - 1));

291:     coo_i[k + 0] = i;
292:     coo_i[k + 1] = i + 1;
293:   }
294:   PetscCall(PetscSFCreate(user->comm, &user->gscatter));
295:   PetscCall(PetscSFSetGraphLayout(user->gscatter, layout, n_coo, NULL, PETSC_USE_POINTER, coo_i));
296:   PetscCall(PetscSFSetUp(user->gscatter));
297:   PetscCall(PetscFree(coo_i));
298:   PetscCall(VecCreate(user->comm, &user->gvalues));
299:   PetscCall(VecSetSizes(user->gvalues, n_coo, PETSC_DETERMINE));
300:   PetscCall(VecSetType(user->gvalues, vec_type));
301:   PetscFunctionReturn(PETSC_SUCCESS);
302: }

304: #if PetscDefined(USING_CUPMCC)

306:   #if PetscDefined(USING_NVCC)
307: typedef cudaStream_t cupmStream_t;
308:     #define PetscCUPMLaunch(...) \
309:       do { \
310:         __VA_ARGS__; \
311:         PetscCallCUDA(cudaGetLastError()); \
312:       } while (0)
313:   #elif PetscDefined(USING_HCC)
314:     #define PetscCUPMLaunch(...) \
315:       do { \
316:         __VA_ARGS__; \
317:         PetscCallHIP(hipGetLastError()); \
318:       } while (0)
319: typedef hipStream_t cupmStream_t;
320:   #endif

322: // x: on-process optimization variables
323: // o: buffer that contains the next optimization variable after the variables on this process
324: template <typename T>
325: PETSC_DEVICE_INLINE_DECL static void rosenbrock_for_loop(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], T &&func) noexcept
326: {
327:   PetscInt idx         = blockIdx.x * blockDim.x + threadIdx.x; // 1D grid
328:   PetscInt num_threads = gridDim.x * blockDim.x;

330:   for (PetscInt c = r.c_start + idx, k = idx; c < r.c_end; c += num_threads, k += num_threads) {
331:     PetscInt    i   = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1));
332:     PetscScalar x_a = x[i - r.i_start];
333:     PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0];

335:     func(k, x_a, x_b);
336:   }
337:   return;
338: }

340: PETSC_KERNEL_DECL void RosenbrockObjective_Kernel(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar f_vec[])
341: {
342:   rosenbrock_for_loop(r, x, o, [&](PetscInt k, PetscScalar x_a, PetscScalar x_b) { f_vec[k] = RosenbrockObjective(r.alpha, x_a, x_b); });
343: }

345: PETSC_KERNEL_DECL void RosenbrockGradient_Kernel(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar g[])
346: {
347:   rosenbrock_for_loop(r, x, o, [&](PetscInt k, PetscScalar x_a, PetscScalar x_b) { RosenbrockGradient(r.alpha, x_a, x_b, &g[2 * k]); });
348: }

350: PETSC_KERNEL_DECL void RosenbrockObjectiveGradient_Kernel(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar f_vec[], PetscScalar g[])
351: {
352:   rosenbrock_for_loop(r, x, o, [&](PetscInt k, PetscScalar x_a, PetscScalar x_b) { f_vec[k] = RosenbrockObjectiveGradient(r.alpha, x_a, x_b, &g[2 * k]); });
353: }

355: PETSC_KERNEL_DECL void RosenbrockHessian_Kernel(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar h[])
356: {
357:   rosenbrock_for_loop(r, x, o, [&](PetscInt k, PetscScalar x_a, PetscScalar x_b) { RosenbrockHessian(r.alpha, x_a, x_b, &h[4 * k]); });
358: }

360: static PetscErrorCode RosenbrockObjective_Device(cupmStream_t stream, Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar f_vec[])
361: {
362:   PetscInt n_comp = r.c_end - r.c_start;

364:   PetscFunctionBegin;
365:   if (n_comp) PetscCUPMLaunch(RosenbrockObjective_Kernel<<<(n_comp + 255) / 256, 256, 0, stream>>>(r, x, o, f_vec));
366:   PetscCall(PetscLogGpuFlops(RosenbrockObjectiveFlops * n_comp));
367:   PetscFunctionReturn(PETSC_SUCCESS);
368: }

370: static PetscErrorCode RosenbrockGradient_Device(cupmStream_t stream, Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar g[])
371: {
372:   PetscInt n_comp = r.c_end - r.c_start;

374:   PetscFunctionBegin;
375:   if (n_comp) PetscCUPMLaunch(RosenbrockGradient_Kernel<<<(n_comp + 255) / 256, 256, 0, stream>>>(r, x, o, g));
376:   PetscCall(PetscLogGpuFlops(RosenbrockGradientFlops * n_comp));
377:   PetscFunctionReturn(PETSC_SUCCESS);
378: }

380: static PetscErrorCode RosenbrockObjectiveGradient_Device(cupmStream_t stream, Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar f_vec[], PetscScalar g[])
381: {
382:   PetscInt n_comp = r.c_end - r.c_start;

384:   PetscFunctionBegin;
385:   if (n_comp) PetscCUPMLaunch(RosenbrockObjectiveGradient_Kernel<<<(n_comp + 255) / 256, 256, 0, stream>>>(r, x, o, f_vec, g));
386:   PetscCall(PetscLogGpuFlops(RosenbrockObjectiveGradientFlops * n_comp));
387:   PetscFunctionReturn(PETSC_SUCCESS);
388: }

390: static PetscErrorCode RosenbrockHessian_Device(cupmStream_t stream, Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar h[])
391: {
392:   PetscInt n_comp = r.c_end - r.c_start;

394:   PetscFunctionBegin;
395:   if (n_comp) PetscCUPMLaunch(RosenbrockHessian_Kernel<<<(n_comp + 255) / 256, 256, 0, stream>>>(r, x, o, h));
396:   PetscCall(PetscLogGpuFlops(RosenbrockHessianFlops * n_comp));
397:   PetscFunctionReturn(PETSC_SUCCESS);
398: }
399: #endif

401: static PetscErrorCode RosenbrockObjective_Host(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscReal *f)
402: {
403:   PetscReal _f = 0.0;

405:   PetscFunctionBegin;
406:   for (PetscInt c = r.c_start; c < r.c_end; c++) {
407:     PetscInt    i   = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1));
408:     PetscScalar x_a = x[i - r.i_start];
409:     PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0];

411:     _f += RosenbrockObjective(r.alpha, x_a, x_b);
412:   }
413:   *f = _f;
414:   PetscCall(PetscLogFlops((RosenbrockObjectiveFlops + 1.0) * (r.c_end - r.c_start)));
415:   PetscFunctionReturn(PETSC_SUCCESS);
416: }

418: static PetscErrorCode RosenbrockGradient_Host(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar g[])
419: {
420:   PetscFunctionBegin;
421:   for (PetscInt c = r.c_start, k = 0; c < r.c_end; c++, k++) {
422:     PetscInt    i   = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1));
423:     PetscScalar x_a = x[i - r.i_start];
424:     PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0];

426:     RosenbrockGradient(r.alpha, x_a, x_b, &g[2 * k]);
427:   }
428:   PetscCall(PetscLogFlops(RosenbrockGradientFlops * (r.c_end - r.c_start)));
429:   PetscFunctionReturn(PETSC_SUCCESS);
430: }

432: static PetscErrorCode RosenbrockObjectiveGradient_Host(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscReal *f, PetscScalar g[])
433: {
434:   PetscReal _f = 0.0;

436:   PetscFunctionBegin;
437:   for (PetscInt c = r.c_start, k = 0; c < r.c_end; c++, k++) {
438:     PetscInt    i   = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1));
439:     PetscScalar x_a = x[i - r.i_start];
440:     PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0];

442:     _f += RosenbrockObjectiveGradient(r.alpha, x_a, x_b, &g[2 * k]);
443:   }
444:   *f = _f;
445:   PetscCall(PetscLogFlops(RosenbrockObjectiveGradientFlops * (r.c_end - r.c_start)));
446:   PetscFunctionReturn(PETSC_SUCCESS);
447: }

449: static PetscErrorCode RosenbrockHessian_Host(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar h[])
450: {
451:   PetscFunctionBegin;
452:   for (PetscInt c = r.c_start, k = 0; c < r.c_end; c++, k++) {
453:     PetscInt    i   = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1));
454:     PetscScalar x_a = x[i - r.i_start];
455:     PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0];

457:     RosenbrockHessian(r.alpha, x_a, x_b, &h[4 * k]);
458:   }
459:   PetscCall(PetscLogFlops(RosenbrockHessianFlops * (r.c_end - r.c_start)));
460:   PetscFunctionReturn(PETSC_SUCCESS);
461: }

463: /* -------------------------------------------------------------------- */

465: static PetscErrorCode FormObjective(Tao tao, Vec X, PetscReal *f, void *ptr)
466: {
467:   AppCtx             user    = (AppCtx)ptr;
468:   PetscReal          f_local = 0.0;
469:   const PetscScalar *x;
470:   const PetscScalar *o = NULL;
471:   PetscMemType       memtype_x;

473:   PetscFunctionBeginUser;
474:   PetscCall(PetscLogEventBegin(user->event_f, tao, NULL, NULL, NULL));
475:   PetscCall(VecScatterBegin(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
476:   PetscCall(VecScatterEnd(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
477:   PetscCall(VecGetArrayReadAndMemType(user->off_process_values, &o, NULL));
478:   PetscCall(VecGetArrayReadAndMemType(X, &x, &memtype_x));
479:   if (memtype_x == PETSC_MEMTYPE_HOST) {
480:     PetscCall(RosenbrockObjective_Host(user->problem, x, o, &f_local));
481:     PetscCallMPI(MPI_Allreduce(&f_local, f, 1, MPI_DOUBLE, MPI_SUM, user->comm));
482: #if PetscDefined(USING_CUPMCC)
483:   } else if (memtype_x == PETSC_MEMTYPE_DEVICE) {
484:     PetscScalar       *_fvec;
485:     PetscScalar        f_scalar;
486:     cupmStream_t      *stream;
487:     PetscDeviceContext dctx;

489:     PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
490:     PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream));
491:     PetscCall(VecGetArrayWriteAndMemType(user->fvector, &_fvec, NULL));
492:     PetscCall(RosenbrockObjective_Device(*stream, user->problem, x, o, _fvec));
493:     PetscCall(VecRestoreArrayWriteAndMemType(user->fvector, &_fvec));
494:     PetscCall(VecSum(user->fvector, &f_scalar));
495:     *f = PetscRealPart(f_scalar);
496: #endif
497:   } else SETERRQ(user->comm, PETSC_ERR_SUP, "Unsupported memtype %d", (int)memtype_x);
498:   PetscCall(VecRestoreArrayReadAndMemType(X, &x));
499:   PetscCall(VecRestoreArrayReadAndMemType(user->off_process_values, &o));
500:   PetscCall(PetscLogEventEnd(user->event_f, tao, NULL, NULL, NULL));
501:   PetscFunctionReturn(PETSC_SUCCESS);
502: }

504: static PetscErrorCode FormGradient(Tao tao, Vec X, Vec G, void *ptr)
505: {
506:   AppCtx             user = (AppCtx)ptr;
507:   PetscScalar       *g;
508:   const PetscScalar *x;
509:   const PetscScalar *o = NULL;
510:   PetscMemType       memtype_x, memtype_g;

512:   PetscFunctionBeginUser;
513:   PetscCall(PetscLogEventBegin(user->event_g, tao, NULL, NULL, NULL));
514:   PetscCall(VecScatterBegin(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
515:   PetscCall(VecScatterEnd(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
516:   PetscCall(VecGetArrayReadAndMemType(user->off_process_values, &o, NULL));
517:   PetscCall(VecGetArrayReadAndMemType(X, &x, &memtype_x));
518:   PetscCall(VecGetArrayWriteAndMemType(user->gvalues, &g, &memtype_g));
519:   PetscAssert(memtype_x == memtype_g, user->comm, PETSC_ERR_ARG_INCOMP, "solution vector and gradient must have save memtype");
520:   if (memtype_x == PETSC_MEMTYPE_HOST) {
521:     PetscCall(RosenbrockGradient_Host(user->problem, x, o, g));
522: #if PetscDefined(USING_CUPMCC)
523:   } else if (memtype_x == PETSC_MEMTYPE_DEVICE) {
524:     cupmStream_t      *stream;
525:     PetscDeviceContext dctx;

527:     PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
528:     PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream));
529:     PetscCall(RosenbrockGradient_Device(*stream, user->problem, x, o, g));
530: #endif
531:   } else SETERRQ(user->comm, PETSC_ERR_SUP, "Unsupported memtype %d", (int)memtype_x);
532:   PetscCall(VecRestoreArrayWriteAndMemType(user->gvalues, &g));
533:   PetscCall(VecRestoreArrayReadAndMemType(X, &x));
534:   PetscCall(VecRestoreArrayReadAndMemType(user->off_process_values, &o));
535:   PetscCall(VecZeroEntries(G));
536:   PetscCall(VecScatterBegin(user->gscatter, user->gvalues, G, ADD_VALUES, SCATTER_REVERSE));
537:   PetscCall(VecScatterEnd(user->gscatter, user->gvalues, G, ADD_VALUES, SCATTER_REVERSE));
538:   PetscCall(PetscLogEventEnd(user->event_g, tao, NULL, NULL, NULL));
539:   PetscFunctionReturn(PETSC_SUCCESS);
540: }

542: /*
543:     FormObjectiveGradient - Evaluates the function, f(X), and gradient, G(X).

545:     Input Parameters:
546: .   tao  - the Tao context
547: .   X    - input vector
548: .   ptr  - optional user-defined context, as set by TaoSetObjectiveGradient()

550:     Output Parameters:
551: .   G - vector containing the newly evaluated gradient
552: .   f - function value

554:     Note:
555:     Some optimization methods ask for the function and the gradient evaluation
556:     at the same time.  Evaluating both at once may be more efficient that
557:     evaluating each separately.
558: */
559: static PetscErrorCode FormObjectiveGradient(Tao tao, Vec X, PetscReal *f, Vec G, void *ptr)
560: {
561:   AppCtx             user    = (AppCtx)ptr;
562:   PetscReal          f_local = 0.0;
563:   PetscScalar       *g;
564:   const PetscScalar *x;
565:   const PetscScalar *o = NULL;
566:   PetscMemType       memtype_x, memtype_g;

568:   PetscFunctionBeginUser;
569:   PetscCall(PetscLogEventBegin(user->event_fg, tao, NULL, NULL, NULL));
570:   PetscCall(VecScatterBegin(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
571:   PetscCall(VecScatterEnd(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
572:   PetscCall(VecGetArrayReadAndMemType(user->off_process_values, &o, NULL));
573:   PetscCall(VecGetArrayReadAndMemType(X, &x, &memtype_x));
574:   PetscCall(VecGetArrayWriteAndMemType(user->gvalues, &g, &memtype_g));
575:   PetscAssert(memtype_x == memtype_g, user->comm, PETSC_ERR_ARG_INCOMP, "solution vector and gradient must have save memtype");
576:   if (memtype_x == PETSC_MEMTYPE_HOST) {
577:     PetscCall(RosenbrockObjectiveGradient_Host(user->problem, x, o, &f_local, g));
578:     PetscCallMPI(MPI_Allreduce((void *)&f_local, (void *)f, 1, MPI_DOUBLE, MPI_SUM, PETSC_COMM_WORLD));
579: #if PetscDefined(USING_CUPMCC)
580:   } else if (memtype_x == PETSC_MEMTYPE_DEVICE) {
581:     PetscScalar       *_fvec;
582:     PetscScalar        f_scalar;
583:     cupmStream_t      *stream;
584:     PetscDeviceContext dctx;

586:     PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
587:     PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream));
588:     PetscCall(VecGetArrayWriteAndMemType(user->fvector, &_fvec, NULL));
589:     PetscCall(RosenbrockObjectiveGradient_Device(*stream, user->problem, x, o, _fvec, g));
590:     PetscCall(VecRestoreArrayWriteAndMemType(user->fvector, &_fvec));
591:     PetscCall(VecSum(user->fvector, &f_scalar));
592:     *f = PetscRealPart(f_scalar);
593: #endif
594:   } else SETERRQ(user->comm, PETSC_ERR_SUP, "Unsupported memtype %d", (int)memtype_x);

596:   PetscCall(VecRestoreArrayWriteAndMemType(user->gvalues, &g));
597:   PetscCall(VecRestoreArrayReadAndMemType(X, &x));
598:   PetscCall(VecRestoreArrayReadAndMemType(user->off_process_values, &o));
599:   PetscCall(VecZeroEntries(G));
600:   PetscCall(VecScatterBegin(user->gscatter, user->gvalues, G, ADD_VALUES, SCATTER_REVERSE));
601:   PetscCall(VecScatterEnd(user->gscatter, user->gvalues, G, ADD_VALUES, SCATTER_REVERSE));
602:   PetscCall(PetscLogEventEnd(user->event_fg, tao, NULL, NULL, NULL));
603:   PetscFunctionReturn(PETSC_SUCCESS);
604: }

606: /* ------------------------------------------------------------------- */
607: /*
608:    FormHessian - Evaluates Hessian matrix.

610:    Input Parameters:
611: .  tao   - the Tao context
612: .  x     - input vector
613: .  ptr   - optional user-defined context, as set by TaoSetHessian()

615:    Output Parameters:
616: .  H     - Hessian matrix

618:    Note:  Providing the Hessian may not be necessary.  Only some solvers
619:    require this matrix.
620: */
621: static PetscErrorCode FormHessian(Tao tao, Vec X, Mat H, Mat Hpre, void *ptr)
622: {
623:   AppCtx             user = (AppCtx)ptr;
624:   PetscScalar       *h;
625:   const PetscScalar *x;
626:   const PetscScalar *o = NULL;
627:   PetscMemType       memtype_x, memtype_h;

629:   PetscFunctionBeginUser;
630:   PetscCall(VecScatterBegin(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
631:   PetscCall(VecScatterEnd(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
632:   PetscCall(VecGetArrayReadAndMemType(user->off_process_values, &o, NULL));
633:   PetscCall(VecGetArrayReadAndMemType(X, &x, &memtype_x));
634:   PetscCall(VecGetArrayWriteAndMemType(user->Hvalues, &h, &memtype_h));
635:   PetscAssert(memtype_x == memtype_h, user->comm, PETSC_ERR_ARG_INCOMP, "solution vector and hessian must have save memtype");
636:   if (memtype_x == PETSC_MEMTYPE_HOST) {
637:     PetscCall(RosenbrockHessian_Host(user->problem, x, o, h));
638: #if PetscDefined(USING_CUPMCC)
639:   } else if (memtype_x == PETSC_MEMTYPE_DEVICE) {
640:     cupmStream_t      *stream;
641:     PetscDeviceContext dctx;

643:     PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
644:     PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream));
645:     PetscCall(RosenbrockHessian_Device(*stream, user->problem, x, o, h));
646: #endif
647:   } else SETERRQ(user->comm, PETSC_ERR_SUP, "Unsupported memtype %d", (int)memtype_x);

649:   PetscCall(MatSetValuesCOO(H, h, INSERT_VALUES));
650:   PetscCall(VecRestoreArrayWriteAndMemType(user->Hvalues, &h));

652:   PetscCall(VecRestoreArrayReadAndMemType(X, &x));
653:   PetscCall(VecRestoreArrayReadAndMemType(user->off_process_values, &o));

655:   if (Hpre != H) PetscCall(MatCopy(H, Hpre, SAME_NONZERO_PATTERN));
656:   PetscFunctionReturn(PETSC_SUCCESS);
657: }

659: static PetscErrorCode TestLMVM(Tao tao)
660: {
661:   KSP       ksp;
662:   PC        pc;
663:   PetscBool is_lmvm;

665:   PetscFunctionBegin;
666:   PetscCall(TaoGetKSP(tao, &ksp));
667:   if (!ksp) PetscFunctionReturn(PETSC_SUCCESS);
668:   PetscCall(KSPGetPC(ksp, &pc));
669:   PetscCall(PetscObjectTypeCompare((PetscObject)pc, PCLMVM, &is_lmvm));
670:   if (is_lmvm) {
671:     Mat       M;
672:     Vec       in, out, out2;
673:     PetscReal mult_solve_dist;
674:     Vec       x;

676:     PetscCall(PCLMVMGetMatLMVM(pc, &M));
677:     PetscCall(TaoGetSolution(tao, &x));
678:     PetscCall(VecDuplicate(x, &in));
679:     PetscCall(VecDuplicate(x, &out));
680:     PetscCall(VecDuplicate(x, &out2));
681:     PetscCall(VecSetRandom(in, NULL));
682:     PetscCall(MatMult(M, in, out));
683:     PetscCall(MatSolve(M, out, out2));

685:     PetscCall(VecAXPY(out2, -1.0, in));
686:     PetscCall(VecNorm(out2, NORM_2, &mult_solve_dist));
687:     if (mult_solve_dist < 1.e-11) {
688:       PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao), "Inverse error of LMVM MatMult and MatSolve: < 1.e-11\n"));
689:     } else if (mult_solve_dist < 1.e-6) {
690:       PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao), "Inverse error of LMVM MatMult and MatSolve: < 1.e-6\n"));
691:     } else {
692:       PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao), "Inverse error of LMVM MatMult and MatSolve is not small: %e\n", (double)mult_solve_dist));
693:     }
694:     PetscCall(VecDestroy(&in));
695:     PetscCall(VecDestroy(&out));
696:     PetscCall(VecDestroy(&out2));
697:   }
698:   PetscFunctionReturn(PETSC_SUCCESS);
699: }

701: static PetscErrorCode RosenbrockMain(void)
702: {
703:   Vec           x;    /* solution vector */
704:   Vec           g;    /* gradient vector */
705:   Mat           H;    /* Hessian matrix */
706:   Tao           tao;  /* Tao solver context */
707:   AppCtx        user; /* user-defined application context */
708:   PetscLogStage solve;

710:   /* Initialize TAO and PETSc */
711:   PetscFunctionBegin;
712:   PetscCall(PetscLogStageRegister("Rosenbrock solve", &solve));

714:   PetscCall(AppCtxCreate(PETSC_COMM_WORLD, &user));
715:   PetscCall(CreateHessian(user, &H));
716:   PetscCall(CreateVectors(user, H, &x, &g));

718:   /* The TAO code begins here */

720:   PetscCall(TaoCreate(user->comm, &tao));
721:   PetscCall(VecZeroEntries(x));
722:   PetscCall(TaoSetSolution(tao, x));

724:   /* Set routines for function, gradient, hessian evaluation */
725:   PetscCall(TaoSetObjective(tao, FormObjective, user));
726:   PetscCall(TaoSetObjectiveAndGradient(tao, g, FormObjectiveGradient, user));
727:   PetscCall(TaoSetGradient(tao, g, FormGradient, user));
728:   PetscCall(TaoSetHessian(tao, H, H, FormHessian, user));

730:   PetscCall(TaoSetFromOptions(tao));

732:   /* SOLVE THE APPLICATION */
733:   PetscCall(PetscLogStagePush(solve));
734:   PetscCall(TaoSolve(tao));
735:   PetscCall(PetscLogStagePop());

737:   if (user->test_lmvm) PetscCall(TestLMVM(tao));

739:   PetscCall(TaoDestroy(&tao));
740:   PetscCall(VecDestroy(&g));
741:   PetscCall(VecDestroy(&x));
742:   PetscCall(MatDestroy(&H));
743:   PetscCall(AppCtxDestroy(&user));
744:   PetscFunctionReturn(PETSC_SUCCESS);
745: }