Padding doesn't work in jitted mode

Hi,

Thanks for the work on the library. For a project I need to pad the last dimension of a tensor. Everything works fine when the function is not jitted, but I get a conversion error when I jit it. It seems to me to come from a bug in the lax_numpy internals (this line), but maybe I’m missing something…

Adrien

The code is as follows:

from jax import jit as jjit
from jax import ops  # needed for ops.index_update / ops.index below
import jax.numpy as jnp
import numpy as np

vals = np.random.randn(50, 100)

def pad_last_dim(array, pad_size):
    ndim = jnp.ndim(array)
    npad = jnp.zeros((ndim, 2), dtype=jnp.int32)
    axis = ndim - 1
    npad = ops.index_update(npad, ops.index[axis, 1], pad_size)
    return jnp.pad(array, npad, 'constant', constant_values=0)

print(pad_last_dim(vals, 1))  # All good
print(jjit(pad_last_dim)(vals, 1))  # raises
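
One plausible reading of the failure, sketched here rather than taken from the thread: under jit the array's rank is still an ordinary Python int, but `pad_size` is replaced by a traced value, so pad widths built from it are no longer concrete integers. The helper `inspect_args` and the `seen` list are hypothetical names used only for this illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np

seen = []  # hypothetical container, filled while JAX traces the function


def inspect_args(array, pad_size):
    # This body runs once at trace time, so we can record what JAX passes in.
    seen.append({
        "ndim_is_int": isinstance(jnp.ndim(array), int),
        "pad_size_is_int": isinstance(pad_size, int),
    })
    return array


jax.jit(inspect_args)(np.ones((50, 100)), 1)
print(seen[0])  # {'ndim_is_int': True, 'pad_size_is_int': False}
```

Because the padded output's shape depends on `pad_size`, JAX needs that value at trace time, which is why marking it static is a natural fix.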

Issue Analytics

  • State: open
  • Created 3 years ago
  • Comments: 9 (9 by maintainers)

Top GitHub Comments

AdrienCorenflos commented, Jul 1, 2020

Thanks a lot for that Matt. Bit late for me but I’ll come back to it tomorrow. I think you should forbid lists and ndarrays tbh. Would be consistent with tf where they consider lists to be a collection of some kind.

In the meantime I like your solution and I’ll use it (I don’t need the ndarray, just was easier to set the padding)

AdrienCorenflos commented, Jul 2, 2020

from jax import jit as jjit
import jax.numpy as jnp
from jax import ops
import numpy as np

vals = np.random.randn(50, 100)

def pad_last_dim(array, pad_size):
    ndim = jnp.ndim(array)
    # Pad only the end of the last axis; [[0, pad_size]] * ndim would pad every axis.
    npad = [(0, 0)] * (ndim - 1) + [(0, pad_size)]
    return jnp.pad(array, npad, 'constant', constant_values=0)

print(pad_last_dim(vals, 1))  # All good
print(jjit(pad_last_dim, static_argnums=(1,))(vals, 1))

Your solution worked for me, thanks a lot!
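
For reference, the same idea can be written as a decorator. This is a minimal sketch, assuming a JAX version whose `jax.jit` accepts `static_argnames`; it is not taken from the thread. The pad widths are built from plain Python tuples so they stay concrete during tracing.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


@partial(jax.jit, static_argnames="pad_size")
def pad_last_dim(array, pad_size):
    # pad_size is static, so these pad widths are ordinary Python ints.
    pad_width = [(0, 0)] * (array.ndim - 1) + [(0, pad_size)]
    return jnp.pad(array, pad_width, mode="constant", constant_values=0)


vals = np.random.randn(50, 100)
padded = pad_last_dim(vals, pad_size=1)
print(padded.shape)  # (50, 101)
```

The trade-off is that each new `pad_size` value triggers a recompilation, which is usually acceptable when only a handful of pad sizes are used.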
