Compatibility with Flax

Will diffrax.diffeqsolve work inside a Flax linen Module? How would you set up the initialization to use Flax inside of ODETerm instead of Equinox?

Issue Analytics

  • State: open
  • Created a year ago
  • Comments: 8 (3 by maintainers)

Top GitHub Comments

1 reaction
patrick-kidger commented, Sep 16, 2022

Hurrah! I’m glad you figured this out.

0 reactions
ameya98 commented, Sep 16, 2022

Ah, I figured out a (slightly hacky) way to do this:

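import diffrax
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp


class NeuralODE(flax.struct.PyTreeNode):
    """A simple neural ODE: encoder -> ODE solve -> decoder."""

    encoder: nn.Module
    derivative_net: nn.Module
    decoder: nn.Module

    def init(self, rng, coords):
        # One independent key per submodule.
        rng, encoder_rng, derivative_net_rng, decoder_rng = jax.random.split(rng, 4)
        # init_with_output returns (output, variables); each output is then used
        # as the example input for initialising the next submodule.
        coords, encoder_params = self.encoder.init_with_output(encoder_rng, coords)
        coords, derivative_net_params = self.derivative_net.init_with_output(derivative_net_rng, coords)
        coords, decoder_params = self.decoder.init_with_output(decoder_rng, coords)

        return {
            "encoder": encoder_params,
            "derivative_net": derivative_net_params,
            "decoder": decoder_params,
        }

    def apply(self, params, coords):
        coords = self.encoder.apply(params["encoder"], coords)

        # Vector field with the (t, y, args) signature diffrax.ODETerm expects.
        # The Flax module is called functionally via .apply, closing over params.
        def f(t, y, args):
            return self.derivative_net.apply(params["derivative_net"], y)

        term = diffrax.ODETerm(f)
        solver = diffrax.Euler()
        solution = diffrax.diffeqsolve(term, solver, t0=0.0, t1=1.0, dt0=0.1, y0=coords)
        # With the default SaveAt(t1=True), ys has a leading time axis of length 1;
        # take the final state so the decoder sees the same shape as y0.
        coords = solution.ys[-1]
        coords = self.decoder.apply(params["decoder"], coords)
        return coords


rng = jax.random.PRNGKey(0)
coords = jnp.ones((1, 4))

model = NeuralODE(
    encoder=nn.Dense(10),
    derivative_net=nn.Dense(10),
    decoder=nn.Dense(4),
)
params = jax.jit(model.init)(rng, coords)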

Then you can simply use this like any other nn.Module:

@jax.jit
def compute_loss(params, coords, true_coords):
    preds = model.apply(params, coords)
    return jnp.abs(preds - true_coords).sum()

grads = jax.grad(compute_loss)(params, coords, jnp.zeros_like(coords))

This just uses flax.struct.PyTreeNode instead of eqx.Module. I didn’t want to mix both of them in my codebase. Thanks a lot for the help!
