Partial Evaluation
When compiling functions with Reactant, the function arguments (and possibly closure fields) may contain non-Reactant values, i.e. numbers and arrays that are not of type Reactant.AbstractConcreteNumber or Reactant.AbstractConcreteArray.
The Reactant compiler may (but is not guaranteed to) treat these non-Reactant values as constant and partially evaluate the function to be compiled based on this.
For example, the function
using Reactant
function add(a, b)
    a + b
end
add (generic function with 1 method)
when compiled with two ConcreteRNumber arguments
using Reactant
x = ConcreteRNumber(3)
y = ConcreteRNumber(4)
addxy = @compile add(x, y)
res = addxy(x, y)
ConcretePJRTNumber{Int64, 1}(7)
returns a result that depends on both arguments x and y:
res = addxy(ConcreteRNumber(7), ConcreteRNumber(8))
ConcretePJRTNumber{Int64, 1}(15)
The StableHLO IR code generated here is:
@code_hlo add(x, y)
module @reactant_add attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  func.func @main(%arg0: tensor<i64> {enzymexla.memory_effects = []}, %arg1: tensor<i64> {enzymexla.memory_effects = []}) -> tensor<i64> attributes {enzymexla.memory_effects = []} {
    %0 = stablehlo.add %arg0, %arg1 : tensor<i64>
    return %0 : tensor<i64>
  }
}
So at the HLO level, there are two variable inputs, %arg0 and %arg1.
However, if the argument y has a non-Reactant value during compilation (4 in this example), then the result of executing the compiled function
addx4 = @compile add(x, 4)
res = addx4(x, 4)
ConcretePJRTNumber{Int64, 1}(7)
will only change based on x, not on the non-Reactant argument y. Calling it with x = 7 and y = 8 therefore gives 7 + 4 == 11, not 7 + 8 == 15:
res = addx4(ConcreteRNumber(7), 8)
ConcretePJRTNumber{Int64, 1}(11)
The StableHLO code shows that the second argument has been replaced by a constant %c during partial evaluation. When the compiled function is executed, the value of y is ignored; at the HLO level, there is only one variable input, %arg0:
@code_hlo add(x, 4)
module @reactant_add attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  func.func @main(%arg0: tensor<i64> {enzymexla.memory_effects = []}) -> tensor<i64> attributes {enzymexla.memory_effects = []} {
    %c = stablehlo.constant dense<4> : tensor<i64>
    %0 = stablehlo.add %arg0, %c : tensor<i64>
    return %0 : tensor<i64>
  }
}
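Because a plain Julia value may be baked into the compiled HLO, code that needs the result to follow a different y has two straightforward options. The sketch below shows both, using only the calls that appear in the examples above; the names addx8, res1 and res2 are illustrative. It deliberately assumes nothing about whether Reactant folds y in a given case: option 2 recompiles with the new value either way, so its result does not depend on that.

```julia
using Reactant

function add(a, b)
    a + b
end

x = ConcreteRNumber(3)

# Option 1: pass y as a Reactant number, so it stays a runtime input
# (like %arg1 in the first HLO listing) instead of being folded into a constant.
addxy = @compile add(x, ConcreteRNumber(4))
res1 = addxy(ConcreteRNumber(7), ConcreteRNumber(8))  # 7 + 8 == 15

# Option 2: keep y as a plain Julia value, but recompile whenever it changes,
# because the compiled function may have baked the old value in as a constant.
addx8 = @compile add(x, 8)
res2 = addx8(ConcreteRNumber(7), 8)  # 7 + 8 == 15
```

Option 1 compiles once and is the natural choice when y changes often; option 2 may let the compiler specialize on the constant, at the cost of a new compilation for every distinct value of y.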