Abstract and 1. Introduction

2. Method

3. Experiments on real data

3.1. Benefits scale with model size and 3.2. Faster inference

3.3. Learning global patterns with multi-byte prediction and 3.4. Searching for the optimal n

3.5. Training for multiple epochs and 3.6. Finetuning multi-token predictors

3.7. Multi-token prediction on natural language

4. Ablations on synthetic data and 4.1. Induction capability

4.2. Algorithmic reasoning

5. Why does it work? Some speculation and 5.1. Lookahead reinforces choice points

5.2. Information-theoretic argument

6. Related work

7. Conclusion, Impact statement, Environmental impact, Acknowledgements, and References

A. Additional results on self-speculative decoding

B. Alternative architectures

C. Training speeds

D. Finetuning

E. Additional results on model scaling behavior

F. Details on CodeContests finetuning

G. Additional results on natural language benchmarks

H. Additional results on abstractive text summarization

I. Additional results on mathematical reasoning in natural language

J. Additional results on induction learning

K. Additional results on algorithmic reasoning

L. Additional intuitions on multi-token prediction

M. Training hyperparameters

Language modeling losses Dong et al. (2019) and Tay et al. (2022) train on a mixture of denoising tasks with different attention masks (full, causal, and prefix attention) to bridge the performance gap with next-token pretraining on generative tasks. Tay et al. (2022) use the span corruption objective, which replaces spans of tokens with special sentinel tokens in the encoder input; the decoder then predicts the contents of those spans. Unlike UniLM, this allows full causal training with teacher forcing. Similarly, Yang et al. (2019) train on permuted sequences while preserving the original positional embeddings, effectively training the model to predict various parts of the sequence given a mix of past and future information. This permuted language modeling is the closest task to ours, since it allows predicting beyond the next token. However, all of these language modeling tasks train on only a small percentage of the input text: on average, only 15% of the tokens are backpropagated through. For Dong et al. (2019), where masking is done in BERT style, it is hard to mask much more than 15% without destroying too much information. For Tay et al. (2022), a larger proportion is technically possible, but in practice the settings used mask between 15% and 25% of the tokens. Yang et al. (2019) also make it possible to train on the whole sequence, since it is only permuted and no information is lost; yet in practice, because a completely random permutation is very hard to reconstruct, only 15% of the tokens are predicted, for training stability reasons.
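
To make the contrast with multi-token prediction concrete, the sketch below illustrates, in plain Python, how a span-corruption objective builds an encoder input and a decoder target: only the masked spans, roughly corruption_rate of the sequence, ever receive a loss signal. The sentinel naming and the span-sampling scheme here are illustrative assumptions rather than the exact preprocessing of Tay et al. (2022).

```python
# Illustrative span-corruption sketch (NOT the exact preprocessing of
# Tay et al., 2022): masked spans are replaced by sentinel tokens in the
# encoder input, and only those spans appear in the decoder target, so
# roughly `corruption_rate` of the tokens contribute to the loss.
import random

def span_corrupt(tokens, corruption_rate=0.15, mean_span_len=3, seed=0):
    rng = random.Random(seed)
    n_to_mask = max(1, int(len(tokens) * corruption_rate))
    masked = set()
    while len(masked) < n_to_mask:
        start = rng.randrange(len(tokens))
        for i in range(start, min(len(tokens), start + mean_span_len)):
            masked.add(i)
    encoder_input, target, sentinel, i = [], [], 0, 0
    while i < len(tokens):
        if i in masked:
            encoder_input.append(f"<extra_id_{sentinel}>")
            target.append(f"<extra_id_{sentinel}>")
            while i < len(tokens) and i in masked:
                target.append(tokens[i])
                i += 1
            sentinel += 1
        else:
            encoder_input.append(tokens[i])
            i += 1
    return encoder_input, target

enc, tgt = span_corrupt("the quick brown fox jumps over the lazy dog".split())
print(enc)  # tokens with masked spans replaced by sentinel placeholders
print(tgt)  # decoder target: each sentinel followed by its masked span
```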

Multi-token prediction in language modelling Qi et al. (2020) argue that multi-token prediction encourages planning, improves representations, and prevents overfitting to the local patterns that can result from teacher-forced training. However, their technical approach replicates the residual stream n-fold, while ours allows for compute-matched comparisons and makes the residual representations participate more directly in the auxiliary loss terms. Stern et al. (2018) and Cai et al. (2024) propose model finetunings with multi-token prediction for faster inference, but do not study the effects of such a loss during pretraining. Pal et al. (2023) use probing methods to show that next-token prediction models are able to predict additional consecutive tokens to a certain extent, but less so than our models, which are specifically trained for this task. Jianyu Zhang (2024) observe improvements on language modelling tasks when using multi-label binary classification over the occurrence of vocabulary words in the future as an auxiliary learning task.
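
As a concrete picture of the shared-trunk, multiple-head setup discussed here, the following PyTorch sketch trains one lightweight output head per future offset on top of a shared causal trunk with a shared unembedding. The module sizes, the use of nn.TransformerEncoder, and the loss bookkeeping are illustrative assumptions, not the exact architecture or training recipe of the paper.

```python
# Schematic multi-token prediction: a shared trunk, one head per future
# offset, a shared unembedding, and a summed cross-entropy loss.
# Hyperparameters and module choices are illustrative assumptions.
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiTokenPredictor(nn.Module):
    def __init__(self, vocab_size, d_model=256, n_future=4):
        super().__init__()
        self.n_future = n_future
        self.embed = nn.Embedding(vocab_size, d_model)
        self.trunk = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True),
            num_layers=2,
        )
        # One lightweight head per future offset; the unembedding is shared.
        self.heads = nn.ModuleList(
            [nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True)
             for _ in range(n_future)]
        )
        self.unembed = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, tokens):
        causal = nn.Transformer.generate_square_subsequent_mask(tokens.size(1))
        z = self.trunk(self.embed(tokens), mask=causal)
        # Head i predicts the token i + 1 positions ahead of each position.
        return [self.unembed(head(z, src_mask=causal)) for head in self.heads]

def multi_token_loss(model, tokens):
    logits = model(tokens[:, :-model.n_future])
    losses = []
    for i, head_logits in enumerate(logits):
        target = tokens[:, i + 1 : tokens.size(1) - model.n_future + i + 1]
        losses.append(F.cross_entropy(
            head_logits.reshape(-1, head_logits.size(-1)), target.reshape(-1)))
    return sum(losses)

model = MultiTokenPredictor(vocab_size=1000)
batch = torch.randint(0, 1000, (2, 32))
print(multi_token_loss(model, batch).item())
```

The sketch simply sums the head losses in a single backward pass; the reordering of forward and backward passes mentioned in the self-speculative decoding paragraph below is what the paper uses to include all loss terms in practice.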

Self-speculative decoding Stern et al. (2018) are, to the best of our knowledge, the first to suggest a speculative decoding scheme for faster inference. Our architecture replaces their linear prediction heads with transformer layers, but is otherwise similar. By reorganizing the order of the forward and backward passes, we can use all loss terms instead of stochastically picking one head for loss computation. Cai et al. (2024) present a more elaborate self-speculative decoding scheme that uses the top-k predictions of each head instead of only the best one. It can be used with the multi-token prediction models we train.
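
The greedy variant of such a scheme can be sketched as follows, reusing the assumed interface of the MultiTokenPredictor sketch above (the model returns one logit tensor per head, head 0 being the ordinary next-token head): the extra heads draft several tokens in a single forward pass, and the next forward pass keeps the longest prefix of the draft that the next-token head reproduces. This follows the general blockwise idea of Stern et al. (2018) and is an illustration, not the exact decoding implementation evaluated in the paper.

```python
# Greedy self-speculative decoding sketch: draft with the extra heads, then
# verify with the next-token head on the next forward pass. Assumes `model`
# returns a list with one (batch, seq, vocab) logit tensor per head.
import torch

@torch.no_grad()
def self_speculative_generate(model, prompt, max_new_tokens=64):
    tokens = prompt.clone()                      # shape (1, prompt_len)
    while tokens.size(1) - prompt.size(1) < max_new_tokens:
        logits = model(tokens)
        # Draft: head i proposes the token i + 1 positions after the last input.
        draft = [lg[:, -1].argmax(-1, keepdim=True) for lg in logits]
        candidate = torch.cat([tokens] + draft, dim=1)
        # Verify: keep drafted tokens only while head 0 agrees with them; the
        # first drafted token is head 0's own prediction and is always kept.
        verify = model(candidate)[0].argmax(-1)
        start, accepted = tokens.size(1), 1
        for i in range(1, len(draft)):
            if candidate[0, start + i] == verify[0, start + i - 1]:
                accepted += 1
            else:
                break
        tokens = candidate[:, : start + accepted]
    return tokens

# Usage with the MultiTokenPredictor sketch above (assumed interface):
#   out = self_speculative_generate(model, torch.randint(0, 1000, (1, 8)))
```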

Multi-target prediction Multi-task learning is the paradigm of training neural networks jointly on several tasks to improve performance on the tasks of interest (Caruana, 1997). Learning with such auxiliary tasks allows models to exploit dependencies between target variables, and can even be preferable in the case of independent targets (Waegeman et al., 2019). While more specifically tailored architectures for multi-target prediction are conceivable (Spyromitros-Xioufis et al., 2016; Read et al., 2021), modern deep learning approaches usually rely on large shared model trunks with separate prediction heads for the respective tasks (Caruana, 1997; Silver et al., 2016; Lample et al., 2022), as we do. Multi-target prediction has been shown to be a successful strategy in various domains, e.g., for learning time series prediction with more distant time steps in the future as auxiliary targets (Vapnik and Vashist, 2009), or for learning from videos with several future frames (Mathieu et al., 2016; Srivastava et al., 2016) or representations of future frames (Vondrick et al., 2016) as auxiliary targets.

This paper is available on arXiv under the CC BY 4.0 DEED license.

Authors:

(1) Fabian Gloeckle, FAIR at Meta and CERMICS Ecole des Ponts ParisTech, an equal contributor;

(2) Badr Youbi Idrissi, FAIR at Meta and LISN Université Paris-Saclay, an equal contributor;

(3) Baptiste Rozière, FAIR at Meta;

(4) David Lopez-Paz, FAIR at Meta, a last author;

(5) Gabriel Synnaeve, FAIR at Meta, a last author.