Autoregressive Transformer Decoder in JAX from scratch

This implementation builds a transformer decoder from the ground up. It doesn't use any higher-level frameworks like Flax; I have used labml for logging and experiment tracking.

I have implemented a simple Module class to build basic building blocks upon.

This was my first JAX project and many implementations were taken from PyTorch implementations at

JAX can optimize and differentiate pure Python functions. Pure functions are functions that take a set of arguments and return a result without side effects, i.e. without modifying anything outside the function such as global variables or object attributes. JAX can also compile these functions and vectorize them so that they run efficiently.
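A minimal sketch of why purity matters, using two hypothetical functions (impure_scale and pure_scale are illustrative, not part of the implementation). jax.jit traces a function once per input shape and dtype, so a side effect such as appending to a global list happens only during tracing, while a pure function behaves the same whether compiled or not and can also be differentiated with jax.grad.

```python
import jax
import jax.numpy as jnp

calls = []  # global state, only used to show what a side effect looks like


def impure_scale(x):
    # Side effect: mutates a global list. Under jax.jit this runs only
    # while JAX traces the function, not on every call.
    calls.append(1)
    return 2.0 * x


def pure_scale(x):
    # The output depends only on the input and nothing outside is modified
    return 2.0 * x


fast_impure = jax.jit(impure_scale)

x = jnp.arange(3.0)
for _ in range(3):
    fast_impure(x)
# The list grew once (during tracing), not three times
print(len(calls))  # 1

# Pure functions can be differentiated: d/dx sum((2x)^2) = 8x
grad_fn = jax.grad(lambda x: jnp.sum(pure_scale(x) ** 2))
print(grad_fn(x))  # 8 * x, i.e. [0., 8., 16.]
```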

In JAX you don't have to worry about batches. The functions are implemented for a single sample, and jax.vmap can vectorize (parallelize) them across the batch dimension (or any other dimension if needed).
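A minimal sketch of that idea, with a hypothetical per-sample function sequence_sum_of_squares: it is written for one sequence, and jax.vmap maps it over the leading (batch) axis without any change to the function itself.

```python
import jax
import jax.numpy as jnp


def sequence_sum_of_squares(seq):
    # Written for a single sample: seq has shape [seq_len]
    return jnp.sum(seq ** 2)


batch = jnp.arange(6.0).reshape(2, 3)  # [batch_size=2, seq_len=3]

# vmap maps the function over axis 0 of its argument
batched = jax.vmap(sequence_sum_of_squares)
print(batched(batch))  # [5., 50.]
```

The same pattern is used throughout this implementation: modules and the loss are written per sample, and batching is added with jax.vmap at the top level.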



from functools import partial
from typing import Dict, NamedTuple, Tuple, Any, Callable
from typing import List, TypeVar, Generic
from typing import Union, Optional

import jax
import jax.numpy as jnp
from labml import lab, monit, experiment, tracker
from labml import logger
from labml.logger import Text
from labml.utils.download import download_file


This is a base class for all modules. It handles parameters and transforms methods to pure functions for JAX to compile and differentiate.

You can skip these modules to get into the models directly.

The module stores parameters and sub-modules separately. When we want to transform a method into a pure function, we pass the parameters of the module and its sub-modules as an argument and assign the passed values to the module.

This is based on a blog post: From PyTorch to JAX: towards neural net frameworks that purify stateful code.
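Before the full class, here is a minimal sketch of the purification idea on a hypothetical one-parameter model (TinyModel is illustrative only). purify returns a function that takes the parameters as an explicit first argument, assigns them to the object, and calls the method, so jax.grad can differentiate with respect to the parameter dictionary. Unlike the Module class below, this sketch restores the previous parameters afterwards so that traced values are not left on the object; that is a design choice of the sketch, not of the original.

```python
import jax
import jax.numpy as jnp


class TinyModel:
    """Hypothetical one-parameter model, used only to illustrate purify."""

    def __init__(self):
        self.params = {'w': jnp.array(3.0)}

    def loss(self, x):
        # Stateful style: the parameter is read from the object
        return (self.params['w'] * x - 1.0) ** 2

    def purify(self, method):
        # Return f(params, *args) where params is an explicit argument
        def pure_method(params, *args):
            saved = self.params
            self.params = params  # use the passed (possibly traced) values
            try:
                return method(*args)
            finally:
                self.params = saved  # do not leave traced values on the object
        return pure_method


model = TinyModel()
pure_loss = model.purify(model.loss)

# Differentiate with respect to the params dictionary (first argument)
grads = jax.grad(pure_loss)(model.params, 2.0)
print(grads['w'])  # 2 * (3 * 2 - 1) * 2 = 20.0
```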

61class Module:

Store all parameters and sub-modules in dictionaries

81    _submodules: Dict[str, 'Module']
82    _params: Dict[str, jnp.ndarray]


84    def __init__(self):
86        self._params = {}
87        self._submodules = {}

Get attribute

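We override the get attribute operation. Python calls __getattr__ only when normal attribute lookup fails, so when you reference a parameter or sub-module with model.attribute (these are not stored in __dict__), this function gets called.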

Read this guide if you are not familiar with Python magic methods.

89    def __getattr__(self, attr_name: str):

If the attribute is a parameter

101        if attr_name in self._params:
102            return self._params[attr_name]

If the attribute is a sub-module

104        elif attr_name in self._submodules:
105            return self._submodules[attr_name]

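Otherwise fall back to normal attributes, which Python stores in __dict__. Raise AttributeError rather than KeyError for missing names, so that hasattr and getattr(obj, name, default) behave as expected.

        else:
            try:
                return self.__dict__[attr_name]
            except KeyError:
                raise AttributeError(attr_name) from None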

Set attribute

We override the set attribute operation. So when you assign an attribute with model.attribute this function gets called.

111    def __setattr__(self, key: str, value: Any):

If the value is also a module

120        if isinstance(value, Module):
121            self._submodules[key] = value

If the value is a JAX array

123        elif isinstance(value, jnp.ndarray):
124            self._params[key] = value

Otherwise add it to __dict__

126        else:
127            self.__dict__[key] = value

Clear parameters

This clears out all the parameters. It is used when a method is called as a pure function: we first clear out all the parameters and then assign the parameters passed to the pure function.

129    def _clear_params(self):

Clear parameters of the module

138        self._params = {}

Recursively clear parameters of submodules

140        for sm in self._submodules.values():
141            sm._clear_params()

Collect all the parameters

This recursively collects all the parameters of the module and sub-modules into a dictionary.

143    def get_params(self) -> Dict[str, jnp.ndarray]:

Parameters of the model

151        params = self._params.copy()

Parameters of the submodules

153        for sm_name, sm in self._submodules.items():
154            for name, value in sm.get_params().items():

The dictionary keys are of the form module_name/module_name/param_name

156                params[sm_name + "/" + name] = value

158        return params

Set all the parameters

160    def _set_params(self, params: Dict[str, jnp.ndarray]):

Iterate through parameters. Their names have the form module_name/module_name/param_name

167        for name, value in params.items():

Split to get module names and parameter name

169            self._set_param(name.split("/"), value)

Set a single parameter

This is called by _set_params

171    def _set_param(self, param_path: List[str], value: jnp.ndarray):

No module names; i.e. a parameter of this module

178        if len(param_path) == 1:
179            self._params[param_path[0]] = value

Parameter of a submodule

181        else:
182            self._submodules[param_path[0]]._set_param(param_path[1:], value)
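The path-keyed parameter dictionary can be illustrated without JAX. This sketch uses nested Python dicts as a stand-in for the module tree; flatten and unflatten are hypothetical helpers that mirror what get_params and _set_params do with module_name/module_name/param_name keys. A flat dictionary of arrays is also a valid JAX pytree, which is why it can be passed straight to jax.grad.

```python
from typing import Any, Dict


def flatten(tree: Dict[str, Any], prefix: str = '') -> Dict[str, Any]:
    """Flatten nested dicts into 'outer/inner/leaf' keys (like get_params)."""
    flat = {}
    for name, value in tree.items():
        key = prefix + name
        if isinstance(value, dict):
            flat.update(flatten(value, key + '/'))
        else:
            flat[key] = value
    return flat


def unflatten(flat: Dict[str, Any]) -> Dict[str, Any]:
    """Rebuild nested dicts by splitting keys on '/' (like _set_params)."""
    tree: Dict[str, Any] = {}
    for key, value in flat.items():
        *path, leaf = key.split('/')
        node = tree
        for part in path:
            node = node.setdefault(part, {})
        node[leaf] = value
    return tree


nested = {'embeddings': {'embeddings': 0},
          'layers': {'0': {'norm_ff': {'gain': 1, 'bias': 2}}}}
flat = flatten(nested)
print(sorted(flat))
# ['embeddings/embeddings', 'layers/0/norm_ff/bias', 'layers/0/norm_ff/gain']
assert unflatten(flat) == nested
```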

Transform a member method to a pure function

This transforms a member method to a pure function that accepts a dictionary of parameters as an argument.

For example,

params = model.get_params()
pure_function = model.purify(model.calculate_loss)
output = pure_function(params, data)
184    def purify(self, method: Callable) -> Callable:
        def pure_method(params: Dict[str, jnp.ndarray], *args):

Clear parameters in the object

202            self._clear_params()

Assign the passed parameters

204            self._set_params(params)

Invoke the method

206            result = method(*args)

Return the result

208            return result

211        return pure_method

Type for generics in the module list class

215M = TypeVar('M', bound=Module)

Module list

This stores a list of modules. We need this for the transformer decoder to hold its list of transformer layers.

218class ModuleList(Module, Generic[M]):

For list of modules

227    _submodules: List[M]

Initialize with a list of modules.

229    def __init__(self, modules: List[M]):
233        super().__init__()
234        self._submodules = modules

Get the idx -th module

236    def __getitem__(self, idx: int) -> M:
240        return self._submodules[idx]

This is not supported

242    def __setitem__(self, key, value):
246        raise NotImplementedError

Number of modules

248    def __len__(self):
252        return len(self._submodules)

Override __getattr__ of Module

254    def __getattr__(self, item):
258        return self.__dict__[item]

Override __setattr__ of Module

260    def __setattr__(self, key, value):
264        self.__dict__[key] = value

Clear all parameters

266    def _clear_params(self):
270        self._params = {}
271        for sm in self._submodules:
272            sm._clear_params()

Get all parameters

274    def get_params(self):
        params = self._params.copy()
279        for i, sm in enumerate(self._submodules):
280            for name, value in sm.get_params().items():
281                params[f'{i}/{name}'] = value
282        return params

Set a parameter

284    def _set_param(self, param_path: List[str], value: jnp.ndarray):
288        self._submodules[int(param_path[0])]._set_param(param_path[1:], value)

Embedding layer

This maintains embeddings by id.

291class Embedding(Module):
  • rnd_key is the PRNG state
  • n_embeddings is the number of embeddings
  • n_dim is the size of an embedding
300    def __init__(self, rnd_key: jax.random.PRNGKey, n_embeddings: int, n_dim: int):
306        super().__init__()

Embeddings are initialized from $\mathcal{N}(0, 1)$

308        self.embeddings = jax.random.normal(rnd_key, (n_embeddings, n_dim))

Return the embeddings for the given ids

310    def __call__(self, ids: jnp.ndarray):
314        return self.embeddings[ids, :]

Embed tokens and add parameterized positional encodings

This is based on our PyTorch implementation.

317class EmbeddingsWithLearnedPositionalEncoding(Module):
  • rnd_key is the PRNG state
  • n_vocab is the vocabulary size
  • d_model is the embedding size
  • max_len is the maximum sequence length (to initialize positional encodings)
327    def __init__(self, rnd_key: jax.random.PRNGKey, n_vocab: int, d_model: int, max_len: int = 4096):
334        super().__init__()


336        self.embeddings = Embedding(rnd_key, n_vocab, d_model)

Coefficient $\frac{1}{\sqrt{d_{model}}}$, used to scale the token embeddings before the positional encodings are added

338        self.pe_coef = 1 / d_model ** 0.5

Positional encodings initialized to zeros

340        self.positional_encodings = jnp.zeros((max_len, d_model))
342    def __call__(self, x: jnp.ndarray):

Get positional encodings

344        pe = self.positional_encodings[:x.shape[0]]

Get embeddings and add positional encodings

346        return self.embeddings(x) * self.pe_coef + pe

Linear Layer

This is a simple linear layer with a weight matrix and a bias vector

349class Linear(Module):
  • rnd_key is the PRNG state
  • in_features is the number of features in the input
  • out_features is the number of features in the output
358    def __init__(self, rnd_key: jax.random.PRNGKey, in_features: int, out_features: int):
364        super().__init__()

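Initialize weights from $\mathcal{U}\left(-\frac{1}{\sqrt{d_{in}}}, \frac{1}{\sqrt{d_{in}}}\right)$, where $d_{in}$ is in_features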

367        rnd_range = 1 / in_features ** 0.5
368        self.weight = jax.random.uniform(rnd_key, (in_features, out_features),
369                                         minval=-rnd_range, maxval=rnd_range)

Initialize the biases to $0$

371        self.bias = jnp.zeros((out_features,))
373    def __call__(self, x: jnp.ndarray):

Multiply by weights and add the bias

375        return jnp.matmul(x, self.weight) + self.bias

Layer Normalization

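This implements the layer normalization from the paper Layer Normalization.

When input $X \in \mathbb{R}^{L \times C}$ is a sequence of embeddings, where $C$ is the number of channels and $L$ is the length of the sequence, layer normalization computes

$$\text{LN}(X) = \gamma \frac{X - \underset{C}{\mathbb{E}}[X]}{\sqrt{\underset{C}{\text{Var}}[X] + \epsilon}} + \beta$$

with $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$.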

This is based on our PyTorch implementation.
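As a quick numerical check of the formula, here is a minimal functional sketch (layer_norm is a hypothetical helper, normalizing over the last axis only) that uses the same $\mathbb{E}[X^2] - \mathbb{E}[X]^2$ variance as the module below. With $\gamma = 1$ and $\beta = 0$, every token should come out with mean close to 0 and standard deviation close to 1.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, gain, bias, eps=1e-5):
    # Normalize over the last axis (the C channels of each token)
    mean = x.mean(axis=-1, keepdims=True)
    var = (x ** 2).mean(axis=-1, keepdims=True) - mean ** 2  # E[X^2] - E[X]^2
    x_norm = (x - mean) / jnp.sqrt(var + eps)
    return gain * x_norm + bias


# [L=4, C=8] inputs with mean about 5 and standard deviation about 3
x = jax.random.normal(jax.random.PRNGKey(0), (4, 8)) * 3.0 + 5.0
y = layer_norm(x, jnp.ones(8), jnp.zeros(8))
print(y.mean(axis=-1))  # close to 0 for every token
print(y.std(axis=-1))   # close to 1 for every token
```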

378class LayerNorm(Module):
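  • normalized_shape $S$ is the shape of the elements (except the batch). The input should then be $X \in \mathbb{R}^{* \times S[0] \times S[1] \times \dots \times S[n]}$
  • eps is $\epsilon$, used in $\sqrt{\text{Var}[X] + \epsilon}$ for numerical stability
  • elementwise_affine is whether to scale and shift the normalized value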
398    def __init__(self, normalized_shape: Union[Tuple[int], List[int]], *,
399                 eps: float = 1e-5, elementwise_affine: bool = True):
407        super().__init__()
409        self.eps = eps
410        self.elementwise_affine = elementwise_affine
411        self.normalized_shape = tuple(normalized_shape)

Create parameters $\gamma$ and $\beta$ for gain and bias

414        if elementwise_affine:
415            self.gain = jnp.ones(normalized_shape)
416            self.bias = jnp.zeros(normalized_shape)
418    def __call__(self, x: jnp.ndarray):

Sanity check to make sure the shapes match

420        assert self.normalized_shape == x.shape[-len(self.normalized_shape):]

The axes to calculate the mean and variance on

423        axes = [-(i + 1) for i in range(len(self.normalized_shape))]

Calculate the mean $\mathbb{E}[X]$ over the normalized axes; this gives one mean per element (e.g. per token)

426        mean = x.mean(axis=axes, keepdims=True)

Calculate the mean of the squares $\mathbb{E}[X^2]$ over the normalized axes, again one per element

429        mean_2 = (x ** 2).mean(axis=axes, keepdims=True)

Variance of each element, $\text{Var}[X] = \mathbb{E}[X^2] - \mathbb{E}[X]^2$

        var = mean_2 - mean ** 2

Normalize, $\hat{X} = \frac{X - \mathbb{E}[X]}{\sqrt{\text{Var}[X] + \epsilon}}$

        x_norm = (x - mean) / (var + self.eps) ** 0.5

Scale and shift, $\text{LN}(X) = \gamma \hat{X} + \beta$

436        if self.elementwise_affine:
437            x_norm = self.gain * x_norm + self.bias

440        return x_norm

Prepare for multi-head attention

This module does a linear transformation and splits the vector into the given number of heads for multi-head attention. This is used to transform key, query, and value vectors.

443class PrepareForMultiHeadAttention(Module):
454    def __init__(self, rnd_key: jax.random.PRNGKey, d_model: int, heads: int, d_k: int):
455        super().__init__()

Linear layer for linear transform

457        self.linear = Linear(rnd_key, d_model, heads * d_k)

Number of heads

459        self.heads = heads

Number of dimensions in vectors in each head

461        self.d_k = d_k
463    def __call__(self, x: jnp.ndarray):

Input has shape [seq_len, d_model] (the module also works with any number of leading dimensions). We apply the linear transformation to the last dimension and split that into the heads.

467        head_shape = x.shape[:-1]

Linear transform

470        x = self.linear(x)

Split last dimension into heads

474        x = x.reshape(*head_shape, self.heads, self.d_k)

Output has shape [seq_len, heads, d_k]; more generally, the leading dimensions of the input are kept and the last dimension is split into [heads, d_k]

477        return x

Multi-Head Attention Module

This computes scaled multi-headed attention from the paper Attention Is All You Need for given query $Q$, key $K$ and value $V$ vectors.

$$\text{Attention}(Q, K, V) = \underset{seq}{\text{softmax}}\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$$

In simple terms, it finds keys that match the query, and gets the values of those keys.

It uses the dot-product of query and key as the indicator of how well they match. Before taking the softmax, the dot-products are scaled by $\frac{1}{\sqrt{d_k}}$. This is done to avoid large dot-product values causing the softmax to give very small gradients when $d_k$ is large.

Softmax is calculated along the axis of the sequence (or time) for keys.

This is based on our PyTorch implementation.
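A self-contained sketch of the core computation, after the query, key and value projections, for one sample with tensors of shape [seq_len, heads, d_k]. The function name causal_attention is hypothetical. Masking here uses jnp.where: a multiplicative mask such as (mask == 0) * -inf would evaluate 0 * -inf = NaN at the allowed positions.

```python
import jax
import jax.numpy as jnp


def causal_attention(q, k, v):
    """q, k, v: [seq_len, heads, d_k] for a single sample (no batch axis)."""
    seq_len, _, d_k = q.shape
    # Scores for every query i, key j and head h: [seq_len, seq_len, heads]
    scores = jnp.einsum('ihd,jhd->ijh', q, k) / jnp.sqrt(d_k)
    # Lower-triangular mask: query i may attend to keys j <= i (same for all heads)
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))[:, :, None]
    scores = jnp.where(mask, scores, -jnp.inf)
    attn = jax.nn.softmax(scores, axis=1)  # normalize over the key axis j
    return jnp.einsum('ijh,jhd->ihd', attn, v)


key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
shape = (5, 2, 4)  # [seq_len, heads, d_k]
q, k, v = (jax.random.normal(key_, shape) for key_ in (key_q, key_k, key_v))
out = causal_attention(q, k, v)
print(out.shape)  # (5, 2, 4)
# The first position can only see itself, so its output equals v[0]
print(jnp.allclose(out[0], v[0]))  # True
```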

480class MultiHeadAttention(Module):
  • rnd_key is the PRNG state
  • heads is the number of heads.
  • d_model is the number of features in the query , key and value vectors.
506    def __init__(self, rnd_key: jax.random.PRNGKey, heads: int, d_model: int):
513        super().__init__()

Split the PRNG state

516        _, *rnd_keys = jax.random.split(rnd_key, 5)

Number of features per head

519        self.d_k = d_model // heads

Number of heads

521        self.heads = heads

These transform the query , key and value vectors for multi-headed attention.

524        self.query = PrepareForMultiHeadAttention(rnd_keys[0], d_model, heads, self.d_k)
525        self.key = PrepareForMultiHeadAttention(rnd_keys[1], d_model, heads, self.d_k)
526        self.value = PrepareForMultiHeadAttention(rnd_keys[2], d_model, heads, self.d_k)

Output layer

529        self.output = Linear(rnd_keys[3], d_model, d_model)

Scaling factor before the softmax

531        self.scale = 1 / self.d_k ** 0.5

query , key and value are the tensors that store collection of query, key and value vectors. They have shape [seq_len, d_model] .

mask has shape [seq_len, seq_len] and mask[i, j] indicates whether query at position i can see key-value at position j .

533    def __call__(self, *,
534                 query: jnp.ndarray,
535                 key: jnp.ndarray,
536                 value: jnp.ndarray,
537                 mask: Optional[jnp.ndarray] = None):

Get sequence length

548        seq_len = len(query)
550        if mask is not None:

Check mask shape

552            assert mask.shape[0] == query.shape[0]
553            assert mask.shape[1] == key.shape[0]

Same mask applied to all heads.

556            mask = mask[:, :, None]

Prepare query , key and value for attention computation. These will then have shape [seq_len, heads, d_k] .

560        query = self.query(query)
561        key = self.key(key)
562        value = self.value(value)

Compute attention scores $Q K^\top$. This gives a tensor of shape [seq_len, seq_len, heads].

567        scores = jnp.einsum('ihd,jhd->ijh', query, key)

Scale scores, $\frac{Q K^\top}{\sqrt{d_k}}$

570        scores *= self.scale

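Apply mask. We use jnp.where instead of adding (mask == 0) * float('-inf'), because at the allowed positions that product is 0 * -inf, which is NaN.

        if mask is not None:
            scores = jnp.where(mask == 0, float('-inf'), scores)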

Softmax attention along the key sequence dimension, $\underset{seq}{\text{softmax}}\left(\frac{Q K^\top}{\sqrt{d_k}}\right)$

578        attn = jax.nn.softmax(scores, axis=1)

Multiply by values

582        x = jnp.einsum("ijh,jhd->ihd", attn, value)

Concatenate multiple heads

585        x = x.reshape(seq_len, -1)

Output layer

588        return self.output(x)

Position-wise Feed-Forward layer

It computes $\text{FFN}(x) = f(x W_1 + b_1) W_2 + b_2$, where $f$ is the activation function. This is based on our PyTorch implementation.

591class FeedForward(Module):
  • rnd_key is the PRNG state
  • d_model is the number of features in a token embedding
  • d_ff is the number of features in the hidden layer of the FFN
  • activation is the activation function
601    def __init__(self, rnd_key: jax.random.PRNGKey, d_model: int, d_ff: int,
602                 activation=jax.nn.relu):
609        super().__init__()

Split the PRNG state

611        _, *rnd_keys = jax.random.split(rnd_key, 5)

Layer one parameterized by weight $W_1$ and bias $b_1$

614        self.layer1 = Linear(rnd_keys[0], d_model, d_ff)

Layer two parameterized by weight $W_2$ and bias $b_2$

616        self.layer2 = Linear(rnd_keys[1], d_ff, d_model)

Activation function

618        self.activation = activation
620    def __call__(self, x: jnp.ndarray):

622        x = self.activation(self.layer1(x))

624        return self.layer2(x)

Transformer Layer

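This is a transformer layer with multi-head attention and a position-wise feed-forward layer. We use pre-layer normalization: the input of each sub-layer is normalized, and the sub-layer output is added back to the un-normalized input.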

627class TransformerLayer(Module):
  • d_model is the token embedding size
  • self_attn is the self attention module
  • feed_forward is the feed forward module
637    def __init__(self,
638                 d_model: int,
639                 self_attn: MultiHeadAttention,
640                 feed_forward: FeedForward):
646        super().__init__()
647        self.size = d_model
648        self.self_attn = self_attn
649        self.feed_forward = feed_forward
650        self.norm_self_attn = LayerNorm([d_model])
651        self.norm_ff = LayerNorm([d_model])
653    def __call__(self, x: jnp.ndarray, mask: jnp.ndarray):

Normalize the vectors before doing self attention

655        z = self.norm_self_attn(x)

Run through self attention, i.e. keys and values are from self

657        self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)
658        x = x + self_attn

Normalize for feed-forward

661        z = self.norm_ff(x)

Pass through the feed-forward network

663        ff = self.feed_forward(z)

Add the feed-forward results

665        x = x + ff

667        return x

Cross Entropy Loss

670class CrossEntropyLoss(Module):
677    def __init__(self):
678        super().__init__()

Use jax.vmap to vectorize the loss function

681        self._loss_vmap = jax.vmap(self._loss, in_axes=(0, 0,))
683    def _loss(self, output: jnp.ndarray, target: jnp.ndarray):

685        return -jax.nn.log_softmax(output)[target]
  • output is the model outputs of shape [seq_len, n_vocab]
  • target is the target of shape [seq_len]
687    def __call__(self, output: jnp.ndarray, target: jnp.ndarray):

Use the vectorized loss function and calculate the mean.

We could have used a for loop to calculate the losses but using vmap is about 10X faster

696        return self._loss_vmap(output, target).mean()
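To see what the vmapped loss computes, this sketch compares it with an explicit gather. jnp.take_along_axis is a vectorized alternative we swap in for illustration; both should give the same per-token negative log-likelihoods. The logits and targets are random stand-ins, not model outputs.

```python
import jax
import jax.numpy as jnp


def token_loss(logits, target):
    # logits: [n_vocab] for one position, target: a scalar token id
    return -jax.nn.log_softmax(logits)[target]


logits = jax.random.normal(jax.random.PRNGKey(0), (6, 10))  # [seq_len, n_vocab]
targets = jnp.array([1, 3, 0, 9, 4, 4])                    # [seq_len]

# vmap over the position axis of both arguments, as CrossEntropyLoss does
per_token = jax.vmap(token_loss, in_axes=(0, 0))(logits, targets)

# Equivalent gather without vmap
gathered = -jnp.take_along_axis(jax.nn.log_softmax(logits, axis=-1),
                                targets[:, None], axis=-1)[:, 0]
print(jnp.allclose(per_token, gathered))  # True
print(per_token.mean())  # the value CrossEntropyLoss would return
```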

Autoregressive Transformer

This is the transformer decoder with embedding and output layers.

699class AutoregressiveTransformer(Module):
707    layers: ModuleList[TransformerLayer]
  • rnd_key is the PRNG state
  • n_vocab is the vocabulary size
  • d_model is the number of features in a token embedding
  • n_layers is the number of transformer layers
  • heads is the number of attention heads
  • d_ff is the number of features in the hidden layer of the FFN
709    def __init__(self, rnd_key: jax.random.PRNGKey, n_vocab: int, d_model: int, n_layers: int, heads: int, d_ff: int):
718        super().__init__()
719        self.n_vocab = n_vocab
720        self.d_model = d_model
721        self.loss_func = CrossEntropyLoss()

For transformer layers

724        layers = []
725        for i in range(n_layers):

Split PRNG state

727            rnd_key, mha_key, ffn_key = jax.random.split(rnd_key, 3)

Create a transformer layer

729            attn = MultiHeadAttention(mha_key, heads, d_model)
730            ffn = FeedForward(ffn_key, d_model, d_ff)
731            layers.append(TransformerLayer(d_model, attn, ffn))

Make a module list

733        self.layers = ModuleList(layers)

Split PRNG state

736        rnd_key, emb_key, out_key = jax.random.split(rnd_key, 3)

Create embedding layer

738        self.embeddings = EmbeddingsWithLearnedPositionalEncoding(emb_key, n_vocab, d_model)

Final normalization and output layer

740        self.norm = LayerNorm([d_model])
741        self.output = Linear(out_key, d_model, n_vocab)
743    def __call__(self, x: jnp.ndarray):

Get sequence length

745        seq_len = len(x)

A causal mask for attention, so that a token can only see itself and the tokens before it

747        mask = jnp.tril(jnp.ones((seq_len, seq_len), bool))

Get embeddings with positional encodings

749        x = self.embeddings(x)

Apply the transformer layers

751        for i in range(len(self.layers)):
752            x = self.layers[i](x, mask)

Final normalization and linear transformation to get the logits

755        return self.output(self.norm(x))

Calculate the loss

757    def get_loss(self, x: jnp.ndarray):

Get model outputs

762        output = self(x)

Cross entropy loss

764        return self.loss_func(output[:-1], x[1:])


The starting sequence is given by seq, and we greedily sample length tokens.

766    def sample(self, seq: jnp.ndarray, length: int = 20):
772        for i in range(length):

Sample the highest probability token

774            idx = jnp.argmax(self(seq)[-1])

Add it to the sequence

776            seq = jnp.concatenate((seq, idx[None]))

Return the sampled sequence

779        return seq
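A minimal sketch of the same greedy loop with a hypothetical stand-in for the model (toy_logits always favours the next id modulo N_VOCAB). Because length is a plain Python int, the loop is unrolled when traced and each step has a static shape, which is what lets jax.jit compile it; a larger length means a longer compile.

```python
import jax
import jax.numpy as jnp

N_VOCAB = 5


def toy_logits(seq):
    """Hypothetical model stand-in: [seq_len] ids -> [seq_len, N_VOCAB] logits
    that always favour (token + 1) % N_VOCAB as the next token."""
    return jax.nn.one_hot((seq + 1) % N_VOCAB, N_VOCAB) * 10.0


def greedy_sample(seq, length=4):
    for _ in range(length):  # Python loop: unrolled when traced by jax.jit
        idx = jnp.argmax(toy_logits(seq)[-1])    # most likely next token
        seq = jnp.concatenate((seq, idx[None]))  # sequence grows by one
    return seq


print(jax.jit(greedy_sample)(jnp.array([0, 1])))  # [0 1 2 3 4 0]
```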

This is a named tuple for storing Adam optimizer state for a parameter

782class AdamState(NamedTuple):
786    m: jnp.ndarray
787    v: jnp.ndarray

Adam Optimizer

This is from paper Adam: A Method for Stochastic Optimization.

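For parameter $\theta_{t-1}$ and gradient $g_t$ at step $t$, the Adam update is,

$$
\begin{aligned}
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t \\
v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2 \\
\hat{m}_t &\leftarrow \frac{m_t}{1-\beta_1^t} \\
\hat{v}_t &\leftarrow \frac{v_t}{1-\beta_2^t} \\
\theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\end{aligned}
$$

where $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalar hyper-parameters. $m_t$ and $v_t$ are the first and second order moments, and $\hat{m}_t$ and $\hat{v}_t$ are the bias-corrected moments. $\epsilon$ is used as a fix for division by zero, but it also acts as a form of hyper-parameter that acts against variance in gradients.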
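The _step method of the class below folds both bias corrections into the step size instead of computing $\hat{m}_t$ and $\hat{v}_t$. This sketch checks numerically that the two forms agree up to where $\epsilon$ is added. The hyper-parameters (including eps=1e-8, the paper's default rather than this class's default) and the gradient values are illustrative assumptions.

```python
import jax.numpy as jnp

lr, beta1, beta2, eps = 1e-3, 0.9, 0.999, 1e-8


def adam_standard(param, m, v, t):
    m_hat = m / (1 - beta1 ** t)  # bias-corrected first moment
    v_hat = v / (1 - beta2 ** t)  # bias-corrected second moment
    return param - lr * m_hat / (jnp.sqrt(v_hat) + eps)


def adam_folded(param, m, v, t):
    # Both bias corrections folded into the step size
    step_size = lr * jnp.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
    return param - step_size * m / (jnp.sqrt(v) + eps)


param = jnp.array([1.0, -2.0, 0.5])
grad = jnp.array([0.1, -0.3, 2.0])
t = 1
m = (1 - beta1) * grad        # m_1, starting from m_0 = 0
v = (1 - beta2) * grad ** 2   # v_1, starting from v_0 = 0

a = adam_standard(param, m, v, t)
b = adam_folded(param, m, v, t)
print(jnp.allclose(a, b, atol=1e-6))  # True
```

At $t = 1$ both forms move each parameter by roughly $\alpha$ in the direction opposite to the sign of its gradient, which is the expected behaviour of Adam's first step.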

790class Adam:
  • params is the tree-map of parameters
  • lr is the learning rate
  • betas is a tuple of ($\beta_1$, $\beta_2$)
  • eps is $\epsilon$
816    def __init__(self, params: Dict,
817                 lr: float = 0.001, betas: Tuple[float, float] = (0.9, 0.999),
818                 eps: float = 1e-16, ):
826        super().__init__()
        self.lr = lr
828        self.betas = betas
829        self.eps = eps

States for each parameter

        self.states = jax.tree_util.tree_map(self._init_state, params)

Optimized step function

834        self._step_jit = jax.jit(self._step)

Number of steps taken

836        self._n_steps = 0

Optimized update state function

838        self._update_state_jit = jax.jit(self._update_state)

Initialize the state for a given parameter

840    def _init_state(self, param: jnp.ndarray):
844        return AdamState(jnp.zeros_like(param), jnp.zeros_like(param))

Step function

  • params is a tree-map of parameters
  • grads is a tree-map of gradients
846    def step(self, params: Dict, grads: Dict):

Increment step

854        self._n_steps += 1

Update states for each parameter

        self.states = jax.tree_util.tree_map(self._update_state_jit, grads, self.states)

Return updated parameters

        return jax.tree_util.tree_map(partial(self._step_jit, self._n_steps), params, self.states)
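The optimizer maps its functions over trees of parameters, gradients and states. jax.tree_multimap has been removed from recent JAX releases; jax.tree_util.tree_map accepts several trees and behaves the same way. This sketch, with illustrative parameter names and a simplified update_state, shows that when the first tree (grads) is a prefix of the second (states), each AdamState is passed to the function as a whole.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class AdamState(NamedTuple):
    m: jnp.ndarray
    v: jnp.ndarray


params = {'linear/weight': jnp.ones((2, 3)), 'linear/bias': jnp.zeros(3)}
grads = {'linear/weight': jnp.full((2, 3), 0.5), 'linear/bias': jnp.ones(3)}

# One zero-initialized AdamState per parameter
states = jax.tree_util.tree_map(
    lambda p: AdamState(jnp.zeros_like(p), jnp.zeros_like(p)), params)


def update_state(grad, state, beta1=0.9, beta2=0.999):
    # grad is a leaf of `grads`; state is the matching AdamState subtree
    m = beta1 * state.m + (1 - beta1) * grad
    v = beta2 * state.v + (1 - beta2) * grad ** 2
    return AdamState(m, v)


states = jax.tree_util.tree_map(update_state, grads, states)
print(states['linear/bias'].m)  # [0.1 0.1 0.1]
```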

Update parameters

This performs an Adam update on the given parameter

860    def _step(self, n_steps: int, param: jnp.ndarray, state: AdamState):

Bias corrections for $\hat{m}_t$: $1 - \beta_1^t$ and for $\hat{v}_t$: $1 - \beta_2^t$

868        bias_correction = [1 - beta ** n_steps for beta in self.betas]

Uncorrected first and second moments $m_t$ and $v_t$

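        m, v = state

Step size with both bias corrections folded in, $\alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}$

        step_size = self.lr * (bias_correction[1] ** 0.5) / bias_correction[0]

Denominator $\sqrt{v_t} + \hat{\epsilon}$

        den = (v ** 0.5) + self.eps

Update $\theta_t \leftarrow \theta_{t-1} - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \hat{\epsilon}}$. This matches the Adam update except that $\hat{\epsilon}$ is added to $\sqrt{v_t}$ rather than to $\sqrt{\hat{v}_t}$.

        return param - step_size * m / den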

Update state

This updates the uncorrected first and second moments $m_t$ and $v_t$

881    def _update_state(self, grad, state: AdamState):

Uncorrected first and second moments $m_{t-1}$ and $v_{t-1}$

888        m, v = state

Clip gradients

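        grad = jnp.clip(grad, -1, 1)

$m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t$

        m = self.betas[0] * m + grad * (1 - self.betas[0])

$v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2$

        v = self.betas[1] * v + (grad ** 2) * (1 - self.betas[1])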

Return the new state

897        return AdamState(m, v)

Tiny Shakespeare dataset

900class TinyShakespeare:
  • rnd_key is the PRNG state
  • seq_len is the sequence length of a sample
  • batch_size is the batch size
907    def __init__(self, rnd_key: jax.random.PRNGKey, seq_len: int, batch_size: int):
914        self.batch_size = batch_size

PRNG key for shuffling the samples

916        _, self.rnd_key = jax.random.split(rnd_key)

Local path of the text file

919        path = lab.get_data_path() / 'tiny_shakespeare.txt'

Download if it doesn't exist

921        url = ''
922        if not path.exists():
923            download_file(url, path)

Read the file

926        with open(str(path), 'r') as f:
            self.text = f.read()

Get the characters/tokens

930        tokens = sorted(list(set(self.text)))

Number of tokens

933        self.n_tokens = len(tokens)

Map tokens to ids

935        self.stoi = {t: i for i, t in enumerate(tokens)}

Id to token/character

937        self.itos = tokens

As a list of ids

940        data = jnp.array([self.stoi[s] for s in list(self.text)])

Number of batches

        self.n_batches = len(data) // (seq_len * batch_size)

Truncate the data to a whole number of batches

        data = data[:self.n_batches * seq_len * batch_size]

Reshape into samples of length seq_len (it would be better to use random offsets, but let's ignore that here)

        self.data = data.reshape((-1, seq_len))

List of sample indexes

        self.idx = jnp.arange(len(self.data))

Setup for iteration

950    def __iter__(self):

Iteration step

955        self._iter_idx = 0

Split PRNG key

957        self.rnd_key, rnd_key = jax.random.split(self.rnd_key)

Shuffle sample indexes

959        self.idx = jax.random.permutation(rnd_key, self.idx)

962        return self

Number of batches

964    def __len__(self):
968        return self.n_batches

Get next batch

970    def __next__(self):

Stop iteration after iterating through all batches

976        if self._iter_idx >= self.n_batches:
977            raise StopIteration()

Sample indexes for the batch

980        idx = self.idx[self._iter_idx * self.batch_size:(self._iter_idx + 1) * self.batch_size]

Increment iteration step

982        self._iter_idx += 1

Return samples

        return self.data[idx]
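The dataset's batching can be sketched on a small stand-in array instead of the downloaded text (the values and sizes below are illustrative assumptions). The token ids are truncated to a whole number of batches, reshaped into samples of length seq_len, and the sample indexes are shuffled with jax.random.permutation and sliced into batches, so each sample appears exactly once per epoch.

```python
import jax
import jax.numpy as jnp

seq_len, batch_size = 4, 2
data = jnp.arange(21)  # stand-in for the token ids of the text

n_batches = len(data) // (seq_len * batch_size)  # 21 // 8 = 2
samples = data[:n_batches * seq_len * batch_size].reshape((-1, seq_len))  # [4, 4]

# Shuffle sample indexes once per epoch, then slice them into batches
idx = jax.random.permutation(jax.random.PRNGKey(0), jnp.arange(len(samples)))
batches = [samples[idx[i * batch_size:(i + 1) * batch_size]]
           for i in range(n_batches)]
print([b.shape for b in batches])  # [(2, 4), (2, 4)]
```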

Run the experiment

988def main():

Create experiment

996    experiment.create(name='jax')

Create PRNG key

998    rnd_key = jax.random.PRNGKey(0)

Create dataset

1000    dataset = TinyShakespeare(rnd_key, seq_len=32, batch_size=128)

Create the model

1003    model = AutoregressiveTransformer(rnd_key, dataset.n_tokens,
1004                                      d_model=128, n_layers=3, heads=8, d_ff=512)

Get model parameters

1006    params = model.get_params()

JAX compiled pure sampling function

1009    pure_sample_fn = jax.jit(model.purify(model.sample))

JAX compiled pure function to get logits for a batch. First we transform model.__call__ to a pure function which accepts two arguments: parameters, and input sequence. Next we vectorize the function to process a batch of samples. in_axes specifies which arguments to parallelize and along which axis. (None, 0) means we have the same parameters but parallelize the inputs across the first axis. out_axes specifies along which axis to merge the results.

1017    pure_forward_fn = jax.jit(jax.vmap(model.purify(model.__call__),
1018                                       in_axes=(None, 0), out_axes=0))

Similarly we vectorize loss computation

1020    pure_loss_fn = jax.jit(jax.vmap(model.purify(model.get_loss),
1021                                    in_axes=(None, 0), out_axes=0))

A function to get mean loss

1024    def get_loss(params, seq):
1025        return pure_loss_fn(params, seq).mean()

A function to compute gradients for the first argument (parameters)

1028    grad_loss_fn = jax.jit(jax.grad(get_loss, argnums=0))
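The same composition, a per-sample loss vmapped with in_axes=(None, 0), averaged, then differentiated with jax.grad(..., argnums=0), can be sketched on a hypothetical tiny linear model. The gradient comes back as a dictionary with the same keys as the parameters, which is the structure the optimizer maps over.

```python
import jax
import jax.numpy as jnp


def per_sample_loss(params, x):
    # Hypothetical single-sample model: predict the last value of x
    pred = jnp.dot(params['w'], x[:-1]) + params['b']
    return (pred - x[-1]) ** 2


# Share params across the batch (None) and map over axis 0 of the data
batched_loss = jax.vmap(per_sample_loss, in_axes=(None, 0), out_axes=0)


def mean_loss(params, batch):
    return batched_loss(params, batch).mean()


grad_fn = jax.jit(jax.grad(mean_loss, argnums=0))

params = {'w': jnp.zeros(2), 'b': jnp.array(0.0)}
batch = jnp.array([[1.0, 2.0, 3.0], [0.0, 1.0, 1.0]])  # [batch_size, seq_len]
grads = grad_fn(params, batch)
print(grads['w'], grads['b'])  # w: [-3., -7.], b: -4.
```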

Create optimizer

1031    optimizer = Adam(params)

Start the experiment

1034    with experiment.start():

Iterate for 32 epochs

1036        for epoch in monit.loop(32):

Iterate through batches

1038            for data in monit.iterate('Train', dataset):

Compute and log the loss

1040                loss = get_loss(params, data)
                tracker.save('loss', loss)

Get the gradients

1043                grads = grad_loss_fn(params, data)

Update parameters

1045                params = optimizer.step(params, grads)

1048            tracker.new_line()

Log a sample after each epoch

1050            prompt = [dataset.stoi[c] for c in 'It ']
1051            sampled = pure_sample_fn(params, jnp.array(prompt))[len(prompt):]
1052            sampled = ''.join([dataset.itos[i] for i in sampled])
1053            sampled = sampled.replace('\n', '\\n')
1054            logger.log(('It ', Text.meta), (sampled, Text.value))

1058if __name__ == '__main__':
1059    main()