Logging metrics during an ODE solve
Hello @patrick-kidger,
thank you for open-sourcing this nice library! I was going to resume work on my own small ODE library, but since Diffrax is much more elaborate than what I have come up with so far, I am inclined to use it instead for an upcoming small project.
One question that came up while reading the source code: is there currently a way to compute step-wise metrics during the solve? (Think logging step sizes, Jacobian eigenvalues, etc.) This would presumably happen in the `integrate` method. Could I, for example, use the `solver_state` pytree for this in, say, overridden solver classes? Thank you for your consideration.
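One simple way to get at step-wise quantities without touching `solver_state` is to record every accepted step and post-process. A minimal sketch, assuming a recent Diffrax version in which `SaveAt(steps=True)` is available; the vector field and tolerances below are illustrative:

```python
import jax.numpy as jnp
import diffrax

# Minimal sketch: recover per-step metrics by saving every accepted step
# and post-processing. Assumes SaveAt(steps=True) is available; the
# dynamics and tolerances are illustrative.

def vector_field(t, y, args):
    return -y  # toy dynamics dy/dt = -y

sol = diffrax.diffeqsolve(
    diffrax.ODETerm(vector_field),
    diffrax.Tsit5(),
    t0=0.0,
    t1=10.0,
    dt0=0.1,
    y0=jnp.array([1.0]),
    stepsize_controller=diffrax.PIDController(rtol=1e-5, atol=1e-8),
    saveat=diffrax.SaveAt(steps=True),  # record every accepted step
)

# Unused buffer entries in sol.ts are padded with inf; filter them out.
ts = sol.ts[jnp.isfinite(sol.ts)]
step_sizes = jnp.diff(ts)  # the step-wise metric: accepted step sizes
```

Quantities such as Jacobian eigenvalues could likewise be computed from the saved `sol.ts` and `sol.ys` in the same post-processing pass.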
Top GitHub Comments
So these are an example of an "integral loss" (or regularisation term): a loss of the form

$$\int k(y(t))\,\mathrm{d}t$$

for some function `k`. Indeed, it's a known trick to avoid calculating these on the forward pass and only calculate their gradients on the backward pass.

(In practice this is pretty much only ever possible when using a custom backward pass like `BacksolveAdjoint`, though. I don't think any autodifferentiation framework, JAX or otherwise, is currently smart enough to elide the forward pass when using the "standard" `RecursiveCheckpointAdjoint`.)

Right now the best way to tackle this in Diffrax is indeed just to append these values to the state during the forward calculation. Even with other libraries, I think people have often had to really know what they are doing and then hack something together, since this is such an edge case for autodifferentiation frameworks.
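To make that concrete, here is a minimal sketch of the append-to-the-state approach; the dynamics `f` and the integrand `k` are illustrative choices of our own, not anything prescribed by Diffrax:

```python
import jax.numpy as jnp
import diffrax

# Minimal sketch: augment the ODE with one extra channel that accumulates
# \int k(y(t)) dt alongside y(t). f and k are illustrative placeholders.

def f(t, y, args):
    return -y

def k(y):
    return jnp.sum(y**2)  # example integrand for the integral loss

def augmented_field(t, state, args):
    y, _ = state  # second channel is the running integral; unused here
    return f(t, y, args), k(y)  # d/dt of (y, integral)

sol = diffrax.diffeqsolve(
    diffrax.ODETerm(augmented_field),
    diffrax.Tsit5(),
    t0=0.0,
    t1=1.0,
    dt0=0.05,
    y0=(jnp.array([1.0]), jnp.array(0.0)),  # integral starts at zero
)
y_final, integral_loss = sol.ys  # each has a leading axis of one save point
```

The final value of the extra channel is the integral loss, and its gradient flows through the solve like any other output.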
I’ve opened #69 to track this though, as this would be nice to have. CC @jacobjinkelly CC @Zymrael as I recall you have something similar in torchdyn?
@patrick-kidger, awesome library!
In this context of logging metrics, I was wondering how one might go about implementing "Learning Differential Equations that are Easy to Solve". My naive implementation appended the regularising term to the state; however, from discussions with one of the authors, I understand there is a more efficient way to implement it.
Do you have any suggestions on how to tackle this with Diffrax? Happy to open a separate issue if you would prefer.
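For reference, a hedged sketch of what that naive appended-state implementation might look like for the K = 2 case of the paper's regulariser, penalising $\int \|\mathrm{d}^2 y/\mathrm{d}t^2\|^2\,\mathrm{d}t$; the dynamics `f` and all other names here are placeholders, not Diffrax API:

```python
import jax
import jax.numpy as jnp
import diffrax

# Hedged sketch: appended-state version of the K = 2 regulariser,
# penalising \int ||d^2 y/dt^2||^2 dt. f stands in for a learned
# vector field; none of these names are Diffrax API.

def f(t, y, args):
    return jnp.tanh(y)  # placeholder dynamics

def second_time_derivative(t, y, args):
    # Along the trajectory, d^2 y/dt^2 = df/dt + (df/dy) f,
    # i.e. one jvp of f at (t, y) along the tangent (1, f(t, y)).
    dy = f(t, y, args)
    _, d2y = jax.jvp(
        lambda t_, y_: f(t_, y_, args), (t, y), (jnp.ones_like(t), dy)
    )
    return d2y

def augmented_field(t, state, args):
    y, _ = state
    reg = jnp.sum(second_time_derivative(t, y, args) ** 2)
    return f(t, y, args), reg

sol = diffrax.diffeqsolve(
    diffrax.ODETerm(augmented_field),
    diffrax.Tsit5(),
    t0=0.0,
    t1=1.0,
    dt0=0.05,
    y0=(jnp.array([1.0]), jnp.array(0.0)),
)
_, regularisation = sol.ys  # add this term to the training loss
```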