Actual source code: ex4.c

  1: static char help[] = "Simple example to test separable objective optimizers.\n";

  3: #include <petsc.h>
  4: #include <petsctao.h>
  5: #include <petscvec.h>
  6: #include <petscmath.h>

  8: #define NWORKLEFT  4
  9: #define NWORKRIGHT 12

 11: typedef struct _UserCtx {
 12:   PetscInt    m;       /* The row dimension of F */
 13:   PetscInt    n;       /* The column dimension of F */
 14:   PetscInt    matops;  /* Matrix format. 0 for stencil, 1 for random */
 15:   PetscInt    iter;    /* Number of iterations for ADMM */
 16:   PetscReal   hStart;  /* Starting point for Taylor test */
 17:   PetscReal   hFactor; /* Taylor test step factor */
 18:   PetscReal   hMin;    /* Taylor test end goal */
 19:   PetscReal   alpha;   /* regularization constant applied to || x ||_p */
 20:   PetscReal   eps;     /* small constant for approximating gradient of || x ||_1 */
 21:   PetscReal   mu;      /* the augmented Lagrangian term in ADMM */
 22:   PetscReal   abstol;
 23:   PetscReal   reltol;
 24:   Mat         F;                     /* matrix in least squares component $(1/2) * || F x - d ||_2^2$ */
 25:   Mat         W;                     /* Workspace matrix. ATA */
 26:   Mat         Hm;                    /* Hessian Misfit*/
 27:   Mat         Hr;                    /* Hessian Reg*/
 28:   Vec         d;                     /* RHS in least squares component $(1/2) * || F x - d ||_2^2$ */
 29:   Vec         workLeft[NWORKLEFT];   /* Workspace for temporary vec */
 30:   Vec         workRight[NWORKRIGHT]; /* Workspace for temporary vec */
 31:   NormType    p;
 32:   PetscRandom rctx;
 33:   PetscBool   soft;
 34:   PetscBool   taylor;   /* Flag to determine whether to run Taylor test or not */
 35:   PetscBool   use_admm; /* Flag to determine whether to run Taylor test or not */
 36: } *UserCtx;

 38: static PetscErrorCode CreateRHS(UserCtx ctx)
 39: {
 40:   PetscFunctionBegin;
 41:   /* build the rhs d in ctx */
 42:   PetscCall(VecCreate(PETSC_COMM_WORLD, &ctx->d));
 43:   PetscCall(VecSetSizes(ctx->d, PETSC_DECIDE, ctx->m));
 44:   PetscCall(VecSetFromOptions(ctx->d));
 45:   PetscCall(VecSetRandom(ctx->d, ctx->rctx));
 46:   PetscFunctionReturn(PETSC_SUCCESS);
 47: }

 49: static PetscErrorCode CreateMatrix(UserCtx ctx)
 50: {
 51:   PetscInt      Istart, Iend, i, j, Ii, gridN, I_n, I_s, I_e, I_w;
 52:   PetscLogStage stage;

 54:   PetscFunctionBegin;
 55:   /* build the matrix F in ctx */
 56:   PetscCall(MatCreate(PETSC_COMM_WORLD, &ctx->F));
 57:   PetscCall(MatSetSizes(ctx->F, PETSC_DECIDE, PETSC_DECIDE, ctx->m, ctx->n));
 58:   PetscCall(MatSetType(ctx->F, MATAIJ));                          /* TODO: Decide specific SetType other than dummy*/
 59:   PetscCall(MatMPIAIJSetPreallocation(ctx->F, 5, NULL, 5, NULL)); /*TODO: some number other than 5?*/
 60:   PetscCall(MatSeqAIJSetPreallocation(ctx->F, 5, NULL));
 61:   PetscCall(MatSetUp(ctx->F));
 62:   PetscCall(MatGetOwnershipRange(ctx->F, &Istart, &Iend));
 63:   PetscCall(PetscLogStageRegister("Assembly", &stage));
 64:   PetscCall(PetscLogStagePush(stage));

 66:   /* Set matrix elements in  2-D five point stencil format. */
 67:   if (!ctx->matops) {
 68:     PetscCheck(ctx->m == ctx->n, PETSC_COMM_WORLD, PETSC_ERR_ARG_SIZ, "Stencil matrix must be square");
 69:     gridN = (PetscInt)PetscSqrtReal((PetscReal)ctx->m);
 70:     PetscCheck(gridN * gridN == ctx->m, PETSC_COMM_WORLD, PETSC_ERR_ARG_SIZ, "Number of rows must be square");
 71:     for (Ii = Istart; Ii < Iend; Ii++) {
 72:       i   = Ii / gridN;
 73:       j   = Ii % gridN;
 74:       I_n = i * gridN + j + 1;
 75:       if (j + 1 >= gridN) I_n = -1;
 76:       I_s = i * gridN + j - 1;
 77:       if (j - 1 < 0) I_s = -1;
 78:       I_e = (i + 1) * gridN + j;
 79:       if (i + 1 >= gridN) I_e = -1;
 80:       I_w = (i - 1) * gridN + j;
 81:       if (i - 1 < 0) I_w = -1;
 82:       PetscCall(MatSetValue(ctx->F, Ii, Ii, 4., INSERT_VALUES));
 83:       PetscCall(MatSetValue(ctx->F, Ii, I_n, -1., INSERT_VALUES));
 84:       PetscCall(MatSetValue(ctx->F, Ii, I_s, -1., INSERT_VALUES));
 85:       PetscCall(MatSetValue(ctx->F, Ii, I_e, -1., INSERT_VALUES));
 86:       PetscCall(MatSetValue(ctx->F, Ii, I_w, -1., INSERT_VALUES));
 87:     }
 88:   } else PetscCall(MatSetRandom(ctx->F, ctx->rctx));
 89:   PetscCall(MatAssemblyBegin(ctx->F, MAT_FINAL_ASSEMBLY));
 90:   PetscCall(MatAssemblyEnd(ctx->F, MAT_FINAL_ASSEMBLY));
 91:   PetscCall(PetscLogStagePop());
 92:   /* Stencil matrix is symmetric. Setting symmetric flag for ICC/Cholesky preconditioner */
 93:   if (!ctx->matops) PetscCall(MatSetOption(ctx->F, MAT_SYMMETRIC, PETSC_TRUE));
 94:   PetscCall(MatTransposeMatMult(ctx->F, ctx->F, MAT_INITIAL_MATRIX, PETSC_DETERMINE, &ctx->W));
 95:   /* Setup Hessian Workspace in same shape as W */
 96:   PetscCall(MatDuplicate(ctx->W, MAT_DO_NOT_COPY_VALUES, &ctx->Hm));
 97:   PetscCall(MatDuplicate(ctx->W, MAT_DO_NOT_COPY_VALUES, &ctx->Hr));
 98:   PetscFunctionReturn(PETSC_SUCCESS);
 99: }

