Serialization
Reactant.Serialization Module
Implements serialization of Reactant compiled functions. Currently supported formats are:
- TensorFlow SavedModel
Exporting to TensorFlow SavedModel
Load PythonCall
Serialization to TensorFlow SavedModel requires PythonCall to be loaded. Loading PythonCall automatically installs tensorflow; if that installation fails, exporting to SavedModel will not work.
A SavedModel contains a complete TensorFlow program, including trained parameters (i.e., tf.Variables) and computation. It does not require the original model-building code to run, which makes it useful for sharing or deploying with TFLite, TensorFlow.js, TensorFlow Serving, or TensorFlow Hub. Refer to the official TensorFlow documentation for more details.
Reactant.Serialization.export_as_tf_saved_model Function
export_as_tf_saved_model(
    thunk::Compiler.Thunk,
    saved_model_path::String,
    target_version::VersionNumber,
    input_locations::Vector=[],
    state_dict::Dict=Dict(),
)
Serializes a compiled Reactant function (a Reactant.Compiler.Thunk) to a TensorFlow SavedModel, which can be used for deployment.
Arguments
thunk: The compiled function to serialize (the output of @compile). For this to work, the thunk must be compiled with serializable=true.
saved_model_path: The path where the SavedModel will be saved.
target_version: The target version for serialization.
input_locations: A vector of input locations. This must be either an empty vector or a vector with one entry per input of the function. Each element can be one of:
    TFSavedModel.VariableType: Indicates whether the input is an input variable or a parameter. A parameter is serialized as a constant in the SavedModel, while an input variable must be supplied at runtime.
    String: The name of a parameter. A corresponding entry must be present in state_dict.
    Integer: The position of an input argument; marks the input as a runtime input argument rather than a parameter.
state_dict: A dictionary mapping parameter names to their values, used to serialize the parameters of the function.
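The rules for input_locations and state_dict can be sketched in plain Julia. This is a hypothetical validator, not Reactant's implementation: SketchVariableType, SketchInputArgument, SketchParameter, check_location, to_location, and normalize_input_locations are illustrative names standing in for TFSavedModel.VariableType, TFSavedModel.InputArgument, and TFSavedModel.Parameter, and the checks only mirror the constraints listed above.

```julia
# HYPOTHETICAL sketch: these types are illustrative stand-ins, not Reactant's
# TFSavedModel.VariableType / InputArgument / Parameter definitions.
abstract type SketchVariableType end

struct SketchInputArgument <: SketchVariableType
    position::Int    # position of a runtime input argument
end

struct SketchParameter <: SketchVariableType
    name::String     # must be a key of state_dict
end

# Parameters need a value in state_dict; input arguments are supplied at runtime.
check_location(loc::SketchInputArgument, state_dict) = loc
function check_location(loc::SketchParameter, state_dict)
    haskey(state_dict, loc.name) ||
        throw(ArgumentError("parameter \"$(loc.name)\" has no entry in state_dict"))
    return loc
end

# Each entry may be a VariableType-like value, a String (parameter name),
# or an Integer (position of an input argument).
to_location(x::SketchVariableType, state_dict) = check_location(x, state_dict)
to_location(x::AbstractString, state_dict) = check_location(SketchParameter(String(x)), state_dict)
to_location(x::Integer, state_dict) = SketchInputArgument(Int(x))

function normalize_input_locations(input_locations::Vector, n_inputs::Int, state_dict::Dict)
    # An empty vector is allowed; this sketch does not model how Reactant handles it.
    isempty(input_locations) && return SketchVariableType[]
    length(input_locations) == n_inputs || throw(ArgumentError(
        "expected $(n_inputs) input locations, got $(length(input_locations))"))
    return SketchVariableType[to_location(x, state_dict) for x in input_locations]
end

# Mirrors the Example below: x is a runtime input, y.x is a stored parameter.
state_dict = Dict("y.x" => rand(Float32, 4, 10))
locs = normalize_input_locations([1, "y.x"], 2, state_dict)
```

Passing a parameter name without a matching state_dict key, or a vector whose length differs from the number of inputs, raises an ArgumentError in this sketch; the real function may report such mistakes differently.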
Example
using Reactant, PythonCall

function fn(x, y)
    return sin.(x) .+ cos.(y.x[1:2, :])
end

x = Reactant.to_rarray(rand(Float32, 2, 10))
y = (; x=Reactant.to_rarray(rand(Float32, 4, 10)))

compiled_fn = @compile serializable = true fn(x, y)

Reactant.Serialization.export_as_tf_saved_model(
    compiled_fn,
    "/tmp/test_saved_model",
    v"1.8.5",
    [
        Reactant.Serialization.TFSavedModel.InputArgument(1),
        Reactant.Serialization.TFSavedModel.Parameter("y.x"),
    ],
    Dict("y.x" => y.x),
)
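The SavedModel can then be loaded from Python:
import tensorflow as tf
import numpy as np

restored_model = tf.saved_model.load("/tmp/test_saved_model")

# Note that the input shape in Python is reversed compared to Julia: (2, 10) in Julia becomes (10, 2).
# Cast to float32 to match the Float32 input the function was compiled with.
x = tf.constant(np.random.rand(10, 2).astype(np.float32))
restored_model.f(x)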
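The reversed shape can be illustrated without Reactant or TensorFlow. Julia arrays are column-major, while NumPy and TensorFlow default to row-major order, and reading a column-major buffer as a row-major array with the dimensions reversed yields the transpose. This sketch assumes that this layout difference is what the note above refers to; rowmajor_view is a hypothetical helper, not part of Reactant.

```julia
# Interpret a flat buffer as a row-major (C-order) 2-D array of size `dims`.
# HYPOTHETICAL helper for illustration only.
function rowmajor_view(buf::AbstractVector, dims::NTuple{2,Int})
    nrows, ncols = dims
    out = similar(buf, nrows, ncols)
    for i in 1:nrows, j in 1:ncols
        # 1-based row-major offset of element (i, j)
        out[i, j] = buf[(i - 1) * ncols + j]
    end
    return out
end

x = rand(Float32, 2, 10)          # Julia shape (2, 10)
buf = vec(x)                      # the column-major memory order of x
y = rowmajor_view(buf, (10, 2))   # how a row-major reader sees the same memory

# Element (i, j) seen by the row-major reader is x[j, i] in Julia.
@assert size(y) == (10, 2)
@assert y == permutedims(x)
```

No data is copied or reordered in memory here; only the indexing convention changes, which is why the dimensions appear reversed rather than the values being scrambled.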