Actual source code: baijfact81.c
1: /*
2: Factorization code for BAIJ format.
3: */
4: #include <../src/mat/impls/baij/seq/baij.h>
5: #include <petsc/private/kernels/blockinvert.h>
6: #if defined(PETSC_HAVE_IMMINTRIN_H) && defined(__AVX2__) && defined(__FMA__) && defined(PETSC_USE_REAL_DOUBLE) && !defined(PETSC_USE_COMPLEX) && !defined(PETSC_USE_64BIT_INDICES)
7: #include <immintrin.h>
8: #endif
9: /*
10: Version for when blocks are 9 by 9
11: */
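/*
  MatLUFactorNumeric_SeqBAIJ_9_NaturalOrdering performs the numeric LU factorization of a
  SeqBAIJ matrix with 9x9 blocks under the natural (identity) ordering; the diagonal blocks
  of the factor are stored already inverted so that the triangular solves reduce to block
  multiplications. MatSolve_SeqBAIJ_9_NaturalOrdering is the matching forward/backward
  substitution, vectorized with AVX2/FMA intrinsics. Both routines are compiled only when
  the guard below holds (immintrin.h, AVX2, FMA, real double-precision scalars, 32-bit indices).
*/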
12: #if defined(PETSC_HAVE_IMMINTRIN_H) && defined(__AVX2__) && defined(__FMA__) && defined(PETSC_USE_REAL_DOUBLE) && !defined(PETSC_USE_COMPLEX) && !defined(PETSC_USE_64BIT_INDICES)
13: PetscErrorCode MatLUFactorNumeric_SeqBAIJ_9_NaturalOrdering(Mat B, Mat A, const MatFactorInfo *info)
14: {
15: Mat C = B;
16: Mat_SeqBAIJ *a = (Mat_SeqBAIJ *)A->data, *b = (Mat_SeqBAIJ *)C->data;
17: PetscInt i, j, k, nz, nzL, row;
18: const PetscInt n = a->mbs, *ai = a->i, *aj = a->j, *bi = b->i, *bj = b->j;
19: const PetscInt *ajtmp, *bjtmp, *bdiag = b->diag, *pj, bs2 = a->bs2;
20: MatScalar *rtmp, *pc, *mwork, *v, *pv, *aa = a->a;
21: PetscInt flg;
22: PetscReal shift = info->shiftamount;
23: PetscBool allowzeropivot, zeropivotdetected;
25: PetscFunctionBegin;
26: allowzeropivot = PetscNot(A->erroriffailure);
28: /* generate work space needed by the factorization */
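/* rtmp provides one bs2-sized block of workspace per block column (the current block row in
   expanded form); mwork is scratch for the block multiply kernel used during elimination */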
29: PetscCall(PetscMalloc2(bs2 * n, &rtmp, bs2, &mwork));
30: PetscCall(PetscArrayzero(rtmp, bs2 * n));
32: for (i = 0; i < n; i++) {
33: /* zero rtmp */
34: /* L part */
35: nz = bi[i + 1] - bi[i];
36: bjtmp = bj + bi[i];
37: for (j = 0; j < nz; j++) PetscCall(PetscArrayzero(rtmp + bs2 * bjtmp[j], bs2));
39: /* U part */
40: nz = bdiag[i] - bdiag[i + 1];
41: bjtmp = bj + bdiag[i + 1] + 1;
42: for (j = 0; j < nz; j++) PetscCall(PetscArrayzero(rtmp + bs2 * bjtmp[j], bs2));
44: /* load in initial (unfactored row) */
45: nz = ai[i + 1] - ai[i];
46: ajtmp = aj + ai[i];
47: v = aa + bs2 * ai[i];
48: for (j = 0; j < nz; j++) PetscCall(PetscArraycpy(rtmp + bs2 * ajtmp[j], v + bs2 * j, bs2));
50: /* elimination */
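/* Row-by-row block elimination: for each block column `row` of L(i,:), rtmp[row] is multiplied
   by the stored inverse of row `row`'s diagonal block (yielding the L factor block), and that
   product times U(row,:) is subtracted from the corresponding rtmp blocks. Blocks that are
   entirely zero are skipped. */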
51: bjtmp = bj + bi[i];
52: nzL = bi[i + 1] - bi[i];
53: for (k = 0; k < nzL; k++) {
54: row = bjtmp[k];
55: pc = rtmp + bs2 * row;
56: for (flg = 0, j = 0; j < bs2; j++) {
57: if (pc[j] != 0.0) {
58: flg = 1;
59: break;
60: }
61: }
62: if (flg) {
63: pv = b->a + bs2 * bdiag[row];
64: /* PetscKernel_A_gets_A_times_B(bs,pc,pv,mwork); *pc = *pc * (*pv); */
65: PetscCall(PetscKernel_A_gets_A_times_B_9(pc, pv, mwork));
67: pj = b->j + bdiag[row + 1] + 1; /* beginning of U(row,:) */
68: pv = b->a + bs2 * (bdiag[row + 1] + 1);
69: nz = bdiag[row] - bdiag[row + 1] - 1; /* num of entries in U(row,:), excluding diag */
70: for (j = 0; j < nz; j++) {
71: /* PetscKernel_A_gets_A_minus_B_times_C(bs,rtmp+bs2*pj[j],pc,pv+bs2*j); */
72: /* rtmp+bs2*pj[j] = rtmp+bs2*pj[j] - (*pc)*(pv+bs2*j) */
73: v = rtmp + bs2 * pj[j];
74: PetscCall(PetscKernel_A_gets_A_minus_B_times_C_9(v, pc, pv + 81 * j));
75: /* pv incremented in PetscKernel_A_gets_A_minus_B_times_C_9 */
76: }
77: PetscCall(PetscLogFlops(1458 * nz + 1377)); /* flops = 2*bs^3*nz + 2*bs^3 - bs^2 */
78: }
79: }
81: /* finished row so stick it into b->a */
82: /* L part */
83: pv = b->a + bs2 * bi[i];
84: pj = b->j + bi[i];
85: nz = bi[i + 1] - bi[i];
86: for (j = 0; j < nz; j++) PetscCall(PetscArraycpy(pv + bs2 * j, rtmp + bs2 * pj[j], bs2));
88: /* Mark diagonal and invert diagonal for simpler triangular solves */
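/* the inversion kernel is passed the shift amount from `info`; when zero pivots are allowed,
   a detected zero pivot is recorded in the factor error type instead of raising an error */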
89: pv = b->a + bs2 * bdiag[i];
90: pj = b->j + bdiag[i];
91: PetscCall(PetscArraycpy(pv, rtmp + bs2 * pj[0], bs2));
92: PetscCall(PetscKernel_A_gets_inverse_A_9(pv, shift, allowzeropivot, &zeropivotdetected));
93: if (zeropivotdetected) C->factorerrortype = MAT_FACTOR_NUMERIC_ZEROPIVOT;
95: /* U part */
96: pv = b->a + bs2 * (bdiag[i + 1] + 1);
97: pj = b->j + bdiag[i + 1] + 1;
98: nz = bdiag[i] - bdiag[i + 1] - 1;
99: for (j = 0; j < nz; j++) PetscCall(PetscArraycpy(pv + bs2 * j, rtmp + bs2 * pj[j], bs2));
100: }
101: PetscCall(PetscFree2(rtmp, mwork));
103: C->ops->solve = MatSolve_SeqBAIJ_9_NaturalOrdering;
104: C->ops->solvetranspose = MatSolveTranspose_SeqBAIJ_N;
105: C->assembled = PETSC_TRUE;
107: PetscCall(PetscLogFlops(1.333333333333 * 9 * 9 * 9 * n)); /* from inverting diagonal blocks */
108: PetscFunctionReturn(PETSC_SUCCESS);
109: }
111: PetscErrorCode MatSolve_SeqBAIJ_9_NaturalOrdering(Mat A, Vec bb, Vec xx)
112: {
113: Mat_SeqBAIJ *a = (Mat_SeqBAIJ *)A->data;
114: const PetscInt *ai = a->i, *aj = a->j, *adiag = a->diag, *vi;
115: PetscInt i, k, n = a->mbs;
116: PetscInt nz, bs = A->rmap->bs, bs2 = a->bs2;
117: const MatScalar *aa = a->a, *v;
118: PetscScalar *x, *s, *t, *ls;
119: const PetscScalar *b;
120: __m256d a0, a1, a2, a3, a4, a5, w0, w1, w2, w3, s0, s1, s2, v0, v1, v2, v3;
122: PetscFunctionBegin;
123: PetscCall(VecGetArrayRead(bb, &b));
124: PetscCall(VecGetArray(xx, &x));
125: t = a->solve_work;
127: /* forward solve the lower triangular */
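/*
   Each 9-vector is held in three AVX2 registers: two full __m256d for entries 0-3 and 4-7,
   plus a masked register carrying entry 8 only. A 9x9 block is applied in nine steps: step c
   broadcasts entry c of the already-solved block t[vi[k]] and subtracts (fnmadd) its product
   with the nine consecutive block values v[9*c .. 9*c+8] from the accumulators. Scalar
   equivalent of one block update:

     for (c = 0; c < 9; c++)
       for (r = 0; r < 9; r++) s[r] -= v[9 * c + r] * (t + bs * vi[k])[c];

   The masked loads/stores guarantee that memory past entry 8 is neither read nor written.
*/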
128: PetscCall(PetscArraycpy(t, b, bs)); /* copy 1st block of b to t */
130: for (i = 1; i < n; i++) {
131: v = aa + bs2 * ai[i];
132: vi = aj + ai[i];
133: nz = ai[i + 1] - ai[i];
134: s = t + bs * i;
135: PetscCall(PetscArraycpy(s, b + bs * i, bs)); /* copy i_th block of b to t */
137: __m256d s0, s1, s2;
138: s0 = _mm256_loadu_pd(s + 0);
139: s1 = _mm256_loadu_pd(s + 4);
140: s2 = _mm256_maskload_pd(s + 8, _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL << 63));
142: for (k = 0; k < nz; k++) {
143: w0 = _mm256_set1_pd((t + bs * vi[k])[0]);
144: a0 = _mm256_loadu_pd(&v[0]);
145: s0 = _mm256_fnmadd_pd(a0, w0, s0);
146: a1 = _mm256_loadu_pd(&v[4]);
147: s1 = _mm256_fnmadd_pd(a1, w0, s1);
148: a2 = _mm256_loadu_pd(&v[8]);
149: s2 = _mm256_fnmadd_pd(a2, w0, s2);
151: w1 = _mm256_set1_pd((t + bs * vi[k])[1]);
152: a3 = _mm256_loadu_pd(&v[9]);
153: s0 = _mm256_fnmadd_pd(a3, w1, s0);
154: a4 = _mm256_loadu_pd(&v[13]);
155: s1 = _mm256_fnmadd_pd(a4, w1, s1);
156: a5 = _mm256_loadu_pd(&v[17]);
157: s2 = _mm256_fnmadd_pd(a5, w1, s2);
159: w2 = _mm256_set1_pd((t + bs * vi[k])[2]);
160: a0 = _mm256_loadu_pd(&v[18]);
161: s0 = _mm256_fnmadd_pd(a0, w2, s0);
162: a1 = _mm256_loadu_pd(&v[22]);
163: s1 = _mm256_fnmadd_pd(a1, w2, s1);
164: a2 = _mm256_loadu_pd(&v[26]);
165: s2 = _mm256_fnmadd_pd(a2, w2, s2);
167: w3 = _mm256_set1_pd((t + bs * vi[k])[3]);
168: a3 = _mm256_loadu_pd(&v[27]);
169: s0 = _mm256_fnmadd_pd(a3, w3, s0);
170: a4 = _mm256_loadu_pd(&v[31]);
171: s1 = _mm256_fnmadd_pd(a4, w3, s1);
172: a5 = _mm256_loadu_pd(&v[35]);
173: s2 = _mm256_fnmadd_pd(a5, w3, s2);
175: w0 = _mm256_set1_pd((t + bs * vi[k])[4]);
176: a0 = _mm256_loadu_pd(&v[36]);
177: s0 = _mm256_fnmadd_pd(a0, w0, s0);
178: a1 = _mm256_loadu_pd(&v[40]);
179: s1 = _mm256_fnmadd_pd(a1, w0, s1);
180: a2 = _mm256_loadu_pd(&v[44]);
181: s2 = _mm256_fnmadd_pd(a2, w0, s2);
183: w1 = _mm256_set1_pd((t + bs * vi[k])[5]);
184: a3 = _mm256_loadu_pd(&v[45]);
185: s0 = _mm256_fnmadd_pd(a3, w1, s0);
186: a4 = _mm256_loadu_pd(&v[49]);
187: s1 = _mm256_fnmadd_pd(a4, w1, s1);
188: a5 = _mm256_loadu_pd(&v[53]);
189: s2 = _mm256_fnmadd_pd(a5, w1, s2);
191: w2 = _mm256_set1_pd((t + bs * vi[k])[6]);
192: a0 = _mm256_loadu_pd(&v[54]);
193: s0 = _mm256_fnmadd_pd(a0, w2, s0);
194: a1 = _mm256_loadu_pd(&v[58]);
195: s1 = _mm256_fnmadd_pd(a1, w2, s1);
196: a2 = _mm256_loadu_pd(&v[62]);
197: s2 = _mm256_fnmadd_pd(a2, w2, s2);
199: w3 = _mm256_set1_pd((t + bs * vi[k])[7]);
200: a3 = _mm256_loadu_pd(&v[63]);
201: s0 = _mm256_fnmadd_pd(a3, w3, s0);
202: a4 = _mm256_loadu_pd(&v[67]);
203: s1 = _mm256_fnmadd_pd(a4, w3, s1);
204: a5 = _mm256_loadu_pd(&v[71]);
205: s2 = _mm256_fnmadd_pd(a5, w3, s2);
207: w0 = _mm256_set1_pd((t + bs * vi[k])[8]);
208: a0 = _mm256_loadu_pd(&v[72]);
209: s0 = _mm256_fnmadd_pd(a0, w0, s0);
210: a1 = _mm256_loadu_pd(&v[76]);
211: s1 = _mm256_fnmadd_pd(a1, w0, s1);
212: a2 = _mm256_maskload_pd(v + 80, _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL << 63));
213: s2 = _mm256_fnmadd_pd(a2, w0, s2);
214: v += bs2;
215: }
216: _mm256_storeu_pd(&s[0], s0);
217: _mm256_storeu_pd(&s[4], s1);
218: _mm256_maskstore_pd(&s[8], _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL << 63), s2);
219: }
221: /* backward solve the upper triangular */
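/*
   Two phases per block row i, processed from the last block row upward:
     1. ls = t_i - sum_j U(i,j)*t_j over the off-diagonal blocks of U, using the same
        broadcast/fnmadd pattern as the forward solve;
     2. t_i = inv(D_i)*ls, where the diagonal block at aa + bs2*adiag[i] was stored already
        inverted by the factorization; the result is then copied into x.
*/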
222: ls = a->solve_work + A->cmap->n;
223: for (i = n - 1; i >= 0; i--) {
224: v = aa + bs2 * (adiag[i + 1] + 1);
225: vi = aj + adiag[i + 1] + 1;
226: nz = adiag[i] - adiag[i + 1] - 1;
227: PetscCall(PetscArraycpy(ls, t + i * bs, bs));
229: s0 = _mm256_loadu_pd(ls + 0);
230: s1 = _mm256_loadu_pd(ls + 4);
231: s2 = _mm256_maskload_pd(ls + 8, _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL << 63));
233: for (k = 0; k < nz; k++) {
234: w0 = _mm256_set1_pd((t + bs * vi[k])[0]);
235: a0 = _mm256_loadu_pd(&v[0]);
236: s0 = _mm256_fnmadd_pd(a0, w0, s0);
237: a1 = _mm256_loadu_pd(&v[4]);
238: s1 = _mm256_fnmadd_pd(a1, w0, s1);
239: a2 = _mm256_loadu_pd(&v[8]);
240: s2 = _mm256_fnmadd_pd(a2, w0, s2);
242: /* v += 9; */
243: w1 = _mm256_set1_pd((t + bs * vi[k])[1]);
244: a3 = _mm256_loadu_pd(&v[9]);
245: s0 = _mm256_fnmadd_pd(a3, w1, s0);
246: a4 = _mm256_loadu_pd(&v[13]);
247: s1 = _mm256_fnmadd_pd(a4, w1, s1);
248: a5 = _mm256_loadu_pd(&v[17]);
249: s2 = _mm256_fnmadd_pd(a5, w1, s2);
251: /* v += 9; */
252: w2 = _mm256_set1_pd((t + bs * vi[k])[2]);
253: a0 = _mm256_loadu_pd(&v[18]);
254: s0 = _mm256_fnmadd_pd(a0, w2, s0);
255: a1 = _mm256_loadu_pd(&v[22]);
256: s1 = _mm256_fnmadd_pd(a1, w2, s1);
257: a2 = _mm256_loadu_pd(&v[26]);
258: s2 = _mm256_fnmadd_pd(a2, w2, s2);
260: /* v += 9; */
261: w3 = _mm256_set1_pd((t + bs * vi[k])[3]);
262: a3 = _mm256_loadu_pd(&v[27]);
263: s0 = _mm256_fnmadd_pd(a3, w3, s0);
264: a4 = _mm256_loadu_pd(&v[31]);
265: s1 = _mm256_fnmadd_pd(a4, w3, s1);
266: a5 = _mm256_loadu_pd(&v[35]);
267: s2 = _mm256_fnmadd_pd(a5, w3, s2);
269: /* v += 9; */
270: w0 = _mm256_set1_pd((t + bs * vi[k])[4]);
271: a0 = _mm256_loadu_pd(&v[36]);
272: s0 = _mm256_fnmadd_pd(a0, w0, s0);
273: a1 = _mm256_loadu_pd(&v[40]);
274: s1 = _mm256_fnmadd_pd(a1, w0, s1);
275: a2 = _mm256_loadu_pd(&v[44]);
276: s2 = _mm256_fnmadd_pd(a2, w0, s2);
278: /* v += 9; */
279: w1 = _mm256_set1_pd((t + bs * vi[k])[5]);
280: a3 = _mm256_loadu_pd(&v[45]);
281: s0 = _mm256_fnmadd_pd(a3, w1, s0);
282: a4 = _mm256_loadu_pd(&v[49]);
283: s1 = _mm256_fnmadd_pd(a4, w1, s1);
284: a5 = _mm256_loadu_pd(&v[53]);
285: s2 = _mm256_fnmadd_pd(a5, w1, s2);
287: /* v += 9; */
288: w2 = _mm256_set1_pd((t + bs * vi[k])[6]);
289: a0 = _mm256_loadu_pd(&v[54]);
290: s0 = _mm256_fnmadd_pd(a0, w2, s0);
291: a1 = _mm256_loadu_pd(&v[58]);
292: s1 = _mm256_fnmadd_pd(a1, w2, s1);
293: a2 = _mm256_loadu_pd(&v[62]);
294: s2 = _mm256_fnmadd_pd(a2, w2, s2);
296: /* v += 9; */
297: w3 = _mm256_set1_pd((t + bs * vi[k])[7]);
298: a3 = _mm256_loadu_pd(&v[63]);
299: s0 = _mm256_fnmadd_pd(a3, w3, s0);
300: a4 = _mm256_loadu_pd(&v[67]);
301: s1 = _mm256_fnmadd_pd(a4, w3, s1);
302: a5 = _mm256_loadu_pd(&v[71]);
303: s2 = _mm256_fnmadd_pd(a5, w3, s2);
305: /* v += 9; */
306: w0 = _mm256_set1_pd((t + bs * vi[k])[8]);
307: a0 = _mm256_loadu_pd(&v[72]);
308: s0 = _mm256_fnmadd_pd(a0, w0, s0);
309: a1 = _mm256_loadu_pd(&v[76]);
310: s1 = _mm256_fnmadd_pd(a1, w0, s1);
311: a2 = _mm256_maskload_pd(v + 80, _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL << 63));
312: s2 = _mm256_fnmadd_pd(a2, w0, s2);
313: v += bs2;
314: }
316: _mm256_storeu_pd(&ls[0], s0);
317: _mm256_storeu_pd(&ls[4], s1);
318: _mm256_maskstore_pd(&ls[8], _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL << 63), s2);
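/* apply the stored inverse of the diagonal block: t_i = inv(D_i) * ls */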
320: w0 = _mm256_setzero_pd();
321: w1 = _mm256_setzero_pd();
322: w2 = _mm256_setzero_pd();
324: /* first row */
325: v0 = _mm256_set1_pd(ls[0]);
326: a0 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[0]);
327: w0 = _mm256_fmadd_pd(a0, v0, w0);
328: a1 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[4]);
329: w1 = _mm256_fmadd_pd(a1, v0, w1);
330: a2 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[8]);
331: w2 = _mm256_fmadd_pd(a2, v0, w2);
333: /* second row */
334: v1 = _mm256_set1_pd(ls[1]);
335: a3 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[9]);
336: w0 = _mm256_fmadd_pd(a3, v1, w0);
337: a4 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[13]);
338: w1 = _mm256_fmadd_pd(a4, v1, w1);
339: a5 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[17]);
340: w2 = _mm256_fmadd_pd(a5, v1, w2);
342: /* third row */
343: v2 = _mm256_set1_pd(ls[2]);
344: a0 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[18]);
345: w0 = _mm256_fmadd_pd(a0, v2, w0);
346: a1 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[22]);
347: w1 = _mm256_fmadd_pd(a1, v2, w1);
348: a2 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[26]);
349: w2 = _mm256_fmadd_pd(a2, v2, w2);
351: /* fourth row */
352: v3 = _mm256_set1_pd(ls[3]);
353: a3 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[27]);
354: w0 = _mm256_fmadd_pd(a3, v3, w0);
355: a4 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[31]);
356: w1 = _mm256_fmadd_pd(a4, v3, w1);
357: a5 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[35]);
358: w2 = _mm256_fmadd_pd(a5, v3, w2);
360: /* fifth row */
361: v0 = _mm256_set1_pd(ls[4]);
362: a0 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[36]);
363: w0 = _mm256_fmadd_pd(a0, v0, w0);
364: a1 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[40]);
365: w1 = _mm256_fmadd_pd(a1, v0, w1);
366: a2 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[44]);
367: w2 = _mm256_fmadd_pd(a2, v0, w2);
369: /* sixth row */
370: v1 = _mm256_set1_pd(ls[5]);
371: a3 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[45]);
372: w0 = _mm256_fmadd_pd(a3, v1, w0);
373: a4 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[49]);
374: w1 = _mm256_fmadd_pd(a4, v1, w1);
375: a5 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[53]);
376: w2 = _mm256_fmadd_pd(a5, v1, w2);
378: /* seventh row */
379: v2 = _mm256_set1_pd(ls[6]);
380: a0 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[54]);
381: w0 = _mm256_fmadd_pd(a0, v2, w0);
382: a1 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[58]);
383: w1 = _mm256_fmadd_pd(a1, v2, w1);
384: a2 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[62]);
385: w2 = _mm256_fmadd_pd(a2, v2, w2);
387: /* eighth row */
388: v3 = _mm256_set1_pd(ls[7]);
389: a3 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[63]);
390: w0 = _mm256_fmadd_pd(a3, v3, w0);
391: a4 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[67]);
392: w1 = _mm256_fmadd_pd(a4, v3, w1);
393: a5 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[71]);
394: w2 = _mm256_fmadd_pd(a5, v3, w2);
396: /* ninth row */
397: v0 = _mm256_set1_pd(ls[8]);
398: a3 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[72]);
399: w0 = _mm256_fmadd_pd(a3, v0, w0);
400: a4 = _mm256_loadu_pd(&(aa + bs2 * adiag[i])[76]);
401: w1 = _mm256_fmadd_pd(a4, v0, w1);
402: a2 = _mm256_maskload_pd((&(aa + bs2 * adiag[i])[80]), _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL << 63));
403: w2 = _mm256_fmadd_pd(a2, v0, w2);
405: _mm256_storeu_pd(&(t + i * bs)[0], w0);
406: _mm256_storeu_pd(&(t + i * bs)[4], w1);
407: _mm256_maskstore_pd(&(t + i * bs)[8], _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL << 63), w2);
409: PetscCall(PetscArraycpy(x + i * bs, t + i * bs, bs));
410: }
412: PetscCall(VecRestoreArrayRead(bb, &b));
413: PetscCall(VecRestoreArray(xx, &x));
414: PetscCall(PetscLogFlops(2.0 * (a->bs2) * (a->nz) - A->rmap->bs * A->cmap->n));
415: PetscFunctionReturn(PETSC_SUCCESS);
416: }
417: #endif