101: static PetscErrorCode SetupWorkspace(UserCtx ctx)
102: {
103:   PetscFunctionBegin;
104:   PetscCall(MatCreateVecs(ctx->F, &ctx->workLeft[0], &ctx->workRight[0]));
105:   for (PetscInt i = 1; i < NWORKLEFT; i++) PetscCall(VecDuplicate(ctx->workLeft[0], &ctx->workLeft[i]));
106:   for (PetscInt i = 1; i < NWORKRIGHT; i++) PetscCall(VecDuplicate(ctx->workRight[0], &ctx->workRight[i]));
107:   PetscFunctionReturn(PETSC_SUCCESS);
108: }

110: static PetscErrorCode ConfigureContext(UserCtx ctx)
111: {
112:   PetscFunctionBegin;
113:   ctx->m        = 16;
114:   ctx->n        = 16;
115:   ctx->eps      = 1.e-3;
116:   ctx->abstol   = 1.e-4;
117:   ctx->reltol   = 1.e-2;
118:   ctx->hStart   = 1.;
119:   ctx->hMin     = 1.e-3;
120:   ctx->hFactor  = 0.5;
121:   ctx->alpha    = 1.;
122:   ctx->mu       = 1.0;
123:   ctx->matops   = 0;
124:   ctx->iter     = 10;
125:   ctx->p        = NORM_2;
126:   ctx->soft     = PETSC_FALSE;
127:   ctx->taylor   = PETSC_TRUE;
128:   ctx->use_admm = PETSC_FALSE;
129:   PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "Configure separable objective example", "ex4.c");
130:   PetscCall(PetscOptionsInt("-m", "The row dimension of matrix F", "ex4.c", ctx->m, &ctx->m, NULL));
131:   PetscCall(PetscOptionsInt("-n", "The column dimension of matrix F", "ex4.c", ctx->n, &ctx->n, NULL));
132:   PetscCall(PetscOptionsInt("-matrix_format", "Decide format of F matrix. 0 for stencil, 1 for random", "ex4.c", ctx->matops, &ctx->matops, NULL));
133:   PetscCall(PetscOptionsInt("-iter", "Iteration number ADMM", "ex4.c", ctx->iter, &ctx->iter, NULL));
134:   PetscCall(PetscOptionsReal("-alpha", "The regularization multiplier. 1 default", "ex4.c", ctx->alpha, &ctx->alpha, NULL));
135:   PetscCall(PetscOptionsReal("-epsilon", "The small constant added to |x_i| in the denominator to approximate the gradient of ||x||_1", "ex4.c", ctx->eps, &ctx->eps, NULL));
136:   PetscCall(PetscOptionsReal("-mu", "The augmented lagrangian multiplier in ADMM", "ex4.c", ctx->mu, &ctx->mu, NULL));
137:   PetscCall(PetscOptionsReal("-hStart", "Taylor test starting point. 1 default.", "ex4.c", ctx->hStart, &ctx->hStart, NULL));
138:   PetscCall(PetscOptionsReal("-hFactor", "Taylor test multiplier factor. 0.5 default", "ex4.c", ctx->hFactor, &ctx->hFactor, NULL));
139:   PetscCall(PetscOptionsReal("-hMin", "Taylor test ending condition. 1.e-3 default", "ex4.c", ctx->hMin, &ctx->hMin, NULL));
140:   PetscCall(PetscOptionsReal("-abstol", "Absolute stopping criterion for ADMM", "ex4.c", ctx->abstol, &ctx->abstol, NULL));
141:   PetscCall(PetscOptionsReal("-reltol", "Relative stopping criterion for ADMM", "ex4.c", ctx->reltol, &ctx->reltol, NULL));
142:   PetscCall(PetscOptionsBool("-taylor", "Flag for Taylor test. Default is true.", "ex4.c", ctx->taylor, &ctx->taylor, NULL));
143:   PetscCall(PetscOptionsBool("-soft", "Flag for testing soft threshold no-op case. Default is false.", "ex4.c", ctx->soft, &ctx->soft, NULL));
144:   PetscCall(PetscOptionsBool("-use_admm", "Use the ADMM solver in this example.", "ex4.c", ctx->use_admm, &ctx->use_admm, NULL));
145:   PetscCall(PetscOptionsEnum("-p", "Norm type.", "ex4.c", NormTypes, (PetscEnum)ctx->p, (PetscEnum *)&ctx->p, NULL));
146:   PetscOptionsEnd();
147:   /* Creating random ctx */
148:   PetscCall(PetscRandomCreate(PETSC_COMM_WORLD, &ctx->rctx));
149:   PetscCall(PetscRandomSetFromOptions(ctx->rctx));
150:   PetscCall(CreateMatrix(ctx));
151:   PetscCall(CreateRHS(ctx));
152:   PetscCall(SetupWorkspace(ctx));
153:   PetscFunctionReturn(PETSC_SUCCESS);
154: }

