
Raising

Raising GPU Kernels

Kernel raising refers to Reactant's ability to transform a program written in GPU kernel style, that is, kernel functions evaluated over a grid of blocks and threads, with operations performed at the scalar level. The transformation raises the program to a tensor-style function (in the StableHLO dialect) in which operations are broadcast over whole arrays.
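As a rough illustration of what raising exploits, the following plain-Julia sketch (it needs neither Reactant nor a GPU) emulates a 1D kernel launch on the CPU. The hypothetical helper `launch_grid!` calls a scalar kernel once per (block, thread) pair, and each call derives its 1-based global index from its block and thread indices. The whole grid computes the same result as the single broadcast `xs .* xs`, which is the tensor-style form a raised kernel takes. The names `square_kernel_scalar!` and `launch_grid!` and the block size of 16 are illustrative assumptions, not Reactant or CUDA.jl APIs.

```julia
# Hypothetical scalar kernel: each (block, thread) pair handles one element.
function square_kernel_scalar!(ys, xs, block, thread, threads_per_block)
    i = (block - 1) * threads_per_block + thread  # 1-based global index
    if i <= length(xs)                             # the last block may be partial
        @inbounds ys[i] = xs[i] * xs[i]
    end
    return nothing
end

# Hypothetical launcher: visits every (block, thread) pair of a 1D grid serially.
function launch_grid!(kernel!, ys, xs; threads_per_block = 16)
    nblocks = cld(length(xs), threads_per_block)
    for block in 1:nblocks, thread in 1:threads_per_block
        kernel!(ys, xs, block, thread, threads_per_block)
    end
    return ys
end

xs = collect(1:64) ./ 64
y_kernel = launch_grid!(square_kernel_scalar!, similar(xs), xs)

# The tensor-style ("raised") form of the same computation is one broadcast.
y_tensor = xs .* xs
y_kernel == y_tensor  # true
```

A real GPU runs the (block, thread) pairs in parallel rather than in a serial loop; the point here is only that the per-thread scalar code and the broadcast describe the same computation.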

This transformation enables several features:

  • Running the raised compute kernel on hardware the original kernel was not designed for (e.g. running a CUDA kernel on a TPU).

  • Enabling further optimizations. Since the raised kernel is indistinguishable from the rest of the program, it can be optimized together with it. For example, two sequential kernel launches, where the second operates on the result of the first, can be fused if both are raised, resulting in a single kernel launch in the final optimized StableHLO program.

  • Enabling automatic differentiation. Reactant does not currently support automatic differentiation of GPU kernels directly, but raising a kernel lets Enzyme differentiate the raised version. For this to work, use the raise_first compilation option so that kernels are raised before Enzyme differentiates the program.
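The fusion point above can be illustrated by analogy in plain Julia. In this sketch, two hypothetical "kernels" written as loops run one after the other: the first squares `xs` into an intermediate buffer and the second adds one to it. Expressed as tensor operations, the two collapse into a single fused expression with no intermediate buffer, which is the kind of rewrite a compiler can apply once both kernels are raised. The function names are illustrative assumptions, and this is an analogy, not Reactant's actual fusion pass.

```julia
# Hypothetical kernel 1: square every element of xs into ys.
function square_pass!(ys, xs)
    for i in eachindex(ys, xs)
        @inbounds ys[i] = xs[i] * xs[i]
    end
    return ys
end

# Hypothetical kernel 2: add one to every element of ys, writing into zs.
function add_one_pass!(zs, ys)
    for i in eachindex(zs, ys)
        @inbounds zs[i] = ys[i] + 1
    end
    return zs
end

xs = collect(1:64) ./ 64

# Unfused: two separate passes that communicate through the buffer `tmp`.
tmp = similar(xs)
z_two_passes = add_one_pass!(similar(xs), square_pass!(tmp, xs))

# Fused: one broadcast expression, evaluated in a single pass, no `tmp`.
z_fused = @. xs * xs + 1

z_two_passes == z_fused  # true
```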

Note

Not all classes of kernels are currently raisable to StableHLO. If your kernel encounters an error while being raised, please open an issue on the Reactant.jl repository.

Example

julia
using Reactant
using KernelAbstractions
using CUDA # needs to be loaded for raising even if CUDA is not functional on your system

Tip

We could also have implemented the kernel directly with CUDA.jl instead of KernelAbstractions.jl.

We will implement a simple kernel to compute the square of a vector.

julia
@kernel function square_kernel!(y, @Const(x))
    i = @index(Global)
    @inbounds y[i] = x[i] * x[i]
end

function square(x)
    y = similar(x)
    backend = KernelAbstractions.get_backend(x)
    kernel! = square_kernel!(backend)
    kernel!(y, x; ndrange=length(x))
    return y
end
square (generic function with 1 method)
julia
x = Reactant.to_rarray(collect(1:1:64) ./ 64)

Let's see what the HLO IR looks like for this function. Raising is enabled automatically on backends such as TPU, which the original kernel was not designed to run on. To enable raising on other backends, pass the raise=true option.

julia
@code_hlo raise=true square(x)
module @reactant_square attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  llvm.module_flags [#llvm.mlir.module_flag<warning, "Dwarf Version", 2 : i32>, #llvm.mlir.module_flag<warning, "Debug Info Version", 3 : i32>]
  func.func @main(%arg0: tensor<64xf64> {enzymexla.memory_effects = []}) -> tensor<64xf64> attributes {enzymexla.memory_effects = []} {
    %0 = stablehlo.multiply %arg0, %arg0 : tensor<64xf64>
    return %0 : tensor<64xf64>
  }
}

Raising Scalar Loops to Tensor IR

We will implement a simple N-body force computation in Reactant. Instead of using broadcasting or other high-level abstractions, we will write it with explicit loops and scalar operations.

julia
using Reactant, PrettyChairmarks

Reactant.allowscalar(true) # generally not recommended to turn on globally
Warning: It's not recommended to use allowscalar([true]) to allow scalar indexing.
Instead, use `allowscalar() do end` or `@allowscalar` to denote exactly which operations can use scalar operations.
@ GPUArraysCore ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:184

We will implement a naive function to compute the attractive force between each pair of particles in a system.

julia
function compute_attractive_force(
    positions::AbstractMatrix, masses::AbstractVector, G::Number
)
    N = size(positions, 2)
    F = similar(positions, N, N)

    @trace for i in 1:N
        @trace for j in 1:N
            dx = positions[1, i] - positions[1, j]
            dy = positions[2, i] - positions[2, j]
            dz = positions[3, i] - positions[3, j]

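            # On the diagonal (i == j) dx is zero, so ifelse selects 0 and
            # discards the Inf produced by inv(0); otherwise invr² = 1/r².
            invr² = ifelse(i == j, dx, inv(dx^2 + dy^2 + dz^2))

            Fx = G * masses[i] * masses[j] * invr² * dx
            Fy = G * masses[i] * masses[j] * invr² * dy
            Fz = G * masses[i] * masses[j] * invr² * dz
            # Store the sum of the three components as one scalar per pair.
            F[i, j] = Fx + Fy + Fz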
        end
    end

    return F
end
compute_attractive_force (generic function with 1 method)
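Written as a formula, for particle positions $(x_i, y_i, z_i)$ and masses $m_i$, the loop fills the $N \times N$ matrix below. This restates the code term by term: it sums the three components into a single number per pair rather than returning a force vector.

$$
F_{ij} =
\begin{cases}
G\, m_i\, m_j\, \dfrac{(x_i - x_j) + (y_i - y_j) + (z_i - z_j)}{r_{ij}^2}, & i \neq j,\\[4pt]
0, & i = j,
\end{cases}
\qquad
r_{ij}^2 = (x_i - x_j)^2 + (y_i - y_j)^2 + (z_i - z_j)^2 .
$$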
julia
positions = randn(Float32, 3, 1024)
masses = rand(Float32, 1024) .* 10

positions_ra = Reactant.to_rarray(positions)
masses_ra = Reactant.to_rarray(masses)

Let's see what the HLO IR looks like for this function with loop raising disabled (via disable_auto_batching_passes=true).

julia
@code_hlo compile_options = CompileOptions(;
    disable_auto_batching_passes=true
) compute_attractive_force(positions_ra, masses_ra, 2.0f0)
module @reactant_compute... attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  func.func @main(%arg0: tensor<1024x3xf32> {enzymexla.memory_effects = []}, %arg1: tensor<1024xf32> {enzymexla.memory_effects = []}) -> tensor<1024x1024xf32> attributes {enzymexla.memory_effects = []} {
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<1x1xf32>
    %c = stablehlo.constant dense<2> : tensor<i32>
    %c_0 = stablehlo.constant dense<0> : tensor<i32>
    %c_1 = stablehlo.constant dense<1024> : tensor<i64>
    %cst_2 = stablehlo.constant dense<2.000000e+00> : tensor<1x1xf32>
    %cst_3 = stablehlo.constant dense<0.000000e+00> : tensor<1024x1024xf32>
    %c_4 = stablehlo.constant dense<1> : tensor<i32>
    %c_5 = stablehlo.constant dense<0> : tensor<i64>
    %c_6 = stablehlo.constant dense<1> : tensor<i64>
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<1024x3xf32>) -> tensor<3x1024xf32>
    %1 = stablehlo.reshape %arg1 : (tensor<1024xf32>) -> tensor<1024x1xf32>
    %2:2 = stablehlo.while(%iterArg = %c_5, %iterArg_7 = %cst_3) : tensor<i64>, tensor<1024x1024xf32> attributes {enzyme.disable_mincut, enzymexla.symmetric_matrix = [#enzymexla<guaranteed UNKNOWN>, #enzymexla<guaranteed UNKNOWN>, #enzymexla<guaranteed UNKNOWN>, #enzymexla<guaranteed UNKNOWN>, #enzymexla<guaranteed UNKNOWN>, #enzymexla<guaranteed UNKNOWN>, #enzymexla<guaranteed NOTGUARANTEED>, #enzymexla<guaranteed UNKNOWN>, #enzymexla<guaranteed UNKNOWN>]}
    cond {
      %3 = stablehlo.compare  LT, %iterArg, %c_1 : (tensor<i64>, tensor<i64>) -> tensor<i1>
      stablehlo.return %3 : tensor<i1>
    } do {
      %3 = stablehlo.add %c_6, %iterArg : tensor<i64>
      %4 = stablehlo.convert %3 : (tensor<i64>) -> tensor<i32>
      %5 = stablehlo.subtract %4, %c_4 : tensor<i32>
      %6:2 = stablehlo.while(%iterArg_8 = %c_5, %iterArg_9 = %iterArg_7) : tensor<i64>, tensor<1024x1024xf32> attributes {enzyme.disable_mincut}
      cond {
        %7 = stablehlo.compare  LT, %iterArg_8, %c_1 : (tensor<i64>, tensor<i64>) -> tensor<i1>
        stablehlo.return %7 : tensor<i1>
      } do {
        %7 = stablehlo.add %c_6, %iterArg_8 : tensor<i64>
        %8 = stablehlo.dynamic_slice %0, %c_0, %5, sizes = [1, 1] : (tensor<3x1024xf32>, tensor<i32>, tensor<i32>) -> tensor<1x1xf32>
        %9 = stablehlo.convert %7 : (tensor<i64>) -> tensor<i32>
        %10 = stablehlo.subtract %9, %c_4 : tensor<i32>
        %11 = stablehlo.dynamic_slice %0, %c_0, %10, sizes = [1, 1] : (tensor<3x1024xf32>, tensor<i32>, tensor<i32>) -> tensor<1x1xf32>
        %12 = stablehlo.dynamic_slice %0, %c_4, %5, sizes = [1, 1] : (tensor<3x1024xf32>, tensor<i32>, tensor<i32>) -> tensor<1x1xf32>
        %13 = stablehlo.dynamic_slice %0, %c_4, %10, sizes = [1, 1] : (tensor<3x1024xf32>, tensor<i32>, tensor<i32>) -> tensor<1x1xf32>
        %14 = stablehlo.dynamic_slice %0, %c, %5, sizes = [1, 1] : (tensor<3x1024xf32>, tensor<i32>, tensor<i32>) -> tensor<1x1xf32>
        %15 = stablehlo.dynamic_slice %0, %c, %10, sizes = [1, 1] : (tensor<3x1024xf32>, tensor<i32>, tensor<i32>) -> tensor<1x1xf32>
        %16 = stablehlo.dynamic_slice %1, %5, %c_0, sizes = [1, 1] : (tensor<1024x1xf32>, tensor<i32>, tensor<i32>) -> tensor<1x1xf32>
        %17 = stablehlo.multiply %cst_2, %16 : tensor<1x1xf32>
        %18 = stablehlo.dynamic_slice %1, %10, %c_0, sizes = [1, 1] : (tensor<1024x1xf32>, tensor<i32>, tensor<i32>) -> tensor<1x1xf32>
        %19 = stablehlo.multiply %17, %18 : tensor<1x1xf32>
        %20 = stablehlo.compare  EQ, %3, %7 : (tensor<i64>, tensor<i64>) -> tensor<i1>
        %21 = stablehlo.reshape %20 : (tensor<i1>) -> tensor<1x1xi1>
        %22 = stablehlo.subtract %8, %11 : tensor<1x1xf32>
        %23 = stablehlo.multiply %22, %22 : tensor<1x1xf32>
        %24 = stablehlo.subtract %12, %13 : tensor<1x1xf32>
        %25 = stablehlo.multiply %24, %24 : tensor<1x1xf32>
        %26 = stablehlo.add %23, %25 : tensor<1x1xf32>
        %27 = stablehlo.subtract %14, %15 : tensor<1x1xf32>
        %28 = stablehlo.multiply %27, %27 : tensor<1x1xf32>
        %29 = stablehlo.add %26, %28 : tensor<1x1xf32>
        %30 = stablehlo.divide %cst, %29 : tensor<1x1xf32>
        %31 = stablehlo.select %21, %22, %30 : tensor<1x1xi1>, tensor<1x1xf32>
        %32 = stablehlo.multiply %19, %31 : tensor<1x1xf32>
        %33 = stablehlo.add %22, %24 : tensor<1x1xf32>
        %34 = stablehlo.add %33, %27 : tensor<1x1xf32>
        %35 = stablehlo.multiply %32, %34 : tensor<1x1xf32>
        %36 = stablehlo.dynamic_update_slice %iterArg_9, %35, %10, %5 : (tensor<1024x1024xf32>, tensor<1x1xf32>, tensor<i32>, tensor<i32>) -> tensor<1024x1024xf32>
        stablehlo.return %7, %36 : tensor<i64>, tensor<1024x1024xf32>
      }
      stablehlo.return %3, %6#1 : tensor<i64>, tensor<1024x1024xf32>
    }
    return %2#1 : tensor<1024x1024xf32>
  }
}

This IR contains a nested loop, which maps poorly onto GPUs and TPUs. Even on CPUs, XLA often does not handle loops well. By default, Reactant therefore attempts to raise such loops to tensor IR.
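Conceptually, raising turns each loop index into a tensor dimension. The plain-Julia sketch below (no Reactant needed) writes that by hand: `force_loops` is the double loop from above without `@trace`, and `force_broadcast` is a hypothetical tensor-style rewrite in which `i` and `j` become the rows and columns of N×N matrices. Both function names are illustrative assumptions; the rewrite is not generated by Reactant. Like the loop IR above (%33 to %35), it adds the three coordinate differences before multiplying, so results match only up to floating-point rounding.

```julia
using Random

# Reference: the double loop from compute_attractive_force, as plain Julia (no @trace).
function force_loops(positions::AbstractMatrix, masses::AbstractVector, G::Number)
    N = size(positions, 2)
    F = similar(positions, N, N)
    for i in 1:N, j in 1:N
        dx = positions[1, i] - positions[1, j]
        dy = positions[2, i] - positions[2, j]
        dz = positions[3, i] - positions[3, j]
        invr² = ifelse(i == j, dx, inv(dx^2 + dy^2 + dz^2))
        Fx = G * masses[i] * masses[j] * invr² * dx
        Fy = G * masses[i] * masses[j] * invr² * dy
        Fz = G * masses[i] * masses[j] * invr² * dz
        F[i, j] = Fx + Fy + Fz
    end
    return F
end

# Hypothetical tensor-style analogue: loop indices i and j become matrix dimensions.
function force_broadcast(positions::AbstractMatrix, masses::AbstractVector, G::Number)
    N = size(positions, 2)
    px, py, pz = positions[1, :], positions[2, :], positions[3, :]
    dx = px .- px'               # dx[i, j] = px[i] - px[j]
    dy = py .- py'
    dz = pz .- pz'
    is_diag = (1:N) .== (1:N)'   # true where i == j
    invr² = ifelse.(is_diag, dx, inv.(dx .^ 2 .+ dy .^ 2 .+ dz .^ 2))
    # Sum the differences first, then multiply once.
    return G .* masses .* masses' .* invr² .* (dx .+ dy .+ dz)
end

# Separate variable names so the tutorial's `positions` and `masses` are untouched.
Random.seed!(0)
pos_demo = randn(Float32, 3, 256)
m_demo = rand(Float32, 256) .* 10
isapprox(force_broadcast(pos_demo, m_demo, 2.0f0), force_loops(pos_demo, m_demo, 2.0f0); rtol = 1e-4)
```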

julia
hlo = @code_hlo compute_attractive_force(positions_ra, masses_ra, 2.0f0)
hlo
module @reactant_compute... attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  func.func @main(%arg0: tensor<1024x3xf32> {enzymexla.memory_effects = []}, %arg1: tensor<1024xf32> {enzymexla.memory_effects = []}) -> tensor<1024x1024xf32> attributes {enzymexla.memory_effects = []} {
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<1024x1024x1xf32>
    %cst_0 = stablehlo.constant dense<2.000000e+00> : tensor<1024x1xf32>
    %c = stablehlo.constant dense<1> : tensor<1024xi64>
    %0 = stablehlo.slice %arg0 [0:1024, 2:3] : (tensor<1024x3xf32>) -> tensor<1024x1xf32>
    %1 = stablehlo.slice %arg0 [0:1024, 1:2] : (tensor<1024x3xf32>) -> tensor<1024x1xf32>
    %2 = stablehlo.slice %arg0 [0:1024, 0:1] : (tensor<1024x3xf32>) -> tensor<1024x1xf32>
    %3 = stablehlo.reshape %arg1 : (tensor<1024xf32>) -> tensor<1024x1xf32>
    %4 = stablehlo.iota dim = 0 : tensor<1024xi64>
    %5 = stablehlo.add %4, %c : tensor<1024xi64>
    %6 = stablehlo.broadcast_in_dim %0, dims = [0, 2] : (tensor<1024x1xf32>) -> tensor<1024x1024x1x1xf32>
    %7 = stablehlo.multiply %cst_0, %3 : tensor<1024x1xf32>
    %8 = stablehlo.broadcast_in_dim %0, dims = [1, 2] : (tensor<1024x1xf32>) -> tensor<1024x1024x1x1xf32>
    %9 = stablehlo.subtract %6, %8 : tensor<1024x1024x1x1xf32>
    %10 = stablehlo.broadcast_in_dim %1, dims = [0, 2] : (tensor<1024x1xf32>) -> tensor<1024x1024x1xf32>
    %11 = stablehlo.broadcast_in_dim %1, dims = [1, 2] : (tensor<1024x1xf32>) -> tensor<1024x1024x1xf32>
    %12 = stablehlo.subtract %10, %11 : tensor<1024x1024x1xf32>
    %13 = stablehlo.multiply %12, %12 : tensor<1024x1024x1xf32>
    %14 = stablehlo.broadcast_in_dim %2, dims = [0, 2] : (tensor<1024x1xf32>) -> tensor<1024x1024x1xf32>
    %15 = stablehlo.broadcast_in_dim %2, dims = [1, 2] : (tensor<1024x1xf32>) -> tensor<1024x1024x1xf32>
    %16 = stablehlo.subtract %14, %15 : tensor<1024x1024x1xf32>
    %17 = stablehlo.add %16, %12 : tensor<1024x1024x1xf32>
    %18 = stablehlo.broadcast_in_dim %7, dims = [0, 2] : (tensor<1024x1xf32>) -> tensor<1024x1024x1xf32>
    %19 = stablehlo.broadcast_in_dim %arg1, dims = [1] : (tensor<1024xf32>) -> tensor<1024x1024x1xf32>
    %20 = stablehlo.multiply %18, %19 : tensor<1024x1024x1xf32>
    %21 = stablehlo.multiply %16, %16 : tensor<1024x1024x1xf32>
    %22 = stablehlo.multiply %9, %9 : tensor<1024x1024x1x1xf32>
    %23 = stablehlo.reshape %22 : (tensor<1024x1024x1x1xf32>) -> tensor<1024x1024x1xf32>
    %24 = stablehlo.add %21, %13 : tensor<1024x1024x1xf32>
    %25 = stablehlo.add %24, %23 : tensor<1024x1024x1xf32>
    %26 = stablehlo.divide %cst, %25 : tensor<1024x1024x1xf32>
    %27 = stablehlo.broadcast_in_dim %5, dims = [0] : (tensor<1024xi64>) -> tensor<1024x1024x1xi64>
    %28 = stablehlo.broadcast_in_dim %5, dims = [1] : (tensor<1024xi64>) -> tensor<1024x1024x1xi64>
    %29 = stablehlo.compare  EQ, %27, %28 : (tensor<1024x1024x1xi64>, tensor<1024x1024x1xi64>) -> tensor<1024x1024x1xi1>
    %30 = stablehlo.select %29, %16, %26 : tensor<1024x1024x1xi1>, tensor<1024x1024x1xf32>
    %31 = stablehlo.broadcast_in_dim %17, dims = [1, 0, 2] : (tensor<1024x1024x1xf32>) -> tensor<1024x1024x1x1xf32>
    %32 = stablehlo.transpose %9, dims = [1, 0, 2, 3] : (tensor<1024x1024x1x1xf32>) -> tensor<1024x1024x1x1xf32>
    %33 = stablehlo.multiply %20, %30 : tensor<1024x1024x1xf32>
    %34 = stablehlo.broadcast_in_dim %33, dims = [1, 0, 2] : (tensor<1024x1024x1xf32>) -> tensor<1024x1024x1x1xf32>
    %35 = stablehlo.add %31, %32 : tensor<1024x1024x1x1xf32>
    %36 = stablehlo.multiply %34, %35 : tensor<1024x1024x1x1xf32>
    %37 = stablehlo.reshape %36 : (tensor<1024x1024x1x1xf32>) -> tensor<1024x1024xf32>
    return %37 : tensor<1024x1024xf32>
  }
}

This IR contains no loops; the computation is expressed entirely in tensor operations. Let's check that the results agree with the plain Julia version.

julia
y_jl = compute_attractive_force(positions, masses, 2.0f0)
y_ra = @jit compute_attractive_force(positions_ra, masses_ra, 2.0f0)
maximum(abs, Array(y_ra) .- y_jl)
0.00024414062f0
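The small nonzero difference is most likely rounding rather than a bug: the compiled program does not evaluate floating-point operations in the same order as the Julia loop. For instance, both IRs above add the three coordinate differences first and multiply once, whereas the Julia code multiplies each component separately and then sums. Floating-point addition is not associative, as this minimal Float64 example shows, so a tolerance-based comparison such as `isapprox(Array(y_ra), y_jl)` is a more robust check than exact equality.

```julia
# Same three numbers, two evaluation orders.
a = (0.1 + 0.2) + 0.3
b = 0.1 + (0.2 + 0.3)

a == b           # false: the rounded result depends on the order
isapprox(a, b)   # true: equal up to a relative tolerance
```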

Let's time the execution of the two versions.

julia
fn1 = @compile sync=true compile_options=CompileOptions(;
    disable_auto_batching_passes=true
) compute_attractive_force(positions_ra, masses_ra, 2.0f0)
fn2 = @compile sync=true compute_attractive_force(positions_ra, masses_ra, 2.0f0)
Reactant compiled function compute_attractive_force (with tag ##compute_attractive_force_reactant#125028)

Runtime for non-raised function:

julia
@bs fn1(positions_ra, masses_ra, 2.0f0)
Chairmarks.Benchmark: 14 samples with 1 evaluation.
 Range (min … max):  6.851 ms … 7.471 ms  GC (min … max): 0.00% … 0.00%
 Time  (median):     7.143 ms                GC (median):    0.00%
 Time  (mean ± σ):   7.154 ms ± 161.839 μs GC (mean ± σ):  0.00% ± 0.00%

 [Histogram: frequency by time, 6.85 ms to 7.47 ms]

 Memory estimate: 416.0 bytes, allocs estimate: 14.

Runtime for raised function:

julia
@bs fn2(positions_ra, masses_ra, 2.0f0)
Chairmarks.Benchmark: 108 samples with 1 evaluation.
 Range (min … max):  613.427 μs … 1.162 ms  GC (min … max): 0.00% … 0.00%
 Time  (median):     852.133 μs               GC (median):    0.00%
 Time  (mean ± σ):   886.722 μs ± 98.353 μs GC (mean ± σ):  0.00% ± 0.00%

 [Histogram: frequency by time, 613 μs to 1.09 ms]

 Memory estimate: 416.0 bytes, allocs estimate: 14.
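From the median and mean times reported above, the raised version is roughly 8× faster than the loop version on the machine that produced these numbers. The snippet below just does that arithmetic; absolute timings and the ratio will vary with hardware and backend.

```julia
# Runtimes copied from the two benchmarks above, converted to seconds.
median_loop, median_raised = 7.143e-3, 852.133e-6
mean_loop, mean_raised = 7.154e-3, 886.722e-6

median_speedup = median_loop / median_raised   # ≈ 8.38
mean_speedup = mean_loop / mean_raised         # ≈ 8.07
```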