This implementation builds a transformer decoder from the ground up in JAX. It doesn't use any higher-level frameworks like Flax; I have used labml for logging and experiment tracking.
I have implemented a simple Module
class to build basic building blocks upon.
This was my first JAX project and many implementations were taken from PyTorch implementations at nn.labml.ai.
JAX can optimize and differentiate pure Python functions. Pure functions are functions that take a bunch of arguments and return a result without modifying anything outside the function (such as object attributes or global variables). JAX can also compile these functions (with XLA) as well as vectorize them to run efficiently.
In JAX you don't have to worry about batches. The functions are implemented for a single sample, and jax.vmap
can vectorize (parallelize) the functions across the batch dimension (or any other dimension if needed).
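For example, here is a minimal standalone sketch (not part of the implementation below) of a per-sample function that is compiled with jax.jit and vectorized across a batch with jax.vmap:

import jax
import jax.numpy as jnp

# A pure function written for a single sample: x has shape [d_model]
def l2_norm(x: jnp.ndarray):
    return jnp.sqrt(jnp.sum(x ** 2))

# jax.vmap maps it over a batch of samples and jax.jit compiles the result
batched_l2_norm = jax.jit(jax.vmap(l2_norm))

x = jnp.arange(12.0).reshape(4, 3)  # a batch of 4 samples
print(batched_l2_norm(x))  # shape [4]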
48from functools import partial
49from typing import Dict, NamedTuple, Tuple, Any, Callable
50from typing import List, TypeVar, Generic
51from typing import Union, Optional
52
53import jax
54import jax.numpy as jnp
55from labml import lab, monit, experiment, tracker
56from labml import logger
57from labml.logger import Text
58from labml.utils.download import download_file
This is a base class for all modules. It handles parameters and transforms methods to pure functions for JAX to compile and differentiate.
You can skip these modules to get into the models directly.
The module stores parameters and sub-modules separately. When we want to transform a method into a pure function, we pass the parameters of the module and its sub-modules as an argument and assign the passed values to the class.
This is based on a blog post: From PyTorch to JAX: towards neural net frameworks that purify stateful code.
61class Module:
Store all parameters and sub-modules in dictionaries
81 _submodules: Dict[str, 'Module']
82 _params: Dict[str, jnp.ndarray]
Initialize
84 def __init__(self):
86 self._params = {}
87 self._submodules = {}
We override the get attribute operation. So when you reference an attribute with model.attribute
this function gets called.
Read this guide if you are not familiar with Python magic methods.
89 def __getattr__(self, attr_name: str):
If the attribute is a parameter
101 if attr_name in self._params:
102 return self._params[attr_name]
If the attribute is a sub-module
104 elif attr_name in self._submodules:
105 return self._submodules[attr_name]
Otherwise, fall back to normal attributes. The attributes are stored in __dict__
by Python.
108 else:
109 return self.__dict__[attr_name]
We override the set attribute operation. So when you assign an attribute with model.attribute
this function gets called.
111 def __setattr__(self, key: str, value: Any):
If the value is also a module
120 if isinstance(value, Module):
121 self._submodules[key] = value
If the value is a JAX array
123 elif isinstance(value, jnp.ndarray):
124 self._params[key] = value
Otherwise add it to __dict__
126 else:
127 self.__dict__[key] = value
This clears out all the parameters. It is used when a method is called as a pure function: we first clear out all the parameters and then assign the parameters passed to the pure function.
129 def _clear_params(self):
Clear parameters of the module
138 self._params = {}
Recursively clear parameters of submodules
140 for sm in self._submodules.values():
141 sm._clear_params()
This recursively collects all the parameters of the module and sub-modules into a dictionary.
143 def get_params(self) -> Dict[str, jnp.ndarray]:
Parameters of the model
151 params = self._params.copy()
Parameters of the submodules
153 for sm_name, sm in self._submodules.items():
154 for name, value in sm.get_params().items():
The dictionary keys are of the form module_name/module_name/param_name
156 params[sm_name + "/" + name] = value
158 return params
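For example, here is a hypothetical sketch (assuming the Module base class above) showing how nested parameters get flattened into path-like keys:

# Hypothetical modules for illustration only
class Child(Module):
    def __init__(self):
        super().__init__()
        self.weight = jnp.zeros((2, 3))  # stored in _params

class Parent(Module):
    def __init__(self):
        super().__init__()
        self.child = Child()  # stored in _submodules
        self.bias = jnp.zeros((3,))

print(Parent().get_params().keys())  # dict_keys(['bias', 'child/weight'])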
160 def _set_params(self, params: Dict[str, jnp.ndarray]):
Iterate through parameters. Their names have the form module_name/module_name/param_name
167 for name, value in params.items():
Split to get module names and parameter name
169 self._set_param(name.split("/"), value)
171 def _set_param(self, param_path: List[str], value: jnp.ndarray):
No module names; i.e. a parameter of this module
178 if len(param_path) == 1:
179 self._params[param_path[0]] = value
Parameter of a submodule
181 else:
182 self._submodules[param_path[0]]._set_param(param_path[1:], value)
This transforms a member method to a pure function that accepts a dictionary of parameters as an argument.
For example,
params = model.get_params()
pure_function = model.purify(model.calculate_loss)
output = pure_function(params, data)
184 def purify(self, method: Callable) -> Callable:
200 def pure_method(params: Dict[str, jnp.array], *args):
Clear parameters in the object
202 self._clear_params()
Assign the passed parameters
204 self._set_params(params)
Invoke the method
206 result = method(*args)
Return the result
208 return result
211 return pure_method
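Here is a hypothetical end-to-end sketch (a made-up Square module built on the Module class above) showing how a purified method composes with jax.grad; it is the same pattern main() uses later:

# Hypothetical module for illustration only
class Square(Module):
    def __init__(self):
        super().__init__()
        self.w = jnp.array(3.0)  # a parameter

    def loss(self, x):
        return (self.w * x) ** 2

model = Square()
params = model.get_params()  # {'w': 3.0}
pure_loss = model.purify(model.loss)
# Gradients with respect to the parameter dictionary
grads = jax.grad(pure_loss)(params, 2.0)
print(grads['w'])  # 2 * (w * 2) * 2 = 24.0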
Type for generics in the module list class
215M = TypeVar('M', bound=Module)
This stores a list of modules. We need this for the transformer decoder to hold the list of transformer layers.
218class ModuleList(Module, Generic[M]):
For list of modules
227 _submodules: List[M]
Initialize with a list of modules.
229 def __init__(self, modules: List[M]):
233 super().__init__()
234 self._submodules = modules
Get the idx-th module
236 def __getitem__(self, idx: int) -> M:
240 return self._submodules[idx]
This is not supported
242 def __setitem__(self, key, value):
246 raise NotImplementedError
248 def __len__(self):
252 return len(self._submodules)
Override __getattr__
of Module
254 def __getattr__(self, item):
258 return self.__dict__[item]
Override __setattr__
of Module
260 def __setattr__(self, key, value):
264 self.__dict__[key] = value
266 def _clear_params(self):
270 self._params = {}
271 for sm in self._submodules:
272 sm._clear_params()
274 def get_params(self):
278 params = self._params
279 for i, sm in enumerate(self._submodules):
280 for name, value in sm.get_params().items():
281 params[f'{i}/{name}'] = value
282 return params
284 def _set_param(self, param_path: List[str], value: jnp.ndarray):
288 self._submodules[int(param_path[0])]._set_param(param_path[1:], value)
291class Embedding(Module):
rnd_key is the PRNG state
n_embeddings is the number of embeddings
n_dim is the size of an embedding
300 def __init__(self, rnd_key: jax.random.PRNGKey, n_embeddings: int, n_dim: int):
306 super().__init__()
Embeddings are initialized from the standard normal distribution $\mathcal{N}(0, 1)$
308 self.embeddings = jax.random.normal(rnd_key, (n_embeddings, n_dim))
Return the embeddings for the given ids
310 def __call__(self, ids: jnp.ndarray):
314 return self.embeddings[ids, :]
This is based on our PyTorch implementation.
317class EmbeddingsWithLearnedPositionalEncoding(Module):
rnd_key is the PRNG state
n_vocab is the vocabulary size
d_model is the embedding size
max_len is the maximum sequence length (to initialize positional encodings)
327 def __init__(self, rnd_key: jax.random.PRNGKey, n_vocab: int, d_model: int, max_len: int = 4096):
334 super().__init__()
Embeddings
336 self.embeddings = Embedding(rnd_key, n_vocab, d_model)
Coefficient $\frac{1}{\sqrt{d_{model}}}$ used to scale the embeddings
338 self.pe_coef = 1 / d_model ** 0.5
Positional encodings initialized to zeros
340 self.positional_encodings = jnp.zeros((max_len, d_model))
342 def __call__(self, x: jnp.ndarray):
Get positional encodings
344 pe = self.positional_encodings[:x.shape[0]]
Get embeddings and add positional encodings
346 return self.embeddings(x) * self.pe_coef + pe
349class Linear(Module):
rnd_key is the PRNG state
in_features is the number of features in the input
out_features is the number of features in the output
358 def __init__(self, rnd_key: jax.random.PRNGKey, in_features: int, out_features: int):
364 super().__init__()
Initialize the weights uniformly from $\mathcal{U}\left(-\frac{1}{\sqrt{in\_features}}, \frac{1}{\sqrt{in\_features}}\right)$
367 rnd_range = 1 / in_features ** 0.5
368 self.weight = jax.random.uniform(rnd_key, (in_features, out_features),
369 minval=-rnd_range, maxval=rnd_range)
Initialize the biases to zero
371 self.bias = jnp.zeros((out_features,))
373 def __call__(self, x: jnp.ndarray):
Multiply by weights and add the bias
375 return jnp.matmul(x, self.weight) + self.bias
This implements layer normalization from the paper Layer Normalization.
When the input $X \in \mathbb{R}^{L \times C}$ is a sequence of embeddings, where $C$ is the number of channels and $L$ is the length of the sequence, the output is $\text{LN}(X) = \gamma \frac{X - \mathbb{E}[X]}{\sqrt{\mathrm{Var}[X] + \epsilon}} + \beta$, with gain $\gamma \in \mathbb{R}^{C}$ and bias $\beta \in \mathbb{R}^{C}$.
This is based on our PyTorch implementation.
378class LayerNorm(Module):
normalized_shape $S$ is the shape of the elements (except the batch). The input should then have shape $* \times S[0] \times S[1] \times \dots \times S[n]$
eps is $\epsilon$, used in $\sqrt{\mathrm{Var}[X] + \epsilon}$ for numerical stability
elementwise_affine is whether to scale and shift the normalized value
398 def __init__(self, normalized_shape: Union[Tuple[int], List[int]], *,
399 eps: float = 1e-5, elementwise_affine: bool = True):
407 super().__init__()
408
409 self.eps = eps
410 self.elementwise_affine = elementwise_affine
411 self.normalized_shape = tuple(normalized_shape)
Create parameters $\gamma$ and $\beta$ for gain and bias
414 if elementwise_affine:
415 self.gain = jnp.ones(normalized_shape)
416 self.bias = jnp.zeros(normalized_shape)
418 def __call__(self, x: jnp.ndarray):
Sanity check to make sure the shapes match
420 assert self.normalized_shape == x.shape[-len(self.normalized_shape):]
The axes to calculate the mean and variance on
423 axes = [-(i + 1) for i in range(len(self.normalized_shape))]
Calculate the mean over the normalized axes, i.e. $\mathbb{E}[X]$ for each element
426 mean = x.mean(axis=axes, keepdims=True)
Calculate the mean of the squares over the normalized axes, i.e. $\mathbb{E}[X^2]$ for each element
429 mean_2 = (x ** 2).mean(axis=axes, keepdims=True)
Variance of all elements, $\mathrm{Var}[X] = \mathbb{E}[X^2] - \mathbb{E}[X]^2$
431 var = mean_2 - mean ** 2
Normalize
433 x_norm = (x - mean) / (var + self.eps) ** 0.5
Scale and shift
436 if self.elementwise_affine:
437 x_norm = self.gain * x_norm + self.bias
440 return x_norm
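As a quick standalone check of the formula (made-up numbers, not part of the model), the same normalization can be computed directly for a single vector:

# Normalize a single vector of C channels
x = jnp.array([1.0, 2.0, 3.0, 4.0])
eps = 1e-5
mean = x.mean()
var = (x ** 2).mean() - mean ** 2  # Var[X] = E[X^2] - E[X]^2
x_norm = (x - mean) / (var + eps) ** 0.5
print(x_norm.mean(), x_norm.var())  # approximately 0 and 1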
This module does a linear transformation and splits the vector into given number of heads for multi-head attention. This is used to transform key, query, and value vectors.
443class PrepareForMultiHeadAttention(Module):
454 def __init__(self, rnd_key: jax.random.PRNGKey, d_model: int, heads: int, d_k: int):
455 super().__init__()
Linear layer for linear transform
457 self.linear = Linear(rnd_key, d_model, heads * d_k)
Number of heads
459 self.heads = heads
Number of dimensions in vectors in each head
461 self.d_k = d_k
463 def __call__(self, x: jnp.ndarray):
Input has shape [seq_len, batch_size, d_model]
or [batch_size, d_model]
. We apply the linear transformation to the last dimension and split that into the heads.
467 head_shape = x.shape[:-1]
Linear transform
470 x = self.linear(x)
Split last dimension into heads
474 x = x.reshape(*head_shape, self.heads, self.d_k)
Output has shape [seq_len, batch_size, heads, d_k]
or [batch_size, heads, d_k]
477 return x
This computes scaled multi-headed attention from the paper Attention Is All You Need for given query, key and value vectors.

$\mathrm{Attention}(Q, K, V) = \underset{seq}{\mathrm{softmax}}\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$

In simple terms, it finds keys that match the query and gets the values of those keys.
It uses the dot-product of query and key as the indicator of how well they match. Before taking the softmax, the dot-products are scaled by $\frac{1}{\sqrt{d_k}}$. This is done to avoid large dot-product values causing softmax to give very small gradients when $d_k$ is large.
Softmax is calculated along the sequence (or time) axis of the keys.
This is based on our PyTorch implementation.
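As a minimal standalone sketch (separate from the class below), scaled dot-product attention for a single head and a single sample can be written directly with einsum:

def scaled_dot_product_attention(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray):
    # q, k, v have shape [seq_len, d_k]
    d_k = q.shape[-1]
    scores = jnp.einsum('id,jd->ij', q, k) / d_k ** 0.5  # [seq_len, seq_len]
    attn = jax.nn.softmax(scores, axis=-1)  # softmax over the key positions
    return jnp.einsum('ij,jd->id', attn, v)  # [seq_len, d_k]

q = k = v = jax.random.normal(jax.random.PRNGKey(0), (5, 16))
print(scaled_dot_product_attention(q, k, v).shape)  # (5, 16)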
480class MultiHeadAttention(Module):
rnd_key is the PRNG state
heads is the number of heads
d_model is the number of features in the query, key and value vectors
506 def __init__(self, rnd_key: jax.random.PRNGKey, heads: int, d_model: int):
513 super().__init__()
Split the PRNG state
516 _, *rnd_keys = jax.random.split(rnd_key, 5)
Number of features per head
519 self.d_k = d_model // heads
Number of heads
521 self.heads = heads
These transform the query, key and value vectors for multi-headed attention.
524 self.query = PrepareForMultiHeadAttention(rnd_keys[0], d_model, heads, self.d_k)
525 self.key = PrepareForMultiHeadAttention(rnd_keys[1], d_model, heads, self.d_k)
526 self.value = PrepareForMultiHeadAttention(rnd_keys[2], d_model, heads, self.d_k)
Output layer
529 self.output = Linear(rnd_keys[3], d_model, d_model)
Scaling factor before the softmax
531 self.scale = 1 / self.d_k ** 0.5
query, key and value are the tensors that store the collections of query, key and value vectors. They have shape [seq_len, d_model].
mask has shape [seq_len, seq_len] and mask[i, j] indicates whether the query at position i can see the key-value at position j.
533 def __call__(self, *,
534 query: jnp.ndarray,
535 key: jnp.ndarray,
536 value: jnp.ndarray,
537 mask: Optional[jnp.ndarray] = None):
Get sequence length
548 seq_len = len(query)
549
550 if mask is not None:
Check mask shape
552 assert mask.shape[0] == query.shape[0]
553 assert mask.shape[1] == key.shape[0]
Same mask applied to all heads.
556 mask = mask[:, :, None]
Prepare query, key and value for the attention computation. These will then have shape [seq_len, heads, d_k].
560 query = self.query(query)
561 key = self.key(key)
562 value = self.value(value)
Compute attention scores $Q K^\top$. This gives a tensor of shape [seq_len, seq_len, heads].
567 scores = jnp.einsum('ihd,jhd->ijh', query, key)
Scale scores
570 scores *= self.scale
Apply mask
573 if mask is not None:
574 scores = jnp.where(mask == 0, float('-inf'), scores)
Apply softmax along the key sequence dimension to get the attention probabilities
578 attn = jax.nn.softmax(scores, axis=1)
Multiply by values
582 x = jnp.einsum("ijh,jhd->ihd", attn, value)
Concatenate multiple heads
585 x = x.reshape(seq_len, -1)
Output layer
588 return self.output(x)
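A hedged usage sketch (assuming the classes above; shapes are per sample, since batching is added later with jax.vmap):

rnd_key = jax.random.PRNGKey(0)
mha = MultiHeadAttention(rnd_key, heads=4, d_model=32)

x = jax.random.normal(rnd_key, (10, 32))  # [seq_len, d_model]
mask = jnp.tril(jnp.ones((10, 10), bool))  # causal mask
out = mha(query=x, key=x, value=x, mask=mask)
print(out.shape)  # (10, 32)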
591class FeedForward(Module):
rnd_key is the PRNG state
d_model is the number of features in a token embedding
d_ff is the number of features in the hidden layer of the FFN
activation is the activation function
601 def __init__(self, rnd_key: jax.random.PRNGKey, d_model: int, d_ff: int,
602 activation=jax.nn.relu):
609 super().__init__()
Split the PRNG state
611 _, *rnd_keys = jax.random.split(rnd_key, 5)
Layer one parameterized by weight and bias
614 self.layer1 = Linear(rnd_keys[0], d_model, d_ff)
Layer two parameterized by weight and bias
616 self.layer2 = Linear(rnd_keys[1], d_ff, d_model)
Activation function
618 self.activation = activation
620 def __call__(self, x: jnp.ndarray):
622 x = self.activation(self.layer1(x))
624 return self.layer2(x)
This is a transformer layer with multi-head attention and a position-wise feed-forward layer. We use pre-layer normalization.
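Concretely, with pre-layer normalization each sub-layer sees the normalized input and its output is added back to the residual stream:

$z = \mathrm{LN}(x), \quad x \leftarrow x + \mathrm{SelfAttention}(z)$
$z = \mathrm{LN}(x), \quad x \leftarrow x + \mathrm{FFN}(z)$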
627class TransformerLayer(Module):
d_model is the token embedding size
self_attn is the self attention module
feed_forward is the feed forward module
637 def __init__(self,
638 d_model: int,
639 self_attn: MultiHeadAttention,
640 feed_forward: FeedForward):
646 super().__init__()
647 self.size = d_model
648 self.self_attn = self_attn
649 self.feed_forward = feed_forward
650 self.norm_self_attn = LayerNorm([d_model])
651 self.norm_ff = LayerNorm([d_model])
653 def __call__(self, x: jnp.ndarray, mask: jnp.ndarray):
Normalize the vectors before doing self attention
655 z = self.norm_self_attn(x)
Run through self attention, i.e. keys and values are from self
657 self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)
658 x = x + self_attn
Normalize for feed-forward
661 z = self.norm_ff(x)
Pass through the feed-forward network
663 ff = self.feed_forward(z)
Add the feed-forward results
665 x = x + ff
667 return x
670class CrossEntropyLoss(Module):
677 def __init__(self):
678 super().__init__()
Use jax.vmap
to vectorize the loss function
681 self._loss_vmap = jax.vmap(self._loss, in_axes=(0, 0,))
683 def _loss(self, output: jnp.ndarray, target: jnp.ndarray):
685 return -jax.nn.log_softmax(output)[target]
output is the model output of shape [seq_len, n_vocab]
target is the targets of shape [seq_len]
687 def __call__(self, output: jnp.ndarray, target: jnp.ndarray):
Use the vectorized loss function and calculate the mean.
We could have used a for loop to calculate the losses but using vmap is about 10X faster
696 return self._loss_vmap(output, target).mean()
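A standalone sketch of the same idea (made-up shapes): jax.vmap maps the per-token loss over the sequence dimension, and the result is averaged:

def token_loss(logits: jnp.ndarray, target: jnp.ndarray):
    # Negative log-likelihood of the target token at one position
    return -jax.nn.log_softmax(logits)[target]

# Vectorize over the sequence dimension of both arguments
seq_loss = jax.vmap(token_loss, in_axes=(0, 0))

logits = jax.random.normal(jax.random.PRNGKey(0), (10, 65))  # [seq_len, n_vocab]
targets = jnp.zeros(10, dtype=jnp.int32)  # [seq_len]
print(seq_loss(logits, targets).mean())  # scalar mean loss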
699class AutoregressiveTransformer(Module):
707 layers: ModuleList[TransformerLayer]
rnd_key is the PRNG state
n_vocab is the vocabulary size
d_model is the number of features in a token embedding
n_layers is the number of transformer layers
heads is the number of attention heads
d_ff is the number of features in the hidden layer of the FFN
709 def __init__(self, rnd_key: jax.random.PRNGKey, n_vocab: int, d_model: int, n_layers: int, heads: int, d_ff: int):
718 super().__init__()
719 self.n_vocab = n_vocab
720 self.d_model = d_model
721 self.loss_func = CrossEntropyLoss()
For transformer layers
724 layers = []
725 for i in range(n_layers):
Split PRNG state
727 rnd_key, mha_key, ffn_key = jax.random.split(rnd_key, 3)
Create a transformer layer
729 attn = MultiHeadAttention(mha_key, heads, d_model)
730 ffn = FeedForward(ffn_key, d_model, d_ff)
731 layers.append(TransformerLayer(d_model, attn, ffn))
Make a module list
733 self.layers = ModuleList(layers)
Split PRNG state
736 rnd_key, emb_key, out_key = jax.random.split(rnd_key, 3)
Create embedding layer
738 self.embeddings = EmbeddingsWithLearnedPositionalEncoding(emb_key, n_vocab, d_model)
Final normalization and output layer
740 self.norm = LayerNorm([d_model])
741 self.output = Linear(out_key, d_model, n_vocab)
743 def __call__(self, x: jnp.ndarray):
Get sequence length
745 seq_len = len(x)
A causal mask for attention so that a token can only see tokens before it
747 mask = jnp.tril(jnp.ones((seq_len, seq_len), bool))
Get embeddings with positional encodings
749 x = self.embeddings(x)
Apply the transformer layers
751 for i in range(len(self.layers)):
752 x = self.layers[i](x, mask)
Final normalization and linear transformation to get the logits
755 return self.output(self.norm(x))
757 def get_loss(self, x: jnp.ndarray):
Get model outputs
762 output = self(x)
Cross entropy loss, with the output at each position predicting the next token (outputs and targets shifted by one)
764 return self.loss_func(output[:-1], x[1:])
766 def sample(self, seq: jnp.ndarray, length: int = 20):
772 for i in range(length):
Sample the highest probability token
774 idx = jnp.argmax(self(seq)[-1])
Add it to the sequence
776 seq = jnp.concatenate((seq, idx[None]))
Return the sampled sequence
779 return seq
This is a named tuple for storing Adam optimizer state for a parameter
782class AdamState(NamedTuple):
786 m: jnp.ndarray
787 v: jnp.ndarray
This is from paper Adam: A Method for Stochastic Optimization.
For parameter $\theta_t$ and gradient $g_t$ at step $t$, the Adam update is,

$m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t$
$v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) g_t^2$
$\hat{m}_t \leftarrow \frac{m_t}{1 - \beta_1^t}$
$\hat{v}_t \leftarrow \frac{v_t}{1 - \beta_2^t}$
$\theta_t \leftarrow \theta_{t-1} - \alpha \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$

where $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalar hyper-parameters. $m_t$ and $v_t$ are the first and second order moments, and $\hat{m}_t$ and $\hat{v}_t$ are the bias-corrected moments. $\epsilon$ is used as a fix for division by zero, but also acts as a form of a hyper-parameter that acts against variance in gradients.
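Here is a hedged standalone sketch of one Adam step for a single array, mirroring the equations above (the Adam class below folds the bias corrections into the step size instead, which is equivalent up to where $\epsilon$ is applied):

def adam_step(param, grad, m, v, t, lr=0.001, betas=(0.9, 0.999), eps=1e-16):
    # Update biased first and second moment estimates
    m = betas[0] * m + (1 - betas[0]) * grad
    v = betas[1] * v + (1 - betas[1]) * grad ** 2
    # Bias-corrected moments
    m_hat = m / (1 - betas[0] ** t)
    v_hat = v / (1 - betas[1] ** t)
    # Parameter update
    return param - lr * m_hat / (v_hat ** 0.5 + eps), m, v

param, m, v = jnp.ones(3), jnp.zeros(3), jnp.zeros(3)
param, m, v = adam_step(param, jnp.array([0.1, -0.2, 0.3]), m, v, t=1)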
790class Adam:
params is the tree-map of parameters
lr is the learning rate $\alpha$
betas is a tuple of ($\beta_1$, $\beta_2$)
eps is $\epsilon$
816 def __init__(self, params: Dict,
817 lr: float = 0.001, betas: Tuple[float, float] = (0.9, 0.999),
818 eps: float = 1e-16, ):
826 super().__init__()
827 self.lr = lr
828 self.betas = betas
829 self.eps = eps
States for each parameter
832 self.states = jax.tree_multimap(self._init_state, params)
Optimized step function
834 self._step_jit = jax.jit(self._step)
Number of steps taken
836 self._n_steps = 0
Optimized update state function
838 self._update_state_jit = jax.jit(self._update_state)
Initialize the state for a given parameter
840 def _init_state(self, param: jnp.ndarray):
844 return AdamState(jnp.zeros_like(param), jnp.zeros_like(param))
846 def step(self, params: Dict, grads: Dict):
Increment step
854 self._n_steps += 1
Update states for each parameter
856 self.states = jax.tree_multimap(self._update_state_jit, grads, self.states)
Return updated parameters
858 return jax.tree_multimap(partial(self._step_jit, self._n_steps), params, self.states)
860 def _step(self, n_steps: int, param: jnp.ndarray, state: AdamState):
Bias corrections $1 - \beta_1^t$ for $m_t$ and $1 - \beta_2^t$ for $v_t$
868 bias_correction = [1 - beta ** n_steps for beta in self.betas]
Uncorrected first and second moments $m_t$ and $v_t$
870 m, v = state
873 step_size = self.lr * (bias_correction[1] ** 0.5) / bias_correction[0]
875 den = (v ** 0.5) + self.eps
879 return param - step_size * m / den
881 def _update_state(self, grad, state: AdamState):
Uncorrected first and second moments $m_{t-1}$ and $v_{t-1}$
888 m, v = state
Clip gradients
890 grad = jnp.clip(grad, -1, 1)
892 m = self.betas[0] * m + grad * (1 - self.betas[0])
894 v = self.betas[1] * v + (grad ** 2) * (1 - self.betas[1])
Return the new state
897 return AdamState(m, v)
900class TinyShakespeare:
rnd_key is the PRNG state
seq_len is the sequence length of a sample
batch_size is the batch size
907 def __init__(self, rnd_key: jax.random.PRNGKey, seq_len: int, batch_size: int):
914 self.batch_size = batch_size
PRNG key for shuffling the samples
916 _, self.rnd_key = jax.random.split(rnd_key)
Local path of the text file
919 path = lab.get_data_path() / 'tiny_shakespeare.txt'
Download if it doesn't exist
921 url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
922 if not path.exists():
923 download_file(url, path)
Read the file
926 with open(str(path), 'r') as f:
927 self.text = f.read()
Get the characters/tokens
930 tokens = sorted(list(set(self.text)))
Number of tokens
933 self.n_tokens = len(tokens)
Map tokens to ids
935 self.stoi = {t: i for i, t in enumerate(tokens)}
Id to token/character
937 self.itos = tokens
As a list of ids
940 data = jnp.array([self.stoi[s] for s in list(self.text)])
Number of batches
942 self.n_batches = len(data) // (seq_len * batch_size)
Truncate
944 data = data[:self.n_batches * seq_len * batch_size]
Reshape into samples (it would be better to use random offsets, but let's ignore that here)
946 self.data = data.reshape((-1, seq_len))
List of sample indexes
948 self.idx = jnp.arange(len(self.data))
Setup for iteration
950 def __iter__(self):
Iteration step
955 self._iter_idx = 0
Split PRNG key
957 self.rnd_key, rnd_key = jax.random.split(self.rnd_key)
Shuffle sample indexes
959 self.idx = jax.random.permutation(rnd_key, self.idx)
962 return self
Number of batches
964 def __len__(self):
968 return self.n_batches
Get next batch
970 def __next__(self):
Stop iteration after iterating through all batches
976 if self._iter_idx >= self.n_batches:
977 raise StopIteration()
Sample indexes for the batch
980 idx = self.idx[self._iter_idx * self.batch_size:(self._iter_idx + 1) * self.batch_size]
Increment iteration step
982 self._iter_idx += 1
Return samples
985 return self.data[idx]
988def main():
Create experiment
996 experiment.create(name='jax')
Create PRNG key
998 rnd_key = jax.random.PRNGKey(0)
Create dataset
1000 dataset = TinyShakespeare(rnd_key, seq_len=32, batch_size=128)
Create the model
1003 model = AutoregressiveTransformer(rnd_key, dataset.n_tokens,
1004 d_model=128, n_layers=3, heads=8, d_ff=512)
Get model parameters
1006 params = model.get_params()
JAX compiled pure sampling function
1009 pure_sample_fn = jax.jit(model.purify(model.sample))
JAX compiled pure function to get logits for a batch. First we transform model.__call__
to a pure function which accepts two arguments: parameters, and input sequence. Next we vectorize the function to process a batch of samples. in_axes
specifies which arguments to parallelize and along which axis. (None, 0)
means we have the same parameters but parallelize the inputs across the first axis. out_axes
specifies along which axis to merge the results.
1017 pure_forward_fn = jax.jit(jax.vmap(model.purify(model.__call__),
1018 in_axes=(None, 0), out_axes=0))
Similarly we vectorize loss computation
1020 pure_loss_fn = jax.jit(jax.vmap(model.purify(model.get_loss),
1021 in_axes=(None, 0), out_axes=0))
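As a standalone illustration of in_axes=(None, 0) (a made-up affine function, not part of the model), the first argument is shared across the batch while the second is split along its first axis:

def affine(params, x):
    return params['w'] * x + params['b']

batched_affine = jax.vmap(affine, in_axes=(None, 0), out_axes=0)

params = {'w': jnp.array(2.0), 'b': jnp.array(1.0)}
x = jnp.arange(3.0)  # treated as a batch of 3 scalar samples
print(batched_affine(params, x))  # [1., 3., 5.]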
A function to get mean loss
1024 def get_loss(params, seq):
1025 return pure_loss_fn(params, seq).mean()
A function to compute gradients for the first argument (parameters)
1028 grad_loss_fn = jax.jit(jax.grad(get_loss, argnums=0))
Create optimizer
1031 optimizer = Adam(params)
Start the experiment
1034 with experiment.start():
Iterate for 32 epochs
1036 for epoch in monit.loop(32):
Iterate through batches
1038 for data in monit.iterate('Train', dataset):
Compute and log the loss
1040 loss = get_loss(params, data)
1041 tracker.save('loss', loss)
Get the gradients
1043 grads = grad_loss_fn(params, data)
Update parameters
1045 params = optimizer.step(params, grads)
1048 tracker.new_line()
Log a sample after each epoch
1050 prompt = [dataset.stoi[c] for c in 'It ']
1051 sampled = pure_sample_fn(params, jnp.array(prompt))[len(prompt):]
1052 sampled = ''.join([dataset.itos[i] for i in sampled])
1053 sampled = sampled.replace('\n', '\\n')
1054 logger.log(('It ', Text.meta), (sampled, Text.value))
1058if __name__ == '__main__':
1059 main()