Neural ODE compatibility with dm-haiku models?

I am currently following the Neural ODE tutorial in the Diffrax documentation (https://docs.kidger.site/diffrax/examples/neural_ode/). Instead of defining Func and NeuralODE as Equinox modules, is there a workaround that lets me use Haiku modules (hk.Module)?

My end goal is to call diffrax.diffeqsolve on an ODE term that uses an hk.Module, like this:

solution = diffrax.diffeqsolve(
    diffrax.ODETerm(self.func),
    diffrax.Tsit5(),
    t0=ts[0],
    t1=ts[-1],
    dt0=ts[1] - ts[0],
    y0=y0,
    stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
    saveat=diffrax.SaveAt(ts=ts),  # saving all of the intermediate states (disc then opt)
)

where self.func = Func() and Func is a Haiku module, as below:

class Func(hk.Module):
    def __init__(self, name=None):
        super().__init__(name=name)
        self.layers = []
        for out_shape in [2,10]:
            self.layers.append(hk.Linear(out_shape, name=name))
            self.layers.append(jax.nn.relu)
        self.layers.append(hk.Linear(1, name=name))

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

Issue Analytics

  • State: closed
  • Created a year ago
  • Comments: 5 (3 by maintainers)

Top GitHub Comments

2 reactions
patrick-kidger commented, May 25, 2022

That’s correct. I don’t think you’re doing anything like that at the moment. So you’ll need to pass those params through to the vector field each time you call it.

i.e. something like

class Network(hk.Module):
    ...

    def __call__(self, t, y, args):
        ...

network = hk.without_apply_rng(hk.transform(...))
params = network.init(...)
vector_field = functools.partial(network.apply, params)
diffeqsolve(ODETerm(vector_field), ...)

This issue isn’t anything to do with Diffrax I’m afraid; what you’re describing are difficulties with Haiku. If you want to use Haiku then I suggest reading up on how to use Haiku, and perhaps trying it in a non-neural-ODE context first to gain familiarity with that library.

0 reactions
thibmonsel commented, May 25, 2022

I'm not sure I understand: in order to use transformed_object.apply, we need the params of our model, which are given by invoking transformed_object.init.