156: static PetscErrorCode DestroyContext(UserCtx *ctx)
157: {
158:   PetscFunctionBegin;
159:   PetscCall(MatDestroy(&(*ctx)->F));
160:   PetscCall(MatDestroy(&(*ctx)->W));
161:   PetscCall(MatDestroy(&(*ctx)->Hm));
162:   PetscCall(MatDestroy(&(*ctx)->Hr));
163:   PetscCall(VecDestroy(&(*ctx)->d));
164:   for (PetscInt i = 0; i < NWORKLEFT; i++) PetscCall(VecDestroy(&(*ctx)->workLeft[i]));
165:   for (PetscInt i = 0; i < NWORKRIGHT; i++) PetscCall(VecDestroy(&(*ctx)->workRight[i]));
166:   PetscCall(PetscRandomDestroy(&(*ctx)->rctx));
167:   PetscCall(PetscFree(*ctx));
168:   PetscFunctionReturn(PETSC_SUCCESS);
169: }

171: /* compute (1/2) * ||F x - d||^2 */
172: static PetscErrorCode ObjectiveMisfit(Tao tao, Vec x, PetscReal *J, void *_ctx)
173: {
174:   UserCtx ctx = (UserCtx)_ctx;
175:   Vec     y;

177:   PetscFunctionBegin;
178:   y = ctx->workLeft[0];
179:   PetscCall(MatMult(ctx->F, x, y));
180:   PetscCall(VecAXPY(y, -1., ctx->d));
181:   PetscCall(VecDot(y, y, J));
182:   *J *= 0.5;
183:   PetscFunctionReturn(PETSC_SUCCESS);
184: }

186: /* compute V = FTFx - FTd */
187: static PetscErrorCode GradientMisfit(Tao tao, Vec x, Vec V, void *_ctx)
188: {
189:   UserCtx ctx = (UserCtx)_ctx;
190:   Vec     FTFx, FTd;

192:   PetscFunctionBegin;
193:   /* work1 is A^T Ax, work2 is Ab, W is A^T A*/
194:   FTFx = ctx->workRight[0];
195:   FTd  = ctx->workRight[1];
196:   PetscCall(MatMult(ctx->W, x, FTFx));
197:   PetscCall(MatMultTranspose(ctx->F, ctx->d, FTd));
198:   PetscCall(VecWAXPY(V, -1., FTd, FTFx));
199:   PetscFunctionReturn(PETSC_SUCCESS);
200: }

202: /* returns FTF */
203: static PetscErrorCode HessianMisfit(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx)
204: {
205:   UserCtx ctx = (UserCtx)_ctx;

207:   PetscFunctionBegin;
208:   if (H != ctx->W) PetscCall(MatCopy(ctx->W, H, DIFFERENT_NONZERO_PATTERN));
209:   if (Hpre != ctx->W) PetscCall(MatCopy(ctx->W, Hpre, DIFFERENT_NONZERO_PATTERN));
210:   PetscFunctionReturn(PETSC_SUCCESS);
211: }

