Improving Language Understanding by Generative Pre-Training

Paper Notes

January 2025

Unsupervised Pre-Training

Given an unsupervised corpus of tokens $\mathcal{U} = \{u_1, \dots, u_n\}$, the objective is to maximize the likelihood (equivalent to minimizing the negative log-likelihood):

$$L_1(\mathcal{U}) = \sum_i \log P(u_i \vert u_{i-k}, \dots, u_{i-1}; \Theta)$$

where $\Theta$ are the model parameters and $k$ is the size of the context window.
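As a concrete sketch, this objective is just next-token cross-entropy. The helper below is illustrative only and assumes the model has already produced logits for every position:

```python
import torch
import torch.nn.functional as F

def l1_loss(logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """Next-token negative log-likelihood, i.e. -L1 (up to averaging).

    logits: (batch, seq_len, vocab) -- unnormalized scores for P(u_i | u_{<i})
    tokens: (batch, seq_len)        -- input token ids
    """
    pred = logits[:, :-1, :]   # predictions made at positions 0..n-2
    target = tokens[:, 1:]     # the tokens those positions should predict
    # Mean of -log P(u_i | context); minimizing this maximizes L1 up to a constant factor.
    return F.cross_entropy(pred.reshape(-1, pred.size(-1)), target.reshape(-1))
```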

They use a Transformer decoder, applying masked multi-head self-attention ($\text{MHSA}$) over the input context tokens, followed by position-wise feed-forward layers, to produce an output probability distribution over target tokens.

$$h_0 = U W_e + W_p$$

$$h_l = \text{TransformerBlock}(h_{l-1}), \quad \forall l \in [1, n]$$

$$P(u) = \text{softmax}(h_n W_e^T)$$

where $W_e$ is the token embedding matrix and $W_p$ is the positional embedding matrix.

Before the softmax, we compute the vector-matrix product $h_n W_e^T$ to obtain the logits from the final hidden representation, $h_n$.
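A minimal PyTorch sketch of these equations, with `nn.TransformerEncoderLayer` plus a causal mask standing in for the paper's exact masked-attention block (module names and layer choices here are illustrative, not the original implementation):

```python
import torch
import torch.nn as nn

class GPTDecoder(nn.Module):
    """Sketch of the forward pass in the equations above (illustrative stand-in)."""

    def __init__(self, vocab_size=40000, ctx_len=512, d_model=768, n_layers=12, n_heads=12):
        super().__init__()
        self.W_e = nn.Embedding(vocab_size, d_model)   # token embedding matrix W_e
        self.W_p = nn.Embedding(ctx_len, d_model)      # learned position embeddings W_p
        # Stand-in for the paper's masked self-attention TransformerBlock stack.
        self.blocks = nn.ModuleList(
            nn.TransformerEncoderLayer(d_model, n_heads, dim_feedforward=4 * d_model,
                                       batch_first=True)
            for _ in range(n_layers)
        )

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # h_0 = U W_e + W_p
        pos = torch.arange(tokens.size(1), device=tokens.device)
        h = self.W_e(tokens) + self.W_p(pos)
        # Causal mask so position i only attends to positions <= i.
        mask = nn.Transformer.generate_square_subsequent_mask(tokens.size(1)).to(tokens.device)
        # h_l = TransformerBlock(h_{l-1})
        for block in self.blocks:
            h = block(h, src_mask=mask)
        # P(u) = softmax(h_n W_e^T): logits via the tied (transposed) embedding matrix.
        return h @ self.W_e.weight.T
```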

If $a^T b = |a||b|\cos(\theta)$, where $\theta$ is the angle between $a$ and $b$, then $\frac{a^T b}{|a||b|} = \cos(\theta)$.

Therefore, when we take the dot product of $h_n$ with each token embedding (each row of $W_e$, i.e., each column of $W_e^T$), we're extracting a similarity metric: the higher the value of the $i$th entry of the output vector $z$, the higher the likelihood that the next word is the one at index $i$.

If $W_e$ represents each word token as an $n$-dimensional vector, this becomes a reliable way to predict the next word.
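A toy NumPy example of this dot-product-as-similarity view (the numbers are made up, purely for illustration):

```python
import numpy as np

# Toy W_e: 4 token embeddings (one per row), 3-dimensional for readability.
W_e = np.array([[ 1.0, 0.0, 0.0],
                [ 0.0, 1.0, 0.0],
                [ 0.7, 0.7, 0.0],
                [-1.0, 0.0, 0.0]])
h_n = np.array([0.9, 0.8, 0.1])      # final hidden state at the last position

z = W_e @ h_n                        # logits: one dot product per token embedding
probs = np.exp(z) / np.exp(z).sum()  # softmax
print(z)                             # [ 0.9   0.8   1.19 -0.9 ] -> index 2 is most aligned with h_n
print(probs.argmax())                # 2
```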

Supervised Fine-Tuning

After training the model with the objective $L_1(\mathcal{U})$, they perform supervised fine-tuning.

Given a labeled dataset $\mathcal{C}$, where each example consists of $m$ input tokens $x^1, \dots, x^m$ and a label $y$ (which can be a sequence or a single index), the model is trained to predict

$$P(y \vert x^1, \dots, x^m) = \text{softmax}(h_l^m W_y)$$

We add a final linear layer, $W_y$, to transform the hidden representation $h_l^m$ into an $n$-dimensional vector that we can feed into the $\text{softmax}$, where $n$ is the number of output classes for fine-tuning.
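A hypothetical sketch of this classification head; reading the hidden state at the last position is an illustration choice (in the paper, the input transformations append an extract token and $h_l^m$ is the final block's activation there):

```python
import torch
import torch.nn as nn

class ClassificationHead(nn.Module):
    """Linear layer W_y applied to the final block's activation at the last token."""

    def __init__(self, d_model: int = 768, n_classes: int = 2):
        super().__init__()
        self.W_y = nn.Linear(d_model, n_classes, bias=False)

    def forward(self, h_l: torch.Tensor) -> torch.Tensor:
        # h_l: (batch, seq_len, d_model) -- activations of the final transformer block
        h_lm = h_l[:, -1, :]     # h_l^m: hidden state at the last input token
        return self.W_y(h_lm)    # logits; softmax of these gives P(y | x^1, ..., x^m)
```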

As the fine-tuning objective, maximizing:

$$L_2(\mathcal{C}) = \sum_{(x, y)} \log P(y \vert x^1, \dots, x^m)$$

alongside $L_1$ as an auxiliary objective:

$$L_3(\mathcal{C}) = L_2(\mathcal{C}) + \lambda \, L_1(\mathcal{C})$$

was found to improve generalization, as the $\lambda L_1$ term constrains the model from overfitting to $L_2$.
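Written as a loss to minimize, the combined objective looks roughly like this (the paper uses $\lambda = 0.5$; argument names and tensor shapes below are assumptions):

```python
import torch
import torch.nn.functional as F

def fine_tune_loss(class_logits, labels, lm_logits, tokens, lam: float = 0.5):
    """Negative of L3 = L2 + lambda * L1, written as a loss to minimize."""
    # L2: supervised classification term over (x, y) pairs.
    l2 = F.cross_entropy(class_logits, labels)
    # L1: auxiliary language-modeling term computed on the same fine-tuning inputs C.
    l1 = F.cross_entropy(lm_logits[:, :-1, :].reshape(-1, lm_logits.size(-1)),
                         tokens[:, 1:].reshape(-1))
    return l2 + lam * l1
```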

Task-Specific Input Transformations

Textual Entailment: Premise and hypothesis are concatenated into one sequence, with a delimiter token ($) in between (is the hypothesis true or false given the premise?).
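Purely illustrative formatting sketch; in the paper the start, delimiter, and extract tokens are learned special tokens, and the names below are placeholders:

```python
# Placeholder special tokens; the paper uses randomly initialized start/delimiter/extract embeddings.
def entailment_input(premise: str, hypothesis: str) -> str:
    return f"<start> {premise} $ {hypothesis} <extract>"
```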

Similarity: There is no inherent ordering of the two sentences, so they construct input sequences containing both possible sentence orderings and process each independently. The two resulting sequence representations $h_l^m$ are combined element-wise and fed into a single linear layer for classification.
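A sketch of that two-pass setup, assuming a hypothetical `encoder` that returns $h_l^m$ for a batch of token sequences:

```python
import torch
import torch.nn as nn

class SimilarityHead(nn.Module):
    """Run both sentence orderings through the encoder, combine, then classify."""

    def __init__(self, encoder: nn.Module, d_model: int = 768, n_classes: int = 2):
        super().__init__()
        self.encoder = encoder           # assumed to return h_l^m: (batch, d_model)
        self.W_y = nn.Linear(d_model, n_classes)

    def forward(self, seq_ab: torch.Tensor, seq_ba: torch.Tensor) -> torch.Tensor:
        h_ab = self.encoder(seq_ab)      # h_l^m for ordering "A $ B"
        h_ba = self.encoder(seq_ba)      # h_l^m for ordering "B $ A"
        return self.W_y(h_ab + h_ba)     # element-wise combination, then the linear output layer
```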

Multiple Choice: Given a document context $z$, question $q$, and a set of possible answers $\{a_k\}$, they concatenate the document context and question with each possible answer, adding a delimiter token ($) in between.

Each sequence is processed independently, and the resulting scores are normalized with a softmax to produce an output distribution over the possible answers.
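A sketch of the multiple-choice scoring, again with hypothetical `encoder` and `scorer` stand-ins:

```python
import torch
import torch.nn as nn

def multiple_choice_probs(encoder: nn.Module, scorer: nn.Linear,
                          option_seqs: torch.Tensor) -> torch.Tensor:
    """Score each [z; q; $; a_k] sequence independently, then softmax over the options.

    option_seqs: (n_options, seq_len) token ids, one row per candidate answer.
    encoder is assumed to return h_l^m per sequence; scorer is a Linear(d_model, 1) head.
    """
    h = encoder(option_seqs)              # (n_options, d_model) final hidden states
    scores = scorer(h).squeeze(-1)        # (n_options,) one scalar per answer
    return torch.softmax(scores, dim=-1)  # distribution over the possible answers
```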

Experiments

Unsupervised Pre-Training

  • BookCorpus Dataset for pre-training -- 7,000+ unpublished books from a variety of genres.

Model Specifications

See Attention

  • Trained a 12-layer decoder-only transformer with masked self-attention heads.
    • 768-dimensional states (the embedding size, and typically the QKV size) and 12 attention heads (giving a per-head dimension of 768 / 12 = 64).
  • Adam optimizer with a max learning rate of 2.5e-4.
    • Learning rate was increased linearly from zero over the first 2,000 updates and then annealed to 0 using a cosine schedule (see the sketch after this list).
    • Trained for 100 epochs on minibatches of 64 randomly sampled sequences of 512 tokens.
  • Weight init of $N(0, 0.02)$.
  • Use BPE with 40,000 merges.
  • Residual, embedding, and attention dropout with a rate of 0.1.
  • Use decoupled weight decay.
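A hedged sketch of this optimization setup (AdamW stands in for Adam with decoupled weight decay at w = 0.01; `total_steps` is an assumption, since the paper specifies epochs rather than an exact step count):

```python
import math
import torch

def make_optimizer_and_scheduler(model, max_lr=2.5e-4, warmup_steps=2000, total_steps=100_000):
    """Linear warmup over `warmup_steps`, then cosine anneal to 0 by `total_steps`."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, weight_decay=0.01)

    def lr_lambda(step: int) -> float:
        if step < warmup_steps:
            return step / warmup_steps                          # linear warmup from 0
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))       # cosine anneal to 0

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    return optimizer, scheduler
```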