Actual source code: bqnk.c
1: #include <../src/tao/bound/impls/bqnk/bqnk.h>
2: #include <petscksp.h>
4: static PetscErrorCode TaoBQNKComputeHessian(Tao tao)
5: {
6: TAO_BNK *bnk = (TAO_BNK *)tao->data;
7: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
8: PetscReal gnorm2, delta;
10: PetscFunctionBegin;
11: /* Alias the LMVM matrix into the TAO hessian */
12: if (tao->hessian) PetscCall(MatDestroy(&tao->hessian));
13: if (tao->hessian_pre) PetscCall(MatDestroy(&tao->hessian_pre));
14: PetscCall(PetscObjectReference((PetscObject)bqnk->B));
15: tao->hessian = bqnk->B;
16: PetscCall(PetscObjectReference((PetscObject)bqnk->B));
17: tao->hessian_pre = bqnk->B;
18: /* Update the Hessian with the latest solution */
19: if (bqnk->is_spd) {
20: gnorm2 = bnk->gnorm * bnk->gnorm;
21: if (gnorm2 == 0.0) gnorm2 = PETSC_MACHINE_EPSILON;
22: if (bnk->f == 0.0) {
23: delta = 2.0 / gnorm2;
24: } else {
25: delta = 2.0 * PetscAbsScalar(bnk->f) / gnorm2;
26: }
27: PetscCall(MatLMVMSymBroydenSetDelta(bqnk->B, delta));
28: }
29: PetscCall(MatLMVMUpdate(tao->hessian, tao->solution, bnk->unprojected_gradient));
30: PetscCall(MatLMVMResetShift(tao->hessian));
31: /* Prepare the reduced sub-matrices for the inactive set */
32: PetscCall(MatDestroy(&bnk->H_inactive));
33: if (bnk->active_idx) {
34: PetscCall(MatCreateSubMatrixVirtual(tao->hessian, bnk->inactive_idx, bnk->inactive_idx, &bnk->H_inactive));
35: PetscCall(PCLMVMSetIS(bqnk->pc, bnk->inactive_idx));
36: } else {
37: PetscCall(PetscObjectReference((PetscObject)tao->hessian));
38: bnk->H_inactive = tao->hessian;
39: PetscCall(PCLMVMClearIS(bqnk->pc));
40: }
41: PetscCall(MatDestroy(&bnk->Hpre_inactive));
42: PetscCall(PetscObjectReference((PetscObject)bnk->H_inactive));
43: bnk->Hpre_inactive = bnk->H_inactive;
44: PetscFunctionReturn(PETSC_SUCCESS);
45: }
47: static PetscErrorCode TaoBQNKComputeStep(Tao tao, PetscBool shift, KSPConvergedReason *ksp_reason, PetscInt *step_type)
48: {
49: TAO_BNK *bnk = (TAO_BNK *)tao->data;
50: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
52: PetscFunctionBegin;
53: PetscCall(TaoBNKComputeStep(tao, shift, ksp_reason, step_type));
54: if (*ksp_reason < 0) {
55: /* Krylov solver failed to converge so reset the LMVM matrix */
56: PetscCall(MatLMVMReset(bqnk->B, PETSC_FALSE));
57: PetscCall(MatLMVMUpdate(bqnk->B, tao->solution, bnk->unprojected_gradient));
58: }
59: PetscFunctionReturn(PETSC_SUCCESS);
60: }
62: PetscErrorCode TaoSolve_BQNK(Tao tao)
63: {
64: TAO_BNK *bnk = (TAO_BNK *)tao->data;
65: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
66: Mat_LMVM *lmvm = (Mat_LMVM *)bqnk->B->data;
67: Mat_LMVM *J0;
68: Mat_SymBrdn *diag_ctx;
69: PetscBool flg = PETSC_FALSE;
71: PetscFunctionBegin;
72: if (!tao->recycle) {
73: PetscCall(MatLMVMReset(bqnk->B, PETSC_FALSE));
74: lmvm->nresets = 0;
75: if (lmvm->J0) {
76: PetscCall(PetscObjectBaseTypeCompare((PetscObject)lmvm->J0, MATLMVM, &flg));
77: if (flg) {
78: J0 = (Mat_LMVM *)lmvm->J0->data;
79: J0->nresets = 0;
80: }
81: }
82: flg = PETSC_FALSE;
83: PetscCall(PetscObjectTypeCompareAny((PetscObject)bqnk->B, &flg, MATLMVMSYMBROYDEN, MATLMVMSYMBADBROYDEN, MATLMVMBFGS, MATLMVMDFP, ""));
84: if (flg) {
85: diag_ctx = (Mat_SymBrdn *)lmvm->ctx;
86: J0 = (Mat_LMVM *)diag_ctx->D->data;
87: J0->nresets = 0;
88: }
89: }
90: PetscCall((*bqnk->solve)(tao));
91: PetscFunctionReturn(PETSC_SUCCESS);
92: }
94: PetscErrorCode TaoSetUp_BQNK(Tao tao)
95: {
96: TAO_BNK *bnk = (TAO_BNK *)tao->data;
97: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
98: PetscInt n, N;
99: PetscBool is_lmvm, is_set, is_sym;
101: PetscFunctionBegin;
102: PetscCall(TaoSetUp_BNK(tao));
103: PetscCall(VecGetLocalSize(tao->solution, &n));
104: PetscCall(VecGetSize(tao->solution, &N));
105: PetscCall(MatSetSizes(bqnk->B, n, n, N, N));
106: PetscCall(MatLMVMAllocate(bqnk->B, tao->solution, bnk->unprojected_gradient));
107: PetscCall(PetscObjectBaseTypeCompare((PetscObject)bqnk->B, MATLMVM, &is_lmvm));
108: PetscCheck(is_lmvm, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "Matrix must be an LMVM-type");
109: PetscCall(MatIsSymmetricKnown(bqnk->B, &is_set, &is_sym));
110: PetscCheck(is_set && is_sym, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "LMVM matrix must be symmetric");
111: PetscCall(KSPGetPC(tao->ksp, &bqnk->pc));
112: PetscCall(PCSetType(bqnk->pc, PCLMVM));
113: PetscCall(PCLMVMSetMatLMVM(bqnk->pc, bqnk->B));
114: PetscFunctionReturn(PETSC_SUCCESS);
115: }
117: static PetscErrorCode TaoSetFromOptions_BQNK(Tao tao, PetscOptionItems *PetscOptionsObject)
118: {
119: TAO_BNK *bnk = (TAO_BNK *)tao->data;
120: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
121: PetscBool is_set;
123: PetscFunctionBegin;
124: PetscCall(TaoSetFromOptions_BNK(tao, PetscOptionsObject));
125: if (bnk->init_type == BNK_INIT_INTERPOLATION) bnk->init_type = BNK_INIT_DIRECTION;
126: PetscCall(MatSetOptionsPrefix(bqnk->B, ((PetscObject)tao)->prefix));
127: PetscCall(MatAppendOptionsPrefix(bqnk->B, "tao_bqnk_"));
128: PetscCall(MatSetFromOptions(bqnk->B));
129: PetscCall(MatIsSPDKnown(bqnk->B, &is_set, &bqnk->is_spd));
130: if (!is_set) bqnk->is_spd = PETSC_FALSE;
131: PetscFunctionReturn(PETSC_SUCCESS);
132: }
134: static PetscErrorCode TaoView_BQNK(Tao tao, PetscViewer viewer)
135: {
136: TAO_BNK *bnk = (TAO_BNK *)tao->data;
137: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
138: PetscBool isascii;
140: PetscFunctionBegin;
141: PetscCall(TaoView_BNK(tao, viewer));
142: PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &isascii));
143: if (isascii) {
144: PetscCall(PetscViewerPushFormat(viewer, PETSC_VIEWER_ASCII_INFO));
145: PetscCall(MatView(bqnk->B, viewer));
146: PetscCall(PetscViewerPopFormat(viewer));
147: }
148: PetscFunctionReturn(PETSC_SUCCESS);
149: }
151: static PetscErrorCode TaoDestroy_BQNK(Tao tao)
152: {
153: TAO_BNK *bnk = (TAO_BNK *)tao->data;
154: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
156: PetscFunctionBegin;
157: PetscCall(MatDestroy(&bnk->Hpre_inactive));
158: PetscCall(MatDestroy(&bnk->H_inactive));
159: PetscCall(MatDestroy(&bqnk->B));
160: PetscCall(PetscFree(bnk->ctx));
161: PetscCall(TaoDestroy_BNK(tao));
162: PetscFunctionReturn(PETSC_SUCCESS);
163: }
165: PETSC_INTERN PetscErrorCode TaoCreate_BQNK(Tao tao)
166: {
167: TAO_BNK *bnk;
168: TAO_BQNK *bqnk;
170: PetscFunctionBegin;
171: PetscCall(TaoCreate_BNK(tao));
172: tao->ops->solve = TaoSolve_BQNK;
173: tao->ops->setfromoptions = TaoSetFromOptions_BQNK;
174: tao->ops->destroy = TaoDestroy_BQNK;
175: tao->ops->view = TaoView_BQNK;
176: tao->ops->setup = TaoSetUp_BQNK;
178: bnk = (TAO_BNK *)tao->data;
179: bnk->computehessian = TaoBQNKComputeHessian;
180: bnk->computestep = TaoBQNKComputeStep;
181: bnk->init_type = BNK_INIT_DIRECTION;
183: PetscCall(PetscNew(&bqnk));
184: bnk->ctx = (void *)bqnk;
185: bqnk->is_spd = PETSC_TRUE;
187: PetscCall(MatCreate(PetscObjectComm((PetscObject)tao), &bqnk->B));
188: PetscCall(PetscObjectIncrementTabLevel((PetscObject)bqnk->B, (PetscObject)tao, 1));
189: PetscCall(MatSetType(bqnk->B, MATLMVMSR1));
190: PetscFunctionReturn(PETSC_SUCCESS);
191: }
193: /*@
194: TaoGetLMVMMatrix - Returns a pointer to the internal LMVM matrix. Valid
195: only for quasi-Newton family of methods.
197: Input Parameter:
198: . tao - `Tao` solver context
200: Output Parameter:
201: . B - LMVM matrix
203: Level: advanced
205: .seealso: `TAOBQNLS`, `TAOBQNKLS`, `TAOBQNKTL`, `TAOBQNKTR`, `MATLMVM`, `TaoSetLMVMMatrix()`
206: @*/
207: PetscErrorCode TaoGetLMVMMatrix(Tao tao, Mat *B)
208: {
209: TAO_BNK *bnk = (TAO_BNK *)tao->data;
210: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
211: PetscBool flg = PETSC_FALSE;
213: PetscFunctionBegin;
214: PetscCall(PetscObjectTypeCompareAny((PetscObject)tao, &flg, TAOBQNLS, TAOBQNKLS, TAOBQNKTR, TAOBQNKTL, ""));
215: PetscCheck(flg, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "LMVM Matrix only exists for quasi-Newton algorithms");
216: *B = bqnk->B;
217: PetscFunctionReturn(PETSC_SUCCESS);
218: }
220: /*@
221: TaoSetLMVMMatrix - Sets an external LMVM matrix into the Tao solver. Valid
222: only for quasi-Newton family of methods.
224: QN family of methods create their own LMVM matrices and users who wish to
225: manipulate this matrix should use TaoGetLMVMMatrix() instead.
227: Input Parameters:
228: + tao - Tao solver context
229: - B - LMVM matrix
231: Level: advanced
233: .seealso: `TAOBQNLS`, `TAOBQNKLS`, `TAOBQNKTL`, `TAOBQNKTR`, `MATLMVM`, `TaoGetLMVMMatrix()`
234: @*/
235: PetscErrorCode TaoSetLMVMMatrix(Tao tao, Mat B)
236: {
237: TAO_BNK *bnk = (TAO_BNK *)tao->data;
238: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
239: PetscBool flg = PETSC_FALSE;
241: PetscFunctionBegin;
242: PetscCall(PetscObjectTypeCompareAny((PetscObject)tao, &flg, TAOBQNLS, TAOBQNKLS, TAOBQNKTR, TAOBQNKTL, ""));
243: PetscCheck(flg, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "LMVM Matrix only exists for quasi-Newton algorithms");
244: PetscCall(PetscObjectBaseTypeCompare((PetscObject)B, MATLMVM, &flg));
245: PetscCheck(flg, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "Given matrix is not an LMVM matrix");
246: if (bqnk->B) PetscCall(MatDestroy(&bqnk->B));
247: PetscCall(PetscObjectReference((PetscObject)B));
248: bqnk->B = B;
249: PetscFunctionReturn(PETSC_SUCCESS);
250: }