Need a way to specify constants

I have a use case where I need to jit a function that slices an array based on values in a dict argument:

import jax.numpy as jnp
import numpy as np
from jax import jit

@jit
def foo(x, d: dict):
  return x[:d['num']] + d['val']

foo(jnp.arange(10), {'num': 5, 'val': np.arange(5)})

Not surprisingly, I got IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax.

Making d static with static_argnums doesn't work either, because static arguments must be hashable and d is not (it is a dict, and it also holds an array):

from functools import partial

@partial(jit, static_argnums=1)
def foo(x, d: dict):
  return x[:d['num']] + d['val']

foo(jnp.arange(10), {'num': 5, 'val': np.arange(5)})

I got ValueError: Non-hashable static arguments are not supported.
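One common restructuring, sketched here under the assumption that the caller can pass the slice length separately from the dict: mark only the Python int as static (via jax.jit's static_argnames) and keep the array as an ordinary traced argument. The name foo_split is hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


# Hypothetical variant of foo: the slice length is passed as its own argument.
# Only `num` (a hashable Python int) is static; `val` stays a traced array.
@partial(jax.jit, static_argnames="num")
def foo_split(x, val, num):
    # num is a concrete Python int at trace time, so the slice bounds are static.
    return x[:num] + val


d = {"num": 5, "val": np.arange(5)}
out = foo_split(jnp.arange(10), d["val"], num=d["num"])
print(out)
```

Each distinct value of num compiles a separate version of the function, which is usually acceptable when it only takes a handful of values.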

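Here is a dirty trick that does work: encode the slice length in the shape of a dummy array, since array shapes are static under jit: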

@jit
def bar(x, d: dict):
  return x[:d['num'].shape[0]] + d['val']

bar(jnp.arange(10), {'num': jnp.zeros((5, 0)), 'val': np.arange(5)})

I got DeviceArray([0, 2, 4, 6, 8], dtype=int32)

Is there a way to mark d['num'] as a constant?

This odd-looking use case makes more sense in Flax, where variables are collected into a dict that is passed as an input.
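A minimal sketch of the same idea for a container, assuming you control the container type: register it as a pytree whose array goes into the traced children while the int goes into the auxiliary data, which jit treats as static and includes in its cache key. The class Params is hypothetical; Flax's flax.struct.dataclass with field(pytree_node=False) follows a similar pattern.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Params:
    """Hypothetical container: `num` is static, `val` is traced."""

    def __init__(self, num, val):
        self.num = num  # Python int, stored as static aux data
        self.val = val  # array leaf, traced under jit

    def tree_flatten(self):
        children = (self.val,)  # leaves that jit will trace
        aux_data = (self.num,)  # must be hashable; part of the jit cache key
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (num,) = aux_data
        (val,) = children
        return cls(num, val)


@jax.jit
def foo(x, p):
    # p.num is a concrete Python int here, so the slice bounds are static.
    return x[:p.num] + p.val


out = foo(jnp.arange(10), Params(5, np.arange(5)))
print(out)
```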

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 8 (3 by maintainers)

Top GitHub Comments

NeilGirdhar commented, Mar 3, 2022

"I want to understand better the difference between the two, since I noticed that in the latter case I can actually use non-hashable"

If you put non-hashable data into static fields, the JIT cache lookup will fail. JAX doesn't yet raise an error when you put unhashable data in static fields.

There’s some more information & discussion of this foot-gun in this issue

Another related issue: #7826

jakevdp commented, Mar 3, 2022

For any custom class, Python by default defines a hash based on the object's identity, without reference to any of the class's members or contents:

import jax.numpy as jnp

class Foo:
  def __init__(self, x):
    self.x = x

f = Foo(jnp.arange(4))
first_hash = hash(f)

f.x = jnp.ones(10)
second_hash = hash(f)

print("This is bad if True:", first_hash == second_hash)
# This is bad if True: True

This means that if you create a PyTree class, use an instance of it as a static argument, mutate it, and pass it to the function again, you will get unexpected results: the object's hash and identity are unchanged, so JAX reuses the cached compilation and ignores the mutation. There's some more information & discussion of this foot-gun in this issue: https://github.com/google/jax/issues/9024
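A small sketch of that foot-gun, assuming a plain class that keeps Python's default identity-based __hash__ and __eq__ (the names Config and scale_by are hypothetical). The attribute is read at trace time, so after the mutation the cached compiled function still uses the old value.

```python
from functools import partial

import jax
import jax.numpy as jnp


class Config:
    # No __hash__/__eq__ defined, so Python uses identity-based defaults.
    def __init__(self, scale):
        self.scale = scale


@partial(jax.jit, static_argnums=1)
def scale_by(x, cfg):
    # cfg.scale is read at trace time and baked into the compiled function.
    return x * cfg.scale


cfg = Config(2.0)
x = jnp.ones(3)
first = scale_by(x, cfg)   # traces and compiles with scale=2.0

cfg.scale = 10.0           # mutate the same object
second = scale_by(x, cfg)  # same hash and identity: cache hit, still uses 2.0
print(first, second)
```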

If you create a new instance each time, you will have a different object (and thus a different hash), and your function will be recompiled on every call.
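To observe the recompilation, one can count traces with a Python side effect, which runs only while JAX traces the function. This is a sketch with hypothetical names; the instances are kept alive in a list so that each call really receives a distinct object.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = 0


class Config:
    # Identity-based hash/eq: every instance is a distinct static value.
    def __init__(self, scale):
        self.scale = scale


@partial(jax.jit, static_argnums=1)
def scale_by(x, cfg):
    global trace_count
    trace_count += 1  # Python side effect: runs only when JAX (re)traces
    return x * cfg.scale


x = jnp.ones(3)
configs = [Config(2.0) for _ in range(3)]  # equal contents, distinct objects
for cfg in configs:
    scale_by(x, cfg)  # new object -> new hash -> cache miss -> retrace

print(trace_count)
```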

There are ways to hack around this and use a jnp.ndarray as a static argument (essentially by wrapping it in a class that defines __hash__), but I wouldn't recommend it: it's easy to inadvertently do the wrong thing, and this kind of problem is usually better solved by rewriting your code to follow typical JAX approaches.
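For illustration only, here is roughly what such a wrapper could look like, assuming a concrete NumPy array (not a traced value) hashed by its shape, dtype and raw bytes. HashableArray is a hypothetical name, and the caveats in the comment apply: the array contents are baked into the compiled function as constants, and every distinct array value triggers a new compilation.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


class HashableArray:
    """Hypothetical wrapper that hashes/compares a NumPy array by value."""

    def __init__(self, arr):
        self.arr = np.array(arr)  # private copy
        # Make the copy read-only so in-place mutation cannot make the hash stale.
        self.arr.setflags(write=False)

    def __hash__(self):
        return hash((self.arr.shape, self.arr.dtype.str, self.arr.tobytes()))

    def __eq__(self, other):
        return (
            isinstance(other, HashableArray)
            and self.arr.shape == other.arr.shape
            and self.arr.dtype == other.arr.dtype
            and np.array_equal(self.arr, other.arr)
        )


@partial(jax.jit, static_argnums=1)
def add_const(x, c):
    # c.arr is a concrete NumPy array at trace time, embedded as a constant.
    return x + c.arr


out = add_const(jnp.arange(3), HashableArray(np.array([10, 20, 30])))
print(out)
```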
