Actual source code: bqnk.c
1: #include <../src/tao/bound/impls/bqnk/bqnk.h>
2: #include <petscksp.h>
4: static PetscErrorCode TaoBQNKComputeHessian(Tao tao)
5: {
6: TAO_BNK *bnk = (TAO_BNK *)tao->data;
7: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
8: PetscReal gnorm2, delta;
10: PetscFunctionBegin;
11: /* Alias the LMVM matrix into the TAO hessian */
12: if (tao->hessian) PetscCall(MatDestroy(&tao->hessian));
13: if (tao->hessian_pre) PetscCall(MatDestroy(&tao->hessian_pre));
14: PetscCall(PetscObjectReference((PetscObject)bqnk->B));
15: tao->hessian = bqnk->B;
16: PetscCall(PetscObjectReference((PetscObject)bqnk->B));
17: tao->hessian_pre = bqnk->B;
18: /* Update the Hessian with the latest solution */
19: if (bqnk->is_spd) {
20: gnorm2 = bnk->gnorm * bnk->gnorm;
21: if (gnorm2 == 0.0) gnorm2 = PETSC_MACHINE_EPSILON;
22: if (bnk->f == 0.0) {
23: delta = 2.0 / gnorm2;
24: } else {
25: delta = 2.0 * PetscAbsScalar(bnk->f) / gnorm2;
26: }
27: PetscCall(MatLMVMSymBroydenSetDelta(bqnk->B, delta));
28: }
29: PetscCall(MatLMVMUpdate(tao->hessian, tao->solution, bnk->unprojected_gradient));
30: PetscCall(MatLMVMResetShift(tao->hessian));
31: /* Prepare the reduced sub-matrices for the inactive set */
32: PetscCall(MatDestroy(&bnk->H_inactive));
33: if (bnk->active_idx) {
34: PetscCall(MatCreateSubMatrixVirtual(tao->hessian, bnk->inactive_idx, bnk->inactive_idx, &bnk->H_inactive));
35: PetscCall(PCLMVMSetIS(bqnk->pc, bnk->inactive_idx));
36: } else {
37: PetscCall(PetscObjectReference((PetscObject)tao->hessian));
38: bnk->H_inactive = tao->hessian;
39: PetscCall(PCLMVMClearIS(bqnk->pc));
40: }
41: PetscCall(MatDestroy(&bnk->Hpre_inactive));
42: PetscCall(PetscObjectReference((PetscObject)bnk->H_inactive));
43: bnk->Hpre_inactive = bnk->H_inactive;
44: PetscFunctionReturn(PETSC_SUCCESS);
45: }
47: static PetscErrorCode TaoBQNKComputeStep(Tao tao, PetscBool shift, KSPConvergedReason *ksp_reason, PetscInt *step_type)
48: {
49: TAO_BNK *bnk = (TAO_BNK *)tao->data;
50: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
52: PetscFunctionBegin;
53: PetscCall(TaoBNKComputeStep(tao, shift, ksp_reason, step_type));
54: if (*ksp_reason < 0) {
55: /* Krylov solver failed to converge so reset the LMVM matrix */
56: PetscCall(MatLMVMReset(bqnk->B, PETSC_FALSE));
57: PetscCall(MatLMVMUpdate(bqnk->B, tao->solution, bnk->unprojected_gradient));
58: }
59: PetscFunctionReturn(PETSC_SUCCESS);
60: }
62: PetscErrorCode TaoSolve_BQNK(Tao tao)
63: {
64: TAO_BNK *bnk = (TAO_BNK *)tao->data;
65: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
66: Mat_LMVM *lmvm = (Mat_LMVM *)bqnk->B->data;
67: Mat_LMVM *J0;
68: PetscBool flg = PETSC_FALSE;
70: PetscFunctionBegin;
71: if (!tao->recycle) {
72: PetscCall(MatLMVMReset(bqnk->B, PETSC_FALSE));
73: lmvm->nresets = 0;
74: if (lmvm->J0) {
75: PetscCall(PetscObjectBaseTypeCompare((PetscObject)lmvm->J0, MATLMVM, &flg));
76: if (flg) {
77: J0 = (Mat_LMVM *)lmvm->J0->data;
78: J0->nresets = 0;
79: }
80: }
81: }
82: PetscCall((*bqnk->solve)(tao));
83: PetscFunctionReturn(PETSC_SUCCESS);
84: }
86: PetscErrorCode TaoSetUp_BQNK(Tao tao)
87: {
88: TAO_BNK *bnk = (TAO_BNK *)tao->data;
89: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
90: PetscInt n, N;
91: PetscBool is_lmvm, is_set, is_sym;
93: PetscFunctionBegin;
94: PetscCall(TaoSetUp_BNK(tao));
95: PetscCall(VecGetLocalSize(tao->solution, &n));
96: PetscCall(VecGetSize(tao->solution, &N));
97: PetscCall(MatSetSizes(bqnk->B, n, n, N, N));
98: PetscCall(MatLMVMAllocate(bqnk->B, tao->solution, bnk->unprojected_gradient));
99: PetscCall(PetscObjectBaseTypeCompare((PetscObject)bqnk->B, MATLMVM, &is_lmvm));
100: PetscCheck(is_lmvm, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "Matrix must be an LMVM-type");
101: PetscCall(MatIsSymmetricKnown(bqnk->B, &is_set, &is_sym));
102: PetscCheck(is_set && is_sym, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "LMVM matrix must be symmetric");
103: PetscCall(KSPGetPC(tao->ksp, &bqnk->pc));
104: PetscCall(PCSetType(bqnk->pc, PCLMVM));
105: PetscCall(PCLMVMSetMatLMVM(bqnk->pc, bqnk->B));
106: PetscFunctionReturn(PETSC_SUCCESS);
107: }
109: static PetscErrorCode TaoSetFromOptions_BQNK(Tao tao, PetscOptionItems PetscOptionsObject)
110: {
111: TAO_BNK *bnk = (TAO_BNK *)tao->data;
112: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
113: PetscBool is_set;
115: PetscFunctionBegin;
116: PetscCall(TaoSetFromOptions_BNK(tao, PetscOptionsObject));
117: if (bnk->init_type == BNK_INIT_INTERPOLATION) bnk->init_type = BNK_INIT_DIRECTION;
118: PetscCall(MatSetOptionsPrefix(bqnk->B, ((PetscObject)tao)->prefix));
119: PetscCall(MatAppendOptionsPrefix(bqnk->B, "tao_bqnk_"));
120: PetscCall(MatSetFromOptions(bqnk->B));
121: PetscCall(MatIsSPDKnown(bqnk->B, &is_set, &bqnk->is_spd));
122: if (!is_set) bqnk->is_spd = PETSC_FALSE;
123: PetscFunctionReturn(PETSC_SUCCESS);
124: }
126: static PetscErrorCode TaoView_BQNK(Tao tao, PetscViewer viewer)
127: {
128: TAO_BNK *bnk = (TAO_BNK *)tao->data;
129: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
130: PetscBool isascii;
132: PetscFunctionBegin;
133: PetscCall(TaoView_BNK(tao, viewer));
134: PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &isascii));
135: if (isascii) {
136: PetscCall(PetscViewerPushFormat(viewer, PETSC_VIEWER_ASCII_INFO));
137: PetscCall(MatView(bqnk->B, viewer));
138: PetscCall(PetscViewerPopFormat(viewer));
139: }
140: PetscFunctionReturn(PETSC_SUCCESS);
141: }
143: static PetscErrorCode TaoDestroy_BQNK(Tao tao)
144: {
145: TAO_BNK *bnk = (TAO_BNK *)tao->data;
146: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
148: PetscFunctionBegin;
149: PetscCall(MatDestroy(&bnk->Hpre_inactive));
150: PetscCall(MatDestroy(&bnk->H_inactive));
151: PetscCall(MatDestroy(&bqnk->B));
152: PetscCall(PetscFree(bnk->ctx));
153: PetscCall(TaoDestroy_BNK(tao));
154: PetscFunctionReturn(PETSC_SUCCESS);
155: }
157: PETSC_INTERN PetscErrorCode TaoCreate_BQNK(Tao tao)
158: {
159: TAO_BNK *bnk;
160: TAO_BQNK *bqnk;
162: PetscFunctionBegin;
163: PetscCall(TaoCreate_BNK(tao));
164: tao->ops->solve = TaoSolve_BQNK;
165: tao->ops->setfromoptions = TaoSetFromOptions_BQNK;
166: tao->ops->destroy = TaoDestroy_BQNK;
167: tao->ops->view = TaoView_BQNK;
168: tao->ops->setup = TaoSetUp_BQNK;
170: bnk = (TAO_BNK *)tao->data;
171: bnk->computehessian = TaoBQNKComputeHessian;
172: bnk->computestep = TaoBQNKComputeStep;
173: bnk->init_type = BNK_INIT_DIRECTION;
175: PetscCall(PetscNew(&bqnk));
176: bnk->ctx = (void *)bqnk;
177: bqnk->is_spd = PETSC_TRUE;
179: PetscCall(MatCreate(PetscObjectComm((PetscObject)tao), &bqnk->B));
180: PetscCall(PetscObjectIncrementTabLevel((PetscObject)bqnk->B, (PetscObject)tao, 1));
181: PetscCall(MatSetType(bqnk->B, MATLMVMSR1));
182: PetscFunctionReturn(PETSC_SUCCESS);
183: }
185: /*@
186: TaoGetLMVMMatrix - Returns a pointer to the internal LMVM matrix. Valid
187: only for quasi-Newton family of methods.
189: Input Parameter:
190: . tao - `Tao` solver context
192: Output Parameter:
193: . B - LMVM matrix
195: Level: advanced
197: .seealso: `TAOBQNLS`, `TAOBQNKLS`, `TAOBQNKTL`, `TAOBQNKTR`, `MATLMVM`, `TaoSetLMVMMatrix()`
198: @*/
199: PetscErrorCode TaoGetLMVMMatrix(Tao tao, Mat *B)
200: {
201: TAO_BNK *bnk = (TAO_BNK *)tao->data;
202: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
203: PetscBool flg = PETSC_FALSE;
205: PetscFunctionBegin;
206: PetscCall(PetscObjectTypeCompareAny((PetscObject)tao, &flg, TAOBQNLS, TAOBQNKLS, TAOBQNKTR, TAOBQNKTL, ""));
207: PetscCheck(flg, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "LMVM Matrix only exists for quasi-Newton algorithms");
208: *B = bqnk->B;
209: PetscFunctionReturn(PETSC_SUCCESS);
210: }
212: /*@
213: TaoSetLMVMMatrix - Sets an external LMVM matrix into the Tao solver. Valid
214: only for quasi-Newton family of methods.
216: QN family of methods create their own LMVM matrices and users who wish to
217: manipulate this matrix should use TaoGetLMVMMatrix() instead.
219: Input Parameters:
220: + tao - Tao solver context
221: - B - LMVM matrix
223: Level: advanced
225: .seealso: `TAOBQNLS`, `TAOBQNKLS`, `TAOBQNKTL`, `TAOBQNKTR`, `MATLMVM`, `TaoGetLMVMMatrix()`
226: @*/
227: PetscErrorCode TaoSetLMVMMatrix(Tao tao, Mat B)
228: {
229: TAO_BNK *bnk = (TAO_BNK *)tao->data;
230: TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
231: PetscBool flg = PETSC_FALSE;
233: PetscFunctionBegin;
234: PetscCall(PetscObjectTypeCompareAny((PetscObject)tao, &flg, TAOBQNLS, TAOBQNKLS, TAOBQNKTR, TAOBQNKTL, ""));
235: PetscCheck(flg, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "LMVM Matrix only exists for quasi-Newton algorithms");
236: PetscCall(PetscObjectBaseTypeCompare((PetscObject)B, MATLMVM, &flg));
237: PetscCheck(flg, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "Given matrix is not an LMVM matrix");
238: if (bqnk->B) PetscCall(MatDestroy(&bqnk->B));
239: PetscCall(PetscObjectReference((PetscObject)B));
240: bqnk->B = B;
241: PetscFunctionReturn(PETSC_SUCCESS);
242: }