Actual source code: bqnk.c
1: #include <../src/tao/bound/impls/bqnk/bqnk.h>
2: #include <petscksp.h>
4: static PetscErrorCode TaoBQNKComputeHessian(Tao tao)
5: {
6: TAO_BNK *bnk = (TAO_BNK *)tao->data;
7: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
8: PetscReal gnorm2, delta;
10: PetscFunctionBegin;
11: /* Alias the LMVM matrix into the TAO hessian */
12: PetscCall(MatDestroy(&tao->hessian));
13: PetscCall(MatDestroy(&tao->hessian_pre));
14: PetscCall(PetscObjectReference((PetscObject)bqnk->B));
15: tao->hessian = bqnk->B;
16: PetscCall(PetscObjectReference((PetscObject)bqnk->B));
17: tao->hessian_pre = bqnk->B;
18: /* Update the Hessian with the latest solution */
19: if (bqnk->is_spd) {
20: gnorm2 = bnk->gnorm * bnk->gnorm;
21: if (gnorm2 == 0.0) gnorm2 = PETSC_MACHINE_EPSILON;
22: if (bnk->f == 0.0) delta = 2.0 / gnorm2;
23: else delta = 2.0 * PetscAbsScalar(bnk->f) / gnorm2;
24: PetscCall(MatLMVMSymBroydenSetDelta(bqnk->B, delta));
25: }
26: PetscCall(MatLMVMUpdate(tao->hessian, tao->solution, bnk->unprojected_gradient));
27: PetscCall(MatLMVMResetShift(tao->hessian));
28: /* Prepare the reduced sub-matrices for the inactive set */
29: PetscCall(MatDestroy(&bnk->H_inactive));
30: if (bnk->active_idx) {
31: PetscCall(MatCreateSubMatrixVirtual(tao->hessian, bnk->inactive_idx, bnk->inactive_idx, &bnk->H_inactive));
32: PetscCall(PCLMVMSetIS(bqnk->pc, bnk->inactive_idx));
33: } else {
34: PetscCall(PetscObjectReference((PetscObject)tao->hessian));
35: bnk->H_inactive = tao->hessian;
36: PetscCall(PCLMVMClearIS(bqnk->pc));
37: }
38: PetscCall(MatDestroy(&bnk->Hpre_inactive));
39: PetscCall(PetscObjectReference((PetscObject)bnk->H_inactive));
40: bnk->Hpre_inactive = bnk->H_inactive;
41: PetscFunctionReturn(PETSC_SUCCESS);
42: }
44: static PetscErrorCode TaoBQNKComputeStep(Tao tao, PetscBool shift, KSPConvergedReason *ksp_reason, PetscInt *step_type)
45: {
46: TAO_BNK *bnk = (TAO_BNK *)tao->data;
47: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
49: PetscFunctionBegin;
50: PetscCall(TaoBNKComputeStep(tao, shift, ksp_reason, step_type));
51: if (*ksp_reason < 0) {
52: /* Krylov solver failed to converge so reset the LMVM matrix */
53: PetscCall(MatLMVMReset(bqnk->B, PETSC_FALSE));
54: PetscCall(MatLMVMUpdate(bqnk->B, tao->solution, bnk->unprojected_gradient));
55: }
56: PetscFunctionReturn(PETSC_SUCCESS);
57: }
59: PetscErrorCode TaoSolve_BQNK(Tao tao)
60: {
61: TAO_BNK *bnk = (TAO_BNK *)tao->data;
62: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
63: Mat_LMVM *lmvm = (Mat_LMVM *)bqnk->B->data;
64: Mat_LMVM *J0;
65: PetscBool flg = PETSC_FALSE;
67: PetscFunctionBegin;
68: if (!tao->recycle) {
69: PetscCall(MatLMVMReset(bqnk->B, PETSC_FALSE));
70: lmvm->nresets = 0;
71: if (lmvm->J0) {
72: PetscCall(PetscObjectBaseTypeCompare((PetscObject)lmvm->J0, MATLMVM, &flg));
73: if (flg) {
74: J0 = (Mat_LMVM *)lmvm->J0->data;
75: J0->nresets = 0;
76: }
77: }
78: }
79: PetscCall((*bqnk->solve)(tao));
80: PetscFunctionReturn(PETSC_SUCCESS);
81: }
83: PetscErrorCode TaoSetUp_BQNK(Tao tao)
84: {
85: TAO_BNK *bnk = (TAO_BNK *)tao->data;
86: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
87: PetscInt n, N;
88: PetscBool is_lmvm, is_set, is_sym;
90: PetscFunctionBegin;
91: PetscCall(TaoSetUp_BNK(tao));
92: PetscCall(VecGetLocalSize(tao->solution, &n));
93: PetscCall(VecGetSize(tao->solution, &N));
94: PetscCall(MatSetSizes(bqnk->B, n, n, N, N));
95: PetscCall(MatLMVMAllocate(bqnk->B, tao->solution, bnk->unprojected_gradient));
96: PetscCall(PetscObjectBaseTypeCompare((PetscObject)bqnk->B, MATLMVM, &is_lmvm));
97: PetscCheck(is_lmvm, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "Matrix must be an LMVM-type");
98: PetscCall(MatIsSymmetricKnown(bqnk->B, &is_set, &is_sym));
99: PetscCheck(is_set && is_sym, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "LMVM matrix must be symmetric");
100: PetscCall(KSPGetPC(tao->ksp, &bqnk->pc));
101: PetscCall(PCSetType(bqnk->pc, PCLMVM));
102: PetscCall(PCLMVMSetMatLMVM(bqnk->pc, bqnk->B));
103: PetscFunctionReturn(PETSC_SUCCESS);
104: }
106: static PetscErrorCode TaoSetFromOptions_BQNK(Tao tao, PetscOptionItems PetscOptionsObject)
107: {
108: TAO_BNK *bnk = (TAO_BNK *)tao->data;
109: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
110: PetscBool is_set;
112: PetscFunctionBegin;
113: PetscCall(TaoSetFromOptions_BNK(tao, PetscOptionsObject));
114: if (bnk->init_type == BNK_INIT_INTERPOLATION) bnk->init_type = BNK_INIT_DIRECTION;
115: PetscCall(MatSetOptionsPrefix(bqnk->B, ((PetscObject)tao)->prefix));
116: PetscCall(MatAppendOptionsPrefix(bqnk->B, "tao_bqnk_"));
117: PetscCall(MatSetFromOptions(bqnk->B));
118: PetscCall(MatIsSPDKnown(bqnk->B, &is_set, &bqnk->is_spd));
119: if (!is_set) bqnk->is_spd = PETSC_FALSE;
120: PetscFunctionReturn(PETSC_SUCCESS);
121: }
123: static PetscErrorCode TaoView_BQNK(Tao tao, PetscViewer viewer)
124: {
125: TAO_BNK *bnk = (TAO_BNK *)tao->data;
126: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
127: PetscBool isascii;
129: PetscFunctionBegin;
130: PetscCall(TaoView_BNK(tao, viewer));
131: PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &isascii));
132: if (isascii) {
133: PetscCall(PetscViewerPushFormat(viewer, PETSC_VIEWER_ASCII_INFO));
134: PetscCall(MatView(bqnk->B, viewer));
135: PetscCall(PetscViewerPopFormat(viewer));
136: }
137: PetscFunctionReturn(PETSC_SUCCESS);
138: }
140: static PetscErrorCode TaoDestroy_BQNK(Tao tao)
141: {
142: TAO_BNK *bnk = (TAO_BNK *)tao->data;
143: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
145: PetscFunctionBegin;
146: PetscCall(MatDestroy(&bnk->Hpre_inactive));
147: PetscCall(MatDestroy(&bnk->H_inactive));
148: PetscCall(MatDestroy(&bqnk->B));
149: PetscCall(PetscFree(bnk->ctx));
150: PetscCall(TaoDestroy_BNK(tao));
151: PetscFunctionReturn(PETSC_SUCCESS);
152: }
154: PETSC_INTERN PetscErrorCode TaoCreate_BQNK(Tao tao)
155: {
156: TAO_BNK *bnk;
157: TAO_BQNK *bqnk;
159: PetscFunctionBegin;
160: PetscCall(TaoCreate_BNK(tao));
161: tao->ops->solve = TaoSolve_BQNK;
162: tao->ops->setfromoptions = TaoSetFromOptions_BQNK;
163: tao->ops->destroy = TaoDestroy_BQNK;
164: tao->ops->view = TaoView_BQNK;
165: tao->ops->setup = TaoSetUp_BQNK;
166: tao->uses_hessian_matrices = PETSC_FALSE;
168: bnk = (TAO_BNK *)tao->data;
169: bnk->computehessian = TaoBQNKComputeHessian;
170: bnk->computestep = TaoBQNKComputeStep;
171: bnk->init_type = BNK_INIT_DIRECTION;
173: PetscCall(PetscNew(&bqnk));
174: bnk->ctx = (void *)bqnk;
175: bqnk->is_spd = PETSC_TRUE;
177: PetscCall(MatCreate(PetscObjectComm((PetscObject)tao), &bqnk->B));
178: PetscCall(PetscObjectIncrementTabLevel((PetscObject)bqnk->B, (PetscObject)tao, 1));
179: PetscCall(MatSetType(bqnk->B, MATLMVMSR1));
180: PetscFunctionReturn(PETSC_SUCCESS);
181: }
183: /*@
184: TaoGetLMVMMatrix - Returns a pointer to the internal LMVM matrix. Valid
185: only for quasi-Newton family of methods.
187: Input Parameter:
188: . tao - `Tao` solver context
190: Output Parameter:
191: . B - LMVM matrix
193: Level: advanced
195: .seealso: `TAOBQNLS`, `TAOBQNKLS`, `TAOBQNKTL`, `TAOBQNKTR`, `MATLMVM`, `TaoSetLMVMMatrix()`
196: @*/
197: PetscErrorCode TaoGetLMVMMatrix(Tao tao, Mat *B)
198: {
199: TAO_BNK *bnk = (TAO_BNK *)tao->data;
200: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
201: PetscBool flg = PETSC_FALSE;
203: PetscFunctionBegin;
204: PetscCall(PetscObjectTypeCompareAny((PetscObject)tao, &flg, TAOBQNLS, TAOBQNKLS, TAOBQNKTR, TAOBQNKTL, ""));
205: PetscCheck(flg, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "LMVM Matrix only exists for quasi-Newton algorithms");
206: *B = bqnk->B;
207: PetscFunctionReturn(PETSC_SUCCESS);
208: }
210: /*@
211: TaoSetLMVMMatrix - Sets an external LMVM matrix into the Tao solver. Valid
212: only for quasi-Newton family of methods.
214: QN family of methods create their own LMVM matrices and users who wish to
215: manipulate this matrix should use TaoGetLMVMMatrix() instead.
217: Input Parameters:
218: + tao - Tao solver context
219: - B - LMVM matrix
221: Level: advanced
223: .seealso: `TAOBQNLS`, `TAOBQNKLS`, `TAOBQNKTL`, `TAOBQNKTR`, `MATLMVM`, `TaoGetLMVMMatrix()`
224: @*/
225: PetscErrorCode TaoSetLMVMMatrix(Tao tao, Mat B)
226: {
227: TAO_BNK *bnk = (TAO_BNK *)tao->data;
228: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
229: PetscBool flg = PETSC_FALSE;
231: PetscFunctionBegin;
232: PetscCall(PetscObjectTypeCompareAny((PetscObject)tao, &flg, TAOBQNLS, TAOBQNKLS, TAOBQNKTR, TAOBQNKTL, ""));
233: PetscCheck(flg, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "LMVM Matrix only exists for quasi-Newton algorithms");
234: PetscCall(PetscObjectBaseTypeCompare((PetscObject)B, MATLMVM, &flg));
235: PetscCheck(flg, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "Given matrix is not an LMVM matrix");
236: PetscCall(MatDestroy(&bqnk->B));
237: PetscCall(PetscObjectReference((PetscObject)B));
238: bqnk->B = B;
239: PetscFunctionReturn(PETSC_SUCCESS);
240: }