Actual source code: submat.c

  1: #include <petsc/private/matimpl.h>

  3: typedef struct {
  4:   IS         isrow, iscol;   /* rows and columns in submatrix, only used to check consistency */
  5:   Vec        lwork, rwork;   /* work vectors inside the scatters */
  6:   Vec        lwork2, rwork2; /* work vectors inside the scatters */
  7:   VecScatter lrestrict, rprolong;
  8:   Mat        A;
  9: } Mat_SubVirtual;

 11: static PetscErrorCode MatScale_SubMatrix(Mat N, PetscScalar a)
 12: {
 13:   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;

 15:   PetscFunctionBegin;
 16:   PetscCall(MatScale(Na->A, a));
 17:   PetscFunctionReturn(PETSC_SUCCESS);
 18: }

 20: static PetscErrorCode MatShift_SubMatrix(Mat N, PetscScalar a)
 21: {
 22:   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;

 24:   PetscFunctionBegin;
 25:   PetscCall(MatShift(Na->A, a));
 26:   PetscFunctionReturn(PETSC_SUCCESS);
 27: }

 29: static PetscErrorCode MatDiagonalScale_SubMatrix(Mat N, Vec left, Vec right)
 30: {
 31:   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;

 33:   PetscFunctionBegin;
 34:   if (right) {
 35:     PetscCall(VecZeroEntries(Na->rwork));
 36:     PetscCall(VecScatterBegin(Na->rprolong, right, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
 37:     PetscCall(VecScatterEnd(Na->rprolong, right, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
 38:   }
 39:   if (left) {
 40:     PetscCall(VecZeroEntries(Na->lwork));
 41:     PetscCall(VecScatterBegin(Na->lrestrict, left, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
 42:     PetscCall(VecScatterEnd(Na->lrestrict, left, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
 43:   }
 44:   PetscCall(MatDiagonalScale(Na->A, left ? Na->lwork : NULL, right ? Na->rwork : NULL));
 45:   PetscFunctionReturn(PETSC_SUCCESS);
 46: }

 48: static PetscErrorCode MatGetDiagonal_SubMatrix(Mat N, Vec d)
 49: {
 50:   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;

 52:   PetscFunctionBegin;
 53:   PetscCall(MatGetDiagonal(Na->A, Na->rwork));
 54:   PetscCall(VecScatterBegin(Na->rprolong, Na->rwork, d, INSERT_VALUES, SCATTER_REVERSE));
 55:   PetscCall(VecScatterEnd(Na->rprolong, Na->rwork, d, INSERT_VALUES, SCATTER_REVERSE));
 56:   PetscFunctionReturn(PETSC_SUCCESS);
 57: }

 59: static PetscErrorCode MatMult_SubMatrix(Mat N, Vec x, Vec y)
 60: {
 61:   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;

 63:   PetscFunctionBegin;
 64:   PetscCall(VecZeroEntries(Na->rwork));
 65:   PetscCall(VecScatterBegin(Na->rprolong, x, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
 66:   PetscCall(VecScatterEnd(Na->rprolong, x, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
 67:   PetscCall(MatMult(Na->A, Na->rwork, Na->lwork));
 68:   PetscCall(VecScatterBegin(Na->lrestrict, Na->lwork, y, INSERT_VALUES, SCATTER_FORWARD));
 69:   PetscCall(VecScatterEnd(Na->lrestrict, Na->lwork, y, INSERT_VALUES, SCATTER_FORWARD));
 70:   PetscFunctionReturn(PETSC_SUCCESS);
 71: }

 73: static PetscErrorCode MatMultAdd_SubMatrix(Mat N, Vec v1, Vec v2, Vec v3)
 74: {
 75:   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;

 77:   PetscFunctionBegin;
 78:   PetscCall(VecZeroEntries(Na->rwork));
 79:   PetscCall(VecScatterBegin(Na->rprolong, v1, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
 80:   PetscCall(VecScatterEnd(Na->rprolong, v1, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
 81:   if (v1 == v2) {
 82:     PetscCall(MatMultAdd(Na->A, Na->rwork, Na->rwork, Na->lwork));
 83:   } else if (v2 == v3) {
 84:     PetscCall(VecZeroEntries(Na->lwork));
 85:     PetscCall(VecScatterBegin(Na->lrestrict, v2, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
 86:     PetscCall(VecScatterEnd(Na->lrestrict, v2, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
 87:     PetscCall(MatMultAdd(Na->A, Na->rwork, Na->lwork, Na->lwork));
 88:   } else {
 89:     if (!Na->lwork2) {
 90:       PetscCall(VecDuplicate(Na->lwork, &Na->lwork2));
 91:     } else {
 92:       PetscCall(VecZeroEntries(Na->lwork2));
 93:     }
 94:     PetscCall(VecScatterBegin(Na->lrestrict, v2, Na->lwork2, INSERT_VALUES, SCATTER_REVERSE));
 95:     PetscCall(VecScatterEnd(Na->lrestrict, v2, Na->lwork2, INSERT_VALUES, SCATTER_REVERSE));
 96:     PetscCall(MatMultAdd(Na->A, Na->rwork, Na->lwork2, Na->lwork));
 97:   }
 98:   PetscCall(VecScatterBegin(Na->lrestrict, Na->lwork, v3, INSERT_VALUES, SCATTER_FORWARD));
 99:   PetscCall(VecScatterEnd(Na->lrestrict, Na->lwork, v3, INSERT_VALUES, SCATTER_FORWARD));
100:   PetscFunctionReturn(PETSC_SUCCESS);
101: }

103: static PetscErrorCode MatMultTranspose_SubMatrix(Mat N, Vec x, Vec y)
104: {
105:   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;

107:   PetscFunctionBegin;
108:   PetscCall(VecZeroEntries(Na->lwork));
109:   PetscCall(VecScatterBegin(Na->lrestrict, x, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
110:   PetscCall(VecScatterEnd(Na->lrestrict, x, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
111:   PetscCall(MatMultTranspose(Na->A, Na->lwork, Na->rwork));
112:   PetscCall(VecScatterBegin(Na->rprolong, Na->rwork, y, INSERT_VALUES, SCATTER_REVERSE));
113:   PetscCall(VecScatterEnd(Na->rprolong, Na->rwork, y, INSERT_VALUES, SCATTER_REVERSE));
114:   PetscFunctionReturn(PETSC_SUCCESS);
115: }

117: static PetscErrorCode MatMultTransposeAdd_SubMatrix(Mat N, Vec v1, Vec v2, Vec v3)
118: {
119:   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;

121:   PetscFunctionBegin;
122:   PetscCall(VecZeroEntries(Na->lwork));
123:   PetscCall(VecScatterBegin(Na->lrestrict, v1, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
124:   PetscCall(VecScatterEnd(Na->lrestrict, v1, Na->lwork, INSERT_VALUES, SCATTER_REVERSE));
125:   if (v1 == v2) {
126:     PetscCall(MatMultTransposeAdd(Na->A, Na->lwork, Na->lwork, Na->rwork));
127:   } else if (v2 == v3) {
128:     PetscCall(VecZeroEntries(Na->rwork));
129:     PetscCall(VecScatterBegin(Na->rprolong, v2, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
130:     PetscCall(VecScatterEnd(Na->rprolong, v2, Na->rwork, INSERT_VALUES, SCATTER_FORWARD));
131:     PetscCall(MatMultTransposeAdd(Na->A, Na->lwork, Na->rwork, Na->rwork));
132:   } else {
133:     if (!Na->rwork2) {
134:       PetscCall(VecDuplicate(Na->rwork, &Na->rwork2));
135:     } else {
136:       PetscCall(VecZeroEntries(Na->rwork2));
137:     }
138:     PetscCall(VecScatterBegin(Na->rprolong, v2, Na->rwork2, INSERT_VALUES, SCATTER_FORWARD));
139:     PetscCall(VecScatterEnd(Na->rprolong, v2, Na->rwork2, INSERT_VALUES, SCATTER_FORWARD));
140:     PetscCall(MatMultTransposeAdd(Na->A, Na->lwork, Na->rwork2, Na->rwork));
141:   }
142:   PetscCall(VecScatterBegin(Na->rprolong, Na->rwork, v3, INSERT_VALUES, SCATTER_REVERSE));
143:   PetscCall(VecScatterEnd(Na->rprolong, Na->rwork, v3, INSERT_VALUES, SCATTER_REVERSE));
144:   PetscFunctionReturn(PETSC_SUCCESS);
145: }

147: static PetscErrorCode MatDestroy_SubMatrix(Mat N)
148: {
149:   Mat_SubVirtual *Na = (Mat_SubVirtual *)N->data;

151:   PetscFunctionBegin;
152:   PetscCall(ISDestroy(&Na->isrow));
153:   PetscCall(ISDestroy(&Na->iscol));
154:   PetscCall(VecDestroy(&Na->lwork));
155:   PetscCall(VecDestroy(&Na->rwork));
156:   PetscCall(VecDestroy(&Na->lwork2));
157:   PetscCall(VecDestroy(&Na->rwork2));
158:   PetscCall(VecScatterDestroy(&Na->lrestrict));
159:   PetscCall(VecScatterDestroy(&Na->rprolong));
160:   PetscCall(MatDestroy(&Na->A));
161:   PetscCall(PetscFree(N->data));
162:   PetscFunctionReturn(PETSC_SUCCESS);
163: }

165: /*@
166:   MatCreateSubMatrixVirtual - Creates a virtual matrix `MATSUBMATRIX` that acts as a submatrix

168:   Collective

170:   Input Parameters:
171: + A     - matrix that we will extract a submatrix of
172: . isrow - rows to be present in the submatrix
173: - iscol - columns to be present in the submatrix

175:   Output Parameter:
176: . newmat - new matrix

178:   Level: developer

180:   Note:
181:   Most will use `MatCreateSubMatrix()` which provides a more efficient representation if it is available.

183: .seealso: [](ch_matrices), `Mat`, `MATSUBMATRIX`, `MATLOCALREF`, `MatCreateLocalRef()`, `MatCreateSubMatrix()`, `MatSubMatrixVirtualUpdate()`
184: @*/
185: PetscErrorCode MatCreateSubMatrixVirtual(Mat A, IS isrow, IS iscol, Mat *newmat)
186: {
187:   Vec             left, right;
188:   PetscInt        m, n;
189:   Mat             N;
190:   Mat_SubVirtual *Na;

192:   PetscFunctionBegin;
196:   PetscAssertPointer(newmat, 4);
197:   *newmat = NULL;

199:   PetscCall(MatCreate(PetscObjectComm((PetscObject)A), &N));
200:   PetscCall(ISGetLocalSize(isrow, &m));
201:   PetscCall(ISGetLocalSize(iscol, &n));
202:   PetscCall(MatSetSizes(N, m, n, PETSC_DETERMINE, PETSC_DETERMINE));
203:   PetscCall(PetscObjectChangeTypeName((PetscObject)N, MATSUBMATRIX));

205:   PetscCall(PetscNew(&Na));
206:   N->data = (void *)Na;

208:   PetscCall(PetscObjectReference((PetscObject)isrow));
209:   PetscCall(PetscObjectReference((PetscObject)iscol));
210:   Na->isrow = isrow;
211:   Na->iscol = iscol;

213:   PetscCall(PetscFree(N->defaultvectype));
214:   PetscCall(PetscStrallocpy(A->defaultvectype, &N->defaultvectype));
215:   /* Do not use MatConvert directly since MatShell has a duplicate operation which does not increase
216:      the reference count of the context. This is a problem if A is already of type MATSHELL */
217:   PetscCall(MatConvertFrom_Shell(A, MATSHELL, MAT_INITIAL_MATRIX, &Na->A));

219:   N->ops->destroy          = MatDestroy_SubMatrix;
220:   N->ops->mult             = MatMult_SubMatrix;
221:   N->ops->multadd          = MatMultAdd_SubMatrix;
222:   N->ops->multtranspose    = MatMultTranspose_SubMatrix;
223:   N->ops->multtransposeadd = MatMultTransposeAdd_SubMatrix;
224:   N->ops->scale            = MatScale_SubMatrix;
225:   N->ops->diagonalscale    = MatDiagonalScale_SubMatrix;
226:   N->ops->shift            = MatShift_SubMatrix;
227:   N->ops->convert          = MatConvert_Shell;
228:   N->ops->getdiagonal      = MatGetDiagonal_SubMatrix;

230:   PetscCall(MatSetBlockSizesFromMats(N, A, A));
231:   PetscCall(PetscLayoutSetUp(N->rmap));
232:   PetscCall(PetscLayoutSetUp(N->cmap));

234:   PetscCall(MatCreateVecs(A, &Na->rwork, &Na->lwork));
235:   PetscCall(MatCreateVecs(N, &right, &left));
236:   PetscCall(VecScatterCreate(Na->lwork, isrow, left, NULL, &Na->lrestrict));
237:   PetscCall(VecScatterCreate(right, NULL, Na->rwork, iscol, &Na->rprolong));
238:   PetscCall(VecDestroy(&left));
239:   PetscCall(VecDestroy(&right));
240:   PetscCall(MatSetUp(N));

242:   N->assembled = PETSC_TRUE;
243:   *newmat      = N;
244:   PetscFunctionReturn(PETSC_SUCCESS);
245: }

/*MC
   MATSUBMATRIX - "submatrix" - A matrix type that represents a virtual submatrix of a matrix

   Level: advanced

   Developer Note:
   The `MatType` is `MATSUBMATRIX`, but the associated routines have `SubMatrixVirtual` in their names; the `MatType` name should likely be changed to
   `MATSUBMATRIXVIRTUAL`

.seealso: [](ch_matrices), `Mat`, `MatCreateSubMatrixVirtual()`, `MatSubMatrixVirtualUpdate()`, `MatCreateSubMatrix()`
M*/

259: /*@
260:   MatSubMatrixVirtualUpdate - Updates a `MATSUBMATRIX` virtual submatrix

262:   Collective

264:   Input Parameters:
265: + N     - submatrix to update
266: . A     - full matrix in the submatrix
267: . isrow - rows in the update (same as the first time the submatrix was created)
268: - iscol - columns in the update (same as the first time the submatrix was created)

270:   Level: developer

272:   Note:
273:   Most will use `MatCreateSubMatrix()` which provides a more efficient representation if it is available.

275: .seealso: [](ch_matrices), `Mat`, `MATSUBMATRIX`, `MatCreateSubMatrixVirtual()`
276: @*/
277: PetscErrorCode MatSubMatrixVirtualUpdate(Mat N, Mat A, IS isrow, IS iscol)
278: {
279:   PetscBool       flg;
280:   Mat_SubVirtual *Na;

282:   PetscFunctionBegin;
287:   PetscCall(PetscObjectTypeCompare((PetscObject)N, MATSUBMATRIX, &flg));
288:   PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONG, "Matrix has wrong type");

290:   Na = (Mat_SubVirtual *)N->data;
291:   PetscCall(ISEqual(isrow, Na->isrow, &flg));
292:   PetscCheck(flg, PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Cannot update submatrix with different row indices");
293:   PetscCall(ISEqual(iscol, Na->iscol, &flg));
294:   PetscCheck(flg, PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Cannot update submatrix with different column indices");

296:   PetscCall(PetscFree(N->defaultvectype));
297:   PetscCall(PetscStrallocpy(A->defaultvectype, &N->defaultvectype));
298:   PetscCall(MatDestroy(&Na->A));
299:   /* Do not use MatConvert directly since MatShell has a duplicate operation which does not increase
300:      the reference count of the context. This is a problem if A is already of type MATSHELL */
301:   PetscCall(MatConvertFrom_Shell(A, MATSHELL, MAT_INITIAL_MATRIX, &Na->A));
302:   PetscFunctionReturn(PETSC_SUCCESS);
303: }