213: /* computes augment Lagrangian objective (with scaled dual):
214:  * 0.5 * ||F x - d||^2  + 0.5 * mu ||x - z + u||^2 */
215: static PetscErrorCode ObjectiveMisfitADMM(Tao tao, Vec x, PetscReal *J, void *_ctx)
216: {
217:   UserCtx   ctx = (UserCtx)_ctx;
218:   PetscReal mu, workNorm, misfit;
219:   Vec       z, u, temp;

221:   PetscFunctionBegin;
222:   mu   = ctx->mu;
223:   z    = ctx->workRight[5];
224:   u    = ctx->workRight[6];
225:   temp = ctx->workRight[10];
226:   /* misfit = f(x) */
227:   PetscCall(ObjectiveMisfit(tao, x, &misfit, _ctx));
228:   PetscCall(VecCopy(x, temp));
229:   /* temp = x - z + u */
230:   PetscCall(VecAXPBYPCZ(temp, -1., 1., 1., z, u));
231:   /* workNorm = ||x - z + u||^2 */
232:   PetscCall(VecDot(temp, temp, &workNorm));
233:   /* augment Lagrangian objective (with scaled dual): f(x) + 0.5 * mu ||x -z + u||^2 */
234:   *J = misfit + 0.5 * mu * workNorm;
235:   PetscFunctionReturn(PETSC_SUCCESS);
236: }

238: /* computes FTFx - FTd  mu*(x - z + u) */
239: static PetscErrorCode GradientMisfitADMM(Tao tao, Vec x, Vec V, void *_ctx)
240: {
241:   UserCtx   ctx = (UserCtx)_ctx;
242:   PetscReal mu;
243:   Vec       z, u, temp;

245:   PetscFunctionBegin;
246:   mu   = ctx->mu;
247:   z    = ctx->workRight[5];
248:   u    = ctx->workRight[6];
249:   temp = ctx->workRight[10];
250:   PetscCall(GradientMisfit(tao, x, V, _ctx));
251:   PetscCall(VecCopy(x, temp));
252:   /* temp = x - z + u */
253:   PetscCall(VecAXPBYPCZ(temp, -1., 1., 1., z, u));
254:   /* V =  FTFx - FTd  mu*(x - z + u) */
255:   PetscCall(VecAXPY(V, mu, temp));
256:   PetscFunctionReturn(PETSC_SUCCESS);
257: }

259: /* returns FTF + diag(mu) */
260: static PetscErrorCode HessianMisfitADMM(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx)
261: {
262:   UserCtx ctx = (UserCtx)_ctx;

264:   PetscFunctionBegin;
265:   PetscCall(MatCopy(ctx->W, H, DIFFERENT_NONZERO_PATTERN));
266:   PetscCall(MatShift(H, ctx->mu));
267:   if (Hpre != H) PetscCall(MatCopy(H, Hpre, DIFFERENT_NONZERO_PATTERN));
268:   PetscFunctionReturn(PETSC_SUCCESS);
269: }

271: /* computes || x ||_p (mult by 0.5 in case of NORM_2) */
272: static PetscErrorCode ObjectiveRegularization(Tao tao, Vec x, PetscReal *J, void *_ctx)
273: {
274:   UserCtx   ctx = (UserCtx)_ctx;
275:   PetscReal norm;

277:   PetscFunctionBegin;
278:   *J = 0;
279:   PetscCall(VecNorm(x, ctx->p, &norm));
280:   if (ctx->p == NORM_2) norm = 0.5 * norm * norm;
281:   *J = ctx->alpha * norm;
282:   PetscFunctionReturn(PETSC_SUCCESS);
283: }

285: /* NORM_2 Case: return x
286:  * NORM_1 Case: x/(|x| + eps)
287:  * Else: TODO */
288: static PetscErrorCode GradientRegularization(Tao tao, Vec x, Vec V, void *_ctx)
289: {
290:   UserCtx   ctx = (UserCtx)_ctx;
291:   PetscReal eps = ctx->eps;

293:   PetscFunctionBegin;
294:   if (ctx->p == NORM_2) PetscCall(VecCopy(x, V));
295:   else if (ctx->p == NORM_1) {
296:     PetscCall(VecCopy(x, ctx->workRight[1]));
297:     PetscCall(VecAbs(ctx->workRight[1]));
298:     PetscCall(VecShift(ctx->workRight[1], eps));
299:     PetscCall(VecPointwiseDivide(V, x, ctx->workRight[1]));
300:   } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2");
301:   PetscFunctionReturn(PETSC_SUCCESS);
302: }

