This was a fun project. At the time I was trying to build a good MIDI prediction model to accompany a keyboard track in real time. That still hasn't panned out. But along the way I was exploring model architectures for auto-regressive sequences that might be simpler or lighter than transformers, especially for running on consumer hardware. And that turned into a side quest to validate these architectures on familiar territory: language modeling. And that turned into a side side quest to build a WebAssembly demo showcasing some of these tiny home-grown models.
This project spanned a few technologies: from cuda to cupy to torch to torch.fx to cpp to emscripten to js.
To do a WebAssembly demo, not only does the model need to be tiny for the page to load quickly, but all the code required to run inference also needs to be tiny. Rather than meticulously evaluating lesser-used tokenizers and BLAS libraries for their compiled size and WebAssembly portability, I chose to just write these myself with that purpose in mind.
I was drawn to RNN models for their cheap inference cost on arbitrary-length sequences, and QRNNs in particular offered some training optimizations by "untangling" the recurrent data flow into parallel channels. While the QRNN has largely fallen out of favor since its release in 2016, I wanted to give it a new chance, mixing its kernel design with the more modern full-model design of transformers: essentially, transplant the QRNN in place of the self-attention layer in a transformer, keeping the residual connections, layer norms, and MLP layers. The intuition here is that the general architecture advancements of transformers are somewhat understated while self-attention got all the glory, and maybe this would breathe fresh life into QRNNs.
The nice folks from Salesforce were kind enough to open source an (almost) working reference impl of QRNNs in pytorch complete with a custom cuda kernel. It just required some minor tweaks to be compatible with newer versions of pytorch.
The kernel implementation was basically ideal for learning. It contains:
- ~20-line naive torch reference impl
- ~60 lines of C code for the forward and backward passes of the QRNN operation
- ~60-line adapter to compile the C code with cupy and satisfy torch.autograd.Function
- ~20-line torch.nn.Module wrapper
- and about ~60 lines of test code
And what better way to learn it than to make my own twist on it. I set out to "simplify" the kernel by removing the f[t] * x[t] step from h[t] = f[t] * x[t] + (1 - f[t]) * h[t-1], since that part is already efficient to compute, and I wanted to try coefficients other than f[t] and 1 - f[t]. I also felt the C code could benefit from some indexing helpers, as well as templating on dtype to support float32, float16, and bfloat16.
So I wrote sqrll.py featuring:
- The name SQRLL, from Simplified Quasi Recurrent Linear Layer
- 8-line naive torch reference impl (sketched below)
- ~70 lines of C++ code, with nd-index helpers and dtype templates
- ~80-line adapter to compile with cupy and satisfy torch.autograd.Function
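For flavor, here's a minimal sketch of what such a naive torch reference looks like in spirit, assuming the simplified recurrence is h[t] = x[t] + r[t] * h[t-1] (the exact formulation in sqrll.py may differ):

import torch

def sqrll_naive(x, r, mem=None):
    # x, r: (batch, time, channels); mem: optional (batch, channels) carry-in state
    # assumed recurrence: h[t] = x[t] + r[t] * h[t-1]
    h = torch.zeros_like(x[:, 0]) if mem is None else mem
    ys = []
    for t in range(x.shape[1]):
        h = x[:, t] + r[:, t] * h
        ys.append(h)
    return torch.stack(ys, dim=1)

The point of the CUDA kernel is to do exactly this scan without the slow Python-level loop over time.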
Then in sqrllm.py I built a useful model stack for training. This file owes some credit to the Meta Llama team for their very readable reference impl of Llama; I borrowed a few things like the overall structure and rms_norm.
The actual SqrllLayer applies a fair amount of extra projections and gates around the basic kernel. This is inspired by the RetNet paper and also the success of the LSTM design. Again, I felt the best chance to breathe new life into the QRNN idea was to give it all the surrounding features that have consistently proven effective in competing architectures.
class SqrllLayer(torch.nn.Module):
    def __init__(self, n_in, n_mem, n_out):
        super().__init__()
        self.wr = torch.nn.Linear(n_in, n_mem)
        self.wi = torch.nn.Linear(n_in, n_mem, bias=False)
        self.wig = torch.nn.Linear(n_in, n_mem)
        self.wog = torch.nn.Linear(n_in, n_mem)
        self.wo = torch.nn.Linear(n_mem, n_out, bias=False)

    def forward(self, x, mem=None):
        ig = self.wig(x).sigmoid()        # input gate
        og = self.wog(x).sigmoid()        # output gate
        r = self.wr(x).sigmoid()          # recurrence coefficient fed to the kernel
        x = self.wi(x) * ig
        y = sqrll_kernel(x, r, mem)       # the recurrent scan
        mem = y[:, -1].detach().clone()   # carry state forward to the next chunk
        y = torch.nn.functional.softsign(y)
        y = y * og
        y = self.wo(y)
        return y, mem
As an alternative to multi-head attention, I added an option to intersperse SqrllLayers and FFN/MLP layers at ratios other than 1:1, i.e. multiple sequential sqrll blocks before the FFN block stand in for multi-head self-attention. Put another way, it's a knob to bias parameter allocation between FFN and QRNN functionality.
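A rough sketch of what I mean, building on the SqrllLayer above (the ratio parameter name here is hypothetical, and norms are omitted for brevity):

import torch

class SqrllBlock(torch.nn.Module):
    # hypothetical sketch: `sqrll_per_mlp` sequential SqrllLayers stand in for multi-head attention
    def __init__(self, n_embed, n_mem, sqrll_per_mlp=2):
        super().__init__()
        self.sqrlls = torch.nn.ModuleList(
            SqrllLayer(n_embed, n_mem, n_embed) for _ in range(sqrll_per_mlp)
        )
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(n_embed, 4 * n_embed),
            torch.nn.GELU(),
            torch.nn.Linear(4 * n_embed, n_embed),
        )

    def forward(self, x, mems=None):
        mems = mems if mems is not None else [None] * len(self.sqrlls)
        new_mems = []
        for layer, mem in zip(self.sqrlls, mems):
            y, mem = layer(x, mem)
            x = x + y                  # residual around each sqrll layer
            new_mems.append(mem)
        x = x + self.mlp(x)            # residual around the FFN block
        return x, new_mems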
Next I needed a training pipeline. I wanted to use Wikipedia, assuming it's the sort of canonical high-quality text data. Finding a good way to get just plain text out of Wikipedia turned out to be quite an adventure, as the raw data is packed with markup syntax, tables, image captions, etc. I settled on zim pre-processed archives of Wikipedia, and the libzim parser to load them. The text coming from this seemed much higher quality than other options I looked at, including some very popular Wikipedia parsers and normalizers.
I wanted a little bit of a "chat" experience, so I also included the SQuAD Q&A dataset to at least get some question-answering behavior. This required a little preprocessing to fit the prompt format I had in mind.
Working out a tokenizer was a side quest on its own, so I started with training byte-predictor models, i.e. a vocab size of 256 with a true no-op tokenizer.
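A "true no-op tokenizer" here just means every byte is its own token id, something like:

def encode(text: str) -> list[int]:
    return list(text.encode("utf-8"))   # vocab size 256: one token id per byte

def decode(tokens: list[int]) -> str:
    return bytes(tokens).decode("utf-8", errors="replace")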
I used a PC with an AMD 7950X and an Nvidia 4090 for training. I did the bulk of data pre-processing "online" as training runs, enabling fast iteration on both data processing and model code. The data preprocessing ran on the CPU, leveraging all 16 cores. Building a rich string-processing library for GPU might be a fun project too, but I decided not to allocate time for it.
After byte-level modeling was working at a basic level, I started exploring tokenization, to take some of the "spelling" load off these <10M parameter models and let them learn more semantics. I used the HuggingFace tokenizers lib to train a BPE model and then used it to tokenize the text online. Training BPE-based LMs proved effective and showed more semantic capability for a given number of parameters.
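For context, the HuggingFace training step looks roughly like the following sketch (the vocab size, pre-tokenizer choice, and text_chunks placeholder are illustrative, not my exact settings):

from tokenizers import Tokenizer, models, pre_tokenizers, trainers

text_chunks = ["some wikipedia text", "more wikipedia text"]  # stand-in for the real data stream

tokenizer = Tokenizer(models.BPE())
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()
trainer = trainers.BpeTrainer(vocab_size=8192)
tokenizer.train_from_iterator(text_chunks, trainer=trainer)

ids = tokenizer.encode("some wikipedia text").ids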
When trying to use the BPE tokens for inference, I learned that HuggingFace BPE does not simply concatenate tokens when decoding. In particular, non-ASCII UTF-8 chunks undergo complex transformations when combined. This seemed unnecessary, especially for my goal of building from scratch and porting to WebAssembly.
So I implemented a BPE tokenizer and trainer in C++ with a focus on simplicity. I chose to operate at the byte level, with no special handling for UTF-8. I did have to add some options to avoid creating bigger tokens across common whitespace boundaries, as downstream LM accuracy suffered without them. Notably, decoding is always a simple concatenation operation.
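In Python terms, decoding in this scheme amounts to the following (the vocab mapping of token id to bytes is a hypothetical stand-in for the real data structure):

def decode(tokens: list[int], vocab: dict[int, bytes]) -> str:
    # each token id maps to a fixed byte string; decoding is pure concatenation
    return b"".join(vocab[t] for t in tokens).decode("utf-8", errors="replace")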
I added a nice Python package with nanobind to give it a clean Python API with minimal data copies. For training, I used this Python interface to train the tokenizer model and then tokenize text online for LM training. On my own benchmarks, it could tokenize text 50x faster than HuggingFace's tokenizer (Python-wrapped Rust).
I also found a handy way to expose the C++ API from the pip-installed package: python -m btok.includes prints the include paths a C++ compiler would need. For the WebAssembly port I would need to include the C++ encode/decode APIs, and this worked nicely for that.
The compression ratio of tokenizers can be measured as (log(n_vocab) * n_tokens) / (log(256) * n_bytes), where n_bytes is the length of the original text and n_tokens is the length after tokenization, assuming enough text to overwhelm sampling noise.
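In code form, that measurement amounts to:

import math

def compression_ratio(n_vocab: int, n_tokens: int, n_bytes: int) -> float:
    # bits to encode the token stream vs bits to encode the raw bytes
    return (math.log(n_vocab) * n_tokens) / (math.log(256) * n_bytes)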
I noticed that I was achieving better compression than HF tokenizers at similar vocab sizes. However, the LMs using btok BPE were consistently very slightly worse than those using HF BPE, so I guess the extra complexity was doing something useful after all. It seems a lot of black magic can go into tokenizer design, with impacts that only show up after language model training.
Now that the tokenizer is ready for a tiny C++ build, only the model itself remains.
I've done a number of trained-model transformations and export adapters in the past at work, so this was a familiar problem. I wanted to keep this process dynamic, allowing myself to rewrite parts of the pytorch model without having to manually rewrite a parallel C++ impl. So I needed a true transpiler.
Fortunately torch.fx is a great front-end for parsing most well-behaved pytorch models. So I just needed to write a backend for C++.
To support the tensor math primitives underpinning my model, I started a tiny tensor math library in C++. It was a fun experiment to get numpy style multi-dim index broadcasting to work using C++ variadic templates and generic programming. Of course a major limitation of the tech is that tensor dimensions have to be compile time constants. But when code generating from a frozen model, it works great.
I used torch.fx.Interpreter to help walk through the traced graph in the order the code would need to run. I let it recurse into most functions and modules, except for those I had a direct C++ translation for. I let it rip on my model, and one by one implemented each "unknown function" it would hang up on.
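The shape of that loop looks roughly like this simplified sketch (not the real torch2cpp backend; the emitted strings are placeholders):

import torch
import torch.fx

class CppEmitter(torch.fx.Interpreter):
    # walks the traced graph in execution order, recording each op it visits
    def __init__(self, gm):
        super().__init__(gm)
        self.lines = []

    def call_function(self, target, args, kwargs):
        self.lines.append(f"// TODO: C++ translation for {target}")
        return super().call_function(target, args, kwargs)

    def call_module(self, target, args, kwargs):
        self.lines.append(f"// TODO: C++ translation for module {target}")
        return super().call_module(target, args, kwargs)

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Softsign())
gm = torch.fx.symbolic_trace(model)
emitter = CppEmitter(gm)
emitter.run(torch.randn(1, 8))
print("\n".join(emitter.lines))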
In addition to all the low-level function calls, I also needed to track all the parameters. I used the bfloat16 binary representation for all parameters, and saved them in the generated C++ as uint16 hex literals, with a constructor to convert back to float32 for webasm. This gives a simple 2x reduction in footprint with negligible accuracy impact. I would have liked to implement int8 quantization and math ops, but other projects were calling my name, so I settled for this.
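The bfloat16 trick is just keeping the top 16 bits of each float32. A minimal sketch (assuming simple truncation rather than rounding):

import struct

def f32_to_bf16_bits(x: float) -> int:
    # bfloat16 shares float32's sign and exponent; keep the top 16 bits
    return struct.unpack("<I", struct.pack("<f", x))[0] >> 16

def bf16_bits_to_f32(bits: int) -> float:
    # pad the low mantissa bits with zeros to get back a float32
    return struct.unpack("<f", struct.pack("<I", (bits & 0xFFFF) << 16))[0]

print(hex(f32_to_bf16_bits(0.15625)))  # 0x3e20 -> the kind of uint16 hex literal stored in the C++
print(bf16_bits_to_f32(0x3e20))        # 0.15625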
This one became a standalone repo as an afterthought, and it's not too far from becoming a good one. Looking back, I see a few places where the LM/tokenizer is overly "mixed in", where it would be better to have a task-agnostic torch converter core and some side modules to handle tokenizers and LM APIs. All of torch2cpp and the WebAssembly demo was completed in a single weekend, after all.
To verify the C++ code, I used a simple command-line wrapper to build it with a native compiler and test basic prompt-completion functions. I set up the APIs of the generated code to cleanly plug into either the CLI or the webasm wrapper without modification.
This took a fair amount of tinkering to find all the right compile flags to arrive at an exposed JS API that made sense to me, as well as bundling everything into a single file to simplify the HTML inclusion.
WebAssembly definitely favors C APIs, so I settled on some classic C patterns like writing into pre-allocated buffers passed by pointer. Model RNN states were just held as singleton (global) variables, which was good enough for this use case.
extern "C" {
void model_reset();
int model_step(int prevtok, float temperature);
int model_encode(char const* str, int str_len, int * out, int out_len);
int model_decode(int const* toks, int toks_len, char * out, int out_len);
} // extern C
Astute readers might notice that I merged the prefill and generate functions under model_step. Don't worry, there isn't much wasted compute here; it just wasn't worth the trouble to optimize the prefill path.
This allowed me to write a simple chat app in HTML/JS to feed user generated prompt text in, and display model generated completions.
The tiny Wikipedia model has a final wasm size of 23.5 MB.
The super tiny Finnegans Wake model has a final wasm size of 4.7 MB.
Both models give very low-quality responses, especially by today's standards. But it was a great experience to build all these parts from scratch. And being able to download a working language model and runtime in the same time as one high-res JPEG is an unusual result these days!