[RELAY][PASS] Make the fuser handle split gracefully
Hi folks,
A common pattern in LSTM/GRU-style cells is a structure like (for simplicity):
```python
from tvm import relay

rnn_dim = 10
X = relay.var("X", shape=(1, rnn_dim))
W = relay.var("y", shape=(3 * rnn_dim, rnn_dim))  # all gate weights stacked row-wise
matmul = relay.nn.dense(X, W)
splitted = relay.split(matmul, indices_or_sections=3, axis=1)
out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
```
Normally when implementing this in Relay, we’d expect graph_fuse to fuse this entire sequence (matmul + split + sigmoid/tanh/exp/add/mul) into a single function, as that’s an entirely reasonable expectation and generates the highest-performance code. That is, we expect:
```
fn (%X: Tensor[(1, 10), float32],
    %y: Tensor[(30, 10), float32]) -> Tensor[(1, 10), float32] {
  %0 = nn.dense(%X, %y, units=None)
  %1 = split(%0, indices_or_sections=int64(3), axis=1)
  %2 = %1.0
  %3 = sigmoid(%2)
  %4 = %1.1
  %5 = tanh(%4)
  %6 = %1.2
  %7 = exp(%6)
  %8 = multiply(%5, %7)
  %9 = add(%3, %8)
  %9
}
```
Instead, Relay generates something like:
```
fn (%X: Tensor[(1, 10), float32],
    %y: Tensor[(30, 10), float32]) -> Tensor[(1, 10), float32] {
  %0 = fn (%p0: Tensor[(1, 10), float32],
           %p1: Tensor[(30, 10), float32]) -> Tensor[(1, 30), float32] {
    %1 = nn.dense(%p0, %p1, units=None)  # ty=Tensor[(1, 30), float32]
    %1
  }
  %2 = %0(%X, %y)  # ty=Tensor[(1, 30), float32]
  %3 = fn (%p01: Tensor[(1, 30), float32])
       -> Tuple[Tensor[(1, 10), float32], Tensor[(1, 10), float32], Tensor[(1, 10), float32]] {
    %4 = split(%p01, indices_or_sections=int64(3), axis=1)  # ty=Tuple[Tensor[(1, 10), float32], Tensor[(1, 10), float32], Tensor[(1, 10), float32]]
    %4
  }
  %5 = %3(%2)  # ty=Tuple[Tensor[(1, 10), float32], Tensor[(1, 10), float32], Tensor[(1, 10), float32]]
  %6 = %5.0
  %7 = %5.1
  %8 = %5.2
  %9 = fn (%p02: Tensor[(1, 10), float32],
           %p11: Tensor[(1, 10), float32],
           %p2: Tensor[(1, 10), float32]) -> Tensor[(1, 10), float32] {
    %10 = sigmoid(%p02)  # ty=Tensor[(1, 10), float32]
    %11 = tanh(%p11)  # ty=Tensor[(1, 10), float32]
    %12 = exp(%p2)  # ty=Tensor[(1, 10), float32]
    %13 = multiply(%11, %12)  # ty=Tensor[(1, 10), float32]
    %14 = add(%10, %13)  # ty=Tensor[(1, 10), float32]
    %14
  }
  %15 = %9(%6, %7, %8)  # ty=Tensor[(1, 10), float32]
  %15
}
```
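For reference, the fused output above can be reproduced by running the fusion pass directly. Here is a minimal sketch using the pass-based API (pass names and the opt level reflect current TVM and may differ from the version this issue was filed against):

```python
import tvm
from tvm import relay

# `out` is the expression built in the snippet at the top of the issue.
f = relay.Function(relay.analysis.free_vars(out), out)
mod = tvm.IRModule.from_expr(f)
mod = relay.transform.InferType()(mod)
mod = relay.transform.FuseOps(fuse_opt_level=2)(mod)
print(mod)  # shows the fused functions the pass produced
```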
Of course it would be possible to implement a “GateComputation” op or similar, which internally is just (split + pointwise functions), but it would be more elegant to avoid that if possible.
I’m not fluent in the Relay GraphFuser code, but I was hoping someone (@jroesch?) knows off the top of their head what needs to be modified inside the fuser, and I or someone else can do the implementation work.
Top GitHub Comments
@masahi thanks for taking point on this 👍
FWIW I realized this also introduces some interesting new ideas for scheduling. Normally in our existing frameworks we batch the entire gate computation into a single GEMM against a stacked (k·D, D) weight matrix (k = 3 for GRU, 4 for LSTM, etc.), as in the sketch below.
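Here is that batching pattern as a plain NumPy sketch (illustrative code, not from the original issue; the names and shapes are assumptions):

```python
import numpy as np

B, D = 1, 10   # batch size, hidden dim
k = 4          # 4 gates for LSTM (a GRU would use k = 3)
X = np.random.randn(B, D).astype("float32")
W = np.random.randn(k * D, D).astype("float32")  # all gate weights stacked row-wise

gates = X @ W.T                          # one (B, k*D) GEMM for every gate at once
i, f, g, o = np.split(gates, k, axis=1)  # each gate then reads a D-wide slice
```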
Once you can fuse with split, this has the problem that computing the gate values means reading at D-strided locations in the GEMM output, so in practice you end up fully realizing that intermediate result. It would probably be more natural to instead reorder the single GEMM into vector-width-blocked variants: instead of a (kD, D) matrix, you’d reorder it into a ((D // V) · k · V, D) matrix and then realize the output at (k, V)-sized locations, as sketched below. Most existing frameworks I know of don’t take advantage of this, but conceptually it has the potential to be quite useful.
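Continuing the NumPy sketch above, here is one way to write that reordering (V and the index arithmetic are illustrative assumptions, not anything the fuser currently does):

```python
V = 2  # vector width (illustrative)
assert D % V == 0

# W is (k*D, D): gate g occupies rows [g*D, (g+1)*D).
# Permute the rows so that, for each V-lane chunk c, the rows for all
# k gates' lanes [c*V, (c+1)*V) sit next to each other.
W_blocked = (
    W.reshape(k, D // V, V, D)   # (gate, chunk, lane, in_dim)
     .transpose(1, 0, 2, 3)      # (chunk, gate, lane, in_dim)
     .reshape(k * D, D)
)

gates_blocked = X @ W_blocked.T  # (B, k*D), now chunk-major: each k*V-wide
                                 # block holds all k gates for one V-lane
                                 # chunk, so consumers can realize it in
                                 # (k, V)-sized tiles rather than D-strided reads.

# Sanity check: undo the permutation and compare with the strided layout.
recovered = (
    gates_blocked.reshape(B, D // V, k, V)
                 .transpose(0, 2, 1, 3)
                 .reshape(B, k * D)
)
assert np.allclose(recovered, gates)
```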