304: /* NORM_2 Case: returns diag(mu)
305:  * NORM_1 Case: diag(mu* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps)))  */
306: static PetscErrorCode HessianRegularization(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx)
307: {
308:   UserCtx   ctx = (UserCtx)_ctx;
309:   PetscReal eps = ctx->eps;
310:   Vec       copy1, copy2, copy3;

312:   PetscFunctionBegin;
313:   if (ctx->p == NORM_2) {
314:     /* Identity matrix scaled by mu */
315:     PetscCall(MatZeroEntries(H));
316:     PetscCall(MatShift(H, ctx->mu));
317:     if (Hpre != H) {
318:       PetscCall(MatZeroEntries(Hpre));
319:       PetscCall(MatShift(Hpre, ctx->mu));
320:     }
321:   } else if (ctx->p == NORM_1) {
322:     /* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps)) */
323:     copy1 = ctx->workRight[1];
324:     copy2 = ctx->workRight[2];
325:     copy3 = ctx->workRight[3];
326:     /* copy1 : 1/sqrt(x_i^2 + eps) */
327:     PetscCall(VecCopy(x, copy1));
328:     PetscCall(VecPow(copy1, 2));
329:     PetscCall(VecShift(copy1, eps));
330:     PetscCall(VecSqrtAbs(copy1));
331:     PetscCall(VecReciprocal(copy1));
332:     /* copy2:  x_i^2.*/
333:     PetscCall(VecCopy(x, copy2));
334:     PetscCall(VecPow(copy2, 2));
335:     /* copy3: abs(x_i^2 + eps) */
336:     PetscCall(VecCopy(x, copy3));
337:     PetscCall(VecPow(copy3, 2));
338:     PetscCall(VecShift(copy3, eps));
339:     PetscCall(VecAbs(copy3));
340:     /* copy2: 1 - x_i^2/abs(x_i^2 + eps) */
341:     PetscCall(VecPointwiseDivide(copy2, copy2, copy3));
342:     PetscCall(VecScale(copy2, -1.));
343:     PetscCall(VecShift(copy2, 1.));
344:     PetscCall(VecAXPY(copy1, 1., copy2));
345:     PetscCall(VecScale(copy1, ctx->mu));
346:     PetscCall(MatZeroEntries(H));
347:     PetscCall(MatDiagonalSet(H, copy1, INSERT_VALUES));
348:     if (Hpre != H) {
349:       PetscCall(MatZeroEntries(Hpre));
350:       PetscCall(MatDiagonalSet(Hpre, copy1, INSERT_VALUES));
351:     }
352:   } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2");
353:   PetscFunctionReturn(PETSC_SUCCESS);
354: }

356: /* NORM_2 Case: 0.5 || x ||_2 + 0.5 * mu * ||x + u - z||^2
357:  * Else : || x ||_2 + 0.5 * mu * ||x + u - z||^2 */
358: static PetscErrorCode ObjectiveRegularizationADMM(Tao tao, Vec z, PetscReal *J, void *_ctx)
359: {
360:   UserCtx   ctx = (UserCtx)_ctx;
361:   PetscReal mu, workNorm, reg;
362:   Vec       x, u, temp;

364:   PetscFunctionBegin;
365:   mu   = ctx->mu;
366:   x    = ctx->workRight[4];
367:   u    = ctx->workRight[6];
368:   temp = ctx->workRight[10];
369:   PetscCall(ObjectiveRegularization(tao, z, &reg, _ctx));
370:   PetscCall(VecCopy(z, temp));
371:   /* temp = x + u -z */
372:   PetscCall(VecAXPBYPCZ(temp, 1., 1., -1., x, u));
373:   /* workNorm = ||x + u - z ||^2 */
374:   PetscCall(VecDot(temp, temp, &workNorm));
375:   *J = reg + 0.5 * mu * workNorm;
376:   PetscFunctionReturn(PETSC_SUCCESS);
377: }

379: /* NORM_2 Case: x - mu*(x + u - z)
380:  * NORM_1 Case: x/(|x| + eps) - mu*(x + u - z)
381:  * Else: TODO */
382: static PetscErrorCode GradientRegularizationADMM(Tao tao, Vec z, Vec V, void *_ctx)
383: {
384:   UserCtx   ctx = (UserCtx)_ctx;
385:   PetscReal mu;
386:   Vec       x, u, temp;

388:   PetscFunctionBegin;
389:   mu   = ctx->mu;
390:   x    = ctx->workRight[4];
391:   u    = ctx->workRight[6];
392:   temp = ctx->workRight[10];
393:   PetscCall(GradientRegularization(tao, z, V, _ctx));
394:   PetscCall(VecCopy(z, temp));
395:   /* temp = x + u -z */
396:   PetscCall(VecAXPBYPCZ(temp, 1., 1., -1., x, u));
397:   PetscCall(VecAXPY(V, -mu, temp));
398:   PetscFunctionReturn(PETSC_SUCCESS);
399: }

401: /* NORM_2 Case: returns diag(mu)
402:  * NORM_1 Case: FTF + diag(mu) */
403: static PetscErrorCode HessianRegularizationADMM(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx)
404: {
405:   UserCtx ctx = (UserCtx)_ctx;

407:   PetscFunctionBegin;
408:   if (ctx->p == NORM_2) {
409:     /* Identity matrix scaled by mu */
410:     PetscCall(MatZeroEntries(H));
411:     PetscCall(MatShift(H, ctx->mu));
412:     if (Hpre != H) {
413:       PetscCall(MatZeroEntries(Hpre));
414:       PetscCall(MatShift(Hpre, ctx->mu));
415:     }
416:   } else if (ctx->p == NORM_1) {
417:     PetscCall(HessianMisfit(tao, x, H, Hpre, (void *)ctx));
418:     PetscCall(MatShift(H, ctx->mu));
419:     if (Hpre != H) PetscCall(MatShift(Hpre, ctx->mu));
420:   } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2");
421:   PetscFunctionReturn(PETSC_SUCCESS);
422: }

