
Model Checkpoint Improvements


The current checkpoint design needs some improvements.

Format

  • The on-disk file names should be readable and deterministic instead of random numbers. For example, if we have two calls
    save_checkpoint(path, target)
    save_checkpoint(path, target)
    
    the second call should fully overwrite the first. However, in the current implementation with randomly generated file names, we store two copies, and the first copy is never deleted or used. A minimal naming sketch follows this list.
  • We should not store any absolute paths.
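
A minimal sketch of one possible deterministic naming scheme (a hypothetical helper, not the current alpa API, assuming a recent jax.tree_util that provides tree_flatten_with_path): derive each file name from the leaf’s pytree key path, so repeated saves to the same path overwrite the same files.

import os
import numpy as np
from jax.tree_util import tree_flatten_with_path

def save_checkpoint(path, target):
    """Hypothetical sketch: one .npy file per leaf, named by its pytree key path."""
    os.makedirs(path, exist_ok=True)
    leaves_with_paths, _ = tree_flatten_with_path(target)
    for key_path, leaf in leaves_with_paths:
        # For a dict pytree, key_path is a tuple of DictKey entries,
        # e.g. (DictKey('1'), DictKey('kernel')) -> "1.kernel.npy".
        # The name depends only on the pytree structure, so a second
        # save_checkpoint(path, target) overwrites the first instead of
        # leaving stale, randomly named copies behind.
        fname = ".".join(str(k.key) for k in key_path) + ".npy"
        np.save(os.path.join(path, fname), np.asarray(leaf))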

Speed

Due to frequent synchronization and compression, the current save/load speed is very slow. We need to benchmark the performance against the peak bandwidth.

The peak load bandwidth can be measured by running a JAX program that reads NumPy arrays from disk into local GPU memory; ditto for save. When we save/load distributed arrays with multiple workers, we should also aim for perfect linear scaling. Possible ways to improve bandwidth include batching small arrays, batching Ray calls, multithreading, and overlapping the driver process with the worker processes.

The micro-benchmark script can be put under https://github.com/alpa-projects/alpa/tree/main/playground/alpa_micro_benchmark; a rough sketch is shown below.
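
A minimal sketch of such a peak-bandwidth micro-benchmark, assuming a single GPU and a ~1 GB float32 array (the file path and sizes are arbitrary placeholders):

import time
import numpy as np
import jax

PATH = "/tmp/bench_array.npy"   # point this at an EFS mount to measure EFS bandwidth
SIZE_GB = 1
arr = np.ones(SIZE_GB * (1 << 30) // 4, dtype=np.float32)   # ~1 GB of float32

# Peak save bandwidth: host memory -> disk, no serialization overhead.
start = time.time()
np.save(PATH, arr)
save_time = time.time() - start

# Peak load bandwidth: disk -> host memory -> local GPU memory.
start = time.time()
loaded = np.load(PATH)
gpu_arr = jax.device_put(loaded, jax.devices("gpu")[0])
gpu_arr.block_until_ready()   # wait for the host-to-device copy to finish
load_time = time.time() - start

gbits = SIZE_GB * 8
print(f"save: {save_time:.2f} s, {gbits / save_time:.2f} Gbps")
print(f"load: {load_time:.2f} s, {gbits / load_time:.2f} Gbps")

Running the same script with PATH on EFS versus a local filesystem covers the two settings compared in the benchmark comments below.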

Features

As a bonus, we want to support partial writing and partial reading. For example, we should be able to save a single checkpoint in two separate calls:

# Write layer 1
param_tree_1 = {
  "1": {"kernel": array, "bias": array},
  "2": {"kernel": None, "bias": None},
}
save_checkpoint(path, param_tree_1)

# Write layer 2
param_tree_2 = {
  "1": {"kernel": None, "bias": None},
  "2": {"kernel": array, "bias": array},
}
save_checkpoint(path, param_tree_2)

Ditto for load; a sketch of one possible interpretation of these semantics follows.
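
One possible interpretation (a hypothetical sketch, not the actual alpa implementation), reusing the per-leaf file naming from the sketch under Format above: a None leaf means “skip this entry”. On the write side the naming sketch already behaves this way, because JAX treats None as an empty subtree, so each save_checkpoint call only writes the files for the leaves it actually carries. A matching partial restore could look like:

import os
import numpy as np
from jax.tree_util import tree_flatten_with_path, tree_unflatten

def restore_checkpoint(path, target):
    """Hypothetical sketch of partial reading: leaves that are None in `target`
    stay None; every other leaf is read back from its per-leaf .npy file."""
    leaves_with_paths, treedef = tree_flatten_with_path(
        target, is_leaf=lambda x: x is None)
    restored = []
    for key_path, leaf in leaves_with_paths:
        if leaf is None:
            restored.append(None)   # the caller did not ask for this entry
            continue
        fname = ".".join(str(k.key) for k in key_path) + ".npy"
        restored.append(np.load(os.path.join(path, fname)))
    return tree_unflatten(treedef, restored)

# Read back only layer 2; the placeholders for requested leaves can be anything
# non-None (e.g. dummy values or shape templates).
# params = restore_checkpoint(path, {"1": {"kernel": None, "bias": None},
#                                    "2": {"kernel": 0, "bias": 0}})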

Minor Issues

The aval for replicated arrays is the same (see https://github.com/alpa-projects/alpa/blob/a0d1d17ddc444924c27ef11816bc0963750570b9/alpa/serialization.py#L177).

Issue Analytics

  • State: closed
  • Created: a year ago
  • Comments: 7 (7 by maintainers)

Top GitHub Comments

1 reaction
PKUFlyingPig commented on May 27, 2022

I did some benchmark experiments for save/load, using three different methods: (1) flax’s checkpoint functions, (2) alpa’s current checkpoint functions (no compression), and (3) np.save/np.load. All experiments were run on both the EFS filesystem and the local filesystem (the results indicate that EFS can be the bottleneck).

First, to benchmark the peak bandwidth for save/load, I used the three methods above to save/load a 1 GB ndarray between GPU memory and disk. The results are as follows:

    Benchmark results on EFS:
    - flax.save_checkpoint:    save average run time: 15.0580 seconds, save average throughput: 0.5313 Gbps
    - flax.restore_checkpoint: load average run time:  6.8287 seconds, load average throughput: 1.2225 Gbps

    - alpa.save_checkpoint:    save average run time: 12.8583 seconds, save average throughput: 0.6222 Gbps
    - alpa.restore_checkpoint: N/A (loading involves multiple hosts and Ray; it is only benchmarked as the distributed baseline at the bottom)

    - np.save:                 save average run time: 10.4157 seconds, save average throughput: 0.7682 Gbps
    - np.load:                 load average run time:  2.9987 seconds, load average throughput: 4.9950 Gbps

    Benchmark results on local filesystem:
    - flax.save_checkpoint:    save average run time:  5.5268 seconds, save average throughput:  1.4475 Gbps
    - flax.restore_checkpoint: load average run time:  5.1856 seconds, load average throughput:  1.5428 Gbps

    - alpa.save_checkpoint:    save average run time: 10.3145 seconds, save average throughput:  0.7756 Gbps
    - alpa.restore_checkpoint: N/A

    - np.save:                 save average run time:  0.8104 seconds, save average throughput:  9.8718 Gbps
    - np.load:                 load average run time:  0.5116 seconds, load average throughput: 15.6365 Gbps

It’s obvious that EFS is the bottleneck for saving, but I’m curious why numpy can be so much faster than flax’s msgpack and alpa’s tensorstore. So I also benchmarked the save functions of the three methods on a real MLP model (3 GB of parameters):

    Benchmark results on EFS:
    - flax.save_checkpoint: average run time: 45.1909 seconds, average throughput: 0.5313 Gbps
    - alpa.save_checkpoint: average run time: 16.1519 seconds, average throughput: 1.4861 Gbps
    - np.save:              average run time: 20.6182 seconds, average throughput: 1.1642 Gbps

    Benchmark results on local disk:
    - flax.save_checkpoint: average run time: 16.1342 seconds, average throughput: 1.4877 Gbps
    - alpa.save_checkpoint: average run time: 10.6634 seconds, average throughput: 2.2510 Gbps
    - np.save:              average run time: 18.2943 seconds, average throughput: 1.3120 Gbps

Now numpy is slower than alpa’s tensorstore. I read the source code of np.save and found that when the target is a Python object, numpy uses pickle to serialize it, but when the target is a plain array, it takes a fast path. This explains the two experiments above. In the single-array case, flax is slow because it uses msgpack (an object-serialization package similar to pickle), alpa is slow because it uses tensorstore, and numpy is fast because it takes the fast path and writes the array data directly, without any serialization overhead. In the MLP case, the save/load target is a TrainState object, so numpy falls back to pickle and loses the performance.
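
A small illustration of the two np.save code paths (timings will vary with the disk; the point is only which path gets taken):

import time
import numpy as np

arr = np.ones(256 * (1 << 20) // 4, dtype=np.float32)   # ~256 MB raw array
obj = {"1": {"kernel": arr, "bias": arr}}                # pytree-like Python object

start = time.time()
np.save("/tmp/raw.npy", arr)                 # fast path: header + raw array bytes
print(f"array fast path: {time.time() - start:.2f} s")

start = time.time()
# A dict is not array-like, so numpy wraps it in a 0-d object array and pickles it.
np.save("/tmp/obj.npy", obj, allow_pickle=True)
print(f"pickled object:  {time.time() - start:.2f} s")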

Conclusion:

  • Object serialization and tensorstore are slow => I will try to implement the resharding feature purely with numpy arrays.
  • EFS is slow => maybe each host can save its part of the model to local disk first, but then the resharding process during loading can be problematic 😦 (a rough sketch of the per-host idea follows this list).
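
A hypothetical sketch of the per-host idea (not alpa’s implementation; the helper name, file layout, and metadata format are made up for illustration): each host dumps the shards it owns to local disk with np.save and records which global index slices they cover, so a later loader or background uploader can reassemble or reshard them.

import os
import json
import numpy as np

def save_local_shards(local_dir, shard_arrays, shard_indices, host_id):
    """Hypothetical fast path: per-host raw .npy shards plus a small metadata file."""
    os.makedirs(local_dir, exist_ok=True)
    metadata = []
    for i, (shard, index) in enumerate(zip(shard_arrays, shard_indices)):
        fname = f"host{host_id}_shard{i}.npy"
        np.save(os.path.join(local_dir, fname), np.asarray(shard))
        # `index` is assumed to be a tuple of slices into the global array,
        # recorded so a loader knows where this shard belongs when resharding.
        metadata.append({"file": fname,
                         "index": [[s.start, s.stop] for s in index]})
    with open(os.path.join(local_dir, f"host{host_id}.meta.json"), "w") as f:
        json.dump(metadata, f)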

Baseline: I also benchmarked alpa’s current distributed checkpoint functions on two hosts as a baseline:

    Benchmark results on EFS:
    - alpa.save_checkpoint: save average run time: 31.5137 seconds, save average throughput: 0.7617 Gbps
    - alpa.restore_checkpoint: load average run time: 15.2250 seconds, load average throughput: 1.5772 Gbps

Performance Target:

  • single-host peak bandwidth (np.save/np.load achieve this on EFS):
    • save: 0.7682 Gbps
    • load: 4.9950 Gbps
  • multiple workers: achieve perfect linear scaling
0 reactions
zhuohan123 commented on Jun 13, 2022

Resolved by #487
