Actual source code: ex4.c
/* Help banner handed to PetscInitialize(); it is never written to, so it is declared const. */
static const char help[] = "Simple example to test separable objective optimizers.\n";
3: #include <petsc.h>
4: #include <petsctao.h>
5: #include <petscvec.h>
6: #include <petscmath.h>
8: #define NWORKLEFT 4
9: #define NWORKRIGHT 12
11: typedef struct _UserCtx {
12: PetscInt m; /* The row dimension of F */
13: PetscInt n; /* The column dimension of F */
14: PetscInt matops; /* Matrix format. 0 for stencil, 1 for random */
15: PetscInt iter; /* Number of iterations for ADMM */
16: PetscReal hStart; /* Starting point for Taylor test */
17: PetscReal hFactor; /* Taylor test step factor */
18: PetscReal hMin; /* Taylor test end goal */
19: PetscReal alpha; /* regularization constant applied to || x ||_p */
20: PetscReal eps; /* small constant for approximating gradient of || x ||_1 */
21: PetscReal mu; /* the augmented Lagrangian term in ADMM */
22: PetscReal abstol;
23: PetscReal reltol;
24: Mat F; /* matrix in least squares component $(1/2) * || F x - d ||_2^2$ */
25: Mat W; /* Workspace matrix. ATA */
26: Mat Hm; /* Hessian Misfit*/
27: Mat Hr; /* Hessian Reg*/
28: Vec d; /* RHS in least squares component $(1/2) * || F x - d ||_2^2$ */
29: Vec workLeft[NWORKLEFT]; /* Workspace for temporary vec */
30: Vec workRight[NWORKRIGHT]; /* Workspace for temporary vec */
31: NormType p;
32: PetscRandom rctx;
33: PetscBool soft;
34: PetscBool taylor; /* Flag to determine whether to run Taylor test or not */
35: PetscBool use_admm; /* Flag to determine whether to run Taylor test or not */
36: } *UserCtx;
38: static PetscErrorCode CreateRHS(UserCtx ctx)
39: {
40: PetscFunctionBegin;
41: /* build the rhs d in ctx */
42: PetscCall(VecCreate(PETSC_COMM_WORLD, &ctx->d));
43: PetscCall(VecSetSizes(ctx->d, PETSC_DECIDE, ctx->m));
44: PetscCall(VecSetFromOptions(ctx->d));
45: PetscCall(VecSetRandom(ctx->d, ctx->rctx));
46: PetscFunctionReturn(PETSC_SUCCESS);
47: }
49: static PetscErrorCode CreateMatrix(UserCtx ctx)
50: {
51: PetscInt Istart, Iend, i, j, Ii, gridN, I_n, I_s, I_e, I_w;
52: PetscLogStage stage;
54: PetscFunctionBegin;
55: /* build the matrix F in ctx */
56: PetscCall(MatCreate(PETSC_COMM_WORLD, &ctx->F));
57: PetscCall(MatSetSizes(ctx->F, PETSC_DECIDE, PETSC_DECIDE, ctx->m, ctx->n));
58: PetscCall(MatSetType(ctx->F, MATAIJ)); /* TODO: Decide specific SetType other than dummy*/
59: PetscCall(MatMPIAIJSetPreallocation(ctx->F, 5, NULL, 5, NULL)); /*TODO: some number other than 5?*/
60: PetscCall(MatSeqAIJSetPreallocation(ctx->F, 5, NULL));
61: PetscCall(MatSetUp(ctx->F));
62: PetscCall(MatGetOwnershipRange(ctx->F, &Istart, &Iend));
63: PetscCall(PetscLogStageRegister("Assembly", &stage));
64: PetscCall(PetscLogStagePush(stage));
66: /* Set matrix elements in 2-D five point stencil format. */
67: if (!ctx->matops) {
68: PetscCheck(ctx->m == ctx->n, PETSC_COMM_WORLD, PETSC_ERR_ARG_SIZ, "Stencil matrix must be square");
69: gridN = (PetscInt)PetscSqrtReal((PetscReal)ctx->m);
70: PetscCheck(gridN * gridN == ctx->m, PETSC_COMM_WORLD, PETSC_ERR_ARG_SIZ, "Number of rows must be square");
71: for (Ii = Istart; Ii < Iend; Ii++) {
72: i = Ii / gridN;
73: j = Ii % gridN;
74: I_n = i * gridN + j + 1;
75: if (j + 1 >= gridN) I_n = -1;
76: I_s = i * gridN + j - 1;
77: if (j - 1 < 0) I_s = -1;
78: I_e = (i + 1) * gridN + j;
79: if (i + 1 >= gridN) I_e = -1;
80: I_w = (i - 1) * gridN + j;
81: if (i - 1 < 0) I_w = -1;
82: PetscCall(MatSetValue(ctx->F, Ii, Ii, 4., INSERT_VALUES));
83: PetscCall(MatSetValue(ctx->F, Ii, I_n, -1., INSERT_VALUES));
84: PetscCall(MatSetValue(ctx->F, Ii, I_s, -1., INSERT_VALUES));
85: PetscCall(MatSetValue(ctx->F, Ii, I_e, -1., INSERT_VALUES));
86: PetscCall(MatSetValue(ctx->F, Ii, I_w, -1., INSERT_VALUES));
87: }
88: } else PetscCall(MatSetRandom(ctx->F, ctx->rctx));
89: PetscCall(MatAssemblyBegin(ctx->F, MAT_FINAL_ASSEMBLY));
90: PetscCall(MatAssemblyEnd(ctx->F, MAT_FINAL_ASSEMBLY));
91: PetscCall(PetscLogStagePop());
92: /* Stencil matrix is symmetric. Setting symmetric flag for ICC/Cholesky preconditioner */
93: if (!ctx->matops) PetscCall(MatSetOption(ctx->F, MAT_SYMMETRIC, PETSC_TRUE));
94: PetscCall(MatTransposeMatMult(ctx->F, ctx->F, MAT_INITIAL_MATRIX, PETSC_DETERMINE, &ctx->W));
95: /* Setup Hessian Workspace in same shape as W */
96: PetscCall(MatDuplicate(ctx->W, MAT_DO_NOT_COPY_VALUES, &ctx->Hm));
97: PetscCall(MatDuplicate(ctx->W, MAT_DO_NOT_COPY_VALUES, &ctx->Hr));
98: PetscFunctionReturn(PETSC_SUCCESS);
99: }
101: static PetscErrorCode SetupWorkspace(UserCtx ctx)
102: {
103: PetscFunctionBegin;
104: PetscCall(MatCreateVecs(ctx->F, &ctx->workLeft[0], &ctx->workRight[0]));
105: for (PetscInt i = 1; i < NWORKLEFT; i++) PetscCall(VecDuplicate(ctx->workLeft[0], &ctx->workLeft[i]));
106: for (PetscInt i = 1; i < NWORKRIGHT; i++) PetscCall(VecDuplicate(ctx->workRight[0], &ctx->workRight[i]));
107: PetscFunctionReturn(PETSC_SUCCESS);
108: }
110: static PetscErrorCode ConfigureContext(UserCtx ctx)
111: {
112: PetscFunctionBegin;
113: ctx->m = 16;
114: ctx->n = 16;
115: ctx->eps = 1.e-3;
116: ctx->abstol = 1.e-4;
117: ctx->reltol = 1.e-2;
118: ctx->hStart = 1.;
119: ctx->hMin = 1.e-3;
120: ctx->hFactor = 0.5;
121: ctx->alpha = 1.;
122: ctx->mu = 1.0;
123: ctx->matops = 0;
124: ctx->iter = 10;
125: ctx->p = NORM_2;
126: ctx->soft = PETSC_FALSE;
127: ctx->taylor = PETSC_TRUE;
128: ctx->use_admm = PETSC_FALSE;
129: PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "Configure separable objective example", "ex4.c");
130: PetscCall(PetscOptionsInt("-m", "The row dimension of matrix F", "ex4.c", ctx->m, &ctx->m, NULL));
131: PetscCall(PetscOptionsInt("-n", "The column dimension of matrix F", "ex4.c", ctx->n, &ctx->n, NULL));
132: PetscCall(PetscOptionsInt("-matrix_format", "Decide format of F matrix. 0 for stencil, 1 for random", "ex4.c", ctx->matops, &ctx->matops, NULL));
133: PetscCall(PetscOptionsInt("-iter", "Iteration number ADMM", "ex4.c", ctx->iter, &ctx->iter, NULL));
134: PetscCall(PetscOptionsReal("-alpha", "The regularization multiplier. 1 default", "ex4.c", ctx->alpha, &ctx->alpha, NULL));
135: PetscCall(PetscOptionsReal("-epsilon", "The small constant added to |x_i| in the denominator to approximate the gradient of ||x||_1", "ex4.c", ctx->eps, &ctx->eps, NULL));
136: PetscCall(PetscOptionsReal("-mu", "The augmented lagrangian multiplier in ADMM", "ex4.c", ctx->mu, &ctx->mu, NULL));
137: PetscCall(PetscOptionsReal("-hStart", "Taylor test starting point. 1 default.", "ex4.c", ctx->hStart, &ctx->hStart, NULL));
138: PetscCall(PetscOptionsReal("-hFactor", "Taylor test multiplier factor. 0.5 default", "ex4.c", ctx->hFactor, &ctx->hFactor, NULL));
139: PetscCall(PetscOptionsReal("-hMin", "Taylor test ending condition. 1.e-3 default", "ex4.c", ctx->hMin, &ctx->hMin, NULL));
140: PetscCall(PetscOptionsReal("-abstol", "Absolute stopping criterion for ADMM", "ex4.c", ctx->abstol, &ctx->abstol, NULL));
141: PetscCall(PetscOptionsReal("-reltol", "Relative stopping criterion for ADMM", "ex4.c", ctx->reltol, &ctx->reltol, NULL));
142: PetscCall(PetscOptionsBool("-taylor", "Flag for Taylor test. Default is true.", "ex4.c", ctx->taylor, &ctx->taylor, NULL));
143: PetscCall(PetscOptionsBool("-soft", "Flag for testing soft threshold no-op case. Default is false.", "ex4.c", ctx->soft, &ctx->soft, NULL));
144: PetscCall(PetscOptionsBool("-use_admm", "Use the ADMM solver in this example.", "ex4.c", ctx->use_admm, &ctx->use_admm, NULL));
145: PetscCall(PetscOptionsEnum("-p", "Norm type.", "ex4.c", NormTypes, (PetscEnum)ctx->p, (PetscEnum *)&ctx->p, NULL));
146: PetscOptionsEnd();
147: /* Creating random ctx */
148: PetscCall(PetscRandomCreate(PETSC_COMM_WORLD, &ctx->rctx));
149: PetscCall(PetscRandomSetFromOptions(ctx->rctx));
150: PetscCall(CreateMatrix(ctx));
151: PetscCall(CreateRHS(ctx));
152: PetscCall(SetupWorkspace(ctx));
153: PetscFunctionReturn(PETSC_SUCCESS);
154: }
156: static PetscErrorCode DestroyContext(UserCtx *ctx)
157: {
158: PetscFunctionBegin;
159: PetscCall(MatDestroy(&(*ctx)->F));
160: PetscCall(MatDestroy(&(*ctx)->W));
161: PetscCall(MatDestroy(&(*ctx)->Hm));
162: PetscCall(MatDestroy(&(*ctx)->Hr));
163: PetscCall(VecDestroy(&(*ctx)->d));
164: for (PetscInt i = 0; i < NWORKLEFT; i++) PetscCall(VecDestroy(&(*ctx)->workLeft[i]));
165: for (PetscInt i = 0; i < NWORKRIGHT; i++) PetscCall(VecDestroy(&(*ctx)->workRight[i]));
166: PetscCall(PetscRandomDestroy(&(*ctx)->rctx));
167: PetscCall(PetscFree(*ctx));
168: PetscFunctionReturn(PETSC_SUCCESS);
169: }
171: /* compute (1/2) * ||F x - d||^2 */
172: static PetscErrorCode ObjectiveMisfit(Tao tao, Vec x, PetscReal *J, void *_ctx)
173: {
174: UserCtx ctx = (UserCtx)_ctx;
175: Vec y;
177: PetscFunctionBegin;
178: y = ctx->workLeft[0];
179: PetscCall(MatMult(ctx->F, x, y));
180: PetscCall(VecAXPY(y, -1., ctx->d));
181: PetscCall(VecDot(y, y, J));
182: *J *= 0.5;
183: PetscFunctionReturn(PETSC_SUCCESS);
184: }
186: /* compute V = FTFx - FTd */
187: static PetscErrorCode GradientMisfit(Tao tao, Vec x, Vec V, void *_ctx)
188: {
189: UserCtx ctx = (UserCtx)_ctx;
190: Vec FTFx, FTd;
192: PetscFunctionBegin;
193: /* work1 is A^T Ax, work2 is Ab, W is A^T A*/
194: FTFx = ctx->workRight[0];
195: FTd = ctx->workRight[1];
196: PetscCall(MatMult(ctx->W, x, FTFx));
197: PetscCall(MatMultTranspose(ctx->F, ctx->d, FTd));
198: PetscCall(VecWAXPY(V, -1., FTd, FTFx));
199: PetscFunctionReturn(PETSC_SUCCESS);
200: }
202: /* returns FTF */
203: static PetscErrorCode HessianMisfit(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx)
204: {
205: UserCtx ctx = (UserCtx)_ctx;
207: PetscFunctionBegin;
208: if (H != ctx->W) PetscCall(MatCopy(ctx->W, H, DIFFERENT_NONZERO_PATTERN));
209: if (Hpre != ctx->W) PetscCall(MatCopy(ctx->W, Hpre, DIFFERENT_NONZERO_PATTERN));
210: PetscFunctionReturn(PETSC_SUCCESS);
211: }
213: /* computes augment Lagrangian objective (with scaled dual):
214: * 0.5 * ||F x - d||^2 + 0.5 * mu ||x - z + u||^2 */
215: static PetscErrorCode ObjectiveMisfitADMM(Tao tao, Vec x, PetscReal *J, void *_ctx)
216: {
217: UserCtx ctx = (UserCtx)_ctx;
218: PetscReal mu, workNorm, misfit;
219: Vec z, u, temp;
221: PetscFunctionBegin;
222: mu = ctx->mu;
223: z = ctx->workRight[5];
224: u = ctx->workRight[6];
225: temp = ctx->workRight[10];
226: /* misfit = f(x) */
227: PetscCall(ObjectiveMisfit(tao, x, &misfit, _ctx));
228: PetscCall(VecCopy(x, temp));
229: /* temp = x - z + u */
230: PetscCall(VecAXPBYPCZ(temp, -1., 1., 1., z, u));
231: /* workNorm = ||x - z + u||^2 */
232: PetscCall(VecDot(temp, temp, &workNorm));
233: /* augment Lagrangian objective (with scaled dual): f(x) + 0.5 * mu ||x -z + u||^2 */
234: *J = misfit + 0.5 * mu * workNorm;
235: PetscFunctionReturn(PETSC_SUCCESS);
236: }
238: /* computes FTFx - FTd mu*(x - z + u) */
239: static PetscErrorCode GradientMisfitADMM(Tao tao, Vec x, Vec V, void *_ctx)
240: {
241: UserCtx ctx = (UserCtx)_ctx;
242: PetscReal mu;
243: Vec z, u, temp;
245: PetscFunctionBegin;
246: mu = ctx->mu;
247: z = ctx->workRight[5];
248: u = ctx->workRight[6];
249: temp = ctx->workRight[10];
250: PetscCall(GradientMisfit(tao, x, V, _ctx));
251: PetscCall(VecCopy(x, temp));
252: /* temp = x - z + u */
253: PetscCall(VecAXPBYPCZ(temp, -1., 1., 1., z, u));
254: /* V = FTFx - FTd mu*(x - z + u) */
255: PetscCall(VecAXPY(V, mu, temp));
256: PetscFunctionReturn(PETSC_SUCCESS);
257: }
259: /* returns FTF + diag(mu) */
260: static PetscErrorCode HessianMisfitADMM(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx)
261: {
262: UserCtx ctx = (UserCtx)_ctx;
264: PetscFunctionBegin;
265: PetscCall(MatCopy(ctx->W, H, DIFFERENT_NONZERO_PATTERN));
266: PetscCall(MatShift(H, ctx->mu));
267: if (Hpre != H) PetscCall(MatCopy(H, Hpre, DIFFERENT_NONZERO_PATTERN));
268: PetscFunctionReturn(PETSC_SUCCESS);
269: }
271: /* computes || x ||_p (mult by 0.5 in case of NORM_2) */
272: static PetscErrorCode ObjectiveRegularization(Tao tao, Vec x, PetscReal *J, void *_ctx)
273: {
274: UserCtx ctx = (UserCtx)_ctx;
275: PetscReal norm;
277: PetscFunctionBegin;
278: *J = 0;
279: PetscCall(VecNorm(x, ctx->p, &norm));
280: if (ctx->p == NORM_2) norm = 0.5 * norm * norm;
281: *J = ctx->alpha * norm;
282: PetscFunctionReturn(PETSC_SUCCESS);
283: }
285: /* NORM_2 Case: return x
286: * NORM_1 Case: x/(|x| + eps)
287: * Else: TODO */
288: static PetscErrorCode GradientRegularization(Tao tao, Vec x, Vec V, void *_ctx)
289: {
290: UserCtx ctx = (UserCtx)_ctx;
291: PetscReal eps = ctx->eps;
293: PetscFunctionBegin;
294: if (ctx->p == NORM_2) PetscCall(VecCopy(x, V));
295: else if (ctx->p == NORM_1) {
296: PetscCall(VecCopy(x, ctx->workRight[1]));
297: PetscCall(VecAbs(ctx->workRight[1]));
298: PetscCall(VecShift(ctx->workRight[1], eps));
299: PetscCall(VecPointwiseDivide(V, x, ctx->workRight[1]));
300: } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2");
301: PetscFunctionReturn(PETSC_SUCCESS);
302: }
304: /* NORM_2 Case: returns diag(mu)
305: * NORM_1 Case: diag(mu* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps))) */
306: static PetscErrorCode HessianRegularization(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx)
307: {
308: UserCtx ctx = (UserCtx)_ctx;
309: PetscReal eps = ctx->eps;
310: Vec copy1, copy2, copy3;
312: PetscFunctionBegin;
313: if (ctx->p == NORM_2) {
314: /* Identity matrix scaled by mu */
315: PetscCall(MatZeroEntries(H));
316: PetscCall(MatShift(H, ctx->mu));
317: if (Hpre != H) {
318: PetscCall(MatZeroEntries(Hpre));
319: PetscCall(MatShift(Hpre, ctx->mu));
320: }
321: } else if (ctx->p == NORM_1) {
322: /* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps)) */
323: copy1 = ctx->workRight[1];
324: copy2 = ctx->workRight[2];
325: copy3 = ctx->workRight[3];
326: /* copy1 : 1/sqrt(x_i^2 + eps) */
327: PetscCall(VecCopy(x, copy1));
328: PetscCall(VecPow(copy1, 2));
329: PetscCall(VecShift(copy1, eps));
330: PetscCall(VecSqrtAbs(copy1));
331: PetscCall(VecReciprocal(copy1));
332: /* copy2: x_i^2.*/
333: PetscCall(VecCopy(x, copy2));
334: PetscCall(VecPow(copy2, 2));
335: /* copy3: abs(x_i^2 + eps) */
336: PetscCall(VecCopy(x, copy3));
337: PetscCall(VecPow(copy3, 2));
338: PetscCall(VecShift(copy3, eps));
339: PetscCall(VecAbs(copy3));
340: /* copy2: 1 - x_i^2/abs(x_i^2 + eps) */
341: PetscCall(VecPointwiseDivide(copy2, copy2, copy3));
342: PetscCall(VecScale(copy2, -1.));
343: PetscCall(VecShift(copy2, 1.));
344: PetscCall(VecAXPY(copy1, 1., copy2));
345: PetscCall(VecScale(copy1, ctx->mu));
346: PetscCall(MatZeroEntries(H));
347: PetscCall(MatDiagonalSet(H, copy1, INSERT_VALUES));
348: if (Hpre != H) {
349: PetscCall(MatZeroEntries(Hpre));
350: PetscCall(MatDiagonalSet(Hpre, copy1, INSERT_VALUES));
351: }
352: } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2");
353: PetscFunctionReturn(PETSC_SUCCESS);
354: }
356: /* NORM_2 Case: 0.5 || x ||_2 + 0.5 * mu * ||x + u - z||^2
357: * Else : || x ||_2 + 0.5 * mu * ||x + u - z||^2 */
358: static PetscErrorCode ObjectiveRegularizationADMM(Tao tao, Vec z, PetscReal *J, void *_ctx)
359: {
360: UserCtx ctx = (UserCtx)_ctx;
361: PetscReal mu, workNorm, reg;
362: Vec x, u, temp;
364: PetscFunctionBegin;
365: mu = ctx->mu;
366: x = ctx->workRight[4];
367: u = ctx->workRight[6];
368: temp = ctx->workRight[10];
369: PetscCall(ObjectiveRegularization(tao, z, ®, _ctx));
370: PetscCall(VecCopy(z, temp));
371: /* temp = x + u -z */
372: PetscCall(VecAXPBYPCZ(temp, 1., 1., -1., x, u));
373: /* workNorm = ||x + u - z ||^2 */
374: PetscCall(VecDot(temp, temp, &workNorm));
375: *J = reg + 0.5 * mu * workNorm;
376: PetscFunctionReturn(PETSC_SUCCESS);
377: }
379: /* NORM_2 Case: x - mu*(x + u - z)
380: * NORM_1 Case: x/(|x| + eps) - mu*(x + u - z)
381: * Else: TODO */
382: static PetscErrorCode GradientRegularizationADMM(Tao tao, Vec z, Vec V, void *_ctx)
383: {
384: UserCtx ctx = (UserCtx)_ctx;
385: PetscReal mu;
386: Vec x, u, temp;
388: PetscFunctionBegin;
389: mu = ctx->mu;
390: x = ctx->workRight[4];
391: u = ctx->workRight[6];
392: temp = ctx->workRight[10];
393: PetscCall(GradientRegularization(tao, z, V, _ctx));
394: PetscCall(VecCopy(z, temp));
395: /* temp = x + u -z */
396: PetscCall(VecAXPBYPCZ(temp, 1., 1., -1., x, u));
397: PetscCall(VecAXPY(V, -mu, temp));
398: PetscFunctionReturn(PETSC_SUCCESS);
399: }
401: /* NORM_2 Case: returns diag(mu)
402: * NORM_1 Case: FTF + diag(mu) */
403: static PetscErrorCode HessianRegularizationADMM(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx)
404: {
405: UserCtx ctx = (UserCtx)_ctx;
407: PetscFunctionBegin;
408: if (ctx->p == NORM_2) {
409: /* Identity matrix scaled by mu */
410: PetscCall(MatZeroEntries(H));
411: PetscCall(MatShift(H, ctx->mu));
412: if (Hpre != H) {
413: PetscCall(MatZeroEntries(Hpre));
414: PetscCall(MatShift(Hpre, ctx->mu));
415: }
416: } else if (ctx->p == NORM_1) {
417: PetscCall(HessianMisfit(tao, x, H, Hpre, (void *)ctx));
418: PetscCall(MatShift(H, ctx->mu));
419: if (Hpre != H) PetscCall(MatShift(Hpre, ctx->mu));
420: } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2");
421: PetscFunctionReturn(PETSC_SUCCESS);
422: }
424: /* NORM_2 Case : (1/2) * ||F x - d||^2 + 0.5 * || x ||_p
425: * NORM_1 Case : (1/2) * ||F x - d||^2 + || x ||_p */
426: static PetscErrorCode ObjectiveComplete(Tao tao, Vec x, PetscReal *J, PetscCtx ctx)
427: {
428: PetscReal Jm, Jr;
430: PetscFunctionBegin;
431: PetscCall(ObjectiveMisfit(tao, x, &Jm, ctx));
432: PetscCall(ObjectiveRegularization(tao, x, &Jr, ctx));
433: *J = Jm + Jr;
434: PetscFunctionReturn(PETSC_SUCCESS);
435: }
437: /* NORM_2 Case: FTFx - FTd + x
438: * NORM_1 Case: FTFx - FTd + x/(|x| + eps) */
439: static PetscErrorCode GradientComplete(Tao tao, Vec x, Vec V, PetscCtx ctx)
440: {
441: UserCtx cntx = (UserCtx)ctx;
443: PetscFunctionBegin;
444: PetscCall(GradientMisfit(tao, x, cntx->workRight[2], ctx));
445: PetscCall(GradientRegularization(tao, x, cntx->workRight[3], ctx));
446: PetscCall(VecWAXPY(V, 1, cntx->workRight[2], cntx->workRight[3]));
447: PetscFunctionReturn(PETSC_SUCCESS);
448: }
450: /* NORM_2 Case: diag(mu) + FTF
451: * NORM_1 Case: diag(mu* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps))) + FTF */
452: static PetscErrorCode HessianComplete(Tao tao, Vec x, Mat H, Mat Hpre, PetscCtx ctx)
453: {
454: Mat tempH;
456: PetscFunctionBegin;
457: PetscCall(MatDuplicate(H, MAT_SHARE_NONZERO_PATTERN, &tempH));
458: PetscCall(HessianMisfit(tao, x, H, H, ctx));
459: PetscCall(HessianRegularization(tao, x, tempH, tempH, ctx));
460: PetscCall(MatAXPY(H, 1., tempH, DIFFERENT_NONZERO_PATTERN));
461: if (Hpre != H) PetscCall(MatCopy(H, Hpre, DIFFERENT_NONZERO_PATTERN));
462: PetscCall(MatDestroy(&tempH));
463: PetscFunctionReturn(PETSC_SUCCESS);
464: }
466: static PetscErrorCode TaoSolveADMM(UserCtx ctx, Vec x)
467: {
468: PetscInt i;
469: PetscReal u_norm, r_norm, s_norm, primal, dual, x_norm, z_norm;
470: Tao tao1, tao2;
471: Vec xk, z, u, diff, zold, zdiff, temp;
472: PetscReal mu;
474: PetscFunctionBegin;
475: xk = ctx->workRight[4];
476: z = ctx->workRight[5];
477: u = ctx->workRight[6];
478: diff = ctx->workRight[7];
479: zold = ctx->workRight[8];
480: zdiff = ctx->workRight[9];
481: temp = ctx->workRight[11];
482: mu = ctx->mu;
483: PetscCall(VecSet(u, 0.));
484: PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao1));
485: PetscCall(TaoSetType(tao1, TAONLS));
486: PetscCall(TaoSetObjective(tao1, ObjectiveMisfitADMM, (void *)ctx));
487: PetscCall(TaoSetGradient(tao1, NULL, GradientMisfitADMM, (void *)ctx));
488: PetscCall(TaoSetHessian(tao1, ctx->Hm, ctx->Hm, HessianMisfitADMM, (void *)ctx));
489: PetscCall(VecSet(xk, 0.));
490: PetscCall(TaoSetSolution(tao1, xk));
491: PetscCall(TaoSetOptionsPrefix(tao1, "misfit_"));
492: PetscCall(TaoSetFromOptions(tao1));
493: PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao2));
494: if (ctx->p == NORM_2) {
495: PetscCall(TaoSetType(tao2, TAONLS));
496: PetscCall(TaoSetObjective(tao2, ObjectiveRegularizationADMM, (void *)ctx));
497: PetscCall(TaoSetGradient(tao2, NULL, GradientRegularizationADMM, (void *)ctx));
498: PetscCall(TaoSetHessian(tao2, ctx->Hr, ctx->Hr, HessianRegularizationADMM, (void *)ctx));
499: }
500: PetscCall(VecSet(z, 0.));
501: PetscCall(TaoSetSolution(tao2, z));
502: PetscCall(TaoSetOptionsPrefix(tao2, "reg_"));
503: PetscCall(TaoSetFromOptions(tao2));
505: for (i = 0; i < ctx->iter; i++) {
506: PetscCall(VecCopy(z, zold));
507: PetscCall(TaoSolve(tao1)); /* Updates xk */
508: if (ctx->p == NORM_1) {
509: PetscCall(VecWAXPY(temp, 1., xk, u));
510: PetscCall(TaoSoftThreshold(temp, -ctx->alpha / mu, ctx->alpha / mu, z));
511: } else {
512: PetscCall(TaoSolve(tao2)); /* Update zk */
513: }
514: /* u = u + xk -z */
515: PetscCall(VecAXPBYPCZ(u, 1., -1., 1., xk, z));
516: /* r_norm : norm(x-z) */
517: PetscCall(VecWAXPY(diff, -1., z, xk));
518: PetscCall(VecNorm(diff, NORM_2, &r_norm));
519: /* s_norm : norm(-mu(z-zold)) */
520: PetscCall(VecWAXPY(zdiff, -1., zold, z));
521: PetscCall(VecNorm(zdiff, NORM_2, &s_norm));
522: s_norm = s_norm * mu;
523: /* primal : sqrt(n)*ABSTOL + RELTOL*max(norm(x), norm(-z))*/
524: PetscCall(VecNorm(xk, NORM_2, &x_norm));
525: PetscCall(VecNorm(z, NORM_2, &z_norm));
526: primal = PetscSqrtReal(ctx->n) * ctx->abstol + ctx->reltol * PetscMax(x_norm, z_norm);
527: /* Duality : sqrt(n)*ABSTOL + RELTOL*norm(mu*u)*/
528: PetscCall(VecNorm(u, NORM_2, &u_norm));
529: dual = PetscSqrtReal(ctx->n) * ctx->abstol + ctx->reltol * u_norm * mu;
530: PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao1), "Iter %" PetscInt_FMT " : ||x-z||: %g, mu*||z-zold||: %g\n", i, (double)r_norm, (double)s_norm));
531: if (r_norm < primal && s_norm < dual) break;
532: }
533: PetscCall(VecCopy(xk, x));
534: PetscCall(TaoDestroy(&tao1));
535: PetscCall(TaoDestroy(&tao2));
536: PetscFunctionReturn(PETSC_SUCCESS);
537: }
539: /* Second order Taylor remainder convergence test */
540: static PetscErrorCode TaylorTest(UserCtx ctx, Tao tao, Vec x, PetscReal *C)
541: {
542: PetscReal h, J, temp;
543: PetscInt i;
544: PetscInt numValues;
545: PetscReal Jx, Jxhat_comp, Jxhat_pred;
546: PetscReal *Js, *hs;
547: PetscReal gdotdx;
548: PetscReal minrate = PETSC_MAX_REAL;
549: MPI_Comm comm = PetscObjectComm((PetscObject)x);
550: Vec g, dx, xhat;
552: PetscFunctionBegin;
553: PetscCall(VecDuplicate(x, &g));
554: PetscCall(VecDuplicate(x, &xhat));
555: /* choose a perturbation direction */
556: PetscCall(VecDuplicate(x, &dx));
557: PetscCall(VecSetRandom(dx, ctx->rctx));
558: /* evaluate objective at x: J(x) */
559: PetscCall(TaoComputeObjective(tao, x, &Jx));
560: /* evaluate gradient at x, save in vector g */
561: PetscCall(TaoComputeGradient(tao, x, g));
562: PetscCall(VecDot(g, dx, &gdotdx));
564: for (numValues = 0, h = ctx->hStart; h >= ctx->hMin; h *= ctx->hFactor) numValues++;
565: PetscCall(PetscCalloc2(numValues, &Js, numValues, &hs));
566: for (i = 0, h = ctx->hStart; h >= ctx->hMin; h *= ctx->hFactor, i++) {
567: PetscCall(VecWAXPY(xhat, h, dx, x));
568: PetscCall(TaoComputeObjective(tao, xhat, &Jxhat_comp));
569: /* J(\hat(x)) \approx J(x) + g^T (xhat - x) = J(x) + h * g^T dx */
570: Jxhat_pred = Jx + h * gdotdx;
571: /* Vector to dJdm scalar? Dot?*/
572: J = PetscAbsReal(Jxhat_comp - Jxhat_pred);
573: PetscCall(PetscPrintf(comm, "J(xhat): %g, predicted: %g, diff %g\n", (double)Jxhat_comp, (double)Jxhat_pred, (double)J));
574: Js[i] = J;
575: hs[i] = h;
576: }
577: for (PetscInt j = 1; j < numValues; j++) {
578: temp = PetscLogReal(Js[j] / Js[j - 1]) / PetscLogReal(hs[j] / hs[j - 1]);
579: PetscCall(PetscPrintf(comm, "Convergence rate step %" PetscInt_FMT ": %g\n", j - 1, (double)temp));
580: minrate = PetscMin(minrate, temp);
581: }
582: /* If O is not ~2, then the test is wrong */
583: PetscCall(PetscFree2(Js, hs));
584: *C = minrate;
585: PetscCall(VecDestroy(&dx));
586: PetscCall(VecDestroy(&xhat));
587: PetscCall(VecDestroy(&g));
588: PetscFunctionReturn(PETSC_SUCCESS);
589: }
591: int main(int argc, char **argv)
592: {
593: UserCtx ctx;
594: Tao tao;
595: Vec x;
596: Mat H;
598: PetscFunctionBeginUser;
599: PetscCall(PetscInitialize(&argc, &argv, NULL, help));
600: PetscCall(PetscNew(&ctx));
601: PetscCall(ConfigureContext(ctx));
602: /* Define two functions that could pass as objectives to TaoSetObjective(): one
603: * for the misfit component, and one for the regularization component */
604: /* ObjectiveMisfit() and ObjectiveRegularization() */
606: /* Define a single function that calls both components adds them together: the complete objective,
607: * in the absence of a Tao implementation that handles separability */
608: /* ObjectiveComplete() */
609: PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao));
610: PetscCall(TaoSetType(tao, TAONM));
611: PetscCall(TaoSetObjective(tao, ObjectiveComplete, (void *)ctx));
612: PetscCall(TaoSetGradient(tao, NULL, GradientComplete, (void *)ctx));
613: PetscCall(MatDuplicate(ctx->W, MAT_SHARE_NONZERO_PATTERN, &H));
614: PetscCall(TaoSetHessian(tao, H, H, HessianComplete, (void *)ctx));
615: PetscCall(MatCreateVecs(ctx->F, NULL, &x));
616: PetscCall(TaoSetSolution(tao, x));
617: PetscCall(TaoSetFromOptions(tao));
618: if (ctx->use_admm) PetscCall(TaoSolveADMM(ctx, x));
619: else PetscCall(TaoSolve(tao));
620: /* examine solution */
621: PetscCall(VecViewFromOptions(x, NULL, "-view_sol"));
622: if (ctx->taylor) {
623: PetscReal rate;
624: PetscCall(TaylorTest(ctx, tao, x, &rate));
625: }
626: if (ctx->soft) PetscCall(TaoSoftThreshold(x, 0., 0., x));
627: PetscCall(MatDestroy(&H));
628: PetscCall(TaoDestroy(&tao));
629: PetscCall(VecDestroy(&x));
630: PetscCall(DestroyContext(&ctx));
631: PetscCall(PetscFinalize());
632: return 0;
633: }
635: /*TEST
637: build:
638: requires: !complex
640: test:
641: suffix: 0
642: args:
644: test:
645: suffix: l1_1
646: args: -p 1 -tao_type lmvm -alpha 1. -epsilon 1.e-7 -m 64 -n 64 -view_sol -matrix_format 1
648: test:
649: suffix: hessian_1
650: args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type nls
652: test:
653: suffix: hessian_2
654: args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type nls
656: test:
657: suffix: nm_1
658: args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type nm -tao_max_it 50
660: test:
661: suffix: nm_2
662: args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type nm -tao_max_it 50
664: test:
665: suffix: lmvm_1
666: args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type lmvm -tao_max_it 40
668: test:
669: suffix: lmvm_2
670: args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type lmvm -tao_max_it 15
672: test:
673: suffix: soft_threshold_admm_1
674: args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm
676: test:
677: suffix: hessian_admm_1
678: args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type nls -misfit_tao_type nls
680: test:
681: suffix: hessian_admm_2
682: args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type nls -misfit_tao_type nls
684: test:
685: suffix: nm_admm_1
686: args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type nm -misfit_tao_type nm
688: test:
689: suffix: nm_admm_2
690: args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type nm -misfit_tao_type nm -iter 7
692: test:
693: suffix: lmvm_admm_1
694: args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type lmvm -misfit_tao_type lmvm
696: test:
697: suffix: lmvm_admm_2
698: args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type lmvm -misfit_tao_type lmvm
700: test:
701: suffix: soft
702: args: -taylor 0 -soft 1
703: output_file: output/empty.out
705: test:
706: suffix: soft_view_10
707: args: -taylor 0 -soft 1 -tao_view -tao_max_funcs 10
709: TEST*/