424: /* NORM_2 Case : (1/2) * ||F x - d||^2 + 0.5 * || x ||_p
425: *  NORM_1 Case : (1/2) * ||F x - d||^2 + || x ||_p */
426: static PetscErrorCode ObjectiveComplete(Tao tao, Vec x, PetscReal *J, PetscCtx ctx)
427: {
428:   PetscReal Jm, Jr;

430:   PetscFunctionBegin;
431:   PetscCall(ObjectiveMisfit(tao, x, &Jm, ctx));
432:   PetscCall(ObjectiveRegularization(tao, x, &Jr, ctx));
433:   *J = Jm + Jr;
434:   PetscFunctionReturn(PETSC_SUCCESS);
435: }

437: /* NORM_2 Case: FTFx - FTd + x
438:  * NORM_1 Case: FTFx - FTd + x/(|x| + eps) */
439: static PetscErrorCode GradientComplete(Tao tao, Vec x, Vec V, PetscCtx ctx)
440: {
441:   UserCtx cntx = (UserCtx)ctx;

443:   PetscFunctionBegin;
444:   PetscCall(GradientMisfit(tao, x, cntx->workRight[2], ctx));
445:   PetscCall(GradientRegularization(tao, x, cntx->workRight[3], ctx));
446:   PetscCall(VecWAXPY(V, 1, cntx->workRight[2], cntx->workRight[3]));
447:   PetscFunctionReturn(PETSC_SUCCESS);
448: }

450: /* NORM_2 Case: diag(mu) + FTF
451:  * NORM_1 Case: diag(mu* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps))) + FTF  */
452: static PetscErrorCode HessianComplete(Tao tao, Vec x, Mat H, Mat Hpre, PetscCtx ctx)
453: {
454:   Mat tempH;

456:   PetscFunctionBegin;
457:   PetscCall(MatDuplicate(H, MAT_SHARE_NONZERO_PATTERN, &tempH));
458:   PetscCall(HessianMisfit(tao, x, H, H, ctx));
459:   PetscCall(HessianRegularization(tao, x, tempH, tempH, ctx));
460:   PetscCall(MatAXPY(H, 1., tempH, DIFFERENT_NONZERO_PATTERN));
461:   if (Hpre != H) PetscCall(MatCopy(H, Hpre, DIFFERENT_NONZERO_PATTERN));
462:   PetscCall(MatDestroy(&tempH));
463:   PetscFunctionReturn(PETSC_SUCCESS);
464: }

466: static PetscErrorCode TaoSolveADMM(UserCtx ctx, Vec x)
467: {
468:   PetscInt  i;
469:   PetscReal u_norm, r_norm, s_norm, primal, dual, x_norm, z_norm;
470:   Tao       tao1, tao2;
471:   Vec       xk, z, u, diff, zold, zdiff, temp;
472:   PetscReal mu;

474:   PetscFunctionBegin;
475:   xk    = ctx->workRight[4];
476:   z     = ctx->workRight[5];
477:   u     = ctx->workRight[6];
478:   diff  = ctx->workRight[7];
479:   zold  = ctx->workRight[8];
480:   zdiff = ctx->workRight[9];
481:   temp  = ctx->workRight[11];
482:   mu    = ctx->mu;
483:   PetscCall(VecSet(u, 0.));
484:   PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao1));
485:   PetscCall(TaoSetType(tao1, TAONLS));
486:   PetscCall(TaoSetObjective(tao1, ObjectiveMisfitADMM, (void *)ctx));
487:   PetscCall(TaoSetGradient(tao1, NULL, GradientMisfitADMM, (void *)ctx));
488:   PetscCall(TaoSetHessian(tao1, ctx->Hm, ctx->Hm, HessianMisfitADMM, (void *)ctx));
489:   PetscCall(VecSet(xk, 0.));
490:   PetscCall(TaoSetSolution(tao1, xk));
491:   PetscCall(TaoSetOptionsPrefix(tao1, "misfit_"));
492:   PetscCall(TaoSetFromOptions(tao1));
493:   PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao2));
494:   if (ctx->p == NORM_2) {
495:     PetscCall(TaoSetType(tao2, TAONLS));
496:     PetscCall(TaoSetObjective(tao2, ObjectiveRegularizationADMM, (void *)ctx));
497:     PetscCall(TaoSetGradient(tao2, NULL, GradientRegularizationADMM, (void *)ctx));
498:     PetscCall(TaoSetHessian(tao2, ctx->Hr, ctx->Hr, HessianRegularizationADMM, (void *)ctx));
499:   }
500:   PetscCall(VecSet(z, 0.));
501:   PetscCall(TaoSetSolution(tao2, z));
502:   PetscCall(TaoSetOptionsPrefix(tao2, "reg_"));
503:   PetscCall(TaoSetFromOptions(tao2));

