Neural ODE compatibility with dm-haiku models?
I am currently following the Neural ODE tutorial in the Diffrax documentation (https://docs.kidger.site/diffrax/examples/neural_ode/). I was wondering: instead of instantiating the Equinox modules `Func` and `NeuralODE`, is there a workaround to incorporate Haiku modules (`hk.Module`)?
My end goal would be to use `diffrax.diffeqsolve` on an ODE term that uses an `hk.Module`, as such:
```python
solution = diffrax.diffeqsolve(
    diffrax.ODETerm(self.func),
    diffrax.Tsit5(),
    t0=ts[0],
    t1=ts[-1],
    dt0=ts[1] - ts[0],
    y0=y0,
    stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
    saveat=diffrax.SaveAt(ts=ts),  # saving all of the intermediate states (disc then opt)
)
```
where `self.func = Func()` and `Func` is a Haiku module as below:
```python
import haiku as hk
import jax

class Func(hk.Module):
    def __init__(self, name=None):
        super().__init__(name=name)
        self.layers = []
        # Hidden layers of width 2 and 10, each followed by a ReLU.
        for out_shape in [2, 10]:
            self.layers.append(hk.Linear(out_shape, name=name))
            self.layers.append(jax.nn.relu)
        # Output layer with a single unit.
        self.layers.append(hk.Linear(1, name=name))

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
```
Issue analytics: created a year ago; 5 comments (3 by maintainers).
Top GitHub Comments
That’s correct. I don’t think you’re doing anything like that at the moment. So you’ll need to pass those `params` through to the vector field each time you call it, i.e. something like:
This issue isn’t anything to do with Diffrax I’m afraid; what you’re describing are difficulties with Haiku. If you want to use Haiku then I suggest reading up on how to use Haiku, and perhaps trying it in a non-neural-ODE context first to gain familiarity with that library.
I'm not sure I understand: in order to use `transformed_object.apply`, we need the params of our model, which are given by invoking `transformed_object.init`?