RWKV-4 169m/430m in browser with ORT Web / TF.js / tfjs-tflite?
Hi, really exciting project! I’m wondering if you’ve published the model conversion script you used to create the js_models files from the .pth model file? It would be awesome to see how the larger, newer models like RWKV-4 169m/430m perform in the browser! I think the inference speed of RWKV opens up many new possibilities for language models on the web.
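For context, a minimal sketch of what such a .pth-to-ONNX export for ORT Web might look like. This is not the author's actual script: the DummyRWKV module, shapes, and file names are placeholders standing in for the real RWKV RNN-mode forward pass (one token in, logits plus updated state out).

```python
import torch
import torch.nn as nn

# Assumed RWKV-4 169m-ish shapes; the real values come from the checkpoint.
N_LAYER, N_EMBD, VOCAB = 12, 768, 50277

class DummyRWKV(nn.Module):
    """Placeholder for the real RWKV RNN-mode module."""
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(VOCAB, N_EMBD)
        self.head = nn.Linear(N_EMBD, VOCAB, bias=False)

    def forward(self, token, state):
        # state: per-layer recurrent state stacked into one tensor.
        x = self.emb(token)                # [1, N_EMBD]
        new_state = state + x.mean()       # placeholder for the real recurrence
        return self.head(x), new_state     # logits [1, VOCAB], updated state

model = DummyRWKV().eval()
# A real script would first load the .pth checkpoint into the module,
# e.g. via model.load_state_dict(torch.load(<path to .pth>)).

token = torch.tensor([0])
state = torch.zeros(N_LAYER, 5, N_EMBD)

torch.onnx.export(
    model, (token, state), "rwkv_rnn.onnx",
    input_names=["token", "state"],
    output_names=["logits", "new_state"],
    opset_version=15,
)
```

The resulting .onnx file could then be loaded in the browser with onnxruntime-web, feeding the returned new_state back in as state on the next step.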
Top GitHub Comments
Finally tested the webgl backend on a real GPU:
As seen above, the 430m model also works on webgl now. It turns out my state store/restore code was breaking it: with a 24-layer model, it would attempt to stack 24 tensors at once, which exceeds the 16-input-texture limit in WebGL. I worked around this by stacking 12 tensors at a time, twice, then using torch.cat() to glue the two stacks together.
The stacking code can be removed, but then the 430m model would have 120 individual inputs/outputs for state, which sounds scary.
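A hedged illustration of that workaround (shapes are assumed: 24 layers with 5 state vectors per layer, matching the 120 state tensors mentioned below):

```python
import torch

n_layer, n_embd = 24, 1024                      # assumed RWKV-4 430m shapes
states = [torch.zeros(5, n_embd) for _ in range(n_layer)]

# torch.stack(states) would feed 24 inputs into one op, exceeding WebGL's
# 16-input-texture limit, so stack in two halves of 12 and concatenate.
first_half = torch.stack(states[:12])           # 12 inputs
second_half = torch.stack(states[12:])          # 12 inputs
state = torch.cat([first_half, second_half])    # [24, 5, n_embd]
```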
I guess this kind of vindicates the webgl backend? It does outperform wasm when used on a real GPU, and it can also run the non-quantized 430m model, while wasm can’t. Of course, it is still significantly slower than native.
@josephrocca I opened two new PRs in your huggingface repo, one with the updated 430m webgl model and the other removing the outdated model.
@BlinkDL The final [768, 50277] matmul is the slowest component. It’s almost as slow as the entire model on WASM, which is kind of surprising, considering that GPUs are supposed to be good at matmul. It may be caused by the fact that the weight doesn’t fit under the texture size limit of 16384, so onnxruntime does some magic to remap it into a 6214x6214 texture instead, possibly making it slow.
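A quick arithmetic check of where that 6214x6214 figure plausibly comes from (assuming the runtime repacks the weight into roughly the smallest square texture that holds all its elements):

```python
import math

elements = 768 * 50277               # elements in the final projection weight
side = math.ceil(math.sqrt(elements))
print(elements, side)                # 38612736 6214 -> a 6214x6214 texture
```

50277 alone already exceeds the 16384 texture dimension limit, so the weight cannot be stored as a plain [768, 50277] texture and has to be remapped.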
@josephrocca Yeah, I think it’s better to wait for WebGPU instead of pursuing WebGL any further. It seems to work well for graphics, but not so much for compute.