Partial Evaluation
When compiling functions with Reactant, the function arguments (and possible closure fields) may contain non-Reactant values, i.e. numbers and arrays that are not of type Reactant.AbstractConcreteNumber or Reactant.AbstractConcreteArray.
The Reactant compiler may (but is not guaranteed to) treat these non-Reactant values as constant and partially evaluate the function to be compiled based on this.
For example, the function
using Reactant
function add(a, b)
a + b
end;
# output
add (generic function with 1 method)when compiled with two ConcreteRNumber arguments
using Reactant
x = ConcreteRNumber(3)
y = ConcreteRNumber(4)
addxy = @compile add(x, y)
addxy(x, y)
# output
ConcretePJRTNumber{Int64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(7)returns a result that depends on both arguments x and y:
addxy(ConcreteRNumber(7), ConcreteRNumber(8))
# output
ConcretePJRTNumber{Int64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(15)The StableHLO IR code generated here is:
@code_hlo add(x, y)
# output
module @reactant_add attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
func.func @main(%arg0: tensor<i64> {enzymexla.memory_effects = []}, %arg1: tensor<i64> {enzymexla.memory_effects = []}) -> tensor<i64> attributes {enzymexla.memory_effects = []} {
%0 = stablehlo.add %arg0, %arg1 : tensor<i64>
return %0 : tensor<i64>
}
}So at HLO-level, there a are two variable inputs %arg0 and %arg1.
However, if argument y has a non-Reactant value during compilation, (4 in this example) then the result when executing the compiled function
addx4 = @compile add(x, 4)
addx4(x, 4)
# output
ConcretePJRTNumber{Int64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(7)will only change based on x, not on the non-Reactant argument y, we get 7 + 4 == 11, not 7 + 8 == 15:
addx4(ConcreteRNumber(7), 8)
# output
ConcretePJRTNumber{Int64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(11)The StableHLO code shows that the second argument has been replaced by a constant %c during partial evaluation. When the compiled function is executed, the value of y is ignored - at HLO-level, there is only one variable input %arg0:
@code_hlo add(x, 4)
# output
module @reactant_add attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
func.func @main(%arg0: tensor<i64> {enzymexla.memory_effects = []}) -> tensor<i64> attributes {enzymexla.memory_effects = []} {
%c = stablehlo.constant dense<4> : tensor<i64>
%0 = stablehlo.add %arg0, %c : tensor<i64>
return %0 : tensor<i64>
}
}