Core Reactant API
Compile API
Reactant.Compiler.@compile Macro
@compile [optimize = ...] [no_nan = <true/false>] [sync = <true/false>] f(args...)

Compile the function f with arguments args and return the compiled function.
Note
Note that @compile foo(bar(x)) is equivalent to
y = bar(x) # first compute the output of `bar(x)`, say `y`
@compile foo(y) # then compile `foo` for `y`

That is, like @jit, @compile only applies to the outermost function call; it does not compile the composed function foo(bar(x)) jointly. Hence, if you want to compile the composed function foo(bar(x)) jointly, you need to introduce an intermediate function, i.e.,
baz(x) = foo(bar(x))
@compile baz(x)

Options
sync: Reactant computations are asynchronous by default. If true, the computation will be executed synchronously, blocking till the computation is complete. This is recommended when benchmarking.
compile_options: If provided, then all other compilation options will be ignored. This should be an object of type CompileOptions.
optimize: This option maps to the optimization_passes field of CompileOptions. See the documentation of CompileOptions for more details.
client: XLA Client used for compilation. If not specified, the default client is used.
For details about other compilation options see the documentation of CompileOptions.
serializable: If true, the compiled function will be serializable. This is needed for saving the compiled function to disk and loading it later. Defaults to false.
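As a sketch of typical usage (hypothetical function and data; assumes a working Reactant backend):

```julia
using Reactant

sum_sq(x) = sum(abs2, x)          # hypothetical user function

x = Reactant.to_rarray(rand(Float32, 16))

# Compile once, optionally forcing synchronous execution for benchmarking.
compiled = @compile sync = true sum_sq(x)

# Reuse the cached compiled function on inputs of matching shape and type.
compiled(x)
```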
See also @jit, @code_hlo, @code_mhlo, @code_xla.
Reactant.Compiler.@jit Macro
@jit [optimize = ...] [no_nan = <true/false>] [sync = <true/false>] f(args...)

Run @compile f(args...) and then immediately execute it. Most users should use @compile instead, to cache the compiled function and execute it later.
Note
Note that @jit foo(bar(x)) is equivalent to
y = bar(x) # first compute the output of `bar(x)`, say `y`
@jit foo(y) # then compile `foo` for `y` and execute it

That is, like @compile, @jit only applies to the outermost function call; it does not compile the composed function foo(bar(x)) jointly. Hence, if you want to compile the composed function foo(bar(x)) jointly, you need to introduce an intermediate function, i.e.,
baz(x) = foo(bar(x))
@jit baz(x)

Options
sync: Reactant computations are asynchronous by default. If true, the computation will be executed synchronously, blocking till the computation is complete. This is recommended when benchmarking.
compile_options: If provided, then all other compilation options will be ignored. This should be an object of type CompileOptions.
optimize: This option maps to the optimization_passes field of CompileOptions. See the documentation of CompileOptions for more details.
client: XLA Client used for compilation. If not specified, the default client is used.
For details about other compilation options see the documentation of CompileOptions.
See also @compile, @code_hlo, @code_mhlo, @code_xla.
ReactantCore API
ReactantCore.within_compile Function
within_compile()

Returns true if this function is executed in a Reactant compilation context, otherwise false.
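A minimal sketch of using within_compile to guard side effects that should not run while Reactant traces a function (checked_sum is a hypothetical example):

```julia
using Reactant
using ReactantCore: within_compile

function checked_sum(x)
    if !within_compile()
        # Only log when running as plain Julia, not during tracing.
        @info "running outside a Reactant compilation context"
    end
    return sum(x)
end
```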
ReactantCore.@trace Macro
@trace [key = val,...] <expr>

Converts certain expressions like control flow into a Reactant-friendly form. Importantly, if no traced value is found inside the expression, then there is no overhead.
Currently Supported
if conditions (with elseif and other niceties) (@trace if ...)
if statements with a preceding assignment (@trace a = if ...) (note that the macro must be placed before the assignment, not before the if)
for statements with a single induction variable iterating over integers with known step
while statements
Special Considerations
Apply @trace only at the outermost if. Nested if statements will be automatically expanded into the correct form.
Extended Help
Caveats (Deviations from Core Julia Semantics)
New variables introduced
@trace if x > 0
y = x + 1
p = 1
else
y = x - 1
end

In the outer scope p is not defined if x ≤ 0. However, for the traced version, it is defined and set to a dummy value.
Short Circuiting Operations
@trace if x > 0 && z > 0
y = x + 1
else
y = x - 1
end

&& and || are short-circuiting operations. In the traced version, we replace them with & and | respectively.
Type-Unstable Branches
@trace if x > 0
y = 1.0f0
else
y = 1.0
end

This will not compile since y is a Float32 in one branch and a Float64 in the other. You need to ensure that all branches have the same type.
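One way to repair the branch above is to use the same element type in both branches, e.g. (a sketch):

```julia
@trace if x > 0
    y = 1.0f0
else
    y = 0.0f0   # also a Float32 literal, so both branches produce the same type
end
```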
Another example is the following for loop which changes the type of x between iterations.
x = ... # ConcreteRArray{Int64, 1}
for i in 1f0:0.5f0:10f0
x = x .+ i # ConcreteRArray{Float32, 1}
end

Certain Symbols are Reserved
Symbols like [:(:), :nothing, :missing, :Inf, :Inf16, :Inf32, :Inf64, :Base, :Core] are not allowed as variables in @trace expressions. While certain cases might happen to work, they are not guaranteed to. For example, the following will not work:
function fn(x)
nothing = sum(x)
@trace if nothing > 0
y = 1.0
else
y = 2.0
end
return y, nothing
end

Configuration
The behavior of loops can be configured with the following configuration options:
track_numbers::Union{Bool,DataType} - whether Julia numbers should be automatically promoted to traced numbers upon entering the loop.
checkpointing::Union{Bool,Periodic} - whether or not to enable checkpointing when performing reverse-mode differentiation. Can be false (default), true (automatic checkpointing), or Periodic(n) to specify n checkpoints. When true is used, defaults to isqrt(num_iters) checkpoints for for loops with static (non-traced) bounds. Periodic(n) must be used for while loops or for loops with dynamic (traced) bounds when checkpointing is enabled.
mincut::Bool - whether or not to enable the mincut algorithm when performing reverse-mode differentiation (default: false).
ReactantCore.Periodic Type
Periodic(n::Int)

Checkpointing strategy for traced loops that specifies periodic checkpointing with n checkpoints.
Examples
# Explicit periodic checkpointing with 4 checkpoints
@trace checkpointing=Periodic(4) for i in 1:100
x = x .+ 1
end
# Default periodic checkpointing (uses isqrt(num_iters) checkpoints for static bounds)
@trace checkpointing=true for i in 1:100
x = x .+ 1
end

Converting Data
Reactant.to_rarray Function
to_rarray(x; track_numbers=false, sharding=NoSharding(), device=nothing, client=nothing, runtime=nothing)

Convert a Julia value x into its Reactant equivalent by tracing through the structure. Arrays are converted to ConcreteRArray, and (optionally) scalar numbers are converted to ConcreteRNumber.
Keyword Arguments
track_numbers::Union{Bool, Type} = false: Controls whether plain Julia numbers are converted to ConcreteRNumber.
false (default): scalars are left as-is and will be treated as compile-time constants (frozen at tracing time).
true: all scalar numbers are converted to ConcreteRNumber.
A type (e.g. Number, Float64, Int): only scalars that are subtypes of the given type are tracked.
sharding: Sharding specification for the resulting array.
device: Target device for the resulting array.
client: XLA client to use.
runtime: Backend runtime to use (Val(:PJRT) or Val(:IFRT)).
Examples
# Convert an array (always tracked)
x = Reactant.to_rarray([1.0, 2.0, 3.0]) # ConcreteRArray{Float64, 1}
# Convert a scalar WITHOUT tracking (default) — frozen at compile time
t = Reactant.to_rarray(0.5) # plain Float64
# Convert a scalar WITH tracking — varies at runtime
t = Reactant.to_rarray(0.5; track_numbers=true) # ConcreteRNumber{Float64}
# Convert a struct, tracking all number fields
struct Params; values::Vector{Float64}; scale::Float64; end
rparams = Reactant.to_rarray(Params([1.0], 2.0); track_numbers=true)

See also: Partial Evaluation for how untracked values become compile-time constants.
Reactant data types
Reactant.ConcreteRArray Type
ConcreteRArray{T}(
undef, shape::Dims;
client::Union{Nothing,XLA.AbstractClient} = nothing,
device::Union{Nothing,XLA.AbstractDevice} = nothing,
sharding::Sharding.AbstractSharding = Sharding.NoSharding(),
)
ConcretePJRTArray{T}(undef, shape::Integer...; kwargs...)
ConcretePJRTArray(data::Array; kwargs...)

Allocate an uninitialized ConcreteRArray of element type T and size shape, or convert an Array to a ConcreteRArray.
Implementation
Depending on the Reactant xla_runtime preference setting, ConcreteRArray is an alias for ConcretePJRTArray or ConcreteIFRTArray. User code should use ConcreteRArray.
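For example (a sketch of both constructors):

```julia
using Reactant

a = Reactant.ConcreteRArray(ones(Float32, 4, 4))      # convert an existing Array
b = Reactant.ConcreteRArray{Float64}(undef, (2, 3))   # uninitialized 2×3 buffer
```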
Reactant.ConcreteRNumber Type
ConcreteRNumber(
x::Number;
client::Union{Nothing,XLA.AbstractClient} = nothing,
device::Union{Nothing,XLA.AbstractDevice} = nothing,
sharding::Sharding.AbstractSharding = Sharding.NoSharding(),
)
ConcreteRNumber{T<:Number}(x; kwargs...)

Wrap a Number in a ConcreteRNumber.
Implementation
Depending on the Reactant xla_runtime preference setting, ConcreteRNumber is an alias for ConcretePJRTNumber or ConcreteIFRTNumber. User code should use ConcreteRNumber.
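For example (a sketch):

```julia
using Reactant

n = Reactant.ConcreteRNumber(3.14)           # wraps a Float64
m = Reactant.ConcreteRNumber{Float32}(3.14)  # convert to Float32 while wrapping
```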
Inspect Generated HLO
Reactant.Compiler.@code_hlo Macro
@code_hlo [optimize = ...] [no_nan = <true/false>] f(args...)

Prints the compiled MLIR module for the function f with arguments args.
Options
compile_options: If provided, then all other compilation options will be ignored. This should be an object of type CompileOptions.
optimize: This option maps to the optimization_passes field of CompileOptions. See the documentation of CompileOptions for more details.
client: XLA Client used for compilation. If not specified, the default client is used.
For details about other compilation options see the documentation of CompileOptions.
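A sketch of inspecting the generated module for a small function (hypothetical function and data):

```julia
using Reactant

f(x) = sum(x .* 2)
x = Reactant.to_rarray(rand(Float32, 8))

@code_hlo f(x)                    # optimized module
@code_hlo optimize = false f(x)   # module before the optimization passes
```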
See also @code_xla, @code_mhlo.
Reactant.Compiler.@code_mhlo Macro
@code_mhlo [optimize = ...] [no_nan = <true/false>] f(args...)

Similar to @code_hlo, but runs additional passes to export the stablehlo module to MHLO.
Options
compile_options: If provided, then all other compilation options will be ignored. This should be an object of type CompileOptions.
optimize: This option maps to the optimization_passes field of CompileOptions. See the documentation of CompileOptions for more details.
client: XLA Client used for compilation. If not specified, the default client is used.
For details about other compilation options see the documentation of CompileOptions.
See also @code_xla, @code_hlo.
Reactant.Compiler.@code_xla Macro
@code_xla [optimize = ...] [no_nan = <true/false>] f(args...)

Similar to @code_hlo, but runs additional XLA passes and exports MLIR to XLA HLO. This is the post-optimization XLA HLO module.
Options
compile_options: If provided, then all other compilation options will be ignored. This should be an object of type CompileOptions.
optimize: This option maps to the optimization_passes field of CompileOptions. See the documentation of CompileOptions for more details.
client: XLA Client used for compilation. If not specified, the default client is used.
For details about other compilation options see the documentation of CompileOptions.
before_xla_optimizations: If true, return the before_optimizations HLO module.
See also @code_mhlo, @code_hlo.
Reactant.Compiler.code_hlo Function
code_hlo(ctx, f, args; fn_kwargs = NamedTuple(), kwargs...)

Compile the function f with arguments args and return the compiled MLIR module.
See also: @code_hlo.
Reactant.Compiler.code_mhlo Function
code_mhlo(ctx, f, args; fn_kwargs = NamedTuple(), kwargs...)

Compile the function f with arguments args and return the compiled MLIR module.
See also: @code_mhlo.
Reactant.Compiler.code_xla Function
code_xla(ctx, f, args; fn_kwargs = NamedTuple(), kwargs...)

Compile the function f with arguments args and return the compiled HLO module.
See also: @code_xla.
Compile Options
Reactant.CompileOptions Type
CompileOptions

Fine-grained control over the compilation options for the Reactant compiler.
Controlling Optimization Passes
optimization_passes: Optimization passes to run on the traced MLIR code. Valid types of values are:
Bool (true/false): whether to run the optimization passes or not. Defaults to true.
String: a custom string with the passes to run. The string should be a comma-separated list of MLIR passes. For example, "canonicalize,enzyme-hlo-opt".
Symbol: a predefined set of passes to run. Valid options are:
:all: Default set of optimization passes. The exact set of passes is not fixed and may change in future versions of Reactant. This option is recommended for most users.
:none: No optimization passes will be run.
Other predefined options are: :before_kernel, :before_jit, :before_raise, :before_enzyme, :after_enzyme, :just_batch, :canonicalize, :only_enzyme.
no_nan: If true, the optimization passes will assume that the function does not produce NaN values. This can lead to more aggressive optimizations (and potentially incorrect results if the function does produce NaN values).
all_finite: If true, the optimization passes will assume that the function does not produce Inf or -Inf values. This can lead to more aggressive optimizations (and potentially incorrect results if the function does produce Inf or -Inf values).
transpose_propagate: If :up, stablehlo.transpose operations will be propagated up the computation graph. If :down, they will be propagated down. Defaults to :up.
reshape_propagate: If :up, stablehlo.reshape operations will be propagated up the computation graph. If :down, they will be propagated down. Defaults to :up.
max_constant_threshold: If the number of elements in a constant is greater than this threshold (for a non-splatted constant), we will throw an error.
inline: If true, all functions will be inlined. (Default: true).
Raising Options
raise: If true, the function will be compiled with the raising pass, which raises CUDA and KernelAbstractions kernels to HLO. Defaults to false, but is automatically activated if the inputs are sharded.
raise_first: If true, the raising pass will be run before the optimization passes. Defaults to false.
Dialect Specific Options
legalize_chlo_to_stablehlo: If true, chlo dialect ops will be converted to stablehlo ops. (Default: false).
Backend Specific Options
Only for CUDA backend
cudnn_hlo_optimize: Run cuDNN-specific HLO optimizations. This is only relevant for GPU backends and is false by default. Experimental and not heavily tested.
Sharding Options
shardy_passes: Defaults to :post_sdy_propagation. Other options are:
:none: No sharding passes will be run. Shardy + MHLO shardings are handled by XLA.
:post_sdy_propagation: Runs the Shardy propagation passes. MHLO shardings are handled by XLA.
ShardyPropagationOptions: Custom sharding propagation options. MHLO shardings are handled by XLA.
:to_mhlo_shardings: Runs the Shardy propagation passes and then exports the shardings to MHLO. All passes are run via the MLIR pass pipeline and don't involve XLA.
optimize_then_pad: If true, the function will be optimized before padding (for non-divisible sharding axes) is applied. Defaults to true. (Only for sharded inputs)
optimize_communications: If true, additional passes for optimizing communication in sharded computations will be run. Defaults to true. (Only for sharded inputs)
Julia Codegen Options
donated_args: If :auto, the function will automatically donate the arguments that are not preserved in the function body. If :none, no arguments will be donated. Defaults to :auto.
assert_nonallocating: If true, we make sure that no new buffers are returned by the function. Any buffer returned must be donated from the inputs. Defaults to false.
sync: Reactant computations are asynchronous by default. If true, the computation will be executed synchronously, blocking till the computation is complete. This is recommended when benchmarking.
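The options above can be bundled into a CompileOptions object and passed via compile_options (a sketch; it assumes the fields documented here are accepted as keyword arguments of the constructor):

```julia
using Reactant

opts = Reactant.CompileOptions(;
    optimization_passes = :all,
    donated_args = :none,
    sync = true,
)

f(x) = x .+ 1
x = Reactant.to_rarray(ones(4))

# All other macro-level options are ignored when compile_options is given.
compiled = @compile compile_options = opts f(x)
```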
Extended Help
Internal XLA Options
Warning
We have limited control over these options (in terms of semantic versioning) since they are tied to XLA upstream.
xla_executable_build_options: XLA executable build options. Additional options that are passed to ExecutableBuildOptionsProto.
xla_compile_options: XLA compile options. Additional options that are passed to CompileOptionsProto.
xla_debug_options: XLA debug options. Additional options that are passed to DebugOptionsProto.
Private Options
Warning
These options are not part of the public API and are subject to change without any notice or deprecation cycle.
disable_scatter_gather_optimization_passes: Disables the scatter-gather optimization passes. (Default: false).
disable_pad_optimization_passes: Disables the pad optimization passes. (Default: false).
disable_licm_optimization_passes: Disables the Loop Invariant Code Motion (LICM) optimization passes. (Default: false).
disable_reduce_slice_fusion_passes: Disables fusion of slice elementwise and reduce operations. (Default: false).
disable_slice_to_batch_passes: Disables the slice-to-batch fusion optimization passes. (Default: true). (Note that this is generally an expensive pass to run)
disable_concat_to_batch_passes: Disables concatenate-to-batch fusion passes. (Default: false).
disable_loop_raising_passes: Disables raising passes for stablehlo.while. (Default: false).
disable_structured_tensors_detection_passes: Disables structured tensors detection passes. (Default: true).
disable_structured_tensors_passes: Disables structured tensors optimization passes. (Default: false).
strip_llvm_debuginfo: Removes LLVM debug info from the generated IR.
Reactant.DefaultXLACompileOptions Function
DefaultXLACompileOptions(;
donated_args=:auto, sync=false, optimize_then_pad=true, assert_nonallocating=false
)

Runs specific Enzyme-JAX passes to ensure that the generated code is compatible with XLA compilation. For the documentation of the allowed kwargs see CompileOptions.
Warning
This is mostly a benchmarking option, and the default CompileOptions is almost certainly a better option.
Sharding Specific Options
Reactant.OptimizeCommunicationOptions Type
OptimizeCommunicationOptions

Fine-grained control over the optimization passes that rewrite ops to minimize collective communication.
Reactant.ShardyPropagationOptions Type
ShardyPropagationOptions

Fine-grained control over the sharding propagation pipeline. For more information on sharding propagation, see the Shardy Docs.
Options
keep_sharding_rules::Bool: whether to keep existing and created op sharding rules.
conservative_propagation::Bool: whether to disallow split axes and non-divisible sharding axes during propagation.
debug_sharding_origins::Bool: whether to save information about the origin of a sharding on the MLIR module. These would be the shardings on the function inputs, outputs, sharding constraints, and manual computations before propagation.
debug_propagation_edge_sharding::Bool: whether to save information about the edge source of a sharding on the MLIR module. These are what operand/result introduced a sharding on some op result.
skip_convert_to_reshard::Bool
skip_inline::Bool
enable_insert_explicit_collectives::Bool: whether to insert explicit collectives for sharding propagation. This is useful for debugging and checking the location of the communication ops.
Tracing customization
Reactant.@skip_rewrite_func Macro
@skip_rewrite_func f

Mark function f so that Reactant's IR rewrite mechanism will skip it. This can improve compilation time if it's safe to assume that no call inside f will need a @reactant_overlay method.
Info
Note that this marks the whole function, not a specific method with a type signature.
Warning
The macro call should be inside the __init__ function. If you want to mark it for precompilation, you must add the macro call in the global scope too.
See also: @skip_rewrite_type
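A sketch of how a package might apply the macro both in global scope (for precompilation) and inside __init__ (hypothetical package and function names):

```julia
module MyPkg

using Reactant

# A helper assumed to never need a @reactant_overlay method (hypothetical).
fastpath(x) = 2x

# Global-scope call so the mark participates in precompilation...
Reactant.@skip_rewrite_func fastpath

function __init__()
    # ...and the runtime call required by the warning above.
    Reactant.@skip_rewrite_func fastpath
end

end
```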
Reactant.@skip_rewrite_type Macro
@skip_rewrite_type MyStruct
@skip_rewrite_type Type{<:MyStruct}

Mark the constructors of MyStruct so that Reactant's IR rewrite mechanism will skip them. It does the same as @skip_rewrite_func but for type constructors.
If you want to mark the set of constructors over its type parameters or over its abstract type, you should then use the Type{<:MyStruct} syntax.
Warning
The macro call should be inside the __init__ function. If you want to mark it for precompilation, you must add the macro call in the global scope too.
Profile XLA
Reactant can hook into XLA's profiler to generate compilation and execution traces. See the profiling tutorial for more details.
Reactant.Profiler.with_profiler Function
with_profiler(f, trace_output_dir; trace_device=true, trace_host=true,
create_perfetto_link=false, pm_counters=nothing, advanced_config=Dict())

Runs the provided function under a profiler for XLA. The pm_counters keyword enables CUPTI hardware counter collection via the PM sampling API. Pass a comma-separated string of CUPTI metric names, or use DEFAULT_PM_COUNTERS for a standard set.
With PM counters enabled, get_framework_op_stats() returns per-kernel metrics including measured_memory_bw, operational_intensity, and bound_by.
with_profiler("./traces"; pm_counters=Profiler.DEFAULT_PM_COUNTERS) do
compiled_fn(args...)
end

Reactant.Profiler.annotate Function
annotate(f, name, [level=TRACE_ME_LEVEL_CRITICAL]; [metadata])

Generate an annotation in the current trace. Optionally include metadata as key-value pairs.
Example
annotate("my_operation") do
# ... do work ...
end
annotate("my_operation"; metadata=Dict("key1" => "value1", "key2" => 42)) do
# ... do work ...
end

Reactant.Profiler.@annotate Macro
@annotate [name] function foo(a, b, c)
...
end

The created function will generate an annotation in the captured XLA profiles.
Reactant.Profiler.@time Macro
@time [nrepeat=1] [warmup=1] [compile_options=nothing] [profile_dir=nothing] fn(args...; kwargs...)

Profiles the given function and prints the runtime and compile time. fn will be compiled with compile_options if it is not already a Reactant-compiled function.
Reactant.Profiler.@timed Macro
@timed [nrepeat=1] [warmup=1] [compile_options=nothing] [profile_dir=nothing] fn(args...; kwargs...)

Profiles the given function and returns the runtime, compile time, and memory data. fn will be compiled with compile_options if it is not already a Reactant-compiled function.
Reactant.Profiler.@profile Macro
@profile [nrepeat=1] [warmup=1] [compile_options=nothing] [profile_dir=nothing] fn(args...; kwargs...)

Profiles the given function and prints detailed kernel and framework op statistics. fn will be compiled with compile_options if it is not already a Reactant-compiled function.
Returns the result of the function call.
Reactant.Profiler.profiler_activity_start Function
profiler_activity_start(name::String, level::Cint[, metadata::Dict{String, <:Any}])
profiler_activity_start(name::String, level::Cint[, metadata::Pair{String, <:Any}...])

Start a profiler activity with metadata (key-value pairs). The metadata will be encoded into the trace event and can be viewed in profiling tools like Perfetto.
Returns an activity ID that should be passed to profiler_activity_end when the activity ends.
Example
id = profiler_activity_start("my_operation", TRACE_ME_LEVEL_INFO,
"key1" => "value1", "key2" => 42)
# ... do work ...
profiler_activity_end(id)

Reactant.Profiler.profiler_activity_end Function
profiler_activity_end(id::Int64)

End a profiler activity. See profiler_activity_start for more information.
Reactant.Profiler.DEFAULT_PM_COUNTERS Constant
DEFAULT_PM_COUNTERS

Default CUPTI Performance Monitor counters for GPU kernel analysis. Pass to with_profiler via pm_counters=Profiler.DEFAULT_PM_COUNTERS.
Available counters depend on GPU architecture. Common useful ones:
DRAM bandwidth: dram__bytes_read.sum, dram__bytes_write.sum, dram__throughput.avg.pct_of_peak_sustained_elapsed
L2 cache: lts__t_sectors_lookup_hit.sum, lts__t_sectors_lookup_miss.sum, lts__t_bytes.sum
L1/local memory (register spills): l1tex__t_bytes.sum, l1tex__data_pipe_lsu_wavefronts_mem_lg_cmd_local.sum
Compute: sm__inst_executed.sum, sm__sass_thread_inst_executed_op_dfma_pred_on.sum (FP64 FMAs)
Occupancy: sm__warps_active.avg.pct_of_peak_sustained_active
To list all available counters for your GPU, run: ncu --query-metrics (Nsight Compute) or cupti_query --device 0 --getmetrics (CUPTI toolkit)
Note
PM counter collection requires profiling permissions on NVIDIA GPUs. Set NVreg_RestrictProfilingToAdminUsers=0 in /etc/modprobe.d/nvidia-profiler.conf and reload the nvidia kernel module.
XProf APIs
Reactant.Profiler.initialize_xprof_stubs Function
initialize_xprof_stubs(worker_service_address::String)

Initialize XProf stubs for remote profiling. This sets up the worker service address for connecting to the XProf profiler service.
Arguments
worker_service_address: The address of the worker service (e.g., "localhost:9001")
Reactant.Profiler.start_xprof_grpc_server Function
start_xprof_grpc_server(port::Integer)

Start an XProf gRPC server on the specified port. This allows remote profiling connections from tools like TensorBoard.
Arguments
port: The port number to start the gRPC server on
Reactant.Profiler.xspace_to_tools_data Function
xspace_to_tools_data(
xspace_paths::Vector{String}, tool_name::String; options::Dict=Dict()
)

Convert XSpace profile data to a specific tool format.
Arguments
xspace_paths: Vector of paths to XSpace profile directories
tool_name: Name of the tool to convert to (e.g., "trace_viewer", "tensorflow_stats", "overview_page")
options: Optional dictionary of tool-specific options. Values can be Bool, Int, or String.
Returns
Tuple{Vector{UInt8}, Bool}: A tuple of (data, is_binary) where data is the converted profile data and is_binary indicates whether the data is in binary format.
Example
data, is_binary = xspace_to_tools_data(["/path/to/xspace"], "trace_viewer")

Devices
Reactant.devices Function
devices(backend::String)
devices(backend::XLA.AbstractClient = XLA.default_backend())

Return a list of devices available for the given client.
Reactant.addressable_devices Function
addressable_devices(backend::String)
addressable_devices(backend::XLA.AbstractClient = XLA.default_backend())

Return a list of addressable devices available for the given client.
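For example (a sketch; the available backends depend on the installation):

```julia
using Reactant

Reactant.devices("cpu")            # all devices of the CPU client
Reactant.addressable_devices()     # devices addressable from this process
```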
Differentiation Specific API
EnzymeCore.ignore_derivatives Function
ignore_derivatives(x::T)::T

Behaves like the identity function, but disconnects the "shadow" associated with x. This has the effect of preventing any derivatives from being propagated through x.
Enzyme 0.13.74
Support for ignore_derivatives was added in Enzyme 0.13.74.
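A sketch of using it to freeze part of a computation with respect to differentiation (loss is a hypothetical function):

```julia
using EnzymeCore

function loss(x)
    # `scale` carries no shadow, so no derivative flows through it.
    scale = EnzymeCore.ignore_derivatives(sum(x))
    return sum(abs2, x) / scale
end
```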
Persistent Compilation Cache
Reactant.PersistentCompileCache.clear_compilation_cache! Function
clear_compilation_cache!()

Deletes the compilation cache directory. This removes all cached compilation artifacts for all past versions of Reactant_jll.
Internal utils
ReactantCore.materialize_traced_array Function
materialize_traced_array(AbstractArray{<:TracedRNumber})::TracedRArray

Given an AbstractArray{TracedRNumber}, return or create an equivalent TracedRArray.
Reactant.apply_type_with_promotion Function
This function tries to apply the param types to the wrapper type. When there's a constraint conflict, it tries to resolve it:
ConcreteRNumber{T} vs T: resolves to T
other cases: resolved by promote_type
The new param type is then propagated to any param type that depends on it. Apart from the applied type, it also returns a boolean array indicating which of the param types were changed.
For example:
using Reactant
struct Foo{T, A<:AbstractArray{T}}
a::A
end
Reactant.apply_type_with_promotion(Foo, (Int, TracedRArray{Int, 1}))

returns
(Foo{Reactant.TracedRNumber{Int64}, Reactant.TracedRArray{Int64, 1}}, Bool[1, 0])

The first type parameter has been promoted to be in agreement with the second parameter.