Scala DSL
Purpose
Introducing a secondary tensor operation DSL (Domain Specific Language) written in & optimised for the Scala language & its various compilers (the most common being the JVM-based scalac 2.12+; other options are Scala.js & Dotty)
Despite not being a dependently typed language, Scala may be favourable for several reasons:
- Scala is the preferred language of Apache Spark, Apache Flink, CHISEL (Constructing Hardware In Scala Embedded Language) and NVidia RAPIDS, all of which are important infrastructures for large scale grid learning.
- Singleton types (shapeless / 2.13) allow an arbitrary number of shape types to be generated on demand, without relying on code generators or metaprogramming
- Path-dependent types allow a shape algebra to be defined for both symbols & numbers
- Type inference by summoning implicits is more flexible for theorem proving than inheritance & delegation
- Operator overloading & infix syntax allow the DSL to stay closer to mathematical notation, e.g. `vec1 dot_* vec2` can be easier to read than `dotProduct(vec1, vec2)` (see the sketch after this list)
- Advanced type algebra can be expressed through macros & roadmap features. Specifically, the singleton-ops library and its associated discussions in SIP-23 suggest that Scala 2.13 & Dotty will use a more direct & expressive approach
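A minimal sketch of how the infix-operator and singleton-type points could combine, assuming Scala 2.13 literal types (SIP-23); `Vec`, `ofLength` and `dot_*` are illustrative names, not the shapesafe API:

```scala
object InfixSketch {
  // vectors carry their length N as a literal singleton type (SIP-23, Scala 2.13+)
  class Vec[N <: Int] private (val data: Vector[Double]) {
    // infix operator: `a dot_* b` only compiles when both sides share the same N
    def dot_*(that: Vec[N]): Double =
      this.data.lazyZip(that.data).map(_ * _).sum
  }

  object Vec {
    // the run-time length is checked once against the type-level length on construction
    def ofLength[N <: Int](xs: Double*)(implicit n: ValueOf[N]): Vec[N] = {
      require(xs.length == n.value, s"expected ${n.value} elements, got ${xs.length}")
      new Vec[N](xs.toVector)
    }
  }

  val a = Vec.ofLength[3](1.0, 2.0, 3.0)
  val b = Vec.ofLength[3](4.0, 5.0, 6.0)
  val ok: Double = a dot_* b        // 32.0; the shapes are proven equal at compile time
  // val bad = a dot_* Vec.ofLength[2](1.0, 2.0)   // rejected by the compiler
}
```

Because the length lives in the type, a mismatched `dot_*` fails at compile time rather than at run time.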
Usage
A PoC project has been submitted at:
https://github.com/tribbloid/shapesafe/tree/master
At this moment (06/07/2020) it only implements a few operations (dot_product, concat) for double vectors. However, most features in the following list have been proven and can be showcased by running `gradle test` on DoubleVectorSpec.scala.
Required & proven features:
- Type-level shape representation for vector
- Compile-time shape safety for operations with constant shapes (1)
- Run-time shape safety for operations with variable shapes (2)
- (1) can smoothly degrade to (2) for a more flexible programming paradigm (a sketch of this degradation follows this list)
- Not relying on compile-time code generation or defining a large number of compile-time classes, both of which may stress the JVM classloader
- Not relying on esoteric compiler plugins, macros or unstable libraries (the PoC only uses singleton-ops, which is actively maintained by the Lightbend core team)
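A hedged sketch of the degradation from (1) to (2), again with made-up names rather than the shapesafe API: a statically shaped constructor checks the length once, while dynamically shaped data is re-admitted into the typed world only through a run-time check.

```scala
object DegradeSketch {
  final class Vec[N <: Int] private[DegradeSketch] (val data: Vector[Double])

  // (1) statically shaped construction: length checked once against the literal type
  def static[N <: Int](xs: Double*)(implicit n: ValueOf[N]): Vec[N] = {
    require(xs.length == n.value)
    new Vec[N](xs.toVector)
  }

  // (2) dynamically shaped data (e.g. loaded from disk) is accepted into the typed
  // world only if its run-time length matches the expected type-level length
  def narrow[N <: Int](xs: Vector[Double])(implicit n: ValueOf[N]): Either[String, Vec[N]] =
    if (xs.length == n.value) Right(new Vec[N](xs))
    else Left(s"expected length ${n.value}, got ${xs.length}")

  def main(args: Array[String]): Unit = {
    val fromDisk: Vector[Double] = Vector(1.0, 2.0, 3.0) // length unknown to the compiler
    println(narrow[3](fromDisk).isRight)  // true: run-time proof succeeded
    println(narrow[4](fromDisk))          // Left(...): caught before any computation runs
  }
}
```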
Required & unproven:
- Type-level shape representation for matrix
- Type-level shape representation for arbitrary tensors, with no limitation on arity or dimensionality (ideally using recursive generic types, the same technique that defines shapeless.HList; see the sketch after this list)
- Support for UInt & Float element data types
- Support for more operations, particularly those used frequently by ANNs
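A sketch of the recursive-generic-type idea for arbitrary-rank shapes, mirroring the technique behind shapeless.HList; the names `Shape`, `Dim`, `End` and `Tensor` are my own illustration, not a proposed API:

```scala
object ShapeSketch {
  sealed trait Shape
  // a shape is either empty, or a head dimension "consed" onto a tail shape
  final case class Dim[H <: Int, T <: Shape](head: Int, tail: T) extends Shape
  case object End extends Shape
  type End = End.type

  // the tensor's rank and dimensions live in its type, with no upper bound on arity
  final case class Tensor[S <: Shape](data: Vector[Double], shape: S)

  // a 2 x 3 matrix: its shape type is Dim[2, Dim[3, End]]
  val m: Tensor[Dim[2, Dim[3, End]]] =
    Tensor(Vector.fill(6)(0.0), Dim[2, Dim[3, End]](2, Dim[3, End](3, End)))
}
```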
Nice to have:
- Not relying on Church type encoding (the only violation in my PoC is shapeless.Nat), which causes slow compilation (see the sketch after this list)
- Compile-time & run-time shape safety for named tensors
- Shape algebra defined for symbols instead of numbers, supported by user-defined axioms & compile-time theorem proving
- Dotty or Scala.js support
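For context on the Church (Peano) encoding concern, a sketch of what shapeless.Nat-style types look like and why they stress the compiler; the names below are illustrative, not shapeless's own:

```scala
object ChurchSketch {
  sealed trait Nat
  final class Succ[N <: Nat] extends Nat
  final class Zero extends Nat

  // the number 3 is a three-level deep nested type; a dimension of 1000 would nest
  // 1000 levels, which is what slows compilation, whereas a literal type like `1000` is flat
  type Three = Succ[Succ[Succ[Zero]]]
}
```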
Not in the scope:
- AutoDiff / Differentiable Programming: should be agnostic to DSL
- Interoperability with Apache Spark, NVidia RAPIDS or any other library: models should manifest into different executions regardless of the DSL being used
- Shape-safe compositor / pipeline API: too much work
How it fits into roadmap
The competition for supremacy of the deep learning ecosystem is brutal and unforgiving. With torch & tensorflow dominating both the research & development phases, people have little reason to steer away from Python & a dynamically typed, procedurally validated scaffolding paradigm. But there are exceptions: large scale, mission critical, complex systems in production, like autopilot and SLAM, most likely prefer spending considerable effort reinventing & maintaining a more verbose and arduous code base written in C++ or other low level languages. For these systems, the demand for built-in correctness and predictability of execution far outweighs the ability to write more concise code.
This is, IMHO, the market niche where kotlingrad can fit: mega-engineering rather than prototyping. In particular, to enable users to:
- write provably valid neural architectures WITHOUT sanity tests
- if that is not possible, write neural architectures with lower test coverage that validates only the parts that cannot be proven; the 80-20 rule in test coverage is very real and accounts for most edge case failures in an architecture that lacks any types
- under the above premise, write short & easy to understand code
…in that order. My design & optimisation of DSLs should be consistent with this doctrine. The choice of Scala & the JVM as carrier should naturally put kotlingrad in the ecosystem of Apache Spark, and maybe of RISC-V in the long run, both of which target large scale production.
Top GitHub Comments
I’ve been studying DiffTaichi pretty closely over the last few months, and agree it’s a good design choice. The tools they’ve built for physical simulation are also very impressive. My colleagues and I have been working on some similar ideas connecting differentiable physics and rendering that we hope to be able to share soon. Learning smoothing tricks to get stable gradients for collisions and discrete events has been really challenging and I think a good litmus test for where differentiable programming is headed in the future. I’m personally interested in making those techniques accessible to the broader field of software engineering.
This is probably more related to #8, but I am staying in school to pursue further research. Last year, I moved from UdeM to McGill, where I am now working to build a connection between automatic differentiation and learning on programs. I think the AD/ML community is doing important translational research, but in terms of this project, I feel the path to adoption requires taking a longer view on automatic differentiation than other libraries can afford. I think we’re in a good position to do so, and have identified three high level goals for Kotlin∇: (1) language design, (2) symbolic reasoning and (3) graph computation.
One of our goals for Kotlin∇ is to provide a staged eDSL that resembles an imperative program (containing variables, control flow, functions and data types) for mathematical code. The language should look and feel similar to a scripting language like Python with a flexible type system and mathematically idiomatic syntax. I think it is important to actually use the DSL, so I will try to get to the point where it’s fast enough for myself and others to use for training. Another project I recommend checking out is KMath, which has shared a number of inspirations and plans to support a much larger API surface (if that’s important to you).
Another goal for Kotlin∇ is providing tools for AD/compiler research: solvers for equational reasoning, term rewriting and symbolic computation. Although normalization and canonicalization are undecidable in general, if we impose certain constraints, it becomes “just” NP-hard, using Knuth-Bendix or similar completion algorithms. It would be nice to provide a tool for determining if two symbolic expressions are semantically equivalent in a certain axiom system, by using logic programming (e.g. miniKanren, Prolog, SAT/SMT solvers) or learning techniques (e.g. Siamese GNNs, symbolic pregression) for expression simplification.
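As a toy illustration of that term-rewriting direction (my own sketch, not Kotlin∇'s actual simplifier): apply algebraic axioms as rewrite rules until a fixed point, then compare normal forms.

```scala
object RewriteSketch {
  sealed trait Expr
  case class Const(v: Double)       extends Expr
  case class Sym(name: String)      extends Expr
  case class Add(l: Expr, r: Expr)  extends Expr
  case class Mul(l: Expr, r: Expr)  extends Expr

  // one pass of bottom-up rewriting with a few axioms: x+0 -> x, x*1 -> x, x*0 -> 0
  def rewrite(e: Expr): Expr = e match {
    case Add(x, Const(0.0))  => rewrite(x)
    case Add(Const(0.0), x)  => rewrite(x)
    case Mul(x, Const(1.0))  => rewrite(x)
    case Mul(Const(1.0), x)  => rewrite(x)
    case Mul(_, Const(0.0))  => Const(0.0)
    case Mul(Const(0.0), _)  => Const(0.0)
    case Add(l, r)           => Add(rewrite(l), rewrite(r))
    case Mul(l, r)           => Mul(rewrite(l), rewrite(r))
    case other               => other
  }

  // iterate to a fixed point: a (very weak) normal form for equivalence checking
  def normalize(e: Expr): Expr = {
    val next = rewrite(e)
    if (next == e) e else normalize(next)
  }

  def main(args: Array[String]): Unit = {
    val lhs = Add(Mul(Sym("x"), Const(1.0)), Const(0.0))
    println(normalize(lhs) == Sym("x"))  // true: both sides share the normal form `x`
  }
}
```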
Finally, our goal is to compile the entire graph (including control flow, data types, logic, etc.) to sparse matrices, where “executing” the program consists of pure matrix arithmetic. I think one day there will be a VM or interpreter for a high-level language that runs entirely on a matrix processor, only exiting to perform I/O. Users will be able to run existing programs and get parallelization for free. That’s going to be a long term project, but there is some progress we can make today, e.g. leveraging GPU/TPU/IPU intrinsics using GraphBLAS. Graph and matrix representation is one of the things I’ve been working on this summer.
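A toy illustration of the graphs-as-matrices idea (dense Booleans here, rather than the sparse GraphBLAS kernels the comment refers to): one BFS frontier expansion expressed as a matrix-vector product over the adjacency matrix.

```scala
object GraphAsMatrix {
  // adjacency matrix: adj(i)(j) = true iff there is an edge i -> j
  def step(adj: Array[Array[Boolean]], frontier: Array[Boolean]): Array[Boolean] = {
    val n = adj.length
    Array.tabulate(n) { j =>
      // j joins the new frontier iff some current frontier node i has an edge to j
      (0 until n).exists(i => frontier(i) && adj(i)(j))
    }
  }

  def main(args: Array[String]): Unit = {
    val adj = Array(
      Array(false, true,  false),
      Array(false, false, true),
      Array(false, false, false)
    )
    val start = Array(true, false, false)           // BFS starts at node 0
    println(step(adj, start).toVector)               // Vector(false, true, false)
    println(step(adj, step(adj, start)).toVector)    // Vector(false, false, true)
  }
}
```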
I think a hybrid tracing and SCT-based design makes sense and is a good compromise for the JVM. Another approach from Wang et al. (2018) proposes multistage-programming, which seems closer in spirit to what we are currently doing. Have you looked into Lantern and the LMS framework? I think it’s a good architecture and also makes a lot of sense from a compiler perspective. What do you think are the advantages and disadvantages of these two approaches? Happy to discuss how to integrate with DJL or a similar backend. We don’t want to reinvent the wheel, and I think that exercise would be helpful in the context of (1).
It depends on the representation you’re using. In general, common subexpression elimination is NP-hard. If the graph is a DAG, the problem is GI-hard, or if the graph is a tree, it’s something like (Matula, 1978), or maybe a little easier. It’s a well-studied problem in the compiler literature, and there are lots of good resources and efficient implementations around. Due to frequent nonlinearities in typical deep learning pipelines, there are not often many algebraic simplifications you can do, but it could be a useful optimization for differentiable programming more broadly. Kotlin∇ uses some ad-hoc simplifications, but it’s probably not something we should be rolling ourselves. I’m still looking for a good solution; there are some good resources on computer algebra in the readme.
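For a concrete (if simplified) picture of the tree case, a hash-consing sketch: structurally equal subtrees are interned once, so repeated subexpressions collapse into shared nodes. This is my own illustration, not Kotlin∇'s implementation.

```scala
object CseSketch {
  sealed trait Expr
  case class Sym(name: String)      extends Expr
  case class Add(l: Expr, r: Expr)  extends Expr
  case class Mul(l: Expr, r: Expr)  extends Expr

  import scala.collection.mutable
  private val interned = mutable.Map.empty[Expr, Expr]

  // rebuild bottom-up, returning the canonical (shared) instance for each structure
  def hashCons(e: Expr): Expr = {
    val rebuilt = e match {
      case Add(l, r) => Add(hashCons(l), hashCons(r))
      case Mul(l, r) => Mul(hashCons(l), hashCons(r))
      case leaf      => leaf
    }
    interned.getOrElseUpdate(rebuilt, rebuilt)
  }

  def main(args: Array[String]): Unit = {
    val e = Add(Mul(Sym("x"), Sym("y")), Mul(Sym("x"), Sym("y")))
    hashCons(e) match {
      case Add(l, r) => println(l eq r)  // true: both branches share one interned node
      case _         => ()
    }
  }
}
```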
There is a toy symbolic expression generator in our tests. I need to write some more extensive test cases using property-based testing. Chapter 4 of my master’s thesis describes the theory behind PBT and metamorphic testing, which were experimentally validated, but have not yet been integrated into our CI test suite. Our high-level approach is to generate a large number of random trees and run a sensitivity analysis across numerical inputs. We need to set up some regression tests to check for runtime performance and numerical stability, but the idea is that you specify an invariant, and throw an error if you can find a simplification and inputs which violate it. If you’re interested, there are some mature testing frameworks which automate this process, maybe check out ScalaTest if you’re unfamiliar.
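A hedged sketch of that property-based testing loop, using ScalaCheck rather than ScalaTest (an assumption on my part; the generator, the naive `simplify` and the tolerance are placeholders): generate random expression trees and assert that simplification preserves the evaluated value.

```scala
import org.scalacheck.{Gen, Prop, Properties}

object SimplifySpec extends Properties("simplify") {
  sealed trait Expr
  case class Const(v: Double)       extends Expr
  case class Add(l: Expr, r: Expr)  extends Expr

  def eval(e: Expr): Double = e match {
    case Const(v)  => v
    case Add(l, r) => eval(l) + eval(r)
  }

  // a deliberately naive "simplifier" standing in for the real one
  def simplify(e: Expr): Expr = e match {
    case Add(Const(0.0), r) => simplify(r)
    case Add(l, Const(0.0)) => simplify(l)
    case Add(l, r)          => Add(simplify(l), simplify(r))
    case other              => other
  }

  // random expression trees of bounded depth
  def genExpr(depth: Int): Gen[Expr] =
    if (depth == 0) Gen.choose(-10.0, 10.0).map(Const(_))
    else Gen.oneOf(
      Gen.choose(-10.0, 10.0).map(Const(_)),
      for { l <- genExpr(depth - 1); r <- genExpr(depth - 1) } yield Add(l, r)
    )

  // the invariant: simplification must not change the value of the expression
  property("value preserved") = Prop.forAll(genExpr(5)) { e =>
    math.abs(eval(simplify(e)) - eval(e)) < 1e-9
  }
}
```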
This is another active area of research known as tensor contraction ordering or optimal Jacobian accumulation, which in general is NP-hard (Naumann, 2006). If you’re just doing inference, there is a matrix chain multiplication algorithm (Hu & Shing). As you mention, the naïve implementation in most AD packages is not very efficient. For example, JAX uses reverse mode by default. You can also accumulate the Jacobians using a custom order (e.g. `jacfwd(jacrev(f))` / `jacrev(jacfwd(f))`), but these too are suboptimal. As you suggested, it is possible to realize significant speedups by considering the structure of the computation graph. For example, TensorNetwork uses the opt_einsum package (which provides various strategies, e.g. brute force, greedy, DP) to search for the optimal contraction sequence. I was hoping to try throwing Z3 at the problem, but haven’t got around to it yet.
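For the inference-only case, the classic O(n^3) dynamic program for matrix chain ordering gives the flavour (Hu & Shing's algorithm is a faster specialisation of the same problem); this sketch is illustrative, not code from any of the libraries mentioned.

```scala
object ChainOrder {
  // dims(i), dims(i+1) are the row/column counts of matrix i in the chain
  def minCost(dims: Array[Long]): Long = {
    val n = dims.length - 1                       // number of matrices in the chain
    val cost = Array.fill(n, n)(0L)
    for (len <- 2 to n; i <- 0 to n - len) {
      val j = i + len - 1
      cost(i)(j) = (i until j).map { k =>
        // best left grouping + best right grouping + cost of multiplying the two results
        cost(i)(k) + cost(k + 1)(j) + dims(i) * dims(k + 1) * dims(j + 1)
      }.min
    }
    cost(0)(n - 1)
  }

  def main(args: Array[String]): Unit =
    // (10x30)(30x5)(5x60): best order multiplies the first two matrices first => 4500
    println(minCost(Array(10L, 30L, 5L, 60L)))
}
```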