Model Checkpoint Improvements
The current checkpoint design needs some improvements.
Format
- The on-disk file names should be readable and deterministic instead of random numbers.
For example, if we have two calls

    save_checkpoint(path, target)
    save_checkpoint(path, target)

the second call should fully overwrite the first. However, with the current random-number naming we store two copies, and the former is never deleted or used. One possible deterministic naming scheme is sketched after this list.
- We should not store any absolute path
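As a minimal sketch of what deterministic, key-derived naming could look like (this is not alpa's actual implementation; the save_checkpoint signature here and the _flatten helper are hypothetical), each leaf of the parameter pytree maps to a fixed relative file name derived from its key path, so a repeated save to the same path simply overwrites the previous files:

import os
import numpy as np

def _flatten(tree, prefix=()):
    # Flatten a nested dict into (key_path, leaf) pairs.
    if isinstance(tree, dict):
        for k, v in tree.items():
            yield from _flatten(v, prefix + (str(k),))
    else:
        yield prefix, tree

def save_checkpoint(path, target):
    # Hypothetical sketch: file names come from the key paths, so a second
    # save_checkpoint(path, target) overwrites the same files instead of
    # leaving an orphaned, randomly named copy behind.
    os.makedirs(path, exist_ok=True)
    for key_path, leaf in _flatten(target):
        if leaf is None:  # skipped entries also enable partial writes (see Features)
            continue
        fname = ".".join(key_path) + ".npy"  # e.g. "1.kernel.npy"
        np.save(os.path.join(path, fname), np.asarray(leaf))

Only relative names like "1.kernel.npy" are written, so no absolute path ends up inside the checkpoint.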
Speed
Due to frequent synchronization and compression, the current save/load speed is very slow. We need to benchmark the performance against the peak bandwidth.
The peak bandwidth of load can be measured by running a jax program that reads np arrays from disk to local GPU memory. Ditto for save. When we save/load distributed arrays with multiple workers, we should also try to get a perfect linear scaling. The possible ways to improve bandwidth include batching small arrays, batching ray calls, multithreading, and overlapping driver process with worker processes.
The micro-benchmark script can be put under https://github.com/alpa-projects/alpa/tree/main/playground/alpa_micro_benchmark
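A rough starting point for such a script (the ~1 GB array size and the /tmp path are arbitrary choices, and it assumes a local GPU is visible to JAX) is to time the disk read and the host-to-GPU transfer separately and report GB/s:

import time
import numpy as np
import jax

arr = np.random.rand(1024 * 1024 * 1024 // 8)  # ~1 GB of float64

# Save once so we can measure the load path.
np.save("/tmp/bandwidth_bench.npy", arr)

# Disk -> host memory.
start = time.time()
host_arr = np.load("/tmp/bandwidth_bench.npy")
disk_time = time.time() - start

# Host memory -> local GPU memory.
start = time.time()
device_arr = jax.device_put(host_arr)
device_arr.block_until_ready()
transfer_time = time.time() - start

print(f"disk->host: {1.0 / disk_time:.2f} GB/s, "
      f"host->GPU: {1.0 / transfer_time:.2f} GB/s")

The save direction can be measured the same way with np.save and a device-to-host copy.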
Features
As a bonus, we want to support partial writing and partial reading. For example, we can save a checkpoint in two separate calls:
# Write layer 1
param_tree_1 = {
    "1": {"kernel": array, "bias": array},
    "2": {"kernel": None, "bias": None},
}
save_checkpoint(path, param_tree_1)

# Write layer 2
param_tree_2 = {
    "1": {"kernel": None, "bias": None},
    "2": {"kernel": array, "bias": array},
}
save_checkpoint(path, param_tree_2)
Ditto for load.
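Partial load could mirror the write API. The following is a hypothetical counterpart of the naming sketch above (again, not alpa's implementation): None entries in the target tree are skipped, and everything else is read back from disk.

import os
import numpy as np

def load_checkpoint(path, target):
    # Hypothetical sketch: entries that are None in `target` are skipped,
    # all other entries are read back from their key-derived file names.
    def _restore(tree, prefix=()):
        if isinstance(tree, dict):
            return {k: _restore(v, prefix + (str(k),))
                    for k, v in tree.items()}
        if tree is None:
            return None  # partial read: leave this entry untouched
        fname = ".".join(prefix) + ".npy"
        return np.load(os.path.join(path, fname))
    return _restore(target)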
Minors
The aval for replicated arrays is the same (see https://github.com/alpa-projects/alpa/blob/a0d1d17ddc444924c27ef11816bc0963750570b9/alpa/serialization.py#L177).
I did some benchmark experiments for save/load. I chose three different methods: (1) flax's checkpoint functions, (2) alpa's current checkpoint functions (no compression), and (3) np.save/load. All the experiments were conducted on both the EFS filesystem and the local filesystem (the results indicate that EFS can be the bottleneck).
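For concreteness, the numpy and flax paths of this benchmark look roughly like the script below (the /tmp paths are placeholders, the flax call assumes the flax.training.checkpoints API, and the alpa/tensorstore path is omitted here):

import time
import numpy as np
from flax.training import checkpoints

arr = np.random.rand(1024 * 1024 * 1024 // 8)  # ~1 GB of float64

def timed(fn, *args, **kwargs):
    start = time.time()
    fn(*args, **kwargs)
    return time.time() - start

# numpy: writes the raw array buffer directly.
t_np = timed(np.save, "/tmp/bench_np.npy", arr)

# flax: serializes the target with msgpack before writing.
t_flax = timed(checkpoints.save_checkpoint,
               "/tmp/bench_flax", arr, step=0, overwrite=True)

print(f"np.save: {1.0 / t_np:.2f} GB/s, flax: {1.0 / t_flax:.2f} GB/s")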
First, to benchmark the peak bandwidth for save/load, I used the above three methods to save/load a 1GB ndarray between GPU and disk. The results are as follows:
The results clearly show that EFS is the bottleneck for saving, but I was curious why numpy can be so much faster than flax’s msgpack and alpa’s tensorstore. So I also benchmarked the saving functions of the three methods on a real MLP model (3G params):
Now numpy is slower than alpa’s tensorstore. I read the source code of np.save and found that when the target is a Python object, numpy uses pickle to serialize it, but when the target is an array, it takes a fast path. This explains the two experiments above. In the single-array case, flax is slow because it uses msgpack (an object-serialization package similar to pickle), alpa is slow because it uses tensorstore, but numpy is fast because it takes the fast path and writes the array data directly, without any serialization overhead. In the MLP case, the saving/loading target is a TrainState object, so numpy falls back to pickle and loses that performance.
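The difference is easy to see with numpy alone (the shapes and paths below are arbitrary): saving a bare ndarray hits the fast path, while saving a dict/pytree falls back to wrapping it in a 0-d object array and pickling it.

import numpy as np

arr = np.random.rand(1000, 1000)
state = {"params": {"kernel": arr, "bias": np.zeros(1000)}}  # pytree-like object

np.save("/tmp/array.npy", arr)    # fast path: raw array bytes, no pickling
np.save("/tmp/state.npy", state)  # object: wrapped in a 0-d array and pickled

# The pickled file must be loaded with allow_pickle=True and unwrapped.
restored = np.load("/tmp/state.npy", allow_pickle=True).item()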
Conclusion:
Baseline: I also benchmarked alpa’s current distributed checkpoint functions on two hosts:
Performance Target:
Resolved by #487