
Serialization

Reactant.Serialization Module

Implements serialization of Reactant compiled functions. Currently supported formats are:

  • TensorFlow SavedModel

  • EnzymeJAX (a StableHLO module plus a Python wrapper for JAX)


Exporting to TensorFlow SavedModel

Load PythonCall

Serialization to TensorFlow SavedModel requires PythonCall to be loaded. Loading PythonCall automatically installs TensorFlow; if that installation fails, exporting to SavedModel is not possible.

A SavedModel contains a complete TensorFlow program, including trained parameters (i.e., tf.Variables) and computation. It does not require the original model-building code to run, which makes it useful for sharing or deploying with TFLite, TensorFlow.js, TensorFlow Serving, or TensorFlow Hub. Refer to the official TensorFlow documentation for more details.

Reactant.Serialization.export_as_tf_saved_model Function
julia
export_as_tf_saved_model(
    thunk::Compiler.Thunk,
    saved_model_path::String,
    target_version::VersionNumber,
    input_locations::Vector=[],
    state_dict::Dict=Dict(),
)

Serializes a compiled Reactant function (aka Reactant.Compiler.Thunk) to a TensorFlow SavedModel, which can be used for deployment.

Arguments

  • thunk: The compiled function to serialize (output of @compile). For this to work, the thunk must be compiled with serializable=true.

  • saved_model_path: The path where the SavedModel will be saved.

  • target_version: The StableHLO version to target when serializing the compiled module (for example, v"1.8.5").

  • input_locations: A vector of input locations. It must be either empty or have exactly one entry per input of the function. Each element can be one of:

    • TFSavedModel.VariableType: This indicates whether the variable is an input variable or a parameter. A parameter is serialized as a constant in the SavedModel, while an input variable is required at runtime.

    • String: The name of a parameter. This requires a corresponding entry in the state_dict to be present.

    • Integer: The position of the input among the function's arguments. This marks the input as a runtime input argument.

  • state_dict: A dictionary mapping parameter names to their values. This is used to serialize the parameters of the function.

Example

julia
using Reactant, PythonCall

function fn(x, y)
    return sin.(x) .+ cos.(y.x[1:2, :])
end

x = Reactant.to_rarray(rand(Float32, 2, 10))
y = (; x=Reactant.to_rarray(rand(Float32, 4, 10)))

compiled_fn = @compile serializable = true fn(x, y)

Reactant.Serialization.export_as_tf_saved_model(
    compiled_fn,
    "/tmp/test_saved_model",
    v"1.8.5",
    [
        Reactant.Serialization.TFSavedModel.InputArgument(1),
        Reactant.Serialization.TFSavedModel.Parameter("y.x"),
    ],
    Dict("y.x" => y.x),
)
python
import tensorflow as tf
import numpy as np

restored_model = tf.saved_model.load("/tmp/test_saved_model")

# Note that the shape of the input in Python is reversed compared to Julia:
# the Julia (2, 10) Float32 array becomes a (10, 2) float32 array here.
x = tf.constant(np.random.rand(10, 2).astype(np.float32))
restored_model.f(x)
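The shape reversal mentioned in the comment above comes from memory layout: Julia arrays are column-major, while NumPy and TensorFlow default to row-major. The NumPy sketch below is an illustration only (it does not call Reactant or TensorFlow); it reads the column-major buffer of a simulated Julia (2, 10) matrix as row-major and shows that the result is the (10, 2) transpose.

```python
import numpy as np

# Simulate a Julia Float32 matrix of size (2, 10); order="F" fills it column-major like Julia.
julia_shape = (2, 10)
julia_matrix = np.arange(20, dtype=np.float32).reshape(julia_shape, order="F")

# The flat buffer in the order Julia keeps it in memory (column by column).
buffer = julia_matrix.ravel(order="F")

# Reading the same buffer as a row-major (C-order) array yields the reversed shape...
python_view = buffer.reshape(julia_shape[::-1])

# ...and that array is exactly the transpose of the Julia matrix.
print(python_view.shape)  # (10, 2)
print(np.array_equal(python_view, julia_matrix.T))  # True
```

No data is reordered in memory; only the interpretation of the shape changes, which is why an array passed across the Julia/Python boundary appears transposed.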

Exporting to JAX via EnzymeAD

Load NPZ

This export functionality requires the NPZ package to be loaded.

This export functionality generates:

  1. A .mlir file containing the StableHLO representation of your Julia function

  2. Input .npz files containing the input arrays for the function

  3. A Python script that wraps the function for use with enzyme_ad.jax.hlo_call

The generated Python script can be immediately used with JAX and EnzymeAD without any additional Julia dependencies.

Reactant.Serialization.EnzymeJAX.export_to_enzymejax Function
julia
export_to_enzymejax(
    f,
    args...;
    output_dir::Union{String,Nothing}=nothing,
    function_name::String=string(f),
    preserve_sharding::Bool=true,
    compile_options=Reactant.Compiler.CompileOptions(),
)

Export a Julia function to EnzymeJAX format for use in Python/JAX.

This function:

  1. Compiles the function to StableHLO via Reactant.@code_hlo

  2. Saves the MLIR/StableHLO code to a .mlir file

  3. Saves all input arrays to a single compressed .npz file (transposed to account for Julia's column-major vs. NumPy's row-major layout)

  4. Generates a Python script with the function wrapped for EnzymeJAX's hlo_call

Requirements

  • NPZ.jl: Must be loaded with using NPZ for compression support

Arguments

  • f: The Julia function to export

  • args...: The arguments to the function (used to infer types and shapes)

Keyword Arguments

  • output_dir::Union{String,Nothing}: Directory where output files will be saved. If nothing, uses a temporary directory and prints the path.

  • function_name::String: Base name for generated files. Defaults to string(f).

  • preserve_sharding::Bool: Whether to preserve sharding information in the exported function. Defaults to true.

  • compile_options: Compilation options passed to Reactant.Compiler.compile_mlir. See Reactant.Compiler.CompileOptions for more details.

Returns

The path to the generated Python script as a String.
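When output_dir is left as nothing, the script lands in a temporary directory that is not on the Python import path. A minimal sketch for importing it from the returned path with the standard library's importlib follows; `load_generated_module` is a hypothetical helper, and the demo writes a stand-in script rather than real Reactant output. The `run_my_function` name is taken from the example below and is assumed to follow the `run_{function_name}` pattern.

```python
import importlib.util
import tempfile
from pathlib import Path


def load_generated_module(script_path):
    """Import a Python source file from an arbitrary path and return the module."""
    script_path = Path(script_path)
    spec = importlib.util.spec_from_file_location(script_path.stem, script_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot create an import spec for {script_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the file's top-level code
    return module


# Demo with a stand-in script; in practice, pass the path returned by export_to_enzymejax.
demo_dir = Path(tempfile.mkdtemp())
demo_script = demo_dir / "my_function.py"
demo_script.write_text("def run_my_function(*args):\n    return len(args)\n")

module = load_generated_module(demo_script)
print(module.run_my_function(1, 2, 3))  # 3
```

Importing the real script this way still executes its top-level imports, so JAX and EnzymeAD must be installed in the Python environment.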

Files Generated

  • {function_name}.mlir: The StableHLO/MLIR module

  • {function_name}_{id}_inputs.npz: Compressed NPZ file containing all input arrays

  • {function_name}.py: Python script with the function wrapped for EnzymeJAX
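The names of the arrays stored inside {function_name}_{id}_inputs.npz are chosen by Reactant and are not listed here, so it can help to inspect the archive before wiring it into JAX. A small NumPy sketch; `describe_npz` is a hypothetical helper, and the demo archive is a stand-in, not Reactant output.

```python
import tempfile
from pathlib import Path

import numpy as np


def describe_npz(path):
    """Return {array_name: (shape, dtype)} for every array stored in an .npz archive."""
    info = {}
    with np.load(path) as archive:
        for name in archive.files:
            array = archive[name]  # each access decompresses the member, so read it once
            info[name] = (array.shape, array.dtype)
    return info


# Demo with a stand-in archive; the real file is {function_name}_{id}_inputs.npz.
demo_path = Path(tempfile.mkdtemp()) / "demo_inputs.npz"
np.savez_compressed(demo_path, a=np.zeros((3, 2), dtype=np.float32), b=np.float32(10.0))

for name, (shape, dtype) in describe_npz(demo_path).items():
    print(name, shape, dtype)
```

Scalars such as the tracked number z are stored as 0-dimensional arrays, so their shape is reported as ().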

Example

julia
using Reactant, NPZ

# Define a simple function
function my_function(x::AbstractArray, y::NamedTuple, z::Number)
    return x .+ y.x .- y.y .+ z
end

# Create some example inputs
x = Reactant.to_rarray(reshape(collect(Float32, 1:6), 2, 3))
y = (;
    x=Reactant.to_rarray(reshape(collect(Float32, 7:12), 2, 3)),
    y=Reactant.to_rarray(reshape(collect(Float32, 13:18), 2, 3))
)
z = Reactant.to_rarray(10.0f0; track_numbers=true)

# Export to EnzymeJAX; files are written to ./exported so the Python import below can find them
python_file_path = Reactant.Serialization.export_to_enzymejax(
    my_function, x, y, z; output_dir="exported"
)

Then in Python:

python
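# Run the generated Python script (assumes it was exported to ./exported)
from exported.my_function import run_my_function
import jax

# `inputs` must be defined beforehand: the input arrays in the order and
# (reversed, row-major) shapes that the generated script expects
result = run_my_function(*inputs)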
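To sanity-check the Python-side result, the same computation can be reproduced with plain NumPy. This is a sketch under an assumption: that outputs cross the boundary transposed just like the inputs, so the Julia (2, 3) result appears as a (3, 2) array in Python. `julia_reshape` is a hypothetical helper that mimics Julia's column-major reshape.

```python
import numpy as np


def julia_reshape(values, shape):
    """Mimic Julia's reshape(collect(Float32, r), shape), which fills column-major."""
    return np.asarray(list(values), dtype=np.float32).reshape(shape, order="F")


# The same inputs as the Julia example above.
x = julia_reshape(range(1, 7), (2, 3))
y_x = julia_reshape(range(7, 13), (2, 3))
y_y = julia_reshape(range(13, 19), (2, 3))
z = np.float32(10.0)

# my_function computes x .+ y.x .- y.y .+ z elementwise.
expected_julia = x + y_x - y_y + z  # shape (2, 3), as Julia sees it
expected_python = expected_julia.T  # shape (3, 2), assuming the transposed convention
print(expected_python)  # values [[5, 6], [7, 8], [9, 10]]
```

Assuming `run_my_function` returns a single array, `np.testing.assert_allclose(np.asarray(result), expected_python)` gives a quick check.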