
Inverse Accumulation Mode


Inverted Jacobian products are useful in a variety of algorithms, such as the efficient implementation of Newton's method with regularization. However, JAX currently provides only non-inverted Jacobian products (jvp and vjp).

A recent paper suggests that inverted Jacobian products can be implemented efficiently in an automatic differentiation library like JAX:

Siskind, Jeffrey Mark. “Automatic Differentiation: Inverse Accumulation Mode.” (2019).

The interface for the inverted Jacobians could be something like:

  • ijvp that accepts a function and primals and produces a function f_ijvp that accepts input cotangents and produces output cotangents, and
  • ivjp that accepts a function and primals and produces a function f_ivjp that accepts output tangents and produces input tangents.
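For reference, the proposed interface can already be mimicked with a dense baseline. The sketch below uses hypothetical helpers ijvp_dense and ivjp_dense, which are not JAX APIs. Which map each name denotes follows the bullets above. Both helpers build the full Jacobian with jax.jacfwd and call jnp.linalg.solve, so each call costs O(n^3) rather than something proportional to the primal. The toy f and the Newton loop are illustrative assumptions: a Newton step x - J(x)^{-1} f(x) is an inverse-Jacobian-vector product.

```python
import jax
import jax.numpy as jnp

# Hypothetical dense stand-ins for the proposed ijvp/ivjp. They materialize the
# full Jacobian and solve against it (O(n^3)); no inverse accumulation happens.


def ijvp_dense(f, x):
    """Input cotangent ct_in -> output cotangent ct_out with ct_out @ J(x) == ct_in."""
    jac = jax.jacfwd(f)(x)
    # ct_out @ J = ct_in  <=>  J^T ct_out = ct_in
    return lambda ct_in: jnp.linalg.solve(jac.T, ct_in)


def ivjp_dense(f, x):
    """Output tangent t_out -> input tangent t_in with J(x) @ t_in == t_out."""
    jac = jax.jacfwd(f)(x)
    return lambda t_out: jnp.linalg.solve(jac, t_out)


def f(x):
    # Toy map R^2 -> R^2 (an assumption); det J = exp(x0) * (3 x1^2 + 2) + 1 > 0.
    return jnp.array([jnp.exp(x[0]) + x[1], x[1] ** 3 + 2.0 * x[1] - x[0]])


def newton(f, x, steps=10):
    # Newton's method for f(x) = 0: every update applies J(x)^{-1} to f(x).
    for _ in range(steps):
        x = x - ivjp_dense(f, x)(f(x))
    return x
```

An inverse accumulation mode would replace the two dense solves with sweeps over the program, leaving the calling convention unchanged.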

Using these, we could also produce inverse_jacfwd and inverse_jacrev, one of which could be mapped to inverse_jacobian.
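The sketch below shows how hypothetical inverse_jacfwd and inverse_jacrev could be assembled in the same way that jax.jacfwd and jax.jacrev are assembled from jvp and vjp: by pushing each standard basis vector through the corresponding inverse product with jax.vmap. The inverse products here are dense stand-ins (an explicit Jacobian plus jnp.linalg.solve), not inverse accumulation. The toy f is an assumption.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Toy map R^2 -> R^2 (an assumption); det J = exp(x0) * (3 x1^2 + 2) + 1 > 0.
    return jnp.array([jnp.exp(x[0]) + x[1], x[1] ** 3 + 2.0 * x[1] - x[0]])


def inverse_jacfwd(f):
    """Hypothetical: x -> J(x)^{-1}, built column by column."""
    def inv_jac(x):
        jac = jax.jacfwd(f)(x)
        basis = jnp.eye(x.shape[0], dtype=x.dtype)
        # Column i is J^{-1} e_i: an output tangent mapped to an input tangent.
        return jax.vmap(lambda e: jnp.linalg.solve(jac, e), out_axes=1)(basis)
    return inv_jac


def inverse_jacrev(f):
    """Hypothetical: x -> J(x)^{-1}, built row by row."""
    def inv_jac(x):
        jac = jax.jacrev(f)(x)
        basis = jnp.eye(x.shape[0], dtype=x.dtype)
        # Row i is e_i^T J^{-1}: an input cotangent mapped to an output cotangent.
        return jax.vmap(lambda e: jnp.linalg.solve(jac.T, e), out_axes=0)(basis)
    return inv_jac


# As the issue suggests, one of the two could back a generic inverse_jacobian.
inverse_jacobian = inverse_jacfwd
```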

Has anyone on the JAX team looked into this?

Issue Analytics

  • State: open
  • Created: a year ago
  • Reactions: 5
  • Comments: 8 (6 by maintainers)

Top GitHub Comments

3 reactions
mattjj commented, Sep 29, 2022

Thanks to the inverse function theorem, another way to compute the same thing is to compose jvp, vjp, jacfwd, or jacrev with the oryx.core.inverse transformation. That is, for $f : \mathbb{R}^n \to \mathbb{R}^n$ we have $\partial(f^{-1})(f(x)) = (\partial f(x))^{-1}$ at every point $x$, where the right-hand side is a dense matrix inverse; equivalently, $\partial(f^{-1})(y) = (\partial f(f^{-1}(y)))^{-1}$. Actually, I'd be interested in whether that ends up generating the same computation as the one discussed in that paper (which I admit I haven't read yet, beyond the abstract!).
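Here is a small numerical check of that identity. A hand-written inverse f_inv stands in for oryx.core.inverse(f); oryx itself is not used, and nothing about its API is assumed. Note that the Jacobian of the inverse is evaluated at y = f(x), not at x. The toy f is an assumption.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Toy bijection R^2 -> {y : y0 > 0} x R, chosen for illustration.
    return jnp.array([jnp.exp(x[0]), x[1] + x[0] ** 3])


def f_inv(y):
    # Hand-written inverse, standing in for what oryx.core.inverse(f) would produce.
    x0 = jnp.log(y[0])
    return jnp.array([x0, y[1] - x0 ** 3])


x = jnp.array([0.4, -1.2])
y = f(x)

# Inverse function theorem: the Jacobian of f_inv at f(x) equals the inverse
# of the Jacobian of f at x.
jac_of_inverse = jax.jacfwd(f_inv)(y)
inverse_of_jac = jnp.linalg.inv(jax.jacfwd(f)(x))
```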

I think jacfwd-of-oryx.core.inverse-of-f would in general be a fairly different computation from jnp.linalg.inv-of-jacfwd-of-f, because the former would exploit sparsity structure represented in the program and its dataflow, whereas the latter would just operate on a dense matrix.
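To illustrate the structural point, take the toy bijection f(x) = cumsum(exp(x)), chosen only for illustration. Its Jacobian is a dense lower-triangular matrix. Differentiating a hand-written inverse program (again a stand-in for oryx.core.inverse) does O(n) work, because the inverse is a difference followed by a log. The dense route instead builds the n x n Jacobian and solves against it in O(n^3). Both produce the same inverse-Jacobian-vector product.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.cumsum(jnp.exp(x))


def f_inv(y):
    # Undo the cumulative sum by differencing, then undo exp with log: O(n).
    prev = jnp.concatenate([jnp.zeros(1, dtype=y.dtype), y[:-1]])
    return jnp.log(y - prev)


def inverse_jvp_via_inverse_program(x, t):
    # JVP of the inverse program at f(x); follows f_inv's bidiagonal dataflow.
    _, t_in = jax.jvp(f_inv, (f(x),), (t,))
    return t_in


def inverse_jvp_dense(x, t):
    # Materialize the full Jacobian and solve: ignores the program's structure.
    return jnp.linalg.solve(jax.jacfwd(f)(x), t)
```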

2 reactions
qobi commented, Sep 30, 2022

Indeed, the advantage of our approach is that it is compositional and thus has a running time proportional to that of the primal. But, like ordinary reverse mode, it requires a tape whose size is proportional to the running time.
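To make the tape-size remark concrete, here is a minimal reverse-mode tape in Python. It is a generic illustration, not R6RS-AD's or JAX's implementation, and the class and method names are hypothetical. Each primitive operation appends one node recording its local partials, so the tape grows linearly with the number of operations executed. A single backward pass over the tape then yields the gradient.

```python
import math


class Tape:
    """Minimal reverse-mode tape: one node per primitive operation executed."""

    def __init__(self):
        self.values = []   # primal value of each node
        self.parents = []  # per node: list of (parent index, local partial)

    def _push(self, value, parents):
        self.values.append(value)
        self.parents.append(parents)
        return len(self.values) - 1

    def leaf(self, value):
        return self._push(value, [])

    def add(self, a, b):
        return self._push(self.values[a] + self.values[b], [(a, 1.0), (b, 1.0)])

    def mul(self, a, b):
        va, vb = self.values[a], self.values[b]
        return self._push(va * vb, [(a, vb), (b, va)])

    def sin(self, a):
        va = self.values[a]
        return self._push(math.sin(va), [(a, math.cos(va))])

    def backward(self, out, cotangent=1.0):
        """One reverse sweep; nodes were appended in topological order."""
        ct = [0.0] * len(self.values)
        ct[out] = cotangent
        for node in range(len(self.values) - 1, -1, -1):
            for parent, partial in self.parents[node]:
                ct[parent] += ct[node] * partial
        return ct


def tape_length(k):
    """Run k loop iterations of three primitives each; return the tape size."""
    t = Tape()
    x = t.leaf(0.5)
    z = x
    for _ in range(k):
        z = t.add(t.mul(z, z), t.sin(x))
    return len(t.values)
```

For example, tape_length(k) is 1 + 3k, so memory grows with the primal's running time, as it does for any reverse sweep.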

There is no need to do Gaussian elimination. The only inversion is that of the scalar a on the right-hand side of (5a), because the steps in (4) involve only unary and binary operations.
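Here is a minimal sketch of why such a step needs only a scalar division. It assumes a binary step that overwrites its dying operand a with g(a, b) while b survives; this is an illustration, not a reproduction of the paper's (5a). The step's 2 x 2 Jacobian is upper triangular with a unit diagonal entry, so inverting it requires dividing only by the scalar partial g_a. The helper names are hypothetical.

```python
from fractions import Fraction


def matmul(A, B):
    """Product of two small matrices given as lists of rows."""
    return [[sum(A[i][m] * B[m][j] for m in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]


def step_jacobian(g_a, g_b):
    # Step (a, b) -> (g(a, b), b): the result overwrites the dying operand a and
    # b passes through. Rows and columns are ordered (a-slot, b-slot).
    return [[g_a, g_b], [0, 1]]


def step_jacobian_inverse(g_a, g_b):
    # Back-substitution: the only division is by the scalar partial g_a,
    # so the step is invertible exactly when g_a != 0.
    return [[1 / g_a, -g_b / g_a], [0, 1]]


# Example: g(a, b) = a * b at a = 3, b = 5, so g_a = b = 5 and g_b = a = 3.
g_a, g_b = Fraction(5), Fraction(3)
J = step_jacobian(g_a, g_b)
J_inv = step_jacobian_inverse(g_a, g_b)
```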

The catch is that each step must be what we call “equasive”: it must have the same number of inputs and outputs. (We call a step that has more outputs than inputs “expansive” and a step that has fewer outputs than inputs “contractive”.) Equasive steps have square Jacobians; expansive and contractive steps do not, and non-square Jacobians are not invertible.
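One way to apply these definitions to a primitive with a single result is to count "inputs" and "outputs" as live values. The Scheme listing's check, which requires exactly one operand whose fanout drops to zero, appears to follow this reading. Under it, operands used for the last time are consumed, surviving operands pass through unchanged, and the result is produced. The function names below are hypothetical.

```python
def classify(n_in, n_out):
    """Classify a step by how many values it consumes and produces."""
    if n_in == n_out:
        return "equasive"  # square Jacobian: the only kind that can be invertible
    return "expansive" if n_out > n_in else "contractive"  # non-square Jacobian


def classify_primitive(n_dying_operands):
    """Hypothetical: a one-result primitive consumes the operands it uses for the
    last time, passes the others through unchanged, and produces one new value."""
    return classify(n_dying_operands, 1)
```

Under this reading, a binary operation is equasive only when exactly one of its operands dies at that step.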

It is possible that the overall computation is equasive but the individual steps are not. If the dimension of the intermediate state is ever smaller than the input/output dimension, then the Jacobian is not invertible. But the dimension may increase and then decrease as the computation progresses, perhaps more than once, without ever going below the input/output dimension. It is further possible that the dimension varies over the course of the computation (never dropping below the input/output dimension) but returns to the input/output dimension at one or more intermediate points.

If this is the case, the computation can be split at those points into a sequence of equasive chunks that we call “lumps”, and the method can then be applied to that sequence of lumps. When doing this, the lumps are no longer steps consisting of unary and binary operations. Thus, instead of (5a) you have (5b), where on the right-hand side of (5b) you have to invert A. The saving grace is that the dimension of A is likely to be (much) smaller than the input/output dimension of the whole problem.
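Below is a sketch of the lump-splitting rule. It models a computation only by its sequence of state dimensions, which is an assumption: it ignores which values are live. The hypothetical split_into_lumps rejects traces that are not equasive overall or that drop below the input/output dimension. Otherwise, it cuts wherever the dimension returns to the input/output dimension.

```python
def split_into_lumps(dims):
    """dims[k] is the state dimension after k steps; dims[0] is the input
    dimension and dims[-1] the output dimension. Returns (start, end) step
    ranges, each an equasive lump, or raises if inversion is impossible."""
    n = dims[0]
    if dims[-1] != n:
        raise ValueError("overall computation is not equasive")
    if min(dims) < n:
        raise ValueError("state dimension drops below n: Jacobian not invertible")
    lumps, start = [], 0
    for k in range(1, len(dims)):
        if dims[k] == n:  # back at the input/output dimension: close a lump
            lumps.append((start, k))
            start = k
    return lumps
```

For the trace [3, 4, 5, 3, 3, 4, 3], the dimension returns to 3 after steps 3, 4, and 6, giving the lumps (0, 3), (3, 4), and (4, 6).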

We implemented the stepwise equasive variant in a variant of R6RS-AD.

https://github.com/qobi/R6RS-AD

(Note that JAX is based on HIPS Autograd, which is based on R6RS-AD. R6RS-AD predates HIPS Autograd by about 7 years and JAX by about a decade. R6RS-AD was used in

@inproceedings{nips2011, author = {D. Wingate and N. Goodman and A. Stuhlm{\"{u}}ller and J. M. Siskind}, title = {Nonstandard Interpretations of Probabilistic Programs for Efficient Inference}, booktitle = nips, location = {Granada, Spain}, day = {12–15}, month = dec, year = 2011, url = {http://engineering.purdue.edu/~qobi/papers/nips2011.pdf}} )

The implementation is straightforward and is included below. We worked on methods to automatically divide an arbitrary computation graph into lumps (what we call “lumpification”), but that work is not complete. It is complicated because the dimension of the intermediate state depends on how you schedule the operations, so the possible lumpifications depend on scheduling. Optimally scheduling to minimize the dimension of the A of the maximal lump appears to be NP-hard.
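The scheduling dependence shows up even on a toy dataflow graph. The graph here is hypothetical, with inputs p and q and outputs r2 and s2. The helper counts live values after each step of a given topological order. Reducing eagerly keeps at most 4 values live and returns to the input/output dimension of 2 halfway through, which creates a possible lump boundary. Computing all leaves first peaks at 6 live values and returns to 2 only at the end.

```python
def live_widths(inputs, ops, schedule, outputs):
    """ops maps each computed value to its operand names; schedule is a
    topological order of ops. Returns the number of live values before the
    first step and after each step."""
    uses = {}
    for name in schedule:
        for arg in ops[name]:
            uses[arg] = uses.get(arg, 0) + 1
    for name in outputs:  # outputs stay live until the end
        uses[name] = uses.get(name, 0) + 1
    live = set(inputs)
    widths = [len(live)]
    for name in schedule:
        for arg in ops[name]:
            if arg not in live:
                raise ValueError(f"{name} scheduled before its operand {arg}")
            uses[arg] -= 1
            if uses[arg] == 0:
                live.discard(arg)  # last use: the value dies
        live.add(name)
        widths.append(len(live))
    return widths


# Toy graph (hypothetical): inputs p, q; outputs r2, s2.
OPS = {
    "m1": ("p",), "m2": ("p",), "m3": ("p",),
    "n1": ("q",), "n2": ("q",), "n3": ("q",),
    "r1": ("m1", "m2"), "r2": ("r1", "m3"),
    "s1": ("n1", "n2"), "s2": ("s1", "n3"),
}
EAGER = ["m1", "m2", "m3", "r1", "r2", "n1", "n2", "n3", "s1", "s2"]
LAZY = ["m1", "m2", "m3", "n1", "n2", "n3", "r1", "r2", "s1", "s2"]
```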

The link to StackExchange discusses the Moore-Penrose inverse of non-square Jacobians. We spent some time investigating this, as well as a variety of pseudoinverses other than the Moore-Penrose pseudoinverse. We are unaware of any pseudoinverse that is compositional, and compositionality is required to make (3) work. There might be one that is compositional (and useful) that we are unaware of. It might also be the case that a product of Moore-Penrose (or other) pseudoinverses, while not preserving the properties of the pseudoinverse, is still useful. We never got very far along this line of investigation.
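Here is a small counterexample to compositionality of the Moore-Penrose pseudoinverse. It uses the standard fact that the pseudoinverse of a nonzero vector is its transpose divided by its squared norm. Take the row vector A = [1 1] and the column vector B = [1 0]^T. Then (AB)^+ = 1, but B^+ A^+ = 1/2. For square invertible matrices, the reverse-order law (AB)^{-1} = B^{-1} A^{-1} does hold, and that is what allows inverse Jacobians to be accumulated step by step.

```python
from fractions import Fraction


def pinv_vector(v):
    """Moore-Penrose pseudoinverse of a nonzero row or column vector:
    its transpose divided by its squared norm (returned as a flat list)."""
    norm2 = sum(Fraction(c) ** 2 for c in v)
    return [Fraction(c) / norm2 for c in v]


def dot(u, w):
    return sum(a * b for a, b in zip(u, w))


A = [1, 1]  # 1 x 2 row vector
B = [1, 0]  # 2 x 1 column vector

AB = dot(A, B)                                  # the 1 x 1 product, here 1
pinv_AB = Fraction(1) / AB                      # pseudoinverse of a nonzero scalar
composed = dot(pinv_vector(B), pinv_vector(A))  # B^+ A^+, also 1 x 1
```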

#!r6rs

(library
  (tape-AD)
 (export (rename (d+ +))
	 (rename (d- -))
	 (rename (d* *))
	 (rename (d/ /))
	 (rename (dsqrt sqrt))
	 (rename (dexp exp))
	 (rename (dlog log))
	 (rename (dexpt expt))
	 (rename (dsin sin))
	 (rename (dcos cos))
	 (rename (datan atan))
	 (rename (d= =))
	 (rename (d< <))
	 (rename (d> >))
	 (rename (d<= <=))
	 (rename (d>= >=))
	 (rename (dzero? zero?))
	 (rename (dpositive? positive?))
	 (rename (dnegative? negative?))
	 (rename (dreal? real?))
	 write-real
	 j*
	 *j
	 j*^-1
	 *j^-1)
 (import (rnrs))

 (define-record-type tape
  (fields primal
	  factors
	  tapes
	  (mutable fanout)
	  (mutable co/tangent)))

 (define (new-tape primal factors tapes)
  (make-tape primal factors tapes 0 0))

 (define (tapify x) (new-tape x '() '()))

 (define (lift-real->real f df/dx)
  (letrec ((self (lambda (x)
		  (if (tape? x)
		      (new-tape (self (tape-primal x))
				(list (df/dx (tape-primal x)))
				(list x))
		      (f x)))))
   self))

 (define (lift-real*real->real f df/dx1 df/dx2)
  (letrec ((self
	    (lambda (x1 x2)
	     (if (tape? x1)
		 (if (tape? x2)
		     (new-tape (self (tape-primal x1) (tape-primal x2))
			       (list (df/dx1 (tape-primal x1) (tape-primal x2))
				     (df/dx2 (tape-primal x1) (tape-primal x2)))
			       (list x1 x2))
		     (new-tape (self (tape-primal x1) x2)
			       (list (df/dx1 (tape-primal x1) x2))
			       (list x1)))
		 (if (tape? x2)
		     (new-tape (self x1 (tape-primal x2))
			       (list (df/dx2 x1 (tape-primal x2)))
			       (list x2))
		     (f x1 x2))))))
   self))

 (define first car)

 (define second cadr)

 (define rest cdr)

 (define (fold f l)
  (let loop ((l (cdr l)) (c (first l)))
   (if (null? l) c (loop (rest l) (f c (first l))))))

 (define (count-if p l)
  (let loop ((l l) (c 0))
   (cond ((null? l) c)
	 ((p (first l)) (loop (rest l) (+ c 1)))
	 (else (loop (rest l) c)))))

 (define (map-reduce g i f l . ls)
  (if (null? l)
      i
      (apply map-reduce
	     g
	     (g i (apply f (first l) (map first ls)))
	     f
	     (rest l)
	     (map rest ls))))

 (define (list-remove-ith l i)
  (if (zero? i) (rest l) (cons (first l) (list-remove-ith (rest l) (- i 1)))))

 (define (position-if p l)
  (let loop ((l l) (i 0))
   (cond ((null? l) #f)
	 ((p (first l)) i)
	 (else (loop (rest l) (+ i 1))))))

 (define (lift-real^n->real f df/dx1 df/dx2)
  (lambda xs
   (if (null? xs) (f) (fold (lift-real*real->real f df/dx1 df/dx2) xs))))

 (define (lift-real^n+1->real f df/dx df/dx1 df/dx2)
  (lambda xs
   (cond ((null? xs) (f))
	 ((null? (rest xs)) ((lift-real->real f df/dx) (first xs)))
	 (else (fold (lift-real*real->real f df/dx1 df/dx2) xs)))))

 (define (primal* x) (if (tape? x) (primal* (tape-primal x)) x))

 (define (lift-real^n->boolean f) (lambda xs (apply f (map primal* xs))))

 (define d+ (lift-real^n->real + (lambda (x1 x2) 1) (lambda (x1 x2) 1)))

 (define d- (lift-real^n+1->real
	     - (lambda (x) -1) (lambda (x1 x2) 1) (lambda (x1 x2) -1)))

 (define d* (lift-real^n->real * (lambda (x1 x2) x2) (lambda (x1 x2) x1)))

 (define d/ (lift-real^n+1->real
	     /
	     (lambda (x) (d- (d/ (d* x x))))
	     (lambda (x1 x2) (d/ x2))
	     (lambda (x1 x2) (d- (d/ x1 (d* x2 x2))))))

 (define dsqrt (lift-real->real sqrt (lambda (x) (d/ (d* 2 (dsqrt x))))))

 (define dexp (lift-real->real exp (lambda (x) (dexp x))))

 (define dlog (lift-real->real log (lambda (x) (d/ x))))

 (define dexpt
  (lift-real*real->real expt
			(lambda (x1 x2) (d* x2 (dexpt x1 (d- x2 1))))
			(lambda (x1 x2) (d* (dlog x1) (dexpt x1 x2)))))

 (define dsin (lift-real->real sin (lambda (x) (dcos x))))

 (define dcos (lift-real->real cos (lambda (x) (d- (dsin x)))))

 (define (datan . xs)
  (cond ((null? xs) (apply atan xs))
	((null? (rest xs)) (datan (first xs) 1))
	((null? (rest (rest xs)))
	 ((lift-real*real->real
	   atan
	   (lambda (x1 x2) (d/ x2 (d+ (d* x1 x1) (d* x2 x2))))
	   (lambda (x1 x2) (d/ (d- x1) (d+ (d* x1 x1) (d* x2 x2)))))
	  (first xs)
	  (second xs)))
	(else (apply atan xs))))

 (define d= (lift-real^n->boolean =))

 (define d< (lift-real^n->boolean <))

 (define d> (lift-real^n->boolean >))

 (define d<= (lift-real^n->boolean <=))

 (define d>= (lift-real^n->boolean >=))

 (define dzero? (lift-real^n->boolean zero?))

 (define dpositive? (lift-real^n->boolean positive?))

 (define dnegative? (lift-real^n->boolean negative?))

 (define dreal? (lift-real^n->boolean real?))

 (define (write-real x)
  (cond ((tape? x) (write-real (tape-primal x)) x)
	(else (write x) (newline) x)))

 (define (determine-fanout! tape)
  (tape-fanout-set! tape (+ (tape-fanout tape) 1))
  (when (= (tape-fanout tape) 1)
   (for-each determine-fanout! (tape-tapes tape))))

 (define (initialize-co/tangent! tape)
  (tape-co/tangent-set! tape 0)
  (tape-fanout-set! tape (- (tape-fanout tape) 1))
  (when (zero? (tape-fanout tape))
   (for-each initialize-co/tangent! (tape-tapes tape))))

 (define (forward-accumulation-sweep tape)
  (if (null? (tape-tapes tape))
      (tape-co/tangent tape)
      (map-reduce d+
		  0
		  (lambda (factor tape)
		   (d* factor (forward-accumulation-sweep tape)))
		  (tape-factors tape)
		  (tape-tapes tape))))

 (define (reverse-accumulation-sweep! tape)
  (when (zero? (tape-fanout tape))
   (for-each (lambda (tape) (tape-fanout-set! tape (- (tape-fanout tape) 1)))
	     (tape-tapes tape))
   (let ((cotangent (tape-co/tangent tape)))
    (for-each (lambda (factor tape)
	       (tape-co/tangent-set!
		tape (d+ (tape-co/tangent tape) (d* cotangent factor))))
	      (tape-factors tape)
	      (tape-tapes tape)))
   (for-each reverse-accumulation-sweep! (tape-tapes tape))))

 (define (stepwise-equasive-forward-inverse-accumulation-sweep tape)
  (cond
   ((null? (tape-tapes tape)) (tape-co/tangent tape))
   (else
    (for-each (lambda (tape) (tape-fanout-set! tape (- (tape-fanout tape) 1)))
	      (tape-tapes tape))
    (unless (= (count-if (lambda (tape) (zero? (tape-fanout tape)))
			 (tape-tapes tape))
	       1)
     (error #f "Not equasive"))
    (let* ((i (position-if (lambda (tape) (zero? (tape-fanout tape)))
			   (tape-tapes tape)))
	   (factor-i (list-ref (tape-factors tape) i))
	   (tape-i (list-ref (tape-tapes tape) i))
	   (cotangent-i
	    (stepwise-equasive-forward-inverse-accumulation-sweep tape-i)))
     ;;\needswork: commutativity
     (d- (d/ cotangent-i factor-i)
	(map-reduce
	 d+
	 0
	 (lambda (factor tape)
	  ;;\needswork: commutativity
	  (d/ (d* (stepwise-equasive-forward-inverse-accumulation-sweep tape)
		  factor)
	      factor-i))
	 (list-remove-ith (tape-factors tape) i)
	 (list-remove-ith (tape-tapes tape) i)))))))

 (define (stepwise-equasive-reverse-inverse-accumulation-sweep! tape)
  (unless (null? (tape-tapes tape))
   (when (zero? (tape-fanout tape))
    (for-each (lambda (tape) (tape-fanout-set! tape (- (tape-fanout tape) 1)))
	      (tape-tapes tape))
    (unless (= (count-if (lambda (tape) (zero? (tape-fanout tape)))
			 (tape-tapes tape))
	       1)
     (error #f "Not equasive"))
    (let* ((i (position-if (lambda (tape) (zero? (tape-fanout tape)))
			   (tape-tapes tape)))
	   (factor-i (list-ref (tape-factors tape) i))
	   (tape-i (list-ref (tape-tapes tape) i))
	   (tangent-i (tape-co/tangent tape)))
     (for-each
      (lambda (factor tape)
       (tape-co/tangent-set!
	tape
	;;\needswork: commutativity
	(d- (tape-co/tangent tape) (d/ (d* tangent-i factor) factor-i))))
      (list-remove-ith (tape-factors tape) i)
      (list-remove-ith (tape-tapes tape) i))
     ;;\needswork: commutativity
     (tape-co/tangent-set! tape-i (d/ tangent-i factor-i))
     (for-each stepwise-equasive-reverse-inverse-accumulation-sweep!
	       (tape-tapes tape))))))

 (define (map-walk1 f x)
  ;;\needswork: Not safe for space.
  (cond ((eq? x #t) #t)
	((eq? x #f) #f)
	((null? x) '())
	((char? x) x)
	((string? x) x)
	((dreal? x) (f x))
	((pair? x) (cons (map-walk1 f (car x)) (map-walk1 f (cdr x))))
	;;\needswork: vectors
	(else (error #f "Not walkable"))))

 (define (map-walk2 f x x-prime)
  ;;\needswork: Not safe for space.
  (cond ((and (eq? x #t) (eq? x-prime #t)) #t)
	((and (eq? x #f) (eq? x-prime #f)) #f)
	((and (null? x) (null? x-prime)) '())
	((and (char? x) (char? x-prime) (char=? x x-prime)) x)
	((and (string? x) (string? x-prime) (string=? x x-prime)) x)
	((and (dreal? x) (dreal? x-prime)) (f x x-prime))
	((and (pair? x) (pair? x-prime))
	 (cons (map-walk2 f (car x) (car x-prime))
	       (map-walk2 f (cdr x) (cdr x-prime))))
	;;\needswork: vectors
	(else (error #f "Values don't conform: ~s ~s" x x-prime))))

 (define (for-each-walk1! f x)
  ;;\needswork: Not safe for space.
  (cond ((eq? x #t) #f)
	((eq? x #f) #f)
	((null? x) #f)
	((char? x) #f)
	((string? x) #f)
	((dreal? x) (f x))
	((pair? x) (for-each-walk1! f (car x)) (for-each-walk1! f (cdr x)))
	;;\needswork: vectors
	(else (error #f "Not walkable"))))

 (define (for-each-walk2! f x x-prime)
  ;;\needswork: Not safe for space.
  (cond ((and (eq? x #t) (eq? x-prime #t)) #f)
	((and (eq? x #f) (eq? x-prime #f)) #f)
	((and (null? x) (null? x-prime)) #f)
	((and (char? x) (char? x-prime) (char=? x x-prime)) #f)
	((and (string? x) (string? x-prime) (string=? x x-prime)) #f)
	((and (dreal? x) (dreal? x-prime)) (f x x-prime))
	((and (pair? x) (pair? x-prime))
	 (for-each-walk2! f (car x) (car x-prime))
	 (for-each-walk2! f (cdr x) (cdr x-prime)))
	;;\needswork: vectors
	(else (error #f "Values don't conform: ~s ~s" x x-prime))))

 (define (tape-mode mode f x co/tangent)
  (let* ((x-tape (map-walk1 tapify x))
	 (y-tape (f x-tape)))
   (case mode
    ((forward)
     (for-each-walk2! tape-co/tangent-set! x-tape co/tangent)
     (list (map-walk1 tape-primal y-tape)
	   (map-walk1 forward-accumulation-sweep y-tape)))
    ((reverse)
     (for-each-walk1! determine-fanout! y-tape)
     (for-each-walk1! initialize-co/tangent! y-tape)
     (for-each-walk2! tape-co/tangent-set! y-tape co/tangent)
     (for-each-walk1! determine-fanout! y-tape)
     (for-each-walk1!
      (lambda (tape) (tape-fanout-set! tape (- (tape-fanout tape) 1))) y-tape)
     (for-each-walk1! reverse-accumulation-sweep! y-tape)
     (list (map-walk1 tape-primal y-tape)
	   (map-walk1 tape-co/tangent x-tape)))
    ((forward-inverse)
     (for-each-walk2! tape-co/tangent-set! x-tape co/tangent)
     (for-each-walk1! determine-fanout! y-tape)
     (for-each-walk1!
      (lambda (tape) (tape-fanout-set! tape (- (tape-fanout tape) 1))) y-tape)
     (list (map-walk1 tape-primal y-tape)
	   ;; What you return here will depend on whether it is stepwise
	   ;; equasive, expansive, or contractive.
	   (map-walk1 stepwise-equasive-forward-inverse-accumulation-sweep
		      y-tape)))
    ((reverse-inverse)
     (for-each-walk1! determine-fanout! y-tape)
     (for-each-walk1! initialize-co/tangent! y-tape)
     (for-each-walk2! tape-co/tangent-set! y-tape co/tangent)
     (for-each-walk1! determine-fanout! y-tape)
     (for-each-walk1!
      (lambda (tape) (tape-fanout-set! tape (- (tape-fanout tape) 1))) y-tape)
     ;; What you do here will depend on whether it is stepwise equasive,
     ;; expansive, or contractive.
     (for-each-walk1!
      stepwise-equasive-reverse-inverse-accumulation-sweep! y-tape)
     (list (map-walk1 tape-primal y-tape)
	   (map-walk1 tape-co/tangent x-tape)))
    (else (error #f "Unknown mode")))))

 (define (j* f x x-tangent) (tape-mode 'forward f x x-tangent))

 (define (*j f x y-cotangent) (tape-mode 'reverse f x y-cotangent))

 (define (j*^-1 f x x-cotangent) (tape-mode 'forward-inverse f x x-cotangent))

 (define (*j^-1 f x y-tangent) (tape-mode 'reverse-inverse f x y-tangent)))
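The following Python sketch applies the stepwise-equasive idea to a restricted program shape. The restriction is an assumption: the state is a fixed number of slots, and every step overwrites one slot k using itself and one other slot j. That makes every step equasive by construction. This is not a port of the Scheme listing, which works on general expression tapes and finds the dying operand from fanout counts. All names (PRIMS, run, jvp, vjp, inverse_jvp, inverse_vjp) are hypothetical. Each inverse sweep is a single pass over the tape, so its cost is proportional to the primal. The only divisions are by each step's scalar partial g_a, which must be nonzero. The update in inverse_vjp appears to have the same shape as the one in stepwise-equasive-reverse-inverse-accumulation-sweep!: it divides the dying slot by its factor and subtracts a correction from the surviving slot.

```python
import math

# Hypothetical primitive set. Each updates slot k from (x[k], x[j]) and returns
# (new value, d new / d x[k], d new / d x[j]).
PRIMS = {
    "add": lambda a, b: (a + b, 1.0, 1.0),
    "mul": lambda a, b: (a * b, b, a),
    "add_sin": lambda a, b: (a + math.sin(b), 1.0, math.cos(b)),
    "exp_plus": lambda a, b: (math.exp(a) + b, math.exp(a), 1.0),
}


def run(program, x):
    """Primal pass recording one (k, j, g_a, g_b) tape entry per step."""
    x, tape = list(x), []
    for name, k, j in program:
        if k == j:
            raise ValueError("each step must read a distinct second slot")
        x[k], g_a, g_b = PRIMS[name](x[k], x[j])
        tape.append((k, j, g_a, g_b))
    return x, tape


def jvp(tape, t):
    """J t: apply each step's tangent update, first step first."""
    t = list(t)
    for k, j, g_a, g_b in tape:
        t[k] = g_a * t[k] + g_b * t[j]
    return t


def vjp(tape, w):
    """w J: apply each step's transposed update, last step first."""
    w = list(w)
    for k, j, g_a, g_b in reversed(tape):
        w[j] += g_b * w[k]
        w[k] *= g_a
    return w


def inverse_jvp(tape, v):
    """J^{-1} v: undo each step's tangent update, last step first."""
    v = list(v)
    for k, j, g_a, g_b in reversed(tape):
        v[k] = (v[k] - g_b * v[j]) / g_a
    return v


def inverse_vjp(tape, u):
    """u J^{-1}: undo each step's cotangent update, first step first."""
    u = list(u)
    for k, j, g_a, g_b in tape:
        u[k] /= g_a
        u[j] -= g_b * u[k]
    return u
```

For example, with program = [("mul", 0, 1), ("add_sin", 1, 0), ("exp_plus", 2, 0), ("add", 0, 2)] and y, tape = run(program, [1.5, 0.7, -0.3]), jvp(tape, inverse_jvp(tape, v)) recovers v up to rounding.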