feat: serialization #1083
Conversation
We need to trim out all the pointers that get embedded into the flatten/unflatten code.
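For context, a conceptual sketch (not Reactant's actual flatten/unflatten codegen) of why embedded pointers have to go before serialization can work; the names flatten_bad, flatten_ok, lookup_buffer and the pointer literal are hypothetical:

# A raw pointer baked into generated code is only valid in the process that
# created it, so serializing code that captures one produces a dead handle.
bad_ptr = Ptr{Cvoid}(0x00007f3a1c0042a0)        # hypothetical device-buffer handle
flatten_bad(x) = (bad_ptr, x)                    # embeds a process-specific pointer

# A serializable variant keeps only stable identifiers (e.g. a device ordinal)
# and resolves the live handle at run time.
lookup_buffer(device_id) = device_id             # stand-in for a runtime handle lookup
flatten_ok(x, device_id) = (lookup_buffer(device_id), x)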
Force-pushed from 22f8a86 to e57ceff.
using Reactant
# Reactant.Compiler.DEBUG_PRINT_CODEGEN[] = true
mesh = Sharding.Mesh(reshape(collect(0:7), 2, 4), (:x, :y))
sharding = Sharding.NamedSharding(mesh, (:x, :y))
x_ra = Reactant.to_rarray(rand(2, 4); sharding)
y_ra = Reactant.to_rarray(rand(2, 4))
function f(x, y)
y .= x .+ y
return y
end
# ---- Run 1st time
thunk = @compile serializable = true f(x_ra, y_ra)
Reactant.Serialization.serialize("envs/serialized_test.jld2", thunk)
# ---- Run 2nd time
thunk_loaded = Reactant.Serialization.deserialize(
f,
"envs/serialized_test.jld2";
client=Reactant.XLA.default_backend(),
device=nothing,
global_device_ids=collect(0:7),
)
thunk_loaded(x_ra, y_ra)
One particularly relevant question: if we serialize code for one sharding, can we use the deserialized version for a different sharding (and device)? Essentially, could we compile a version on CPU and reshard it for multi-TPU?
Right now I intentionally made global_device_ids an input to the deserialize function with that intent. If the user knows that there are 32 total TPU devices, we can compile on 32 fake CPU devices and then pass in the device ids for the 32 TPUs, and it should work (see the sketch below). It should also be possible to just compile the unsharded version and shard it upon deserialize.
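A minimal sketch of that compile-on-CPU, run-on-TPU workflow, reusing f and the serialize/deserialize API from the example above; the 32-device mesh shape, the XLA_FLAGS override used to fake host devices, and the file name are illustrative assumptions, not part of this PR:

# --- On a CPU-only machine: pretend there are 32 devices so the sharded
# program can be traced and partitioned (how the fake devices are created is
# an assumption; one common way is the XLA host-platform flag below).
# ENV["XLA_FLAGS"] = "--xla_force_host_platform_device_count=32"
using Reactant

mesh = Sharding.Mesh(reshape(collect(0:31), 4, 8), (:x, :y))
sharding = Sharding.NamedSharding(mesh, (:x, :y))
x_ra = Reactant.to_rarray(rand(4, 8); sharding)
y_ra = Reactant.to_rarray(rand(4, 8))

thunk = @compile serializable=true f(x_ra, y_ra)
Reactant.Serialization.serialize("envs/serialized_tpu32.jld2", thunk)

# --- On the TPU machine: point the deserialized executable at the 32 real
# TPU devices by passing their global ids.
thunk_tpu = Reactant.Serialization.deserialize(
    f,
    "envs/serialized_tpu32.jld2";
    client=Reactant.XLA.default_backend(),   # TPU client on that machine
    device=nothing,
    global_device_ids=collect(0:31),          # ids of the 32 real TPU devices
)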
This might be very useful for the scaling runs, where we serialize the unsharded version and only pay the cost of propagation and partitioning instead of having to re-trace the whole program.
Yeah, that's what I'm thinking. There's a separate long-term question of whether we could even make a size-agnostic serialization, but that's for another day.
refactor: remove all runtime info from compiled function body
perf: optimize mesh codegen
fix: pjrt codegen
fix: hlosharding codegen
feat: serialize/deserialize pipeline
extremely wip...