
Transform Feedforward-Network + solver into a Recurrent-Network


Hello Patrick,

Let me first quickly motivate my feature request. As a side project I am currently working on model-based optimal control. For, e.g., an only partially observable environment, stateful agents are useful. So suppose the action selection of an agent is given by the following method:

def select_action(params, state, observation, time):
    apply = neural_network.apply
    state, action = apply(params, state, observation, time)
    return state, action

while True:
    state, action = select_action(..., observation, env.time)
    observation = env.step(action)

Typically, the apply function is some recurrent neural network. Suppose the environment env is differentiable because it is just some model of the environment (maybe another network). Now, I would like to replace the recurrent neural network with a feedforward network + solver without changing the API of the agent.

I was wondering whether constructing the following is possible and sensible, i.e. I would like to transform a choice of feedforward network + solver into a recurrent network:

def select_action(params, ode_state, observation, time):
    rhs = lambda x,u: neural_network.apply(params, x, u)
    solution, ode_state = odeint(ode_state, rhs, t1=time, u=(observation, time))
    return ode_state, solution.x(time)

I would like to emphasize that this select_action must remain differentiable, i.e. the x output with respect to the network parameters.

I would love to hear your input 😃 Anyway, thank you in advance.

Issue Analytics

  • State: closed
  • Created: a year ago
  • Comments: 5 (5 by maintainers)

Top GitHub Comments

1 reaction
patrick-kidger commented, May 29, 2022

Yep, this is definitely possible. Diffrax is intrinsically differentiable so no special care is needed. Untested, but perhaps something like the following:

import equinox as eqx
import diffrax as dfx
import jax.numpy as jnp

# wraps an MLP to concatenate state and observation together
class Func(eqx.Module):
    mlp: eqx.nn.MLP

    def __init__(self, state_size, observation_size, width_size, depth, key):
        in_size = 1 + state_size + observation_size
        self.mlp = eqx.nn.MLP(in_size, state_size, width_size, depth, key=key)

    def __call__(self, t, state, observation):
        in_ = jnp.concatenate([t[None], state, observation])
        return self.mlp(in_)

func = Func(...)
get_action = eqx.nn.MLP(...)

def select_action(model, state, observation, time):
    func, get_action = model
    prev_time, state = state
    term = dfx.ODETerm(func)
    # specify solver, dt0, stepsize_controller in whatever way you think appropriate
    solver = ...
    dt0 = ...
    stepsize_controller = ...
    sol = dfx.diffeqsolve(term, solver, prev_time, time, dt0, state, args=observation,
                          stepsize_controller=stepsize_controller)
    (state,) = sol.ys
    action = get_action(state)
    state = (time, state)
    return state, action

It’s not critical, but as a nice-to-have this uses Equinox as a convenient neural network library.
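Gradients with respect to the network parameters then flow through the solve as usual. A rough, untested sketch of that (the sizes, keys, loss and target below are placeholders, and the solver/dt0/stepsize_controller placeholders above still need to be filled in, e.g. with dfx.Tsit5() and a fixed dt0):

import jax.random as jr

key = jr.PRNGKey(0)
fkey, akey = jr.split(key)
state_size, observation_size, action_size = 4, 3, 2  # placeholder sizes
func = Func(state_size, observation_size, width_size=32, depth=2, key=fkey)
get_action = eqx.nn.MLP(state_size, action_size, width_size=32, depth=2, key=akey)
model = (func, get_action)

@eqx.filter_grad
def loss(model, state, observation, time, target):
    # differentiates the action w.r.t. the network parameters, through the solve
    _, action = select_action(model, state, observation, time)
    return jnp.sum((action - target) ** 2)

init_state = (jnp.array(0.0), jnp.zeros(state_size))
grads = loss(model, init_state, jnp.zeros(observation_size),
             jnp.array(1.0), jnp.zeros(action_size))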

0 reactions
SimiPixel commented, Jun 5, 2022

For completeness, let me post my minimal working example. Spoiler: this uses Haiku, simply because I am already comfortable with it. Equinox would probably make this more beautiful 😃

import diffrax as dfx
import jax.numpy as jnp
from acme.jax import utils 
from functools import partial 
import haiku as hk 
import jax 

sampling_rate = 100 # Hz
stepsize_controller = dfx.ConstantStepSize()
dt0 = 1/sampling_rate
solver = dfx.Euler()

action_size = 3
obs_size = 2
u_dummy = jnp.ones((action_size))

latent_state_size = 20
hidden_layers = [50,50]
@hk.without_apply_rng
@hk.transform_with_state
def rhs(t, u):
    t = jnp.atleast_1d(t)
    x = hk.get_state("x", shape=(latent_state_size,), init=jnp.zeros, dtype=jnp.float32)
    txu = utils.batch_concat((t, x, u), num_batch_dims=0)
    X = hk.nets.MLP(hidden_layers + [latent_state_size])(txu)
    return {"~": {"x": X}}

def haiku2dfx_rhs(rhs):
    def __rhs(params):
        def _rhs(t, x, u):
            # x is simply passed through
            dxdt, x = rhs(params, x, t, u)
            del x
            return dxdt
        return _rhs
    return __rhs

# this is not great / quite confusing
dxdt = haiku2dfx_rhs(rhs.apply)

@hk.without_apply_rng
@hk.transform
def measurement_function(x):
    x = utils.batch_concat(x, num_batch_dims=0)
    C = hk.get_parameter("C", shape=(obs_size, x.shape[-1]), dtype=jnp.float32,
                         init=lambda shape, dtype: jax.random.normal(hk.next_rng_key(), shape, dtype=dtype))
    return jnp.matmul(C, x)

def gen_init_solver_state(solver: dfx.AbstractSolver, params_rhs, x0):
    term = dfx.ODETerm(dxdt(params_rhs))
    t0 = 0.0
    return solver.init(term, t0=t0, t1=t0 + dt0, y0=x0, args=u_dummy)

def gen_init_controller_state():
    return dt0

saveat = dfx.SaveAt(t1=True, solver_state=True, controller_state=True, made_jump=True)

def step_fun_dynamics_to_time(params, state, u, time):
    prev_time, x, solver_state, controller_state, made_jump = state
    term = dfx.ODETerm(dxdt(params["rhs"]))

    sol = dfx.diffeqsolve(term, solver, prev_time, time, dt0, x, args=u,
                          stepsize_controller=stepsize_controller, saveat=saveat,
                          solver_state=solver_state, controller_state=controller_state,
                          made_jump=made_jump)
    x = sol.ys
    x = utils.squeeze_batch_dim(x)
    state = (time, x, sol.solver_state, sol.controller_state, sol.made_jump)
    obs = measurement_function.apply(params["C"], x)
    return state, obs

def step_fun_dynamics(params, state, u):
    prev_time = state[0]
    return step_fun_dynamics_to_time(params, state, u, prev_time + dt0)

@jax.jit
@partial(jax.vmap, in_axes=(None, None, 0))
def unrolled_step_fun_dynamics(params, state, us):
    step_fun_dynamics_constraint = lambda state, u: step_fun_dynamics(params, state, u)
    state, obss = jax.lax.scan(step_fun_dynamics_constraint, init=state, xs=us)
    return obss

# initialise parameters
params_rhs, x0 = rhs.init(jax.random.PRNGKey(1), 0.0, u_dummy)
C = measurement_function.init(jax.random.PRNGKey(1), jnp.ones((latent_state_size,)))

params = {
    "rhs": params_rhs,
    "C": C 
}

# initialise step functions state
# (t0, x0, solver_state0, controller_state0, made_jump0)
made_jump0 = False 
init_state = (0.0, x0, gen_init_solver_state(solver, params_rhs, x0), gen_init_controller_state(), made_jump0)

# make prediction
bs = 32
T = 5.0
uss = jnp.ones((bs, int(T*sampling_rate), action_size))
obsss = unrolled_step_fun_dynamics(params, init_state, uss)
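
And gradients with respect to params flow through the unrolled solve in the usual way; a rough sketch, where the loss and targets are just placeholders:

def loss_fn(params, state, uss, targets):
    # scalar training loss over the predicted observations
    obsss = unrolled_step_fun_dynamics(params, state, uss)
    return jnp.mean((obsss - targets) ** 2)

targets = jnp.zeros(uss.shape[:-1] + (obs_size,))  # placeholder targets
grads = jax.grad(loss_fn)(params, init_state, uss, targets)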

Thanks Patrick for your help. It works perfectly.
