ENH: multi dimensional wasserstein/earth mover distance in Scipy
Is your feature request related to a problem? Please describe.
Currently, SciPy has its own implementation of the Wasserstein distance: scipy.stats.wasserstein_distance. However, scipy.stats.wasserstein_distance only works with one-dimensional data. @Eight1911 created issue #10382 in 2019 suggesting more general support for multi-dimensional data. As far as I know, his pull request was not merged and the link to his script is now invalid.
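For reference, a minimal sketch of the existing one-dimensional call. The sample values are illustrative only; with equal-size, equally weighted samples, the result is the mean absolute difference between the sorted values.

```python
import numpy as np
from scipy.stats import wasserstein_distance

# Two 1-D samples; every point of u has to travel 5 units to reach v.
u = np.array([0.0, 1.0, 3.0])
v = np.array([5.0, 6.0, 8.0])

d = wasserstein_distance(u, v)
print(d)
```

Optional u_weights and v_weights arguments let the samples carry non-uniform mass.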
A bit of background:
The Wasserstein distance, also called the earth mover's distance or the optimal transport distance, is a metric for comparing two probability distributions. In the discrete case, the Wasserstein distance can be understood as the cost of an optimal transport plan between two samples. A brief and intuitive introduction can be found here.
Since it was first introduced in the Monge problem, the Wasserstein distance has been studied for many years. It is widely used to compare discrete distributions. For example, it has been used to compare color histograms in computer vision and as a similarity metric for anomaly detection. In the WGAN neural network framework, it is used as a loss function.
In the one-dimensional case, the Wasserstein distance is equal to the energy distance. However, this does not hold in multi-dimensional metric spaces.
Describe the solution you’d like.
I suggest that SciPy pick up that issue and add a multi-dimensional Wasserstein distance. I’m willing to contribute it myself if the SciPy developers are happy with that. My current idea is to implement a function that supports multiple ‘backend’ techniques. The function would look like this:
wass_nd(u, v, u_weight = None, v_weight = None, method = ['linear_programming', 'network_simplex', 'sinkhorn', ...])
and return a list of Wasserstein values. Each element in the list corresponds to an element of the method argument.
For now, I think the function should support at least the linear programming method, as it is a well-studied approach in optimal transport theory. I wrote a Python script for that based on Vincent Herrmann’s blog using scipy.optimize.linprog and tested it against scipy.stats.wasserstein_distance on one-dimensional data. Although it is comparatively slow, the results are consistent.
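A minimal sketch of what such a linear programming formulation could look like. This is not the script mentioned above: the helper name wasserstein_lp, the Euclidean ground cost, and the dense constraint matrices are assumptions made for illustration. The unknown is the flattened transport plan P, constrained so that its row sums match the weights of u and its column sums match the weights of v.

```python
import numpy as np
from scipy.optimize import linprog
from scipy.spatial.distance import cdist
from scipy.stats import wasserstein_distance


def wasserstein_lp(u, v, u_weights=None, v_weights=None):
    """Hypothetical helper: Wasserstein-1 distance via a transport LP."""
    u = np.asarray(u, dtype=float)
    v = np.asarray(v, dtype=float)
    # Treat 1-D input as a set of points in a 1-D space.
    if u.ndim == 1:
        u = u[:, None]
    if v.ndim == 1:
        v = v[:, None]
    m, n = len(u), len(v)

    # Uniform weights by default; otherwise normalize to total mass 1.
    a = np.full(m, 1.0 / m) if u_weights is None else np.asarray(u_weights, float)
    b = np.full(n, 1.0 / n) if v_weights is None else np.asarray(v_weights, float)
    a, b = a / a.sum(), b / b.sum()

    # Ground cost: Euclidean distance between every pair of points (assumption).
    cost = cdist(u, v)

    # P is flattened row-major, so entry (i, j) sits at index i * n + j.
    row_sums = np.kron(np.eye(m), np.ones((1, n)))  # sum_j P[i, j] = a[i]
    col_sums = np.kron(np.ones((1, m)), np.eye(n))  # sum_i P[i, j] = b[j]
    A_eq = np.vstack([row_sums, col_sums])
    b_eq = np.concatenate([a, b])

    res = linprog(cost.ravel(), A_eq=A_eq, b_eq=b_eq,
                  bounds=(0, None), method="highs")
    if not res.success:
        raise RuntimeError(res.message)
    return res.fun


# Consistency check against the existing 1-D implementation.
rng = np.random.default_rng(0)
x, y = rng.normal(size=20), rng.normal(loc=1.0, size=30)
print(wasserstein_lp(x, y), wasserstein_distance(x, y))
```

The dense constraint matrix has (m + n) rows and m * n columns, which is one reason this approach becomes slow and memory-hungry for large samples; a sparse matrix would reduce the memory cost but not the size of the LP.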
The Sinkhorn algorithm is another approach to the optimal transport problem, first introduced in this paper. It provides an approximate solution by adding an entropy regularization term to the original optimal transport problem.
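A minimal sketch of the Sinkhorn iterations in NumPy, assuming uniform weights and a Euclidean ground cost. The function name sinkhorn_cost and the parameters reg and n_iter are hypothetical choices for illustration, and this plain (non-log-domain) version can underflow when the cost is large relative to reg.

```python
import numpy as np
from scipy.spatial.distance import cdist


def sinkhorn_cost(u, v, reg=0.1, n_iter=1000):
    """Hypothetical helper: entropy-regularized transport cost and plan."""
    u = np.asarray(u, dtype=float)
    v = np.asarray(v, dtype=float)
    a = np.full(len(u), 1.0 / len(u))  # uniform weights on u (assumption)
    b = np.full(len(v), 1.0 / len(v))  # uniform weights on v (assumption)

    cost = cdist(u, v)          # Euclidean ground cost
    K = np.exp(-cost / reg)     # Gibbs kernel

    # Alternately rescale rows and columns so the plan matches both marginals.
    f = np.ones_like(a)
    g = np.ones_like(b)
    for _ in range(n_iter):
        f = a / (K @ g)
        g = b / (K.T @ f)

    plan = f[:, None] * K * g[None, :]
    return float(np.sum(plan * cost)), plan


# Two points shifted up by one unit: the exact distance is 1.
u = np.array([[0.0, 0.0], [1.0, 0.0]])
v = np.array([[0.0, 1.0], [1.0, 1.0]])
approx, plan = sinkhorn_cost(u, v)
print(approx)
```

The returned cost is that of the regularized plan, so it is an upper bound on the exact distance that tightens as reg decreases, at the price of slower convergence and a higher risk of underflow.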
More details about the LP and Sinkhorn approaches can be found in Computational Optimal Transport by Gabriel Peyré and Marco Cuturi. I’m not very familiar with network simplex algorithms, but several of them are also used to solve the optimal transport problem, such as the revised simplex algorithm and the shortsimplex algorithm.
Describe alternatives you’ve considered.
In Python, there are two alternatives for the Wasserstein distance. One is the Wasserstein distance in the POT package; the other is the cv2.EMD method. However, both of them have issues when the sample size becomes too large.
Additional context (e.g. screenshots, GIFs)
No response
Issue Analytics
- State:
- Created a year ago
- Reactions:2
- Comments:16 (9 by maintainers)
I develop scientific Python more than I use it, but this is one of the few features I have actually needed for research. Still, I lean toward suggesting, @com3dian, that you propose the enhancement to POT. They would need to update their minimum supported SciPy version to 1.6 because that’s when linprog method='HiGHS' was introduced, but if you can show the superiority of a simple linear programming solution for very large problems, perhaps they would work with you to add that method and go further. There are a few reasons I feel this way, but the best ones are probably that it would likely be easier to find highly qualified reviewers over there (perhaps I’m wrong - other maintainers are welcome to chime in), and POT seems to be the Python library for this sort of thing (IMO), whereas SciPy is a more general-purpose library. I can maybe see extending the existing wasserstein_distance to multiple dimensions using linprog, but I wouldn’t want to get scipy.stats into the business of a many-method, fully-featured EMD suite. What do you think @tupui @tirthasheshpatel @chrisb83 @ev-br?
Great. Added some comments there to help you debug. Yeah, we should not add spatial.EMD in this PR; that would be a separate thing, and someone other than me would need to review it.