Actual source code: hypre2.hip.cxx
1: #include <petsc/private/petschypre.h>
2: #include <petscdevice_hip.h>
3: #include <../src/mat/impls/hypre/mhypre_kernels.hpp>
4: #include <../src/mat/impls/hypre/mhypre.h>
6: PetscErrorCode MatZeroRows_HIP(PetscInt n, const PetscInt rows[], const HYPRE_Int i[], const HYPRE_Int j[], HYPRE_Complex a[], HYPRE_Complex diag)
7: {
8: const PetscInt blkDimX = 16, blkDimY = 32;
9: PetscInt gridDimX = (n + blkDimX - 1) / blkDimX;
10: hipStream_t stream;
12: PetscFunctionBegin;
13: if (!n) PetscFunctionReturn(PETSC_SUCCESS);
14: PetscCall(PetscGetCurrentHIPStream(&stream));
15: hipLaunchKernelGGL(ZeroRows, dim3(gridDimX, 1), dim3(blkDimX, blkDimY), 0, stream, n, rows, i, j, a, diag);
16: PetscCallHIP(hipGetLastError());
17: PetscFunctionReturn(PETSC_SUCCESS);
18: }
20: PetscErrorCode PetscHypreIntCastArray_HIP(PetscInt n, const PetscInt *a, HYPRE_Int *b)
21: {
22: hipStream_t stream;
24: PetscFunctionBegin;
25: if (n) {
26: PetscCall(PetscGetCurrentHIPStream(&stream));
27: hipLaunchKernelGGL(CastArray, dim3((n + 255) / 256), dim3(256), 0, stream, n, a, b);
28: PetscCallHIP(hipGetLastError());
29: }
30: PetscFunctionReturn(PETSC_SUCCESS);
31: }
33: PetscErrorCode MatHypreDeviceMalloc_HIP(size_t size, void **ptr)
34: {
35: PetscFunctionBegin;
36: if (size) PetscCallHIP(hipMalloc(ptr, size));
37: else *ptr = NULL;
38: PetscFunctionReturn(PETSC_SUCCESS);
39: }
41: PetscErrorCode MatHypreDeviceFree_HIP(void *a)
42: {
43: PetscFunctionBegin;
44: PetscCallHIP(hipFree(a));
45: PetscFunctionReturn(PETSC_SUCCESS);
46: }