Serialization
Reactant.Serialization Module
Implements serialization of Reactant compiled functions. Currently supported formats are:
- TensorFlow SavedModel export
- EnzymeJAX export for JAX integration
Exporting to TensorFlow SavedModel
Load PythonCall
Serialization to TensorFlow SavedModel requires PythonCall to be loaded. Loading PythonCall automatically installs TensorFlow. If the TensorFlow installation fails, exporting to SavedModel will not work.
A SavedModel contains a complete TensorFlow program, including trained parameters (i.e., tf.Variables) and computation. It does not require the original model-building code to run, which makes it useful for sharing or deploying with TFLite, TensorFlow.js, TensorFlow Serving, or TensorFlow Hub. Refer to the official TensorFlow SavedModel documentation for more details.
Reactant.Serialization.export_as_tf_saved_model Function
export_as_tf_saved_model(
thunk::Compiler.Thunk,
saved_model_path::String,
target_version::VersionNumber,
input_locations::Vector=[],
state_dict::Dict=Dict(),
)
Serializes a compiled Reactant function (aka Reactant.Compiler.Thunk) to a TensorFlow SavedModel that can be used for deployment.
Arguments
- thunk: The compiled function to serialize (output of @compile). For this to work, the thunk must be compiled with serializable=true.
- saved_model_path: The path where the SavedModel will be saved.
- target_version: The version for serialization.
- input_locations: A vector of input locations. This must be an empty vector or a vector of size equal to the number of inputs of the function. Each element can be one of:
  - TFSavedModel.VariableType: This indicates whether the variable is an input variable or a parameter. A parameter is serialized as a constant in the SavedModel, while an input variable is required at runtime.
  - String: The name of a parameter. This requires a corresponding entry in the state_dict to be present.
  - Integer: The position of an input argument. This is used to indicate that the input is an input argument.
- state_dict: A dictionary mapping parameter names to their values. This is used to serialize the parameters of the function.
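To make the input_locations rules concrete, they can be restated as a small Python model. This is only a hypothetical sketch of the documented constraints, not Reactant's implementation: InputArgument, Parameter and classify_locations are illustrative names (standing in for TFSavedModel.InputArgument and TFSavedModel.Parameter), and the meaning of an empty vector is not modeled.

```python
from dataclasses import dataclass


# Hypothetical stand-ins for TFSavedModel.InputArgument and TFSavedModel.Parameter;
# they model the documented rules only and are not Reactant's types.
@dataclass(frozen=True)
class InputArgument:
    position: int  # position of the input argument


@dataclass(frozen=True)
class Parameter:
    name: str  # must be a key of state_dict


def classify_locations(locations, n_inputs, state_dict):
    """Validate input_locations and label each input as a runtime input or a constant."""
    # An empty vector is allowed; its meaning is not modeled in this sketch.
    if locations and len(locations) != n_inputs:
        raise ValueError("input_locations must be empty or have one entry per input")
    labels = []
    for loc in locations:
        # A bare String names a parameter; a bare Integer gives an input position.
        if isinstance(loc, str):
            loc = Parameter(loc)
        elif isinstance(loc, int):
            loc = InputArgument(loc)
        if isinstance(loc, Parameter):
            if loc.name not in state_dict:
                raise KeyError(f"no state_dict entry for parameter {loc.name!r}")
            labels.append(("constant", loc.name))  # serialized into the SavedModel
        elif isinstance(loc, InputArgument):
            labels.append(("runtime input", loc.position))  # required when calling the model
        else:
            raise TypeError(f"unsupported input location: {loc!r}")
    return labels


# Mirrors the Julia example: x is a runtime input, y.x is a named parameter.
example = classify_locations([InputArgument(1), Parameter("y.x")], 2, {"y.x": [[0.0]]})
```

Applied to the Julia example, x is labeled a runtime input and y.x a constant taken from state_dict, which is how the prose above describes the two kinds of location.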
Example
using Reactant, PythonCall
function fn(x, y)
return sin.(x) .+ cos.(y.x[1:2, :])
end
x = Reactant.to_rarray(rand(Float32, 2, 10))
y = (; x=Reactant.to_rarray(rand(Float32, 4, 10)))
compiled_fn = @compile serializable = true fn(x, y)
Reactant.Serialization.export_as_tf_saved_model(
compiled_fn,
"/tmp/test_saved_model",
v"1.8.5",
[
Reactant.Serialization.TFSavedModel.InputArgument(1),
Reactant.Serialization.TFSavedModel.Parameter("y.x"),
],
Dict("y.x" => y.x),
)
Then in Python:
import tensorflow as tf
import numpy as np
restored_model = tf.saved_model.load("/tmp/test_saved_model")
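# Note that the array dimensions are reversed in Python compared to Julia: the 2x10 Julia input becomes a (10, 2) array.
# The function was compiled for Float32, so cast NumPy's default float64 array to float32.
x = tf.constant(np.random.rand(10, 2).astype(np.float32))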
restored_model.f(x)
Exporting to JAX via EnzymeAD
Load NPZ
This export functionality requires the NPZ package to be loaded.
This export functionality generates:
- A .mlir file containing the StableHLO representation of your Julia function
- An input .npz file containing the input arrays for the function
- A Python script that wraps the function for use with enzyme_ad.jax.hlo_call
The generated Python script can be immediately used with JAX and EnzymeAD without any additional Julia dependencies.
Reactant.Serialization.EnzymeJAX.export_to_enzymejax Function
export_to_enzymejax(
f,
args...;
output_dir::Union{String,Nothing}=nothing,
function_name::String=string(f),
preserve_sharding::Bool=true,
compile_options=Reactant.Compiler.CompileOptions(),
)
Export a Julia function to EnzymeJAX format for use in Python/JAX.
This function:
- Compiles the function to StableHLO via Reactant.@code_hlo
- Saves the MLIR/StableHLO code to a .mlir file
- Saves all input arrays to a single compressed .npz file (transposed to account for Julia's column-major vs NumPy's row-major layout)
- Generates a Python script with the function wrapped for EnzymeJAX's hlo_call
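The transposition mentioned above (and the reversed input shape in the TensorFlow example) follows from memory layout: Julia stores arrays column-major, while NumPy defaults to row-major. A minimal NumPy sketch, independent of Reactant, using the same 2x3 matrix built from 1:6 as the example below:

```python
import numpy as np

# A Julia reshape(collect(Float32, 1:6), 2, 3) stores its data column by column.
julia_shape = (2, 3)
buffer = np.arange(1, 7, dtype=np.float32)  # raw memory: 1, 2, 3, 4, 5, 6

# Interpreting the buffer in column-major (Fortran) order recovers the Julia matrix:
# [[1, 3, 5],
#  [2, 4, 6]]
as_julia = buffer.reshape(julia_shape, order="F")

# Reading the same memory in NumPy's default row-major (C) order needs the reversed
# shape, and the result is the transpose of the Julia matrix:
# [[1, 2],
#  [3, 4],
#  [5, 6]]
as_numpy = buffer.reshape(julia_shape[::-1])
```

No data is copied or reordered here; only the interpretation of the same buffer changes, which is why the shapes appear reversed on the Python side.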
Requirements
- NPZ.jl: Must be loaded with using NPZ for compression support
Arguments
- f: The Julia function to export
- args...: The arguments to the function (used to infer types and shapes)
Keyword Arguments
- output_dir::Union{String,Nothing}: Directory where output files will be saved. If nothing, uses a temporary directory and prints the path.
- function_name::String: Base name for generated files
- preserve_sharding::Bool: Whether to preserve sharding information in the exported function. Defaults to true.
- compile_options: Compilation options passed to Reactant.Compiler.compile_mlir. See Reactant.Compiler.CompileOptions for more details.
Returns
The path to the generated Python script as a String.
Files Generated
- {function_name}.mlir: The StableHLO/MLIR module
- {function_name}_{id}_inputs.npz: Compressed NPZ file containing all input arrays
- {function_name}.py: Python script with the function wrapped for EnzymeJAX
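The inputs archive is an ordinary NumPy .npz file, so it can be inspected with np.load. This minimal sketch assumes nothing about the real archive: because the generated file name contains an {id} and the key names are not documented here, it writes a stand-in archive with hypothetical keys arg_0 and arg_1 and reads it back.

```python
import os
import tempfile

import numpy as np

with tempfile.TemporaryDirectory() as tmp:
    # Hypothetical file name; the real one follows {function_name}_{id}_inputs.npz
    path = os.path.join(tmp, "my_function_example_inputs.npz")
    # Hypothetical key names, used only for this stand-in archive
    np.savez_compressed(
        path,
        arg_0=np.ones((3, 2), dtype=np.float32),
        arg_1=np.zeros((3, 2), dtype=np.float32),
    )
    # np.load returns an NpzFile; .files lists the stored array names
    with np.load(path) as archive:
        names = list(archive.files)
        loaded = {name: archive[name] for name in names}
```

On a real export, listing archive.files shows the actual keys. Whether they line up with the argument order expected by the generated script should be checked in the script itself.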
Example
using Reactant, NPZ
# Define a simple function
function my_function(x::AbstractArray, y::NamedTuple, z::Number)
return x .+ y.x .- y.y .+ z
end
# Create some example inputs
x = Reactant.to_rarray(reshape(collect(Float32, 1:6), 2, 3))
y = (;
x=Reactant.to_rarray(reshape(collect(Float32, 7:12), 2, 3)),
y=Reactant.to_rarray(reshape(collect(Float32, 13:18), 2, 3))
)
z = Reactant.to_rarray(10.0f0; track_numbers=true)
# Export to EnzymeJAX
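# output_dir="exported" makes the generated script importable as exported.my_function
python_file_path = Reactant.Serialization.export_to_enzymejax(my_function, x, y, z; output_dir="exported")
Then in Python: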
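# Run the generated Python script
# Assumes the export used output_dir="exported" and this is run from its parent directory
from exported.my_function import run_my_function
import jax
# `inputs` is not defined in this snippet; it must hold the function's input arrays
# (for example, loaded from the generated .npz file)
result = run_my_function(*inputs)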