Actual source code: mcomposite.c

  1: #include <../src/mat/impls/shell/shell.h>

  3: const char *const MatCompositeMergeTypes[] = {"left", "right", "MatCompositeMergeType", "MAT_COMPOSITE_", NULL};

  5: typedef struct _Mat_CompositeLink *Mat_CompositeLink;
  6: struct _Mat_CompositeLink {
  7:   Mat               mat;
  8:   Vec               work;
  9:   Mat_CompositeLink next, prev;
 10: };

 12: typedef struct {
 13:   MatCompositeType      type;
 14:   Mat_CompositeLink     head, tail;
 15:   Vec                   work;
 16:   PetscInt              nmat;
 17:   PetscBool             merge;
 18:   MatCompositeMergeType mergetype;
 19:   MatStructure          structure;

 21:   PetscScalar *scalings;
 22:   PetscBool    merge_mvctx; /* Whether need to merge mvctx of component matrices */
 23:   Vec         *lvecs;       /* [nmat] Basically, they are Mvctx->lvec of each component matrix */
 24:   PetscScalar *larray;      /* [len] Data arrays of lvecs[] are stored consecutively in larray */
 25:   PetscInt     len;         /* Length of larray[] */
 26:   Vec          gvec;        /* Union of lvecs[] without duplicated entries */
 27:   PetscInt    *location;    /* A map that maps entries in garray[] to larray[] */
 28:   VecScatter   Mvctx;
 29: } Mat_Composite;

 31: static PetscErrorCode MatDestroy_Composite(Mat mat)
 32: {
 33:   Mat_Composite    *shell;
 34:   Mat_CompositeLink next, oldnext;
 35:   PetscInt          i;

 37:   PetscFunctionBegin;
 38:   PetscCall(MatShellGetContext(mat, &shell));
 39:   next = shell->head;
 40:   while (next) {
 41:     PetscCall(MatDestroy(&next->mat));
 42:     if (next->work && (!next->next || next->work != next->next->work)) PetscCall(VecDestroy(&next->work));
 43:     oldnext = next;
 44:     next    = next->next;
 45:     PetscCall(PetscFree(oldnext));
 46:   }
 47:   PetscCall(VecDestroy(&shell->work));

 49:   if (shell->Mvctx) {
 50:     for (i = 0; i < shell->nmat; i++) PetscCall(VecDestroy(&shell->lvecs[i]));
 51:     PetscCall(PetscFree3(shell->location, shell->larray, shell->lvecs));
 52:     PetscCall(PetscFree(shell->larray));
 53:     PetscCall(VecDestroy(&shell->gvec));
 54:     PetscCall(VecScatterDestroy(&shell->Mvctx));
 55:   }

 57:   PetscCall(PetscFree(shell->scalings));
 58:   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeAddMat_C", NULL));
 59:   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeSetType_C", NULL));
 60:   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeGetType_C", NULL));
 61:   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeSetMergeType_C", NULL));
 62:   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeSetMatStructure_C", NULL));
 63:   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeGetMatStructure_C", NULL));
 64:   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeMerge_C", NULL));
 65:   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeGetNumberMat_C", NULL));
 66:   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeGetMat_C", NULL));
 67:   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeSetScalings_C", NULL));
 68:   PetscCall(PetscFree(shell));
 69:   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatShellSetContext_C", NULL)); // needed to avoid a call to MatShellSetContext_Immutable()
 70:   PetscFunctionReturn(PETSC_SUCCESS);
 71: }

 73: static PetscErrorCode MatMult_Composite_Multiplicative(Mat A, Vec x, Vec y)
 74: {
 75:   Mat_Composite    *shell;
 76:   Mat_CompositeLink next;
 77:   Vec               out;

 79:   PetscFunctionBegin;
 80:   PetscCall(MatShellGetContext(A, &shell));
 81:   next = shell->head;
 82:   PetscCheck(next, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()");
 83:   while (next->next) {
 84:     if (!next->work) { /* should reuse previous work if the same size */
 85:       PetscCall(MatCreateVecs(next->mat, NULL, &next->work));
 86:     }
 87:     out = next->work;
 88:     PetscCall(MatMult(next->mat, x, out));
 89:     x    = out;
 90:     next = next->next;
 91:   }
 92:   PetscCall(MatMult(next->mat, x, y));
 93:   if (shell->scalings) {
 94:     PetscScalar scale = 1.0;
 95:     for (PetscInt i = 0; i < shell->nmat; i++) scale *= shell->scalings[i];
 96:     PetscCall(VecScale(y, scale));
 97:   }
 98:   PetscFunctionReturn(PETSC_SUCCESS);
 99: }

101: static PetscErrorCode MatMultTranspose_Composite_Multiplicative(Mat A, Vec x, Vec y)
102: {
103:   Mat_Composite    *shell;
104:   Mat_CompositeLink tail;
105:   Vec               out;

107:   PetscFunctionBegin;
108:   PetscCall(MatShellGetContext(A, &shell));
109:   tail = shell->tail;
110:   PetscCheck(tail, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()");
111:   while (tail->prev) {
112:     if (!tail->prev->work) { /* should reuse previous work if the same size */
113:       PetscCall(MatCreateVecs(tail->mat, NULL, &tail->prev->work));
114:     }
115:     out = tail->prev->work;
116:     PetscCall(MatMultTranspose(tail->mat, x, out));
117:     x    = out;
118:     tail = tail->prev;
119:   }
120:   PetscCall(MatMultTranspose(tail->mat, x, y));
121:   if (shell->scalings) {
122:     PetscScalar scale = 1.0;
123:     for (PetscInt i = 0; i < shell->nmat; i++) scale *= shell->scalings[i];
124:     PetscCall(VecScale(y, scale));
125:   }
126:   PetscFunctionReturn(PETSC_SUCCESS);
127: }

129: static PetscErrorCode MatMult_Composite(Mat mat, Vec x, Vec y)
130: {
131:   Mat_Composite     *shell;
132:   Mat_CompositeLink  cur;
133:   Vec                y2, xin;
134:   Mat                A, B;
135:   PetscInt           i, j, k, n, nuniq, lo, hi, mid, *gindices, *buf, *tmp, tot;
136:   const PetscScalar *vals;
137:   const PetscInt    *garray;
138:   IS                 ix, iy;
139:   PetscBool          match;

141:   PetscFunctionBegin;
142:   PetscCall(MatShellGetContext(mat, &shell));
143:   cur = shell->head;
144:   PetscCheck(cur, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()");

146:   /* Try to merge Mvctx when instructed but not yet done. We did not do it in MatAssemblyEnd() since at that time
147:      we did not know whether mat is ADDITIVE or MULTIPLICATIVE. Only now we are assured mat is ADDITIVE and
148:      it is legal to merge Mvctx, because all component matrices have the same size.
149:    */
150:   if (shell->merge_mvctx && !shell->Mvctx) {
151:     /* Currently only implemented for MATMPIAIJ */
152:     for (cur = shell->head; cur; cur = cur->next) {
153:       PetscCall(PetscObjectTypeCompare((PetscObject)cur->mat, MATMPIAIJ, &match));
154:       if (!match) {
155:         shell->merge_mvctx = PETSC_FALSE;
156:         goto skip_merge_mvctx;
157:       }
158:     }

160:     /* Go through matrices first time to count total number of nonzero off-diag columns (may have dups) */
161:     tot = 0;
162:     for (cur = shell->head; cur; cur = cur->next) {
163:       PetscCall(MatMPIAIJGetSeqAIJ(cur->mat, NULL, &B, NULL));
164:       PetscCall(MatGetLocalSize(B, NULL, &n));
165:       tot += n;
166:     }
167:     PetscCall(PetscMalloc3(tot, &shell->location, tot, &shell->larray, shell->nmat, &shell->lvecs));
168:     shell->len = tot;

170:     /* Go through matrices second time to sort off-diag columns and remove dups */
171:     PetscCall(PetscMalloc1(tot, &gindices)); /* No Malloc2() since we will give one to petsc and free the other */
172:     PetscCall(PetscMalloc1(tot, &buf));
173:     nuniq = 0; /* Number of unique nonzero columns */
174:     for (cur = shell->head; cur; cur = cur->next) {
175:       PetscCall(MatMPIAIJGetSeqAIJ(cur->mat, NULL, &B, &garray));
176:       PetscCall(MatGetLocalSize(B, NULL, &n));
177:       /* Merge pre-sorted garray[0,n) and gindices[0,nuniq) to buf[] */
178:       i = j = k = 0;
179:       while (i < n && j < nuniq) {
180:         if (garray[i] < gindices[j]) buf[k++] = garray[i++];
181:         else if (garray[i] > gindices[j]) buf[k++] = gindices[j++];
182:         else {
183:           buf[k++] = garray[i++];
184:           j++;
185:         }
186:       }
187:       /* Copy leftover in garray[] or gindices[] */
188:       if (i < n) {
189:         PetscCall(PetscArraycpy(buf + k, garray + i, n - i));
190:         nuniq = k + n - i;
191:       } else if (j < nuniq) {
192:         PetscCall(PetscArraycpy(buf + k, gindices + j, nuniq - j));
193:         nuniq = k + nuniq - j;
194:       } else nuniq = k;
195:       /* Swap gindices and buf to merge garray of the next matrix */
196:       tmp      = gindices;
197:       gindices = buf;
198:       buf      = tmp;
199:     }
200:     PetscCall(PetscFree(buf));

202:     /* Go through matrices third time to build a map from gindices[] to garray[] */
203:     tot = 0;
204:     for (cur = shell->head, j = 0; cur; cur = cur->next, j++) { /* j-th matrix */
205:       PetscCall(MatMPIAIJGetSeqAIJ(cur->mat, NULL, &B, &garray));
206:       PetscCall(MatGetLocalSize(B, NULL, &n));
207:       PetscCall(VecCreateSeqWithArray(PETSC_COMM_SELF, 1, n, NULL, &shell->lvecs[j]));
208:       /* This is an optimized PetscFindInt(garray[i],nuniq,gindices,&shell->location[tot+i]), using the fact that garray[] is also sorted */
209:       lo = 0;
210:       for (i = 0; i < n; i++) {
211:         hi = nuniq;
212:         while (hi - lo > 1) {
213:           mid = lo + (hi - lo) / 2;
214:           if (garray[i] < gindices[mid]) hi = mid;
215:           else lo = mid;
216:         }
217:         shell->location[tot + i] = lo; /* gindices[lo] = garray[i] */
218:         lo++;                          /* Since garray[i+1] > garray[i], we can safely advance lo */
219:       }
220:       tot += n;
221:     }

223:     /* Build merged Mvctx */
224:     PetscCall(ISCreateGeneral(PETSC_COMM_SELF, nuniq, gindices, PETSC_OWN_POINTER, &ix));
225:     PetscCall(ISCreateStride(PETSC_COMM_SELF, nuniq, 0, 1, &iy));
226:     PetscCall(VecCreateMPIWithArray(PetscObjectComm((PetscObject)mat), 1, mat->cmap->n, mat->cmap->N, NULL, &xin));
227:     PetscCall(VecCreateSeq(PETSC_COMM_SELF, nuniq, &shell->gvec));
228:     PetscCall(VecScatterCreate(xin, ix, shell->gvec, iy, &shell->Mvctx));
229:     PetscCall(VecDestroy(&xin));
230:     PetscCall(ISDestroy(&ix));
231:     PetscCall(ISDestroy(&iy));
232:   }

234: skip_merge_mvctx:
235:   PetscCall(VecSet(y, 0));
236:   if (!((Mat_Shell *)mat->data)->left_work) PetscCall(VecDuplicate(y, &(((Mat_Shell *)mat->data)->left_work)));
237:   y2 = ((Mat_Shell *)mat->data)->left_work;

239:   if (shell->Mvctx) { /* Have a merged Mvctx */
240:     /* Suppose we want to compute y = sMx, where s is the scaling factor and A, B are matrix M's diagonal/off-diagonal part. We could do
241:        in y = s(Ax1 + Bx2) or y = sAx1 + sBx2. The former incurs less FLOPS than the latter, but the latter provides an opportunity to
242:        overlap communication/computation since we can do sAx1 while communicating x2. Here, we use the former approach.
243:      */
244:     PetscCall(VecScatterBegin(shell->Mvctx, x, shell->gvec, INSERT_VALUES, SCATTER_FORWARD));
245:     PetscCall(VecScatterEnd(shell->Mvctx, x, shell->gvec, INSERT_VALUES, SCATTER_FORWARD));

247:     PetscCall(VecGetArrayRead(shell->gvec, &vals));
248:     for (i = 0; i < shell->len; i++) shell->larray[i] = vals[shell->location[i]];
249:     PetscCall(VecRestoreArrayRead(shell->gvec, &vals));

251:     for (cur = shell->head, tot = i = 0; cur; cur = cur->next, i++) { /* i-th matrix */
252:       PetscCall(MatMPIAIJGetSeqAIJ(cur->mat, &A, &B, NULL));
253:       PetscUseTypeMethod(A, mult, x, y2);
254:       PetscCall(MatGetLocalSize(B, NULL, &n));
255:       PetscCall(VecPlaceArray(shell->lvecs[i], &shell->larray[tot]));
256:       PetscUseTypeMethod(B, multadd, shell->lvecs[i], y2, y2);
257:       PetscCall(VecResetArray(shell->lvecs[i]));
258:       PetscCall(VecAXPY(y, shell->scalings ? shell->scalings[i] : 1.0, y2));
259:       tot += n;
260:     }
261:   } else {
262:     if (shell->scalings) {
263:       for (cur = shell->head, i = 0; cur; cur = cur->next, i++) {
264:         PetscCall(MatMult(cur->mat, x, y2));
265:         PetscCall(VecAXPY(y, shell->scalings[i], y2));
266:       }
267:     } else {
268:       for (cur = shell->head; cur; cur = cur->next) PetscCall(MatMultAdd(cur->mat, x, y, y));
269:     }
270:   }
271:   PetscFunctionReturn(PETSC_SUCCESS);
272: }

274: static PetscErrorCode MatMultTranspose_Composite(Mat A, Vec x, Vec y)
275: {
276:   Mat_Composite    *shell;
277:   Mat_CompositeLink next;
278:   Vec               y2 = NULL;
279:   PetscInt          i;

281:   PetscFunctionBegin;
282:   PetscCall(MatShellGetContext(A, &shell));
283:   next = shell->head;
284:   PetscCheck(next, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()");

286:   PetscCall(MatMultTranspose(next->mat, x, y));
287:   if (shell->scalings) {
288:     PetscCall(VecScale(y, shell->scalings[0]));
289:     if (!((Mat_Shell *)A->data)->right_work) PetscCall(VecDuplicate(y, &(((Mat_Shell *)A->data)->right_work)));
290:     y2 = ((Mat_Shell *)A->data)->right_work;
291:   }
292:   i = 1;
293:   while ((next = next->next)) {
294:     if (!shell->scalings) PetscCall(MatMultTransposeAdd(next->mat, x, y, y));
295:     else {
296:       PetscCall(MatMultTranspose(next->mat, x, y2));
297:       PetscCall(VecAXPY(y, shell->scalings[i++], y2));
298:     }
299:   }
300:   PetscFunctionReturn(PETSC_SUCCESS);
301: }

303: static PetscErrorCode MatGetDiagonal_Composite(Mat A, Vec v)
304: {
305:   Mat_Composite    *shell;
306:   Mat_CompositeLink next;
307:   PetscInt          i;

309:   PetscFunctionBegin;
310:   PetscCall(MatShellGetContext(A, &shell));
311:   next = shell->head;
312:   PetscCheck(next, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()");
313:   PetscCall(MatGetDiagonal(next->mat, v));
314:   if (shell->scalings) PetscCall(VecScale(v, shell->scalings[0]));

316:   if (next->next && !shell->work) PetscCall(VecDuplicate(v, &shell->work));
317:   i = 1;
318:   while ((next = next->next)) {
319:     PetscCall(MatGetDiagonal(next->mat, shell->work));
320:     PetscCall(VecAXPY(v, shell->scalings ? shell->scalings[i++] : 1.0, shell->work));
321:   }
322:   PetscFunctionReturn(PETSC_SUCCESS);
323: }

325: static PetscErrorCode MatAssemblyEnd_Composite(Mat Y, MatAssemblyType t)
326: {
327:   Mat_Composite *shell;

329:   PetscFunctionBegin;
330:   PetscCall(MatShellGetContext(Y, &shell));
331:   if (shell->merge) PetscCall(MatCompositeMerge(Y));
332:   else PetscCall(MatAssemblyEnd_Shell(Y, t));
333:   PetscFunctionReturn(PETSC_SUCCESS);
334: }

336: static PetscErrorCode MatSetFromOptions_Composite(Mat A, PetscOptionItems *PetscOptionsObject)
337: {
338:   Mat_Composite *a;

340:   PetscFunctionBegin;
341:   PetscCall(MatShellGetContext(A, &a));
342:   PetscOptionsHeadBegin(PetscOptionsObject, "MATCOMPOSITE options");
343:   PetscCall(PetscOptionsBool("-mat_composite_merge", "Merge at MatAssemblyEnd", "MatCompositeMerge", a->merge, &a->merge, NULL));
344:   PetscCall(PetscOptionsEnum("-mat_composite_merge_type", "Set composite merge direction", "MatCompositeSetMergeType", MatCompositeMergeTypes, (PetscEnum)a->mergetype, (PetscEnum *)&a->mergetype, NULL));
345:   PetscCall(PetscOptionsBool("-mat_composite_merge_mvctx", "Merge MatMult() vecscat contexts", "MatCreateComposite", a->merge_mvctx, &a->merge_mvctx, NULL));
346:   PetscOptionsHeadEnd();
347:   PetscFunctionReturn(PETSC_SUCCESS);
348: }

350: /*@
351:   MatCreateComposite - Creates a matrix as the sum or product of one or more matrices

353:   Collective

355:   Input Parameters:
356: + comm - MPI communicator
357: . nmat - number of matrices to put in
358: - mats - the matrices

360:   Output Parameter:
361: . mat - the matrix

363:   Options Database Keys:
364: + -mat_composite_merge       - merge in `MatAssemblyEnd()`
365: . -mat_composite_merge_mvctx - merge Mvctx of component matrices to optimize communication in `MatMult()` for ADDITIVE matrices
366: - -mat_composite_merge_type  - set merge direction

368:   Level: advanced

370:   Note:
371:   Alternative construction
372: .vb
373:        MatCreate(comm,&mat);
374:        MatSetSizes(mat,m,n,M,N);
375:        MatSetType(mat,MATCOMPOSITE);
376:        MatCompositeAddMat(mat,mats[0]);
377:        ....
378:        MatCompositeAddMat(mat,mats[nmat-1]);
379:        MatAssemblyBegin(mat,MAT_FINAL_ASSEMBLY);
380:        MatAssemblyEnd(mat,MAT_FINAL_ASSEMBLY);
381: .ve

383:   For the multiplicative form the product is mat[nmat-1]*mat[nmat-2]*....*mat[0]

385: .seealso: [](ch_matrices), `Mat`, `MatDestroy()`, `MatMult()`, `MatCompositeAddMat()`, `MatCompositeGetMat()`, `MatCompositeMerge()`, `MatCompositeSetType()`,
386:           `MATCOMPOSITE`, `MatCompositeType`
387: @*/
388: PetscErrorCode MatCreateComposite(MPI_Comm comm, PetscInt nmat, const Mat *mats, Mat *mat)
389: {
390:   PetscFunctionBegin;
391:   PetscCheck(nmat >= 1, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Must pass in at least one matrix");
392:   PetscAssertPointer(mat, 4);
393:   PetscCall(MatCreate(comm, mat));
394:   PetscCall(MatSetType(*mat, MATCOMPOSITE));
395:   for (PetscInt i = 0; i < nmat; i++) PetscCall(MatCompositeAddMat(*mat, mats[i]));
396:   PetscCall(MatAssemblyBegin(*mat, MAT_FINAL_ASSEMBLY));
397:   PetscCall(MatAssemblyEnd(*mat, MAT_FINAL_ASSEMBLY));
398:   PetscFunctionReturn(PETSC_SUCCESS);
399: }

401: static PetscErrorCode MatCompositeAddMat_Composite(Mat mat, Mat smat)
402: {
403:   Mat_Composite    *shell;
404:   Mat_CompositeLink ilink, next;
405:   VecType           vtype_mat, vtype_smat;
406:   PetscBool         match;

408:   PetscFunctionBegin;
409:   PetscCall(MatShellGetContext(mat, &shell));
410:   next = shell->head;
411:   PetscCall(PetscNew(&ilink));
412:   ilink->next = NULL;
413:   PetscCall(PetscObjectReference((PetscObject)smat));
414:   ilink->mat = smat;

416:   if (!next) shell->head = ilink;
417:   else {
418:     while (next->next) next = next->next;
419:     next->next  = ilink;
420:     ilink->prev = next;
421:   }
422:   shell->tail = ilink;
423:   shell->nmat += 1;

425:   /* If all of the partial matrices have the same default vector type, then the composite matrix should also have this default type.
426:      Otherwise, the default type should be "standard". */
427:   PetscCall(MatGetVecType(smat, &vtype_smat));
428:   if (shell->nmat == 1) PetscCall(MatSetVecType(mat, vtype_smat));
429:   else {
430:     PetscCall(MatGetVecType(mat, &vtype_mat));
431:     PetscCall(PetscStrcmp(vtype_smat, vtype_mat, &match));
432:     if (!match) PetscCall(MatSetVecType(mat, VECSTANDARD));
433:   }

435:   /* Retain the old scalings (if any) and expand it with a 1.0 for the newly added matrix */
436:   if (shell->scalings) {
437:     PetscCall(PetscRealloc(sizeof(PetscScalar) * shell->nmat, &shell->scalings));
438:     shell->scalings[shell->nmat - 1] = 1.0;
439:   }

441:   /* The composite matrix requires PetscLayouts for its rows and columns; we copy these from the constituent partial matrices. */
442:   if (shell->nmat == 1) PetscCall(PetscLayoutReference(smat->cmap, &mat->cmap));
443:   PetscCall(PetscLayoutReference(smat->rmap, &mat->rmap));
444:   PetscFunctionReturn(PETSC_SUCCESS);
445: }

447: /*@
448:   MatCompositeAddMat - Add another matrix to a composite matrix.

450:   Collective

452:   Input Parameters:
453: + mat  - the composite matrix
454: - smat - the partial matrix

456:   Level: advanced

458: .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeGetMat()`, `MATCOMPOSITE`
459: @*/
460: PetscErrorCode MatCompositeAddMat(Mat mat, Mat smat)
461: {
462:   PetscFunctionBegin;
465:   PetscUseMethod(mat, "MatCompositeAddMat_C", (Mat, Mat), (mat, smat));
466:   PetscFunctionReturn(PETSC_SUCCESS);
467: }

469: static PetscErrorCode MatCompositeSetType_Composite(Mat mat, MatCompositeType type)
470: {
471:   Mat_Composite *b;

473:   PetscFunctionBegin;
474:   PetscCall(MatShellGetContext(mat, &b));
475:   b->type = type;
476:   if (type == MAT_COMPOSITE_MULTIPLICATIVE) {
477:     PetscCall(MatShellSetOperation(mat, MATOP_GET_DIAGONAL, NULL));
478:     PetscCall(MatShellSetOperation(mat, MATOP_MULT, (void (*)(void))MatMult_Composite_Multiplicative));
479:     PetscCall(MatShellSetOperation(mat, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Composite_Multiplicative));
480:     b->merge_mvctx = PETSC_FALSE;
481:   } else {
482:     PetscCall(MatShellSetOperation(mat, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Composite));
483:     PetscCall(MatShellSetOperation(mat, MATOP_MULT, (void (*)(void))MatMult_Composite));
484:     PetscCall(MatShellSetOperation(mat, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Composite));
485:   }
486:   PetscFunctionReturn(PETSC_SUCCESS);
487: }

489: /*@
490:   MatCompositeSetType - Indicates if the matrix is defined as the sum of a set of matrices or the product.

492:   Logically Collective

494:   Input Parameters:
495: + mat  - the composite matrix
496: - type - the `MatCompositeType` to use for the matrix

498:   Level: advanced

500: .seealso: [](ch_matrices), `Mat`, `MatDestroy()`, `MatMult()`, `MatCompositeAddMat()`, `MatCreateComposite()`, `MatCompositeGetType()`, `MATCOMPOSITE`,
501:           `MatCompositeType`
502: @*/
503: PetscErrorCode MatCompositeSetType(Mat mat, MatCompositeType type)
504: {
505:   PetscFunctionBegin;
508:   PetscUseMethod(mat, "MatCompositeSetType_C", (Mat, MatCompositeType), (mat, type));
509:   PetscFunctionReturn(PETSC_SUCCESS);
510: }

512: static PetscErrorCode MatCompositeGetType_Composite(Mat mat, MatCompositeType *type)
513: {
514:   Mat_Composite *shell;

516:   PetscFunctionBegin;
517:   PetscCall(MatShellGetContext(mat, &shell));
518:   *type = shell->type;
519:   PetscFunctionReturn(PETSC_SUCCESS);
520: }

522: /*@
523:   MatCompositeGetType - Returns type of composite.

525:   Not Collective

527:   Input Parameter:
528: . mat - the composite matrix

530:   Output Parameter:
531: . type - type of composite

533:   Level: advanced

535: .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeSetType()`, `MATCOMPOSITE`, `MatCompositeType`
536: @*/
537: PetscErrorCode MatCompositeGetType(Mat mat, MatCompositeType *type)
538: {
539:   PetscFunctionBegin;
541:   PetscAssertPointer(type, 2);
542:   PetscUseMethod(mat, "MatCompositeGetType_C", (Mat, MatCompositeType *), (mat, type));
543:   PetscFunctionReturn(PETSC_SUCCESS);
544: }

546: static PetscErrorCode MatCompositeSetMatStructure_Composite(Mat mat, MatStructure str)
547: {
548:   Mat_Composite *shell;

550:   PetscFunctionBegin;
551:   PetscCall(MatShellGetContext(mat, &shell));
552:   shell->structure = str;
553:   PetscFunctionReturn(PETSC_SUCCESS);
554: }

556: /*@
557:   MatCompositeSetMatStructure - Indicates structure of matrices in the composite matrix.

559:   Not Collective

561:   Input Parameters:
562: + mat - the composite matrix
563: - str - either `SAME_NONZERO_PATTERN`, `DIFFERENT_NONZERO_PATTERN` (default) or `SUBSET_NONZERO_PATTERN`

565:   Level: advanced

567:   Note:
568:   Information about the matrices structure is used in `MatCompositeMerge()` for additive composite matrix.

570: .seealso: [](ch_matrices), `Mat`, `MatAXPY()`, `MatCreateComposite()`, `MatCompositeMerge()` `MatCompositeGetMatStructure()`, `MATCOMPOSITE`
571: @*/
572: PetscErrorCode MatCompositeSetMatStructure(Mat mat, MatStructure str)
573: {
574:   PetscFunctionBegin;
576:   PetscUseMethod(mat, "MatCompositeSetMatStructure_C", (Mat, MatStructure), (mat, str));
577:   PetscFunctionReturn(PETSC_SUCCESS);
578: }

580: static PetscErrorCode MatCompositeGetMatStructure_Composite(Mat mat, MatStructure *str)
581: {
582:   Mat_Composite *shell;

584:   PetscFunctionBegin;
585:   PetscCall(MatShellGetContext(mat, &shell));
586:   *str = shell->structure;
587:   PetscFunctionReturn(PETSC_SUCCESS);
588: }

590: /*@
591:   MatCompositeGetMatStructure - Returns the structure of matrices in the composite matrix.

593:   Not Collective

595:   Input Parameter:
596: . mat - the composite matrix

598:   Output Parameter:
599: . str - structure of the matrices

601:   Level: advanced

603: .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeSetMatStructure()`, `MATCOMPOSITE`
604: @*/
605: PetscErrorCode MatCompositeGetMatStructure(Mat mat, MatStructure *str)
606: {
607:   PetscFunctionBegin;
609:   PetscAssertPointer(str, 2);
610:   PetscUseMethod(mat, "MatCompositeGetMatStructure_C", (Mat, MatStructure *), (mat, str));
611:   PetscFunctionReturn(PETSC_SUCCESS);
612: }

614: static PetscErrorCode MatCompositeSetMergeType_Composite(Mat mat, MatCompositeMergeType type)
615: {
616:   Mat_Composite *shell;

618:   PetscFunctionBegin;
619:   PetscCall(MatShellGetContext(mat, &shell));
620:   shell->mergetype = type;
621:   PetscFunctionReturn(PETSC_SUCCESS);
622: }

624: /*@
625:   MatCompositeSetMergeType - Sets order of `MatCompositeMerge()`.

627:   Logically Collective

629:   Input Parameters:
630: + mat  - the composite matrix
631: - type - `MAT_COMPOSITE_MERGE RIGHT` (default) to start merge from right with the first added matrix (mat[0]),
632:           `MAT_COMPOSITE_MERGE_LEFT` to start merge from left with the last added matrix (mat[nmat-1])

634:   Level: advanced

636:   Note:
637:   The resulting matrix is the same regardless of the `MatCompositeMergeType`. Only the order of operation is changed.
638:   If set to `MAT_COMPOSITE_MERGE_RIGHT` the order of the merge is mat[nmat-1]*(mat[nmat-2]*(...*(mat[1]*mat[0])))
639:   otherwise the order is (((mat[nmat-1]*mat[nmat-2])*mat[nmat-3])*...)*mat[0].

641: .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeMerge()`, `MATCOMPOSITE`
642: @*/
643: PetscErrorCode MatCompositeSetMergeType(Mat mat, MatCompositeMergeType type)
644: {
645:   PetscFunctionBegin;
648:   PetscUseMethod(mat, "MatCompositeSetMergeType_C", (Mat, MatCompositeMergeType), (mat, type));
649:   PetscFunctionReturn(PETSC_SUCCESS);
650: }

652: static PetscErrorCode MatCompositeMerge_Composite(Mat mat)
653: {
654:   Mat_Composite    *shell;
655:   Mat_CompositeLink next, prev;
656:   Mat               tmat, newmat;
657:   Vec               left, right, dshift;
658:   PetscScalar       scale, shift;
659:   PetscInt          i;

661:   PetscFunctionBegin;
662:   PetscCall(MatShellGetContext(mat, &shell));
663:   next = shell->head;
664:   prev = shell->tail;
665:   PetscCheck(next, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()");
666:   PetscCall(MatShellGetScalingShifts(mat, &shift, &scale, &dshift, &left, &right, (Mat *)MAT_SHELL_NOT_ALLOWED, (IS *)MAT_SHELL_NOT_ALLOWED, (IS *)MAT_SHELL_NOT_ALLOWED));
667:   if (shell->type == MAT_COMPOSITE_ADDITIVE) {
668:     if (shell->mergetype == MAT_COMPOSITE_MERGE_RIGHT) {
669:       i = 0;
670:       PetscCall(MatDuplicate(next->mat, MAT_COPY_VALUES, &tmat));
671:       if (shell->scalings) PetscCall(MatScale(tmat, shell->scalings[i++]));
672:       while ((next = next->next)) PetscCall(MatAXPY(tmat, shell->scalings ? shell->scalings[i++] : 1.0, next->mat, shell->structure));
673:     } else {
674:       i = shell->nmat - 1;
675:       PetscCall(MatDuplicate(prev->mat, MAT_COPY_VALUES, &tmat));
676:       if (shell->scalings) PetscCall(MatScale(tmat, shell->scalings[i--]));
677:       while ((prev = prev->prev)) PetscCall(MatAXPY(tmat, shell->scalings ? shell->scalings[i--] : 1.0, prev->mat, shell->structure));
678:     }
679:   } else {
680:     if (shell->mergetype == MAT_COMPOSITE_MERGE_RIGHT) {
681:       PetscCall(MatDuplicate(next->mat, MAT_COPY_VALUES, &tmat));
682:       while ((next = next->next)) {
683:         PetscCall(MatMatMult(next->mat, tmat, MAT_INITIAL_MATRIX, PETSC_DETERMINE, &newmat));
684:         PetscCall(MatDestroy(&tmat));
685:         tmat = newmat;
686:       }
687:     } else {
688:       PetscCall(MatDuplicate(prev->mat, MAT_COPY_VALUES, &tmat));
689:       while ((prev = prev->prev)) {
690:         PetscCall(MatMatMult(tmat, prev->mat, MAT_INITIAL_MATRIX, PETSC_DETERMINE, &newmat));
691:         PetscCall(MatDestroy(&tmat));
692:         tmat = newmat;
693:       }
694:     }
695:     if (shell->scalings) {
696:       for (i = 0; i < shell->nmat; i++) scale *= shell->scalings[i];
697:     }
698:   }

700:   if (left) PetscCall(PetscObjectReference((PetscObject)left));
701:   if (right) PetscCall(PetscObjectReference((PetscObject)right));
702:   if (dshift) PetscCall(PetscObjectReference((PetscObject)dshift));

704:   PetscCall(MatHeaderReplace(mat, &tmat));

706:   PetscCall(MatDiagonalScale(mat, left, right));
707:   PetscCall(MatScale(mat, scale));
708:   PetscCall(MatShift(mat, shift));
709:   PetscCall(VecDestroy(&left));
710:   PetscCall(VecDestroy(&right));
711:   if (dshift) {
712:     PetscCall(MatDiagonalSet(mat, dshift, ADD_VALUES));
713:     PetscCall(VecDestroy(&dshift));
714:   }
715:   PetscFunctionReturn(PETSC_SUCCESS);
716: }

718: /*@
719:   MatCompositeMerge - Given a composite matrix, replaces it with a "regular" matrix
720:   by summing or computing the product of all the matrices inside the composite matrix.

722:   Collective

724:   Input Parameter:
725: . mat - the composite matrix

727:   Options Database Keys:
728: + -mat_composite_merge      - merge in `MatAssemblyEnd()`
729: - -mat_composite_merge_type - set merge direction

731:   Level: advanced

733:   Note:
734:   The `MatType` of the resulting matrix will be the same as the `MatType` of the FIRST matrix in the composite matrix.

736: .seealso: [](ch_matrices), `Mat`, `MatDestroy()`, `MatMult()`, `MatCompositeAddMat()`, `MatCreateComposite()`, `MatCompositeSetMatStructure()`, `MatCompositeSetMergeType()`, `MATCOMPOSITE`
737: @*/
738: PetscErrorCode MatCompositeMerge(Mat mat)
739: {
740:   PetscFunctionBegin;
742:   PetscUseMethod(mat, "MatCompositeMerge_C", (Mat), (mat));
743:   PetscFunctionReturn(PETSC_SUCCESS);
744: }

746: static PetscErrorCode MatCompositeGetNumberMat_Composite(Mat mat, PetscInt *nmat)
747: {
748:   Mat_Composite *shell;

750:   PetscFunctionBegin;
751:   PetscCall(MatShellGetContext(mat, &shell));
752:   *nmat = shell->nmat;
753:   PetscFunctionReturn(PETSC_SUCCESS);
754: }

756: /*@
757:   MatCompositeGetNumberMat - Returns the number of matrices in the composite matrix.

759:   Not Collective

761:   Input Parameter:
762: . mat - the composite matrix

764:   Output Parameter:
765: . nmat - number of matrices in the composite matrix

767:   Level: advanced

769: .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeGetMat()`, `MATCOMPOSITE`
770: @*/
771: PetscErrorCode MatCompositeGetNumberMat(Mat mat, PetscInt *nmat)
772: {
773:   PetscFunctionBegin;
775:   PetscAssertPointer(nmat, 2);
776:   PetscUseMethod(mat, "MatCompositeGetNumberMat_C", (Mat, PetscInt *), (mat, nmat));
777:   PetscFunctionReturn(PETSC_SUCCESS);
778: }

780: static PetscErrorCode MatCompositeGetMat_Composite(Mat mat, PetscInt i, Mat *Ai)
781: {
782:   Mat_Composite    *shell;
783:   Mat_CompositeLink ilink;
784:   PetscInt          k;

786:   PetscFunctionBegin;
787:   PetscCall(MatShellGetContext(mat, &shell));
788:   PetscCheck(i < shell->nmat, PetscObjectComm((PetscObject)mat), PETSC_ERR_ARG_OUTOFRANGE, "index out of range: %" PetscInt_FMT " >= %" PetscInt_FMT, i, shell->nmat);
789:   ilink = shell->head;
790:   for (k = 0; k < i; k++) ilink = ilink->next;
791:   *Ai = ilink->mat;
792:   PetscFunctionReturn(PETSC_SUCCESS);
793: }

795: /*@
796:   MatCompositeGetMat - Returns the ith matrix from the composite matrix.

798:   Logically Collective

800:   Input Parameters:
801: + mat - the composite matrix
802: - i   - the number of requested matrix

804:   Output Parameter:
805: . Ai - ith matrix in composite

807:   Level: advanced

809: .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeGetNumberMat()`, `MatCompositeAddMat()`, `MATCOMPOSITE`
810: @*/
811: PetscErrorCode MatCompositeGetMat(Mat mat, PetscInt i, Mat *Ai)
812: {
813:   PetscFunctionBegin;
816:   PetscAssertPointer(Ai, 3);
817:   PetscUseMethod(mat, "MatCompositeGetMat_C", (Mat, PetscInt, Mat *), (mat, i, Ai));
818:   PetscFunctionReturn(PETSC_SUCCESS);
819: }

821: static PetscErrorCode MatCompositeSetScalings_Composite(Mat mat, const PetscScalar *scalings)
822: {
823:   Mat_Composite *shell;
824:   PetscInt       nmat;

826:   PetscFunctionBegin;
827:   PetscCall(MatShellGetContext(mat, &shell));
828:   PetscCall(MatCompositeGetNumberMat(mat, &nmat));
829:   if (!shell->scalings) PetscCall(PetscMalloc1(nmat, &shell->scalings));
830:   PetscCall(PetscArraycpy(shell->scalings, scalings, nmat));
831:   PetscFunctionReturn(PETSC_SUCCESS);
832: }

834: /*@
835:   MatCompositeSetScalings - Sets separate scaling factors for component matrices.

837:   Logically Collective

839:   Input Parameters:
840: + mat      - the composite matrix
841: - scalings - array of scaling factors with scalings[i] being factor of i-th matrix, for i in [0, nmat)

843:   Level: advanced

845: .seealso: [](ch_matrices), `Mat`, `MatScale()`, `MatDiagonalScale()`, `MATCOMPOSITE`
846: @*/
847: PetscErrorCode MatCompositeSetScalings(Mat mat, const PetscScalar *scalings)
848: {
849:   PetscFunctionBegin;
851:   PetscAssertPointer(scalings, 2);
853:   PetscUseMethod(mat, "MatCompositeSetScalings_C", (Mat, const PetscScalar *), (mat, scalings));
854:   PetscFunctionReturn(PETSC_SUCCESS);
855: }

857: /*MC
858:    MATCOMPOSITE - A matrix defined by the sum (or product) of one or more matrices.
859:     The matrices need to have a correct size and parallel layout for the sum or product to be valid.

861:   Level: advanced

863:    Note:
   To use the product of the matrices, call `MatCompositeSetType`(mat, `MAT_COMPOSITE_MULTIPLICATIVE`).

866:   Developer Notes:
867:   This is implemented on top of `MATSHELL` to get support for scaling and shifting without requiring duplicate code

  Users cannot call `MatShellSetOperation()` on this class; there is some error checking to catch that incorrect usage

871: .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeSetScalings()`, `MatCompositeAddMat()`, `MatSetType()`, `MatCompositeSetType()`, `MatCompositeGetType()`,
872:           `MatCompositeSetMatStructure()`, `MatCompositeGetMatStructure()`, `MatCompositeMerge()`, `MatCompositeSetMergeType()`, `MatCompositeGetNumberMat()`, `MatCompositeGetMat()`
873: M*/

875: PETSC_EXTERN PetscErrorCode MatCreate_Composite(Mat A)
876: {
877:   Mat_Composite *b;

879:   PetscFunctionBegin;
880:   PetscCall(PetscNew(&b));

882:   b->type        = MAT_COMPOSITE_ADDITIVE;
883:   b->nmat        = 0;
884:   b->merge       = PETSC_FALSE;
885:   b->mergetype   = MAT_COMPOSITE_MERGE_RIGHT;
886:   b->structure   = DIFFERENT_NONZERO_PATTERN;
887:   b->merge_mvctx = PETSC_TRUE;

889:   PetscCall(MatSetType(A, MATSHELL));
890:   PetscCall(MatShellSetContext(A, b));
891:   PetscCall(MatShellSetOperation(A, MATOP_DESTROY, (void (*)(void))MatDestroy_Composite));
892:   PetscCall(MatShellSetOperation(A, MATOP_MULT, (void (*)(void))MatMult_Composite));
893:   PetscCall(MatShellSetOperation(A, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Composite));
894:   PetscCall(MatShellSetOperation(A, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Composite));
895:   PetscCall(MatShellSetOperation(A, MATOP_ASSEMBLY_END, (void (*)(void))MatAssemblyEnd_Composite));
896:   PetscCall(MatShellSetOperation(A, MATOP_SET_FROM_OPTIONS, (void (*)(void))MatSetFromOptions_Composite));
897:   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeAddMat_C", MatCompositeAddMat_Composite));
898:   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeSetType_C", MatCompositeSetType_Composite));
899:   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeGetType_C", MatCompositeGetType_Composite));
900:   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeSetMergeType_C", MatCompositeSetMergeType_Composite));
901:   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeSetMatStructure_C", MatCompositeSetMatStructure_Composite));
902:   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeGetMatStructure_C", MatCompositeGetMatStructure_Composite));
903:   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeMerge_C", MatCompositeMerge_Composite));
904:   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeGetNumberMat_C", MatCompositeGetNumberMat_Composite));
905:   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeGetMat_C", MatCompositeGetMat_Composite));
906:   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeSetScalings_C", MatCompositeSetScalings_Composite));
907:   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatShellSetContext_C", MatShellSetContext_Immutable));
908:   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatShellSetContextDestroy_C", MatShellSetContextDestroy_Immutable));
909:   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatShellSetManageScalingShifts_C", MatShellSetManageScalingShifts_Immutable));
910:   PetscCall(PetscObjectChangeTypeName((PetscObject)A, MATCOMPOSITE));
911:   PetscFunctionReturn(PETSC_SUCCESS);
912: }