Some of my thoughts, filtered slightly for public consumption.

Understanding Transformers

The famous Attention Is All You Need paper is very readable for cutting-edge research, but it still assumes a fair amount of background familiarity with neural networks and the research directions of the time, both to reconstruct the full model it describes and to understand why the authors made the design choices they did. Without this background it took me a fair amount of work to understand, so I've written this up both to clarify my own understanding and to help others.

The Transformer Architecture

Deep learning models for text generally consist of:

  1. A tokenizer with a vocabulary of tens or hundreds of thousands of tokens (often whole words), which splits the text into these tokens. Each token is represented as a vector with a 1 at the index of the corresponding token and 0s at every other index; this is known as a sparse representation of the text.
  2. A linear map that converts the sparse representation to a much lower-dimensional dense representation.
  3. A multi-layered neural network that repeatedly transforms this dense representation.
  4. A head that transforms the final layer's output into the desired output, e.g. text.
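
A rough sketch of this pipeline in code, with illustrative names and sizes (not tied to any particular library):

    import numpy as np

    VOCAB_SIZE = 50_000   # number of tokens the tokenizer knows about
    HIDDEN_DIM = 512      # dimensionality of the dense representation

    # 2. Linear map from the sparse (one-hot) representation to a dense one.
    #    Multiplying a one-hot vector by this matrix just selects a row, so in
    #    practice it is implemented as a table lookup.
    embedding = np.random.randn(VOCAB_SIZE, HIDDEN_DIM) * 0.02

    def run_model(token_ids, layers, head):
        x = embedding[token_ids]   # (sequence_length, HIDDEN_DIM) dense vectors
        for layer in layers:       # 3. repeatedly transform the dense representation
            x = layer(x)
        return head(x)             # 4. map the final layer's output to the desired output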

Each layer in the neural network can contain various sub-layers, but in many architectures including Transformers these will always include:

  1. A linear map from the previous layer
  2. A non-linear activation function applied elementwise to the linear map's output
  3. A normalization function to keep the output norm constant (or at least bounded)
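
A single layer of this shape might look roughly like the following minimal sketch, using ReLU as the non-linearity and layer normalization as the normalization function (other choices are common):

    import numpy as np

    def layer(x, weights, bias):
        x = x @ weights + bias                  # 1. linear map from the previous layer's output
        x = np.maximum(x, 0.0)                  # 2. elementwise non-linearity (ReLU here)
        mean = x.mean(axis=-1, keepdims=True)   # 3. layer normalization keeps the output
        std = x.std(axis=-1, keepdims=True)     #    roughly zero-mean with unit variance
        return (x - mean) / (std + 1e-5)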

The original paper focused on a specific type of deep learning model known as an encoder-decoder model, which was at the time the dominant approach to text-based tasks such as translation; the Transformer architecture itself, however, is not specific to encoder-decoder models. As the name suggests, an encoder-decoder model consists of 2 components:

  1. An encoder, which reads the entire input sequence (e.g. the sentence to be translated) and transforms it into a dense representation.
  2. A decoder, which takes this representation along with the first N tokens of the output and predicts the token at index N+1.

The decoder is typically run repeatedly: the highest-probability token predicted at index N+1 is appended to the output and the decoder is run again, until a special "end of sequence" token is predicted or the output hits a length limit. Sometimes a search strategy such as beam search is used instead, keeping multiple candidate output tokens and continuing to generate with each of them for several iterations before using some heuristic to discard most branches in the search tree, but this is very computationally expensive.
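
A minimal sketch of the simple (greedy) decoding loop, assuming a decoder function that returns a probability distribution over the vocabulary for the next position (the names and token ids here are illustrative):

    import numpy as np

    END_OF_SEQUENCE = 0   # hypothetical id of the "end of sequence" token
    MAX_LENGTH = 100      # length limit on the generated output

    def greedy_decode(decoder, encoder_output, output_tokens):
        """Repeatedly take the highest-probability next token until EOS or the length limit."""
        while len(output_tokens) < MAX_LENGTH:
            # The decoder sees the encoder's representation plus the first N output
            # tokens and predicts a distribution over the token at index N+1.
            next_token_probs = decoder(encoder_output, output_tokens)
            next_token = int(np.argmax(next_token_probs))
            if next_token == END_OF_SEQUENCE:
                break
            output_tokens.append(next_token)
        return output_tokens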

The Transformer architecture introduces several components to this:

  1. Attention sub-layers, in which each position in a sequence computes a weighted combination of the values at all positions, with the weights determined by how well that position's "query" matches the other positions' "keys". Each attention sub-layer runs several of these computations in parallel ("multi-head" attention).
  2. Self-attention within the encoder, masked self-attention within the decoder (so it cannot look ahead at tokens it has not yet generated), and cross-attention from the decoder to the encoder's output.
  3. Positional encodings added to the dense representation, since attention on its own has no notion of token order.
  4. Residual connections around every sub-layer, so that each sub-layer only needs to learn an adjustment to its input.
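
The core attention computation from the paper, "scaled dot-product attention", is compact enough to sketch directly; here Q, K and V are matrices of queries, keys and values with one row per position:

    import numpy as np

    def scaled_dot_product_attention(Q, K, V):
        d_k = K.shape[-1]
        scores = Q @ K.T / np.sqrt(d_k)                  # how well each query matches each key
        scores -= scores.max(axis=-1, keepdims=True)     # subtract max for numerical stability
        weights = np.exp(scores)
        weights /= weights.sum(axis=-1, keepdims=True)   # softmax over the key positions
        return weights @ V                               # weighted combination of the values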

Training Models

Another important piece of background information to understand is that neural networks in general are known to be universal approximators — regardless of the exact structure, any neural network can be trained to approximate any function or imitate any program, provided it is large enough. Furthermore, while there are theoretical bounds on the size of the neural network necessary to approximate various kinds of functions or solve various classes of problems, these results have not generally mapped well onto the performance of the models in practice. As a result, it is irresponsible to talk about a deep learning model without talking about the training process.

In general, deep learning models are trained via a relatively simple process, iterating over a large number of data points and at each iteration:

  1. Running the model on the input
  2. Along the way, recording each operation and its derivative with respect to its inputs (a process called automatic differentiation)
  3. Computing a loss based on an objective function (usually some notion of distance to some expected output for that input) and the derivative of this objective function with respect to all parameters of the model, using the previously computed derivatives and the chain rule
  4. Updating the parameters by a small amount along the direction that minimizes the objective function
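
With a framework that performs the automatic differentiation, this loop is only a few lines; a minimal sketch in PyTorch, with a placeholder model and random data standing in for a real dataset:

    import torch

    model = torch.nn.Linear(16, 1)                              # stand-in for a real deep model
    loss_fn = torch.nn.MSELoss()                                # notion of distance to the expected output
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)    # "small amount" = learning rate

    dataset = [(torch.randn(16), torch.randn(1)) for _ in range(100)]   # fake data points

    for inputs, expected in dataset:
        outputs = model(inputs)            # 1. and 2. run the model; PyTorch tracks derivatives along the way
        loss = loss_fn(outputs, expected)  # 3. compute the loss...
        loss.backward()                    #    ...and its derivative w.r.t. every parameter
        optimizer.step()                   # 4. nudge the parameters to reduce the loss
        optimizer.zero_grad()              # clear the stored derivatives for the next iteration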

Additionally, it is typical to randomly zero a fraction of the outputs of each layer at each training step in order to make the model learn a more robust (though more redundant) representation and to break certain symmetries. This is known as dropout.
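
A sketch of what dropout does to a layer's output at training time; rescaling by 1/(1 - p) keeps the expected value of each output unchanged:

    import numpy as np

    def dropout(x, p=0.1):
        """Randomly zero a fraction p of the outputs and rescale the rest."""
        mask = np.random.rand(*x.shape) >= p   # keep each element with probability 1 - p
        return x * mask / (1.0 - p)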

While choices such as how much to update in each iteration (the learning rate) and the exact notion of distance used (known as the loss function) have an impact, by far the most important choice in training the model is the dataset used and what outputs are expected. This choice might seem like a foregone conclusion: if you are trying to train a model to perform a certain task, surely you should train on examples of the inputs and outputs of that task? In practice, however, it is better to train primarily on tasks that foster a greater "understanding" than on the tasks closest to your use case, and only afterwards train on the desired task for a much smaller number of iterations. Often this involves switching out the head to generate a different kind of output, such as a binary classifier. This insight was best illustrated by the creation of BERT in 2018, which was trained to predict masked tokens within text and then fine-tuned for various downstream tasks, achieving state-of-the-art performance on each.
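
Switching out the head for fine-tuning might look roughly like this in PyTorch; the "body" here is a placeholder for the pretrained layers, and real setups go through a library-specific API:

    import torch

    hidden_dim, vocab_size = 512, 30_000

    body = torch.nn.Linear(hidden_dim, hidden_dim)              # stand-in for the pretrained layers
    pretraining_head = torch.nn.Linear(hidden_dim, vocab_size)  # predicts the masked tokens
    classifier_head = torch.nn.Linear(hidden_dim, 2)            # new head for a binary classification task

    # Pre-training optimizes body + pretraining_head on masked-token prediction;
    # fine-tuning keeps the body, discards the pretraining head, and trains the new
    # head (and usually the body too) for far fewer iterations at a lower learning rate.
    fine_tune_params = list(body.parameters()) + list(classifier_head.parameters())
    optimizer = torch.optim.Adam(fine_tune_params, lr=1e-5)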

Current Frontier Models

Most current models for general text generation use an even simpler architecture — they are decoder-only, with no encoder and hence no cross-attention. Instead of feeding the input text to an encoder, these models rely on some simple transformation of the input text so that the model can reasonably treat it as the earlier part of its output which it is simply continuing. A common method is to turn the input into the start of a dialog, e.g. if the user asks "What is 1+1?" the decoder will be asked to predict the next token in this sequence:

User: What is 1+1?
Assistant:
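
In its simplest form this transformation is just string formatting; real systems use model-specific templates with special tokens, but the idea is the same:

    def format_prompt(user_message: str) -> str:
        """Turn a user message into a text prefix for the decoder to continue."""
        return f"User: {user_message}\nAssistant:"

    prompt = format_prompt("What is 1+1?")
    # The model is then repeatedly asked to predict the next token after "Assistant:".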

Decoder-only models are more parallelizable than encoder-decoder models, which has made it practical to train larger models with longer context windows. For most applications, giving up the dedicated encoder has been a worthwhile trade-off.

Decoder-only models are not well-suited to the masking technique used to train BERT; instead they are trained on a large variety of tasks that have been structured as continuing a stream of text. The exact details are closely guarded trade secrets, since the training data and process differentiate these models more than their architecture does.