505:   for (i = 0; i < ctx->iter; i++) {
506:     PetscCall(VecCopy(z, zold));
507:     PetscCall(TaoSolve(tao1)); /* Updates xk */
508:     if (ctx->p == NORM_1) {
509:       PetscCall(VecWAXPY(temp, 1., xk, u));
510:       PetscCall(TaoSoftThreshold(temp, -ctx->alpha / mu, ctx->alpha / mu, z));
511:     } else {
512:       PetscCall(TaoSolve(tao2)); /* Update zk */
513:     }
514:     /* u = u + xk -z */
515:     PetscCall(VecAXPBYPCZ(u, 1., -1., 1., xk, z));
516:     /* r_norm : norm(x-z) */
517:     PetscCall(VecWAXPY(diff, -1., z, xk));
518:     PetscCall(VecNorm(diff, NORM_2, &r_norm));
519:     /* s_norm : norm(-mu(z-zold)) */
520:     PetscCall(VecWAXPY(zdiff, -1., zold, z));
521:     PetscCall(VecNorm(zdiff, NORM_2, &s_norm));
522:     s_norm = s_norm * mu;
523:     /* primal : sqrt(n)*ABSTOL + RELTOL*max(norm(x), norm(-z))*/
524:     PetscCall(VecNorm(xk, NORM_2, &x_norm));
525:     PetscCall(VecNorm(z, NORM_2, &z_norm));
526:     primal = PetscSqrtReal(ctx->n) * ctx->abstol + ctx->reltol * PetscMax(x_norm, z_norm);
527:     /* Duality : sqrt(n)*ABSTOL + RELTOL*norm(mu*u)*/
528:     PetscCall(VecNorm(u, NORM_2, &u_norm));
529:     dual = PetscSqrtReal(ctx->n) * ctx->abstol + ctx->reltol * u_norm * mu;
530:     PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao1), "Iter %" PetscInt_FMT " : ||x-z||: %g, mu*||z-zold||: %g\n", i, (double)r_norm, (double)s_norm));
531:     if (r_norm < primal && s_norm < dual) break;
532:   }
533:   PetscCall(VecCopy(xk, x));
534:   PetscCall(TaoDestroy(&tao1));
535:   PetscCall(TaoDestroy(&tao2));
536:   PetscFunctionReturn(PETSC_SUCCESS);
537: }

539: /* Second order Taylor remainder convergence test */
540: static PetscErrorCode TaylorTest(UserCtx ctx, Tao tao, Vec x, PetscReal *C)
541: {
542:   PetscReal  h, J, temp;
543:   PetscInt   i;
544:   PetscInt   numValues;
545:   PetscReal  Jx, Jxhat_comp, Jxhat_pred;
546:   PetscReal *Js, *hs;
547:   PetscReal  gdotdx;
548:   PetscReal  minrate = PETSC_MAX_REAL;
549:   MPI_Comm   comm    = PetscObjectComm((PetscObject)x);
550:   Vec        g, dx, xhat;

552:   PetscFunctionBegin;
553:   PetscCall(VecDuplicate(x, &g));
554:   PetscCall(VecDuplicate(x, &xhat));
555:   /* choose a perturbation direction */
556:   PetscCall(VecDuplicate(x, &dx));
557:   PetscCall(VecSetRandom(dx, ctx->rctx));
558:   /* evaluate objective at x: J(x) */
559:   PetscCall(TaoComputeObjective(tao, x, &Jx));
560:   /* evaluate gradient at x, save in vector g */
561:   PetscCall(TaoComputeGradient(tao, x, g));
562:   PetscCall(VecDot(g, dx, &gdotdx));

564:   for (numValues = 0, h = ctx->hStart; h >= ctx->hMin; h *= ctx->hFactor) numValues++;
565:   PetscCall(PetscCalloc2(numValues, &Js, numValues, &hs));
566:   for (i = 0, h = ctx->hStart; h >= ctx->hMin; h *= ctx->hFactor, i++) {
567:     PetscCall(VecWAXPY(xhat, h, dx, x));
568:     PetscCall(TaoComputeObjective(tao, xhat, &Jxhat_comp));
569:     /* J(\hat(x)) \approx J(x) + g^T (xhat - x) = J(x) + h * g^T dx */
570:     Jxhat_pred = Jx + h * gdotdx;
571:     /* Vector to dJdm scalar? Dot?*/
572:     J = PetscAbsReal(Jxhat_comp - Jxhat_pred);
573:     PetscCall(PetscPrintf(comm, "J(xhat): %g, predicted: %g, diff %g\n", (double)Jxhat_comp, (double)Jxhat_pred, (double)J));
574:     Js[i] = J;
575:     hs[i] = h;
576:   }
577:   for (PetscInt j = 1; j < numValues; j++) {
578:     temp = PetscLogReal(Js[j] / Js[j - 1]) / PetscLogReal(hs[j] / hs[j - 1]);
579:     PetscCall(PetscPrintf(comm, "Convergence rate step %" PetscInt_FMT ": %g\n", j - 1, (double)temp));
580:     minrate = PetscMin(minrate, temp);
581:   }
582:   /* If O is not ~2, then the test is wrong */
583:   PetscCall(PetscFree2(Js, hs));
584:   *C = minrate;
585:   PetscCall(VecDestroy(&dx));
586:   PetscCall(VecDestroy(&xhat));
587:   PetscCall(VecDestroy(&g));
588:   PetscFunctionReturn(PETSC_SUCCESS);
589: }

