Implement `scipy.optimize.linear_sum_assignment`

Implement `scipy.optimize.linear_sum_assignment`, which solves the linear sum assignment problem (minimum-weight matching in a bipartite graph). Among other things, this is useful for estimating the Wasserstein distance between two distributions from their empirical measures.
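
Why the assignment problem gives the Wasserstein distance: take two empirical measures with the same number of equally weighted samples. Some optimal transport plan between them is a permutation, so the exact distance is the average cost of an optimal one-to-one matching. The sketch below uses NumPy/SciPy and rests on that equal-size, uniform-weight assumption. The function name `empirical_wasserstein` is hypothetical.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist


def empirical_wasserstein(x, y, p=1):
    """p-Wasserstein distance between two empirical measures.

    Assumes x and y contain the same number of equally weighted samples,
    given as arrays of shape (n,) or (n, d).
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # Treat 1-D input as n points in one dimension.
    x = x.reshape(len(x), -1)
    y = y.reshape(len(y), -1)
    if x.shape[0] != y.shape[0]:
        raise ValueError("x and y must contain the same number of samples")
    # Pairwise ground cost: cost[i, j] = ||x_i - y_j||^p (Euclidean norm).
    cost = cdist(x, y) ** p
    # With uniform weights an optimal transport plan is a permutation,
    # so the assignment problem yields the exact optimum.
    rows, cols = linear_sum_assignment(cost)
    return cost[rows, cols].mean() ** (1.0 / p)
```

In one dimension this matches the sorted-samples formula. For example, `[0, 1, 3]` against `[5, 1, 2]` gives 4/3 for `p=1`.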

Issue Analytics

  • State: open
  • Created a year ago
  • Reactions: 3
  • Comments: 5 (2 by maintainers)

Top GitHub Comments

1 reaction
rdilip commented, Nov 7, 2022

Any updates on this? This seems particularly important for set-to-set machine learning methods (e.g. DETR).
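
In DETR-style set prediction, each ground-truth object is matched to one distinct prediction by solving an assignment problem over a pairwise matching cost. The loss is then computed on the matched pairs. The sketch below is a simplified, hypothetical matcher. It uses a negative class-probability cost plus an L1 box cost, weighted by an arbitrary `box_weight`. DETR's own matcher also adds a generalized-IoU term, which is omitted here.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def hungarian_match(pred_probs, pred_boxes, tgt_labels, tgt_boxes, box_weight=1.0):
    """Match each target to a distinct prediction at minimum total cost.

    pred_probs: (N, C) class probabilities for N predictions
    pred_boxes: (N, 4) predicted boxes
    tgt_labels: (M,) integer class labels of M <= N targets
    tgt_boxes:  (M, 4) target boxes
    Returns (pred_idx, tgt_idx): matched pairs, with pred_idx sorted.
    """
    # Classification cost: predicting the target's class with high
    # probability is cheap. Shape (N, M).
    cost_class = -pred_probs[:, tgt_labels]
    # Box cost: L1 distance between every prediction/target pair. Shape (N, M).
    cost_box = np.abs(pred_boxes[:, None, :] - tgt_boxes[None, :, :]).sum(axis=-1)
    cost = cost_class + box_weight * cost_box
    # Rectangular problem: every target gets exactly one prediction and the
    # N - M leftover predictions stay unmatched ("no object").
    pred_idx, tgt_idx = linear_sum_assignment(cost)
    return pred_idx, tgt_idx
```

This matching step is the piece that currently has no JAX equivalent. It runs once per image in the batch, before the loss.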

1 reaction
carlosgmartin commented, Jun 27, 2022

@avinashsai @riversdark Here is scipy’s C++ implementation, ported to JAX. It’s not in fully JITable form yet:

from itertools import count

from jax import numpy as jnp, random, jit
from jax.lax import cond, while_loop
from scipy.optimize import linear_sum_assignment

def augmenting_path(cost, u, v, path, row4col, i):
    minVal = 0
    num_remaining = cost.shape[1]
    # Reverse order (as in SciPy) so that a constant cost matrix yields the identity assignment
    remaining = jnp.arange(cost.shape[1])[::-1]

    SR = jnp.full(cost.shape[0], False)
    SC = jnp.full(cost.shape[1], False)
    shortestPathCosts = jnp.full(cost.shape[1], jnp.inf)

    sink = -1
    while sink == -1:
        index = -1
        lowest = jnp.inf
        SR = SR.at[i].set(True)

        for it in range(num_remaining):
            j = remaining[it]

            r = minVal + cost[i, j] - u[i] - v[j]

            path = cond(
                r < shortestPathCosts[j],
                lambda: path.at[j].set(i),
                lambda: path
            )
            shortestPathCosts = shortestPathCosts.at[j].min(r)

            index = cond(
                (shortestPathCosts[j] < lowest) | 
                ((shortestPathCosts[j] == lowest) & (row4col[j] == -1)),
                lambda: it,
                lambda: index
            )
            lowest = jnp.minimum(lowest, shortestPathCosts[j])

        minVal = lowest
        if minVal == jnp.inf: # infeasible cost matrix
            sink = -1
            break

        j = remaining[index]

        pred = row4col[j] == -1
        sink = cond(pred, lambda: j, lambda: sink)
        i = cond(~pred, lambda: row4col[j], lambda: i)

        SC = SC.at[j].set(True)
        num_remaining -= 1
        remaining = remaining.at[index].set(remaining[num_remaining])

    return sink, minVal, remaining, SR, SC, shortestPathCosts, path

def solve(cost):
    transpose = cost.shape[1] < cost.shape[0]

    if transpose:
        cost = cost.T

    u = jnp.full(cost.shape[0], 0.)
    v = jnp.full(cost.shape[1], 0.)
    path = jnp.full(cost.shape[1], -1)
    col4row = jnp.full(cost.shape[0], -1)
    row4col = jnp.full(cost.shape[1], -1)

    for curRow in range(cost.shape[0]):

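        j, minVal, remaining, SR, SC, shortestPathCosts, path = augmenting_path(cost, u, v, path, row4col, curRow)
        if j == -1:  # no finite augmenting path exists for this row
            raise ValueError('cost matrix is infeasible')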

        # Update the dual variables u (rows) and v (columns)
        u = u.at[curRow].add(minVal)

        mask = SR & (jnp.arange(cost.shape[0]) != curRow)
        u = u.at[mask].add(minVal - shortestPathCosts[col4row][mask])

        v = v.at[SC].add(shortestPathCosts[SC] - minVal)

        # Augment the previous solution along the shortest path ending at sink j
        while True:
            i = path[j]
            row4col = row4col.at[j].set(i)

            col4row, j = col4row.at[i].set(j), col4row[i]

            if i == curRow:
                break

    if transpose:
        v = col4row.argsort()
        return col4row[v], v
    else:
        return jnp.arange(cost.shape[0]), col4row

def main():
    key = random.PRNGKey(0)
    for t in count():
        key, subkey = random.split(key)
        shape = random.randint(subkey, [2], 0, 6)

        key, subkey = random.split(key)
        cost = random.uniform(subkey, shape)

        if t < 0:  # raise this threshold to skip ahead to a known failing case
            continue

        row_ind_1, col_ind_1 = linear_sum_assignment(cost)
        row_ind_2, col_ind_2 = solve(cost)

        print('{:5} {}'.format(t,
            (row_ind_1 == row_ind_2).all() and 
            (col_ind_1 == col_ind_2).all()
        ))

if __name__ == '__main__':
    main()
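
Until a native, jittable implementation exists, one possible workaround is a different technique: rather than porting the algorithm, call SciPy on the host from inside jitted code with `jax.pure_callback`. The sketch below assumes a recent JAX release that provides `jax.pure_callback` and `jax.ShapeDtypeStruct`. The names `lsa_via_callback` and `_lsa_host` are hypothetical. Indices are returned as int32 because JAX disables 64-bit types by default.

```python
import numpy as np
import jax
import jax.numpy as jnp
from scipy.optimize import linear_sum_assignment


def _lsa_host(cost):
    # Runs on the host with a concrete array; SciPy returns int64 indices,
    # so cast to int32 to match the declared result types.
    rows, cols = linear_sum_assignment(np.asarray(cost))
    return rows.astype(np.int32), cols.astype(np.int32)


@jax.jit
def lsa_via_callback(cost):
    # For an (nr, nc) matrix SciPy returns min(nr, nc) index pairs; the
    # shape is static at trace time, so the output type can be declared.
    n = min(cost.shape)
    out_types = (
        jax.ShapeDtypeStruct((n,), jnp.int32),
        jax.ShapeDtypeStruct((n,), jnp.int32),
    )
    return jax.pure_callback(_lsa_host, out_types, cost)
```

The callback transfers the cost matrix off the device on every call. It also has no derivative rule. If `cost` depends on parameters being differentiated, pass `jax.lax.stop_gradient(cost)` to the callback and index the original `cost` with the returned indices.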

Top Results From Across the Web

scipy.optimize.linear_sum_assignment — SciPy v1.9.3 Manual
The linear sum assignment problem [1] is also known as minimum weight matching in bipartite graphs. A problem instance is described by a...

Scipy - Linear Sum Assignment - Show the Workings
So, I am now trying to work with option #2. Here is what I have so far: import numpy as np from scipy.optimize...

Optimization (scipy.optimize) — SciPy v1.11.0.dev0+1176.121 ...
Optimization (scipy.optimize) ... Linear sum assignment problem example ... Code which makes use of this Hessian product to minimize the Rosenbrock ......

[Python] 7 lines Hungarian algorithm/Linear sum assignment ...
from scipy.optimize import linear_sum_assignment import numpy as np class Solution: def maximumANDSum(self, nums: List[int], m: int) -> int: ...

Linear Sum Assignment Solver | OR-Tools - Google Developers
This section describes the linear sum assignment solver, a specialized solver for the simple assignment problem, which can be faster than ...