[Roadmap] CPU Performance Optimization for PyG

🚀 The feature, motivation and pitch

The goal of this roadmap is to optimize CPU performance for PyG (including torch_scatter, torch_sparse).

As a first step, we will start with single-node inference performance optimization on:

  • Homogeneous Models: GCN or GAT, PNA, EdgeConv
  • Heterogeneous Models: to_hetero, R-GCN, R-GAT

The next step will extend the optimization effort to (distributed) training.

Performance Profiling

CPU platform: Ice Lake Xeon

Generic benchmarking

  • GCN + ogbn-products: (torch_sparse::spmm_sum 96.04%)
  • GCN + Reddit: layer=1, hidden=16 (DataLoader 83.49%, aten::scatter_add_ 8.47%)
  • GCN + Reddit: layer=3, hidden=32 (DataLoader 59.83%, aten::scatter_add_ 24.76%)
  • SAGE + ogbn-products: (aten::scatter_add_ 27.61%, DataLoader 25.70%, aten::index 20.26%)
  • GAT + CiteSeer: (aten::scatter_add_ 30.91%, torch_scatter::scatter_max 24.54%, aten::mm 10.34%, aten::index_select 6.71%). Most models under pytorch_geometric/benchmark/citation behave similarly from a performance perspective.
  • to_hetero_mag: (aten::addmm 21.69%, aten::scatter_add_ 20.60%, aten::index_select 13.48%, DataLoader 12.31%)
  • PNA: (torch_scatter::scatter_max 39.34%, torch_scatter::scatter_min 39.25%). Follow-up needed: collect the scatter_reduce tensor shapes/strides (similar issue to aten::scatter_add_?)
  • dynamicEdgeConv: (torch_scatter::scatter_max 66.91%, torch_cluster::knn 23.56%), source: benchmark/points/edge_cnn.py
  • EdgeConv: (torch_scatter::scatter_max 53.61%, aten::index_select 21.73%, DataLoader 16.11%), source: https://github.com/pyg-team/pytorch_geometric/pull/4915
  • pytorch_geometric/benchmark/kernel
  • pytorch_geometric/benchmark/points

Large dataset benchmarking

  • GraphSAGE + mag240m profiling (table below)
  • analysis of the ratio of profiler-recorded time to total runtime (to evaluate other overhead, such as NumPy calls, if any)
  • gather the input range for spmm_{sum|max|mean} for oneDNN RFC proposal (future plan)
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...        71.27%      608.842s        71.39%      609.891s      70.223ms          8685  
                                torch_sparse::spmm_mean        14.91%      127.390s        14.93%      127.522s       7.342ms         17370  
                                            aten::addmm         3.77%       32.166s         7.34%       62.727s       1.806ms         34740  
                                            aten::copy_         3.60%       30.766s         3.60%       30.766s     161.007us        191082  
                                               aten::mm         2.29%       19.588s         2.30%       19.683s       1.133ms         17370  
                                aten::native_batch_norm         0.94%        7.989s         1.01%        8.657s     332.256us         26055  

DataLoader (including preprocessing of the input data) is the major bottleneck here, mostly from_numpy (246s) and to (169s), both triggered by data type conversion in convert_batch.

Performance Hotspots

  • DataLoader (mini-batch mode): mostly introduced by preprocessing in the samplers (e.g. NeighborSampler).
  • edge_index in CSR: spmm_sum or spmm_max from torch_sparse.
  • edge_index in COO: scatter_add, scatter_max (torch_scatter), index_select, index, etc.
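To make the CSR/COO distinction above concrete, here is a minimal pure-Python sketch (not the torch_sparse or torch_scatter code) of the same neighbor-sum aggregation done both ways. It assumes edges are (src, dst) pairs with messages flowing from src to dst; the helper names scatter_add_coo, coo_to_csr and spmm_sum_csr are hypothetical. The COO path writes into out[dst] from arbitrary edges, which is where write conflicts come from once the loop is parallelized; the CSR path gives every output row its own contiguous slice of col, so rows can be processed independently.

```python
def scatter_add_coo(edges, x, num_nodes):
    """COO path: for every (src, dst) edge, add x[src] into out[dst]."""
    dim = len(x[0])
    out = [[0.0] * dim for _ in range(num_nodes)]
    for src, dst in edges:
        row = out[dst]  # arbitrary edges may hit the same row: the write-conflict source
        for k in range(dim):
            row[k] += x[src][k]
    return out


def coo_to_csr(edges, num_nodes):
    """Group edges by destination: row i's sources are col[rowptr[i]:rowptr[i + 1]]."""
    counts = [0] * num_nodes
    for _, dst in edges:
        counts[dst] += 1
    rowptr = [0] * (num_nodes + 1)
    for i in range(num_nodes):
        rowptr[i + 1] = rowptr[i] + counts[i]
    col = [0] * len(edges)
    cursor = rowptr[:-1]  # next free slot per row (slicing makes a copy)
    for src, dst in edges:
        col[cursor[dst]] = src
        cursor[dst] += 1
    return rowptr, col


def spmm_sum_csr(rowptr, col, x):
    """CSR path: each output row reads only its own slice, so rows are independent."""
    dim = len(x[0])
    out = []
    for i in range(len(rowptr) - 1):
        acc = [0.0] * dim
        for e in range(rowptr[i], rowptr[i + 1]):
            for k in range(dim):
                acc[k] += x[col[e]][k]
        out.append(acc)
    return out


edges = [(0, 1), (2, 1), (1, 0), (2, 0), (0, 2)]
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
rowptr, col = coo_to_csr(edges, num_nodes=3)
# Both paths give [[8.0, 10.0], [6.0, 8.0], [1.0, 2.0]].
coo_out = scatter_add_coo(edges, x, 3)
csr_out = spmm_sum_csr(rowptr, col, x)
```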

Python level API upgrade in model scripts

The DataLoader is a major hotspot, so the first step is to upgrade the DataLoader from NeighborSampler to NeighborLoader, which has a native C++ implementation.
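For reference, the core of neighbor sampling is a per-seed loop over a CSR adjacency. The following is a simplified pure-Python sketch of one hop of uniform sampling without replacement; it is an illustration under that assumption, not PyG's native implementation, and sample_neighbors is a hypothetical name. Each seed is handled independently, which is what makes this loop a natural candidate for OpenMP parallelization in native code.

```python
import random


def sample_neighbors(rowptr, col, seeds, fanout, rng):
    """One hop of uniform neighbor sampling without replacement on a CSR graph.

    Returns the sampled edges as (neighbor, seed) pairs plus the sorted list of
    reached neighbor ids, which would form the next hop's frontier.
    """
    edges = []
    reached = set()
    for s in seeds:  # seeds are independent: the natural parallel loop
        nbrs = col[rowptr[s]:rowptr[s + 1]]
        for n in rng.sample(nbrs, min(fanout, len(nbrs))):
            edges.append((n, s))
            reached.add(n)
    return edges, sorted(reached)


rowptr, col = [0, 2, 4, 5], [1, 2, 0, 2, 0]
edges, frontier = sample_neighbors(rowptr, col, seeds=[0, 1], fanout=1, rng=random.Random(0))
```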

Native level kernel optimization

Phase One Optimizations

  • NeighborLoader parallelization: the current implementation is sequential (probably to avoid oversubscription when the data loader uses multiple workers). Unlike on GPU, running the data loading thread and the computation thread asynchronously does not always make sense on CPU. In some cases it is better to run the data loading step and the computation step sequentially while parallelizing each torch operator with OpenMP (i.e. intra-op parallelism). Hotspot: torch_sparse::neighbor_sample.
  • aten::sort: GCN + ogbn-products spends roughly 1/3 of its time on sort in the preprocessing step (which is not covered by the profiler results for model inference), introduced by indexing into a sparse tensor at gnn.py#L123. The root cause is that aten::sort(dim) could only be parallelized over dimensions != dim, and the grain size was not set correctly. Fixed by #74897.
  • spmm_{max|mean|sum} (torch_sparse): add vectorization and prefetching (for the indirect memory access), and apply blocking on M and K (if necessary).
  • scatter_add and scatter_max (torch_scatter): scatter_add (with extended index) was optimized in #82703, but it still needs more polishing.

The current scatter_add implementation parallelizes over the inner dimension to avoid write conflicts. Ideally we should parallelize over the outer dimension and vectorize over the inner dimension, but that requires resolving the write conflicts on the output tensor. We will experiment with different implementations for the given input range.

  • index_select, optimized via #76868.
  • index: optimizing index directly would be difficult; instead, NeighborLoader could switch to a more performant op such as index_select, or use a customized kernel.
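One common way to parallelize scatter_add over the outer dimension while avoiding the write conflict described above is to give each thread a private output buffer and reduce the buffers at the end. The sketch below is pure Python with a hypothetical function name; it only illustrates the structure (Python threads give no real speedup here), and it is not the approach taken in #82703. It trades extra memory (num_threads × num_out × dim) for conflict-free writes, which is one reason the best choice depends on the input range.

```python
from concurrent.futures import ThreadPoolExecutor


def scatter_add_private_buffers(index, src, num_out, num_threads):
    """scatter_add over dim 0: out[index[e]] += src[e], parallel over edges."""
    dim = len(src[0])
    num_edges = len(index)
    chunk = (num_edges + num_threads - 1) // num_threads

    def work(t):
        # Each worker takes a contiguous slice of the outer (edge) dimension and
        # accumulates into its own private buffer, so workers never share writes.
        local = [[0.0] * dim for _ in range(num_out)]
        for e in range(t * chunk, min((t + 1) * chunk, num_edges)):
            row, s = local[index[e]], src[e]
            for k in range(dim):  # inner dimension: the loop a native kernel would vectorize
                row[k] += s[k]
        return local

    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        partials = list(pool.map(work, range(num_threads)))

    # Final reduction over the per-thread buffers.
    out = [[0.0] * dim for _ in range(num_out)]
    for local in partials:
        for r in range(num_out):
            for k in range(dim):
                out[r][k] += local[r][k]
    return out


index = [1, 1, 0, 2, 0]
src = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0], [5.0, 5.0]]
result = scatter_add_private_buffers(index, src, num_out=3, num_threads=2)
```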

Phase Two Optimizations

  • enable ‘std’ reduce type in scatter_reduce
  • add MultiAggregation kernel
  • optimize segment_reduce and align the reduce types between scatter_reduce, spmm_reduce, segment_reduce.
  • kernel fusion for GAS (#71300), possibly dispatching on the TensorTypeId of CPU and SparseCPU.
  • scatter_add: cache the sorted index.
  • knn (torch_cluster): need to follow up on shape info to determine the proper way to parallelize the kernel. Evaluate knn from oneAPI DAL.
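The "cache the sorted index" item can be sketched as follows: argsort the index once, build row offsets, and then reuse both for every scatter_add that shares the same index (for example, across layers on the same graph). This is a hypothetical pure-Python illustration; CachedScatterPlan is not a torch_scatter API. Once the index is sorted, the reduction becomes a segment reduction, which is also why aligning scatter_reduce and segment_reduce is attractive.

```python
class CachedScatterPlan:
    """Hypothetical helper: sort `index` once, reuse the result for every scatter_add."""

    def __init__(self, index, num_out):
        # Stable argsort of the destination index, computed a single time.
        self.perm = sorted(range(len(index)), key=index.__getitem__)
        counts = [0] * num_out
        for dst in index:
            counts[dst] += 1
        # rowptr[r]:rowptr[r + 1] is the segment of perm that lands in output row r.
        self.rowptr = [0]
        for c in counts:
            self.rowptr.append(self.rowptr[-1] + c)

    def scatter_add(self, src):
        """Segment reduction over the cached order; each output row is independent."""
        dim = len(src[0])
        out = []
        for r in range(len(self.rowptr) - 1):
            acc = [0.0] * dim
            for p in range(self.rowptr[r], self.rowptr[r + 1]):
                s = src[self.perm[p]]
                for k in range(dim):
                    acc[k] += s[k]
            out.append(acc)
        return out


plan = CachedScatterPlan(index=[1, 1, 0, 2, 0], num_out=3)
layer1 = plan.scatter_add([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0], [5.0, 5.0]])
layer2 = plan.scatter_add([[1.0], [1.0], [1.0], [1.0], [1.0]])  # no re-sort needed
```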

Design option for vectorization

To vectorize kernels from torch-sparse and torch-scatter, we have multiple options:

  • vectorize inside torch-sparse/torch-scatter: the simplest way is to use #pragma omp simd and add the compiler flag -march=skylake-avx512, but this won’t cover bfloat16 (bfloat16 is stored as a uint16 and won’t be vectorized properly by the compiler).
  • vectorize inside torch-sparse/torch-scatter: use the at::vec::Vectorized<scalar_t> wrapper. This covers bfloat16, but we need to customize the CMake scripts to be compatible with PyTorch’s CPU build flags: _DEFAULT (scalar code), _AVX2 and _AVX512.
  • vectorize inside torch core: at::vec::Vectorized<scalar_t> works without any change, but the operators need to move from torch-sparse/torch-scatter into torch. This makes more sense for the fused GAS kernel.

(The current decision is to go with option 3 as much as possible.)
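The bfloat16 limitation of the first option is easiest to see at the bit level: a bfloat16 value is the upper 16 bits of a float32, so to the compiler it is just an integer, and arithmetic has to widen to float32, compute, and round back. The pure-Python sketch below shows that round trip with round-to-nearest-even; the helper names are hypothetical, NaN handling is omitted, and it only illustrates the conversion a vectorized kernel must perform, not how at::vec::Vectorized implements it.

```python
import struct


def float_to_bf16_bits(x: float) -> int:
    """Round a float32 value to bfloat16 (round-to-nearest-even); return the raw 16 bits.
    NaN handling is omitted in this sketch."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)  # ties go to the even upper half
    return ((bits + rounding_bias) >> 16) & 0xFFFF


def bf16_bits_to_float(b: int) -> float:
    """Widen bfloat16 to float32 by placing its 16 bits in the high half."""
    return struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))[0]


def bf16_add(a: int, b: int) -> int:
    """Add two bfloat16 values: widen to float32, add, round back to bfloat16."""
    return float_to_bf16_bits(bf16_bits_to_float(a) + bf16_bits_to_float(b))


one = float_to_bf16_bits(1.0)  # 0x3F80
two = bf16_add(one, one)       # 0x4000, i.e. 2.0
```

Adding the raw bits as integers instead (0x3F80 + 0x3F80 = 0x7F00) would decode to 2**127, roughly 1.7e38, rather than 2.0, which is why plain integer auto-vectorization cannot be used for bfloat16 arithmetic.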

Bfloat16 enabling in torch-sparse/torch-scatter

(Highly related to the chosen vectorization method.)

  • Need more work to determine operator list for bfloat16 support.

Validation

  • verify float32 accuracy
  • verify bfloat16 accuracy
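Because vectorized, blocked or re-partitioned kernels change the order of floating-point accumulation, results generally cannot be compared bit-for-bit with a reference, so accuracy checks use a tolerance. A minimal pure-Python sketch of such a check follows, assuming an allclose-style criterion |actual - expected| <= atol + rtol * |expected| (the same form used by torch.testing.assert_close). round_to_bf16 emulates bfloat16 storage by keeping 8 significant bits; both helper names and the tolerances are hypothetical, chosen only for illustration.

```python
import math


def round_to_bf16(x: float) -> float:
    """Emulate bfloat16 storage: keep 8 significant bits, round half to even.
    Overflow, NaN and subnormal handling are omitted in this sketch."""
    if x == 0.0 or not math.isfinite(x):
        return x
    m, e = math.frexp(x)  # x == m * 2**e with 0.5 <= |m| < 1
    return math.ldexp(round(m * 256) / 256, e)


def max_violation(actual, expected, rtol, atol):
    """Largest |a - e| - (atol + rtol * |e|); a value <= 0 means every element passes."""
    return max(abs(a - e) - (atol + rtol * abs(e)) for a, e in zip(actual, expected))


# Reference reduction in full precision vs. the same reduction on bfloat16-rounded inputs.
values = [1.00390625] * 4                               # each value is a bfloat16 rounding tie
reference = [sum(values)]                               # 4.015625
bf16_result = [sum(round_to_bf16(v) for v in values)]  # 4.0 (each input rounds to 1.0)
```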

Issue Analytics

  • State:open
  • Created a year ago
  • Comments:16 (7 by maintainers)

Top GitHub Comments

5 reactions
mingfeima commented, Aug 19, 2022

Update on spmm optimizations, PR submitted at https://github.com/pytorch/pytorch/pull/83727.

This ports the spmm reduction from torch-sparse to torch. The current PR is only meant to demonstrate the performance gains; the API definition needs further amendment.

Only sum is added for now; max, mean and min will follow, and the algorithm is pretty much the same.
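The point that max, mean and min share the sum algorithm can be illustrated with a CSR sketch in which only the combine step changes. This is a hypothetical pure-Python reference (spmm_reduce is not the PyTorch or torch-sparse signature); it assumes out[i] reduces value[j] * x[col[j]] over the nonzeros j of row i, and that empty rows produce zeros, which is an assumed convention rather than something stated in the PR.

```python
def spmm_reduce(rowptr, col, value, x, reduce="sum"):
    """CSR sparse-dense matmul where the row reduction is sum, mean, max or min."""
    if reduce not in ("sum", "mean", "max", "min"):
        raise ValueError(f"unsupported reduce: {reduce}")
    combine = {
        "sum": lambda a, b: a + b,
        "mean": lambda a, b: a + b,  # sum first, divide by the row length at the end
        "max": max,
        "min": min,
    }[reduce]
    dim = len(x[0])
    out = []
    for i in range(len(rowptr) - 1):
        start, end = rowptr[i], rowptr[i + 1]
        acc = None
        for j in range(start, end):  # identical traversal for every reduction
            msg = [value[j] * v for v in x[col[j]]]
            acc = msg if acc is None else [combine(a, m) for a, m in zip(acc, msg)]
        if acc is None:  # empty row (assumed convention: zeros)
            acc = [0.0] * dim
        elif reduce == "mean":
            acc = [a / (end - start) for a in acc]
        out.append(acc)
    return out


rowptr, col, value = [0, 2, 4, 5], [1, 2, 0, 2, 0], [1.0] * 5
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
sum_out = spmm_reduce(rowptr, col, value, x, "sum")
max_out = spmm_reduce(rowptr, col, value, x, "max")
```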

I selected the benchmark ./ogb/examples/nodeproppred/products/gnn.py, since it originally spent the majority of its time in torch_sparse::spmm_sum. spmm got roughly a 5x speedup on my 20-core machine.

  1. before
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
       torch_sparse::spmm_sum        97.09%       56.086s        97.09%       56.088s        6.232s             9
                 aten::linear         0.00%      85.000us         1.38%     795.485ms      88.387ms             9
                 aten::matmul         0.00%      57.000us         1.38%     795.260ms      88.362ms             9
                     aten::mm         1.38%     795.201ms         1.38%     795.203ms      88.356ms             9
                   aten::relu         0.00%      50.000us         0.76%     440.434ms      73.406ms             6
              aten::clamp_min         0.76%     440.384ms         0.76%     440.384ms      73.397ms             6
                   aten::add_         0.57%     327.801ms         0.57%     327.801ms      36.422ms             9
            aten::log_softmax         0.00%      23.000us         0.10%      55.503ms      18.501ms             3
           aten::_log_softmax         0.10%      55.480ms         0.10%      55.480ms      18.493ms             3
                 aten::argmax         0.09%      53.149ms         0.09%      53.153ms      13.288ms             4
                  aten::index         0.01%       5.771ms         0.01%       5.839ms     324.389us            18
                  aten::empty         0.00%       1.088ms         0.00%       1.088ms      77.714us            14
  2. after
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
               aten::spmm_sum        87.35%       11.826s        87.36%       11.827s        1.314s             9
                 aten::linear         0.00%      92.000us         5.87%     794.451ms      88.272ms             9
                 aten::matmul         0.00%      62.000us         5.87%     794.208ms      88.245ms             9
                     aten::mm         5.87%     794.143ms         5.87%     794.146ms      88.238ms             9
                   aten::relu         0.00%      53.000us         3.35%     452.977ms      75.496ms             6
              aten::clamp_min         3.35%     452.924ms         3.35%     452.924ms      75.487ms             6
                   aten::add_         2.58%     348.663ms         2.58%     348.663ms      38.740ms             9
                 aten::argmax         0.42%      57.473ms         0.42%      57.475ms      14.369ms             4
            aten::log_softmax         0.00%      22.000us         0.39%      52.605ms      17.535ms             3
           aten::_log_softmax         0.39%      52.583ms         0.39%      52.583ms      17.528ms             3
                  aten::index         0.04%       5.100ms         0.04%       5.174ms     287.444us            18
                  aten::empty         0.01%       1.097ms         0.01%       1.097ms      78.357us            14

To break down the optimization scheme a little bit:

  • original (spmm): 56.086s
  • naive vectorization: 29.314s
  • unroll by 4: 25.664s
  • rowwise blocking x16: 21.953s
  • balanced thread partition: 11.826s
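The step-by-step timings above can be turned into speedups with a few lines of arithmetic. The numbers are copied from the list, and each step is assumed to include all previous ones. The final ratio, about 4.74x, is consistent with the "roughly 5x" quoted earlier in the comment.

```python
# Timings (seconds) from the breakdown above, assumed cumulative.
steps = [
    ("original (spmm)", 56.086),
    ("naive vectorization", 29.314),
    ("unroll by 4", 25.664),
    ("rowwise blocking x16", 21.953),
    ("balanced thread partition", 11.826),
]


def speedups(steps):
    """Return (name, speedup vs. original, speedup vs. previous step) per step."""
    base = steps[0][1]
    rows, prev = [], base
    for name, t in steps:
        rows.append((name, base / t, prev / t))
        prev = t
    return rows


for name, total, step in speedups(steps):
    print(f"{name:28s} {total:5.2f}x total  {step:5.2f}x step")
```

Naive vectorization (about 1.91x) and the balanced thread partition (about 1.86x) contribute the two largest individual steps.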

The balanced thread partition aims to balance the payload across threads. If we naively parallelize along the row direction, the payload looks like this (I collected the number of edges per row for each thread):

### thread: 0; min: 1; max: 17482; avg = 172.599
### thread: 1; min: 1; max: 9918; avg = 137.251
### thread: 2; min: 1; max: 5786; avg = 39.7606
### thread: 3; min: 1; max: 4062; avg = 40.0852
### thread: 4; min: 1; max: 10406; avg = 39.7207
### thread: 5; min: 1; max: 3491; avg = 40.0985
### thread: 6; min: 1; max: 5965; avg = 40.0117
### thread: 7; min: 1; max: 5865; avg = 40.3841
### thread: 8; min: 1; max: 5892; avg = 39.969
### thread: 9; min: 1; max: 6076; avg = 39.9995
### thread: 10; min: 1; max: 5215; avg = 40.0757
### thread: 11; min: 1; max: 3893; avg = 40.1075
### thread: 12; min: 1; max: 8052; avg = 39.8108
### thread: 13; min: 1; max: 4062; avg = 39.7186
### thread: 14; min: 1; max: 3243; avg = 40.3022
### thread: 15; min: 1; max: 5008; avg = 40.4213
### thread: 16; min: 1; max: 7657; avg = 40.0987
### thread: 17; min: 1; max: 6784; avg = 40.0618
### thread: 18; min: 1; max: 4810; avg = 39.8836
### thread: 19; min: 1; max: 6429; avg = 39.9829

The first 2 threads carry much more payload than the others, so the payload needs to be balanced. Normally we would use dynamic scheduling in OpenMP, but that does not fit PyTorch’s at::parallel_for, which is essentially static scheduling, so I partitioned manually here (the logic may be refined further later).
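One plausible way to do such a manual partition, sketched here under assumptions (the PR's actual logic is not shown in this comment), is to split the CSR rows into contiguous ranges with roughly equal numbers of nonzeros by binary-searching rowptr, instead of giving each thread an equal number of rows as a static schedule would. The helper names are hypothetical.

```python
from bisect import bisect_left


def balanced_row_partition(rowptr, num_threads):
    """Boundaries b such that thread t owns rows [b[t], b[t+1]) with ~equal nnz.

    Thread t starts at the first row whose starting offset rowptr[row] reaches
    t * nnz / num_threads; rows are never split, so one very heavy row still
    lands on a single thread.
    """
    num_rows = len(rowptr) - 1
    nnz = rowptr[-1]
    bounds = [0]
    for t in range(1, num_threads):
        target = nnz * t // num_threads
        bounds.append(bisect_left(rowptr, target, 0, num_rows))
    bounds.append(num_rows)
    return bounds


def equal_row_partition(num_rows, num_threads):
    """Static schedule: contiguous chunks with (almost) the same number of rows."""
    chunk = -(-num_rows // num_threads)  # ceiling division
    return [min(t * chunk, num_rows) for t in range(num_threads)] + [num_rows]


def nnz_per_thread(rowptr, bounds):
    """Number of edges each thread processes for the given row boundaries."""
    return [rowptr[b] - rowptr[a] for a, b in zip(bounds, bounds[1:])]


# A skewed graph: row 0 has 8 edges, the other 8 rows have 1 edge each.
rowptr = [0, 8, 9, 10, 11, 12, 13, 14, 15, 16]
static_load = nnz_per_thread(rowptr, equal_row_partition(9, 2))      # [12, 4]
balanced_load = nnz_per_thread(rowptr, balanced_row_partition(rowptr, 2))  # [8, 8]
```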

2 reactions
mingfeima commented, Jul 6, 2022

The current benchmark profiling results use the default settings. Some scripts, for example to_hetero_mag, explicitly set num_workers; otherwise the default setting of 4 applies.

DataLoader time in the benchmark profile results actually consists of two parts:

  • IO: loading data from disk to memory
  • preprocessing: sampling, data type conversion, etc.

The second part takes more time, so it can still be improved with a single worker plus OpenMP parallelism. If we use num_workers>0, we need to make sure OpenMP in each worker has the correct settings (omp_num_threads and core affinity binding) to avoid oversubscription.

Actually, data loader optimization is a rather complex issue, perhaps more complex than optimizing the kernels 😦, since it is more of a tuning job: finding the most balanced point between the workload (memory footprint, computation complexity, etc.) and the hardware capacity (IO, memory bandwidth, ALU FLOPS).

Usually we do not do data loading optimizations, since the real deployment case would probably be even more complex (some vendors have mechanisms like prefetching and batching to improve overall user experience and efficiency). But DGL has done some optimizations here, so we need to do at least something similar; otherwise the out-of-the-box performance of PyG would look bad.

Anyway, we will make sure that OpenMP has the correct settings for both num_workers=0 and num_workers=N, and that each sampler is properly parallelized. num_workers=0 benefits preprocessing more, while num_workers=N benefits IO more. We will let users decide which way to go (maybe with a BKM or some simple guidelines).
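The "correct OpenMP settings" for num_workers=N mostly come down to giving each worker a disjoint set of cores and a matching thread count. A hypothetical pure-Python planner is sketched below; it ignores the main process's compute threads and SMT siblings, which a real setup would also need to account for. Each worker could then apply its plan from a DataLoader worker_init_fn, for example with torch.set_num_threads(len(cores)) and, on Linux, os.sched_setaffinity(0, cores); that wiring is an assumption, not something described in the comment.

```python
def plan_worker_threads(num_cores, num_workers):
    """Split cores among DataLoader processes to avoid oversubscription.

    num_workers=0 means the main process does loading and preprocessing itself,
    so it gets every core. Otherwise each worker gets a disjoint, contiguous
    block of cores; leftover cores go to the first workers.
    Returns one list of core ids per process.
    """
    if num_workers == 0:
        return [list(range(num_cores))]
    if num_workers > num_cores:
        raise ValueError("more workers than cores would leave some workers without a core")
    base, extra = divmod(num_cores, num_workers)
    plans, start = [], 0
    for w in range(num_workers):
        n = base + (1 if w < extra else 0)
        plans.append(list(range(start, start + n)))
        start += n
    return plans


single = plan_worker_threads(8, 0)   # one process, all 8 cores
multi = plan_worker_threads(10, 4)   # four workers with 3, 3, 2, 2 cores
```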
