feat: serialization #1083


Draft · wants to merge 1 commit into main from ap/serialize

Conversation

avik-pal
Collaborator

extremely wip...

@avik-pal
Collaborator Author

We need to trim out all the pointers that get embedded in the generated flatten/unflatten code.
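To illustrate why this matters (a hypothetical sketch, not Reactant's actual codegen): a raw pointer baked into generated code is only valid in the process that created it, so any code that captures one as a literal cannot meaningfully survive serialization to disk.

```julia
# Hypothetical illustration of the pointer-capture problem.
buf = Libc.malloc(8)        # process-local allocation
addr = UInt(buf)            # some address, e.g. 0x00007f8c2c003a40

# If codegen splices `addr` into the generated body as a literal, the
# serialized code refers to a dead address once it is deserialized in a
# fresh process. The fix is to keep such runtime handles out of the
# serialized function body and re-resolve them at load time.
Libc.free(buf)
```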

@avik-pal avik-pal force-pushed the ap/serialize branch 3 times, most recently from 22f8a86 to e57ceff Compare March 29, 2025 02:01
@avik-pal
Collaborator Author

```julia
using Reactant

# Reactant.Compiler.DEBUG_PRINT_CODEGEN[] = true

mesh = Sharding.Mesh(reshape(collect(0:7), 2, 4), (:x, :y))

sharding = Sharding.NamedSharding(mesh, (:x, :y))

x_ra = Reactant.to_rarray(rand(2, 4); sharding)
y_ra = Reactant.to_rarray(rand(2, 4))

function f(x, y)
    y .= x .+ y
    return y
end

# ---- Run 1st time

thunk = @compile serializable = true f(x_ra, y_ra)

Reactant.Serialization.serialize("envs/serialized_test.jld2", thunk)

# ---- Run 2nd time

thunk_loaded = Reactant.Serialization.deserialize(
    f,
    "envs/serialized_test.jld2";
    client=Reactant.XLA.default_backend(),
    device=nothing,
    global_device_ids=collect(0:7),
)

thunk_loaded(x_ra, y_ra)
```

@wsmoses
Member

wsmoses commented Mar 29, 2025

One particularly relevant question: if we serialize code for one sharding, can we use the deserialized version with a different sharding (and device)?

Essentially, that would let us compile a version on CPU and reshard for multi-TPU.

@avik-pal
Collaborator Author

Right now I intentionally made `global_device_ids` an input to the `deserialize` function with that intent. If the user knows there are 32 total TPU devices, we can compile on 32 fake CPU devices and then pass in the device IDs for the 32 TPUs, and it should work.

It should also be possible to compile just the unsharded version and shard it upon deserialization.
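A sketch of that compile-on-CPU, run-on-TPU workflow, assuming the `deserialize` keyword arguments from the example above carry over unchanged; the TPU client lookup here is a hypothetical placeholder, not a confirmed Reactant API:

```julia
using Reactant

# Suppose the thunk was compiled against 32 fake CPU devices and serialized.
# On the TPU host we pass the 32 real device ids instead.
thunk_tpu = Reactant.Serialization.deserialize(
    f,
    "envs/serialized_test.jld2";
    client=tpu_client,                # hypothetical: a TPU-backed XLA client
    device=nothing,
    global_device_ids=collect(0:31),  # the 32 real TPU device ids
)
```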

@avik-pal
Collaborator Author

> It should also be possible to just compile the unsharded version and shard it upon deserialize

This could be very useful for the scaling runs, where we serialize the unsharded version and pay only the cost of propagation and partitioning instead of having to re-trace the whole program.

@wsmoses
Member

wsmoses commented Mar 29, 2025

> It should also be possible to just compile the unsharded version and shard it upon deserialize
>
> This might be very useful for the scaling runs, where we serialize the unsharded version and only pay the cost for propagation and partitioning instead of having to re-trace the whole program

Yeah, that's what I'm thinking.

There's a separate long-term question of whether we could even make a size-agnostic serialization, but that's for another day.

- refactor: remove all runtime info from compiled function body
- perf: optimize mesh codegen
- fix: pjrt codegen
- fix: hlosharding codegen
- feat: serialize/deserialize pipeline