591: int main(int argc, char **argv)
592: {
593:   UserCtx ctx;
594:   Tao     tao;
595:   Vec     x;
596:   Mat     H;

598:   PetscFunctionBeginUser;
599:   PetscCall(PetscInitialize(&argc, &argv, NULL, help));
600:   PetscCall(PetscNew(&ctx));
601:   PetscCall(ConfigureContext(ctx));
602:   /* Define two functions that could pass as objectives to TaoSetObjective(): one
603:    * for the misfit component, and one for the regularization component */
604:   /* ObjectiveMisfit() and ObjectiveRegularization() */

606:   /* Define a single function that calls both components adds them together: the complete objective,
607:    * in the absence of a Tao implementation that handles separability */
608:   /* ObjectiveComplete() */
609:   PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao));
610:   PetscCall(TaoSetType(tao, TAONM));
611:   PetscCall(TaoSetObjective(tao, ObjectiveComplete, (void *)ctx));
612:   PetscCall(TaoSetGradient(tao, NULL, GradientComplete, (void *)ctx));
613:   PetscCall(MatDuplicate(ctx->W, MAT_SHARE_NONZERO_PATTERN, &H));
614:   PetscCall(TaoSetHessian(tao, H, H, HessianComplete, (void *)ctx));
615:   PetscCall(MatCreateVecs(ctx->F, NULL, &x));
616:   PetscCall(TaoSetSolution(tao, x));
617:   PetscCall(TaoSetFromOptions(tao));
618:   if (ctx->use_admm) PetscCall(TaoSolveADMM(ctx, x));
619:   else PetscCall(TaoSolve(tao));
620:   /* examine solution */
621:   PetscCall(VecViewFromOptions(x, NULL, "-view_sol"));
622:   if (ctx->taylor) {
623:     PetscReal rate;
624:     PetscCall(TaylorTest(ctx, tao, x, &rate));
625:   }
626:   if (ctx->soft) PetscCall(TaoSoftThreshold(x, 0., 0., x));
627:   PetscCall(MatDestroy(&H));
628:   PetscCall(TaoDestroy(&tao));
629:   PetscCall(VecDestroy(&x));
630:   PetscCall(DestroyContext(&ctx));
631:   PetscCall(PetscFinalize());
632:   return 0;
633: }

635: /*TEST

637:   build:
638:     requires: !complex

640:   test:
641:     suffix: 0
642:     args:

644:   test:
645:     suffix: l1_1
646:     args: -p 1 -tao_type lmvm -alpha 1. -epsilon 1.e-7 -m 64 -n 64 -view_sol -matrix_format 1

648:   test:
649:     suffix: hessian_1
650:     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type nls

652:   test:
653:     suffix: hessian_2
654:     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type nls

656:   test:
657:     suffix: nm_1
658:     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type nm -tao_max_it 50

660:   test:
661:     suffix: nm_2
662:     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type nm -tao_max_it 50

664:   test:
665:     suffix: lmvm_1
666:     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type lmvm -tao_max_it 40

668:   test:
669:     suffix: lmvm_2
670:     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type lmvm -tao_max_it 15

672:   test:
673:     suffix: soft_threshold_admm_1
674:     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm

676:   test:
677:     suffix: hessian_admm_1
678:     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type nls -misfit_tao_type nls

680:   test:
681:     suffix: hessian_admm_2
682:     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type nls -misfit_tao_type nls

684:   test:
685:     suffix: nm_admm_1
686:     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type nm -misfit_tao_type nm

688:   test:
689:     suffix: nm_admm_2
690:     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type nm -misfit_tao_type nm -iter 7

692:   test:
693:     suffix: lmvm_admm_1
694:     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type lmvm -misfit_tao_type lmvm

696:   test:
697:     suffix: lmvm_admm_2
698:     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type lmvm -misfit_tao_type lmvm

700:   test:
701:     suffix: soft
702:     args: -taylor 0 -soft 1
703:     output_file: output/empty.out

705:   test:
706:     suffix: soft_view_10
707:     args: -taylor 0 -soft 1 -tao_view -tao_max_funcs 10

709: TEST*/