Contents
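1. Introduction and Research Background
This paper analyzes the in-context learning ability of transformer models. It focuses in particular on how the diverse function classes used during pretraining affect model selection and generalization. Building on the in-context learning setup introduced by Garg et al. [4], it examines how transformers learn and predict linear and nonlinear functions.
\[x_i \sim \mathcal{N}(0, I_d), \qquad f \sim D(F), \qquad \tilde{f}(x_{n+1}) = \text{model's prediction of } f(x_{n+1})\]
Here, $\mathcal{N}(0, I_d)$ denotes the $d$-dimensional standard normal distribution, and $D(F)$ denotes a distribution over function classes. With this background, we analyze how a transformer learns each function and predicts its values.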
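2. Preliminaries and Prior Methods
A transformer model processes sequence data and predicts the next token. This work considers the problem of using an input sequence $s = (x_1, f(x_1), \ldots, x_n, f(x_n), x_{n+1})$ to generate a prediction of $f(x_{n+1})$. The model's performance is measured by the predictive squared loss
\[E[(\tilde{f}(x_{n+1}) - f(x_{n+1}))^2]\]
This quantity measures how accurately the model predicts the function $f$ at a new input, given the in-context inputs and labels; the expectation is taken over the randomness in the prompt and the query.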
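3. The Model Selection Phenomenon
The parameters $\theta$ of the transformer model $m_{\theta}(s)$ are pretrained on sequences $s_i = (x_{i,1}, f(x_{i,1}), \ldots, x_{i,n}, f(x_{i,n}))$ drawn from a data source $D$. The main objective is to minimize the mean squared loss
\[\ell = E[(\tilde{f}(x_{i,j}) - f(x_{i,j}))^2]\]
In this work, function classes ranging from simple ones, such as linear functions, to complex nonlinear ones are used as pretraining data. This mixture of function classes is central to studying how the model selects among function classes based on in-context examples and how it generalizes to new data.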
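4. Generalization and Model Performance Analysis
Various experiments were performed to evaluate the generalization ability of the transformer model. These experiments show how the model responds to data from function classes it did not see during pretraining. For example, in an experiment using combinations of linear and sine functions, the model made errors when generating predictions for function combinations not seen in the training data. This suggests that the model achieves high predictive accuracy on function classes close to the training data but can struggle to generalize to new kinds of data.
This study provides an in-depth analysis of how transformer models learn various function classes and how they use them to generate predictions for new data. It also explores how the diversity of the pretraining data affects the model's generalization ability.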
One of the impressive capabilities demonstrated by large language models is their ability to do few-shot learning by providing examples in-context and asking the model to generate a response to follow the final input provided [3]. Researchers have taken the underlying machine learning technology, transformer models, and demonstrated that they can also perform in-context learning tasks in domains other than language [4, 1, 5]. In these works, they demonstrate the ability of transformers to learn high-dimensional and non-linear functions of the inputs from in-context examples, often matching or exceeding the performance of state-of-the-art machine learning models tuned for the purpose of learning those functions. For example, Garg et al. [4] showed that after pretraining on sparse linear data, a transformer network can in-context learn unseen sparse linear functions as well as the Lasso, which is known to be statistically optimal for the data being modeled. In these models, as in the large language models that they are designed to reflect, pretraining (or fine-tuning) the model with relevant data to teach the model how to perform in-context learning is critical to enabling this capability. In this work, we focus on a specific aspect of this pretraining: the composition of the pretraining data mixture.
The few-shot learning setup that we study follows Garg et al. [4], where the goal is to use a set of provided inputs and labels, $\left((x_1, f(x_1)), (x_2, f(x_2)), \ldots, (x_n, f(x_n))\right)$, to make a prediction about the label $f(x_{n+1})$ for a new input $x_{n+1}$. The number of examples provided is small relative to the amount of data used to pretrain the meta-learning model that performs this task (hence the “few-shot” nomenclature). A common approach to applying sequence models to few-shot learning is to first pass the examples in sequentially, alternating inputs and labels, as $(x_1, f(x_1), x_2, f(x_2), \ldots, x_n, f(x_n))$. Finally, the test input $x_{n+1}$ is passed as the final element of the sequence, and the model's prediction for the next item in the sequence is treated as the predicted label. Previous work [4, 1, 5] shows that transformer models are capable of learning many types of data distributions for $(x, f(x))$ pairs and investigates transformer models’ ability to make such predictions.
Training the model to be capable of such predictions requires fitting it on many sequences of the form $s_i = (x_{i,1}, f(x_{i,1}), x_{i,2}, \ldots, x_{i,n}, f(x_{i,n}), x_{i,n+1}, f(x_{i,n+1}))$. All examples within a sequence are generated by the same function $f$, and each sequence uses a different function $f$ drawn from some distribution $D(F)$ over the function class $F$. We investigate interactions between pretraining data composition and transformers’ abilities to few-shot learn related tasks.
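As a minimal sketch of the training-sequence construction just described, the hypothetical helpers below draw one $f$ per sequence, reuse it for every example in that sequence, and interleave inputs with labels. They assume dense linear functions with $\beta \sim \mathcal{N}(0, I_d)$ as the function class and use plain Python lists rather than any particular framework. Scalar labels are zero-padded to $d$ dimensions, which is one way to realize the convention of filling lower-dimensional items with zeros.

```python
import random

def sample_dense_linear(d, rng):
    """Hypothetical sampler for one f(x) = beta^T x with beta ~ N(0, I_d)."""
    beta = [rng.gauss(0.0, 1.0) for _ in range(d)]
    return lambda x: sum(b * xi for b, xi in zip(beta, x))

def build_sequence(f, d, n, rng):
    """Return the tokens (x_1, f(x_1), ..., x_{n+1}, f(x_{n+1})) for one fixed f.

    Each input x_j ~ N(0, I_d); each scalar label is zero-padded to length d
    so that every token is a d-dimensional vector.
    """
    tokens = []
    for _ in range(n + 1):
        x = [rng.gauss(0.0, 1.0) for _ in range(d)]
        tokens.append(x)
        tokens.append([f(x)] + [0.0] * (d - 1))  # label in slot 0, zeros elsewhere
    return tokens

def build_training_set(num_sequences, d, n, seed=0):
    """Draw a fresh f for every sequence; all examples in a sequence share it."""
    rng = random.Random(seed)
    return [build_sequence(sample_dense_linear(d, rng), d, n, rng)
            for _ in range(num_sequences)]

# Example: 4 sequences, d = 20, n = 40 in-context examples plus one query pair.
train = build_training_set(num_sequences=4, d=20, n=40)
print(len(train), len(train[0]), len(train[0][0]))  # 4 82 20
```

At test time the final label token would be dropped, so the prompt ends with the query input $x_{n+1}$.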
Our contributions are as follows:
Transformers are sequence models that provide next-token predictions conditional on the previous sequence tokens. We consider a data-generating model where $d$-dimensional covariates are drawn as $x_i \sim \mathcal{N}(0, I_d)$ and a (random) function $f \sim D(F)$ is sampled from a distribution over function classes. Like Garg et al. [4] and Akyürek et al. [1], we frame the ICL problem as providing a single prompt sequence $s = (x_1, f(x_1), x_2, f(x_2), \ldots, x_n, f(x_n), x_{n+1})$ to the model (i.e., a transformer) and generating a prediction $\tilde{f}(x_{n+1})$ of $f(x_{n+1})$. We refer to this problem of predicting the next token as in-context learning. The performance of an in-context learner is judged by its predictive squared loss $E[(\tilde{f}(x_{n+1}) - f(x_{n+1}))^2]$, with the expectation taken over the randomness in the prompt and query.
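To make the expectation in the predictive squared loss concrete, here is a small Monte Carlo sketch. It assumes dense linear functions ($\beta \sim \mathcal{N}(0, I_d)$) as $D(F)$, and `predict(xs, ys, x_query)` is a hypothetical stand-in for the in-context learner, not the paper's model. Each trial redraws $f$, the prompt inputs, and the query, so the average covers all of the randomness named above.

```python
import random

def predictive_squared_loss(predict, d, n, num_trials=10000, seed=0):
    """Monte Carlo estimate of E[(f_tilde(x_{n+1}) - f(x_{n+1}))^2].

    Each trial draws a fresh f (assumed here: dense linear, beta ~ N(0, I_d)),
    n in-context inputs and one query, all with x ~ N(0, I_d).
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_trials):
        beta = [rng.gauss(0.0, 1.0) for _ in range(d)]
        xs = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n + 1)]
        ys = [sum(b * xi for b, xi in zip(beta, x)) for x in xs]
        y_pred = predict(xs[:n], ys[:n], xs[n])  # only the first n labels are shown
        total += (y_pred - ys[n]) ** 2
    return total / num_trials

def zero_predictor(xs, ys, x_query):
    """Context-ignoring baseline; for this f-distribution its expected loss is d."""
    return 0.0

loss = predictive_squared_loss(zero_predictor, d=5, n=10)  # should be close to d = 5
```

The zero predictor is a useful sanity check: with independent standard normal $\beta$ and $x$, the cross terms vanish and $E[(\beta^T x)^2] = d$, so any learner that actually uses the context should score well below $d$.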
Garg et al. [4] demonstrated that by pretraining a transformer model on simulated prompts represented as such sequences, the model is able to in-context learn unseen functions drawn from the same distribution as the pretraining data. For example, they demonstrate that such pretrained transformers perform as well as state-of-the-art ML methods on linear function classes (with both sparse and dense coefficient vectors) when pretrained on data generated from linear functions, on decision trees when pretrained on data generated from decision trees, and on ReLU networks when pretrained on data generated from ReLU networks. Akyürek et al. [1] study transformers’ ability to learn linear models in-context and provide mechanistic interpretations of how they may be performing such learning. Li et al. [5] study generalization properties of transformers in this setting and demonstrate similar results for linear dynamical systems. Raventós et al. [6] investigate the role of pretraining function diversity for in-context learning in the setting of pure linear regression, arguing that a sufficiently diverse distribution over linear tasks in pretraining is needed for ICL at test time. Closest to our work on model selection is that of Bai et al. [2], which explores transformers’ abilities to perform model selection on empirical grounds and provides rigorous theoretical guarantees for transformers’ generalization properties in pretraining and their downstream in-context prediction performance. The guarantees and empirics in Bai et al. [2] are restricted to model selection among different linear function class families, namely evenly weighted task mixtures of linear regression with different label noise strengths, and of linear/logistic regression.
We assume that the inputs and outputs are all represented as real scalars or vectors. If they are not in the same dimension, we can project the lower-dimensional items into the higher-dimensional space, filling in zeros for the empty dimensions.
In our setting, given a data source $D$ containing sequences $s_i = (x_{i,1}, f(x_{i,1}), \ldots, x_{i,n}, f(x_{i,n}))$, we pretrain the parameters $\theta$ of the transformer model $m_{\theta}(s)$ by minimizing the “teacher forcing” objective for the squared loss $\ell$,
\[\min_{\theta}\; E_{s_i \sim D}\left[\frac{1}{n}\sum_{j=1}^{n} \ell\big(m_{\theta}(s_{i,1:j}),\, f(x_{i,j})\big)\right],\]
where $s_{i,1:j}$ refers to the $i$-th sequence up to (but not including) the $j$-th output $f(x_{i,j})$. We use a 9.5M-parameter decoder model with 12 layers, 8 attention heads, and a 256-dimensional embedding space, as in Garg et al. [4]. Details on the model and our training setup are in Appendix A.1.
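The teacher-forcing objective scores the model once per prefix $s_{i,1:j}$ of each sequence. The sketch below spells out that indexing for a single sequence; `predict` is a hypothetical stand-in for $m_\theta$, and the toy running-mean predictor exists only to make the arithmetic checkable. A causal transformer would produce all $n$ predictions in one forward pass rather than $n$ separate calls.

```python
def teacher_forcing_loss(xs, ys, predict):
    """Average squared loss over the prefixes s_{i,1:j}, j = 1..n, of one sequence.

    For each j the predictor sees (x_1, f(x_1), ..., x_{j-1}, f(x_{j-1}), x_j)
    and is scored against f(x_j).
    """
    n = len(xs)
    total = 0.0
    for j in range(n):  # 0-based j here corresponds to output j + 1 in the text
        y_hat = predict(xs[:j], ys[:j], xs[j])
        total += (y_hat - ys[j]) ** 2
    return total / n

def running_mean_predictor(xs, ys, x_query):
    """Toy predictor: mean of the labels seen so far, or 0.0 with an empty context."""
    return sum(ys) / len(ys) if ys else 0.0

# Squared errors per prefix: (0-2)^2, (2-4)^2, (3-6)^2, so the mean is 17/3
loss = teacher_forcing_loss([[1.0], [2.0], [3.0]], [2.0, 4.0, 6.0], running_mean_predictor)
```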
The focus of this paper is understanding how the construction of the data source $D$ affects the in-context learning abilities of the model in the controlled setting of learning function classes. In the case of studying a single function class family, as in Garg et al. [4], Akyürek et al. [1] and Li et al. [5], for a linear data-generating model a single $f$ can be sampled by drawing $\beta \sim \mathcal{N}(0, I_d)$ and defining $f(x) = \beta^T x$ for use in a sequence $s_i$.
In this work, we use data mixtures that combine examples generated from multiple distinct function families. We consider several distributions over base function classes $D(F)$: (a) $D(F_{\text{dense}})$: dense linear functions, (b) $D(F_{\text{sparse,nnz}})$: sparse linear functions with nnz non-zero coordinates, (c) $D(F_{\text{relu}})$: two-layer ReLU networks, and (d) $D(F_{\text{sine}})$: sinusoidal functions. Additional details on how these base distributions over function classes are generated are provided in Appendix A.2. Mixture distributions over function classes take the form
\[D(F) = w \cdot D(F_A) + (1 - w) \cdot D(F_B)\]
for selected function classes $F_A$ and $F_B$. Each training prompt sequence is constructed by randomly selecting a function class from the mixture according to $w$ and then sampling $f$ from that base class. For example, for $w = 0.25$ with selected function classes $D(F_{\text{dense}})$ and $D(F_{\text{sparse,2}})$, a sequence $s_i$ is drawn from $D(F_{\text{dense}})$ with probability 0.25 and from $D(F_{\text{sparse,2}})$ with probability 0.75.
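A minimal sketch of drawing one function from the two-component mixture, using the dense and sparse linear classes from the example above. The exact sparse construction is given in Appendix A.2, not here, so this sketch assumes that nnz uniformly chosen coordinates receive $\mathcal{N}(0, 1)$ weights and the rest are zero; all helper names are hypothetical.

```python
import random

def sample_dense(d, rng):
    """D(F_dense): every coefficient drawn from N(0, 1)."""
    return [rng.gauss(0.0, 1.0) for _ in range(d)]

def sample_sparse(d, nnz, rng):
    """D(F_sparse,nnz), assumed form: nnz random coordinates get N(0, 1) weights."""
    beta = [0.0] * d
    for k in rng.sample(range(d), nnz):
        beta[k] = rng.gauss(0.0, 1.0)
    return beta

def sample_from_mixture(w, d, nnz, rng):
    """Draw the coefficients of one f from w * D(F_dense) + (1 - w) * D(F_sparse,nnz).

    The component is chosen once per prompt sequence; every example in that
    sequence then uses the returned beta.
    """
    if rng.random() < w:
        return "dense", sample_dense(d, rng)
    return "sparse", sample_sparse(d, nnz, rng)

rng = random.Random(0)
draws = [sample_from_mixture(0.25, d=20, nnz=2, rng=rng) for _ in range(20000)]
frac_dense = sum(name == "dense" for name, _ in draws) / len(draws)  # roughly 0.25
```

The component choice is made per sequence rather than per example, which is what lets the model infer from the context which class generated the prompt.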
Garg et al. [4] argue that transformers generalize well to tasks/functions drawn from the same distribution as the training data. However, one general open question is how these models perform on examples that are out-of-distribution relative to the training data.
In the setting studied here, we interpret the generalization question as the following: “Can a model generate good predictions with in-context examples from a function not in any of the base function classes seen in the pretraining data mixture?” We build up to this question by first studying the abilities of transformers to perform model selection between different function class families seen in pretraining in Section 3. We then transition to answering the OOD-generalization question for a few important cases in Section 4.
One question that arises when pretraining on a data mixture of different function classes is “how does the model select between different function classes when presented with in-context examples in the support of the pretraining mixture?” We address this question here from an empirical point of view. In this section, we find that the models make optimal (or nearly optimal) predictions after seeing in-context examples from a function class that is a member of the pretraining data mixture. We also observe how models perform on functions that are not cleanly part of any single component function class, before exploring some functions that are definitively out-of-distribution from all pretraining data in Section 4.
We begin with the study of linear functions, which have received significant attention in this area of in-context learning. Garg et al. [4] show that transformers pretrained on linear functions perform nearly optimally at in-context learning on new linear functions. In particular, they consider two models: one trained on dense linear functions (where all of the coefficients of the linear model are non-zero), and one trained on sparse linear functions (where, say, only 2 of the 20 coefficients are non-zero). Each model performs as well as linear regression and Lasso regression on new dense and sparse linear functions, respectively. We additionally compare these two models to a model pretrained on a mixture of both sparse and dense linear functions.
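For readers who want a concrete reference point for the Lasso baseline, the following is a generic cyclic coordinate-descent Lasso in pure Python. It is a swapped-in textbook method, not the solver used by Garg et al. [4] or in these experiments, and the objective scaling $\frac{1}{2n}\lVert y - Xb\rVert^2 + \lambda \lVert b \rVert_1$ is an assumed convention. With $\lambda = 0$ and $n \ge d$ it approaches ordinary least squares, the dense-function baseline.

```python
import random

def soft_threshold(z, t):
    """Proximal operator of t*|b|: shrink z toward 0 by t."""
    if z > t:
        return z - t
    if z < -t:
        return z + t
    return 0.0

def lasso_cd(X, y, lam, num_sweeps=200):
    """Minimize (1/(2n)) * ||y - X b||^2 + lam * ||b||_1 by cyclic coordinate descent."""
    n, d = len(X), len(X[0])
    b = [0.0] * d
    resid = list(y)  # y - X b, with b = 0 initially
    col_sq = [sum(X[i][k] ** 2 for i in range(n)) / n for k in range(d)]
    for _ in range(num_sweeps):
        for k in range(d):
            # correlation of column k with the partial residual that excludes b_k
            rho = sum(X[i][k] * resid[i] for i in range(n)) / n + col_sq[k] * b[k]
            new_bk = soft_threshold(rho, lam) / col_sq[k]
            delta = new_bk - b[k]
            if delta != 0.0:
                for i in range(n):
                    resid[i] -= X[i][k] * delta  # keep resid = y - X b up to date
                b[k] = new_bk
    return b

# Noiseless 2-sparse example (illustrative sizes, not the paper's d = 20 setup)
rng = random.Random(0)
d, n = 5, 40
beta_true = [1.5, 0.0, 0.0, -2.0, 0.0]
X = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
y = [sum(b * xi for b, xi in zip(beta_true, row)) for row in X]
beta_hat = lasso_cd(X, y, lam=1e-3)  # close to beta_true, slightly shrunk toward 0
```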
Figure 1 shows that the model pretrained on a \(D(F) = 0.5 \cdot D(F_{\text{dense}}) + 0.5 \cdot D(F_{\text{sparse,2}})\) mixture performs similarly at in-context learning to models pretrained on only one function class. Since the model pretrained on the mixture performs similarly to the models shown by Garg et al. [4] to be theoretically optimal, we infer that this model is nearly optimal as well. The ICL learning curves in Figure 2 show that this in-context model selection ability is relatively uniform with respect to the number of in-context examples provided. In Figure 2, we also see that ICL learning curves for pretraining data mixtures with various non-trivial weights $w$ (or $1 - w$) for a given function class nearly match the optimal baseline sample complexity of pretraining a model purely on that function class. We observe only small deviations, and these decay quickly as the ICL sample count increases, matching the behavior in Figure 1, which corresponds to a single point on the ICL learning curve. Figure 2 also demonstrates that transformer ICL generalization suffers out-of-distribution. Even though dense and sparse linear classes are both linear functions, we can see the poor performance of the red curve in Figure 2a (which corresponds to a transformer pretrained on only sparse linear functions and evaluated on dense linear data) and vice versa for the teal curve in Figure 2b. We see similar behavior for other nonlinear function classes, as detailed in Appendix B. We additionally briefly explore the effect of model size on model selection in the later part of Appendix B.
The mechanistic question about how these empirical phenomena come to be is interesting, but we do not address it here. We believe that first clearly documenting the current phenomena is important for the research community.
Returning to the experiment in Figure 1 and plotting the error as a function of the number of non-zero coefficients over the entire range of possibilities shows that the model pretrained on the mixture with $w = 0.5$, $D(F) = 0.5 \cdot D(F_{\text{dense}}) + 0.5 \cdot D(F_{\text{sparse,2}})$, performs as well as the models pretrained on the mixture components, i.e., $w = 0$ and $w = 1$, throughout (see Figure 3a). This suggests that the model is capable of performing model selection to choose whether to make predictions using knowledge solely from one base function class in the pretraining mixture or the other. Indeed, Figure 3b shows that when the examples provided in-context come from very sparse or very dense functions, the predictions are nearly identical to those made by the models pretrained on either only sparse or only dense data, respectively.
We examine the In-Context Learning (ICL) generalization capabilities of models along two axes. First, we explore and test ICL performance on functions the model has never seen in training but could plausibly predict: convex combinations of functions drawn from the pretraining function classes. Second, we evaluate ICL performance on functions which are extreme versions of functions seen in pretraining (i.e., sinusoids with much higher or lower frequencies than those typically seen in pretraining). In both cases, we find little evidence of out-of-distribution generalization. When the function is significantly far from those seen during pretraining, the predictions are erratic. However, when the function is sufficiently close to the pretraining data, the model approximates it well with predictions from the function classes on which it was pretrained.
Figure 3a shows that the transformer’s predictions at moderate sparsity levels (nnz = 3 to 7) are not similar to any of the predictions from either of the function classes provided at pretraining, but rather, something in between the two. Hence, one may hypothesize that the model has some inductive bias that allows it to combine pretrained function classes in nontrivial ways. For instance, one may suspect that the model can produce predictions from the combination of the functions that it saw during pretraining. To test this hypothesis in a context with clearly disjoint function classes, we study the ability to perform ICL on linear functions, sinusoids, and convex combinations of the two. We focus on the one dimensional case to make evaluating and visualizing nonlinear function classes straightforward.
Figure 4: Comparing predictions from models pretrained on different data sources after providing 3 sets of in-context examples (shown in red). Predictions are made by sweeping over x values passed in as the last element of the sequence after the in-context examples.
(a) The models pretrained on linear functions, or on both linear functions and sinusoids, make good linear predictions when provided examples from a linear model. (b) The models pretrained on sinusoids, or on both linear functions and sinusoids, make good sinusoidal predictions when provided examples from a cosine. (c) None of the models predict well when provided examples from a convex combination of a linear and a cosine function (although the linear-only model approximates the line of best fit).
Figure 4c shows a specific convex combination of a linear function and a sinusoid. In Figure 5, we sweep over the relative weights of the linear function and sine wave in the convex combination. Here, we observe that when the combined function is predominantly from one function class or the other, i.e., well-approximated by the function classes learned during pretraining, the in-context predictions are reasonable. However, when both functions contribute significantly to the convex combination, the model makes erratic predictions not well-justified by the in-context examples. This shows that the model selection capabilities of the model are limited by proximity to the pretraining data, and suggests that broad coverage of function space is critical for generalized in-context learning capabilities.
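The difference between a convex combination and a pretraining mixture is easy to blur, so here is a one-dimensional sketch of both. A convex combination adds the weighted y-values of two functions into a single new function, whereas the pretraining mixture picks exactly one base function per prompt sequence. The slope of 1 and the unit-frequency, zero-phase sine are illustrative assumptions, not the functions used in the figures.

```python
import math
import random

def linear(x):
    """Hypothetical linear component (slope 1, chosen for illustration)."""
    return x

def sine(x):
    """Hypothetical sinusoidal component (unit frequency, zero phase)."""
    return math.sin(x)

def convex_combination(a):
    """One function whose y-values are a * linear(x) + (1 - a) * sine(x), 0 <= a <= 1."""
    return lambda x: a * linear(x) + (1.0 - a) * sine(x)

def mixture_draw(w, rng):
    """Pretraining mixture: choose ONE base function for a whole prompt sequence."""
    return linear if rng.random() < w else sine

g = convex_combination(0.5)  # a blend that belongs to neither base class
rng = random.Random(0)
f = mixture_draw(0.5, rng)   # always exactly linear or exactly sine, never a blend
```

Sweeping `a` from 0 to 1 moves `g` from a pure sinusoid to a pure line; the intermediate values are the functions the pretrained models never saw.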
Figure 5: Predictions from transformer models pretrained on linear functions (blue), sinusoids (orange), or both (green) when provided the red examples in-context, coming from a combination of a linear function and sinusoid, with the relative weights noted in the titles of each plot.
The previous convex combinations were specifically constructed so that the model had never seen similar functions in pretraining. Shifting to the scenario where sections of the function class space are increasingly rare in the pretraining data, we find that model generalization starts strong and then degrades dramatically as the tasks become so rare as to be out-of-distribution. Specifically, we trained models on a mixture of sinusoids with frequencies drawn from a Gamma(6, 1/6) distribution (shape 6, scale 1/6), which ensures that the average frequency is 1. Frequencies below 0.01 or above 5 are extremely rare: the probability of sampling a frequency outside this range is less than $1 \times 10^{-7}$, i.e., we expect to see fewer than about 100 such examples out of the 1 billion used in pretraining. Figure 6 shows the predictions from all of the models on a variety of frequencies inside and outside this range.
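The rarity claim can be checked directly. This sketch reads Gamma(6, 1/6) as shape 6 and scale 1/6 (the parametrization under which the mean is 1, as stated) and uses the standard identity that, for an integer shape $k$, $P(X > x) = P(\mathrm{Poisson}(x/\theta) \le k - 1)$, so only the standard library is needed.

```python
import math

def gamma_sf_integer_shape(x, shape, scale):
    """P(X > x) for X ~ Gamma(shape, scale) with integer shape (Erlang distribution).

    Uses the identity P(X > x) = P(Poisson(x / scale) <= shape - 1).
    """
    lam = x / scale
    return math.exp(-lam) * sum(lam ** k / math.factorial(k) for k in range(shape))

SHAPE, SCALE = 6, 1.0 / 6.0  # mean = SHAPE * SCALE = 1
p_above = gamma_sf_integer_shape(5.0, SHAPE, SCALE)          # P(freq > 5)
p_below = 1.0 - gamma_sf_integer_shape(0.01, SHAPE, SCALE)   # P(freq < 0.01)
p_outside = p_above + p_below
expected_in_1e9 = p_outside * 1e9
```

Under these assumptions the tail mass is about $2.3 \times 10^{-8}$, almost all of it above frequency 5, which is consistent with the stated bound of $10^{-7}$.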
Note here convex combination refers to adding the weighted y-values of the functions. This is distinct from the weighted mixture distribution used in pretraining which samples either a linear function or sine function for the f used in each prompt sequence.
Figure 6. Predictions from transformer models pretrained on sinusoids (orange), or both sinusoids and linear functions (green) when provided the red examples in-context, coming from sinusoids of increasing frequency (noted in the plot title).
We have empirically explored the role of the pretraining data composition on the ability of pretrained transformers to in-context learn function classes both inside and outside the support of their pretraining data distribution. We have empirically shown that for task families or function classes well-represented in the pretraining mixture, the cost of selecting the appropriate function class to use for in-context learning is nearly zero. We next explored generalizability in two scenarios: (1) we found that the pretrained transformers struggle to predict on convex combinations of functions drawn from pretraining function classes, and (2) we observed that transformers can generalize effectively on rarer sections of the function-class space but still break down as the tasks become out-of-distribution.
An important question is understanding how the observations we make here carry over to tokenized models and to questions represented in natural language. We attempted to train a tokenized model for the one-dimensional examples presented in Section 4 by binning the scalar values into buckets and treating the bucket indices as tokens for the input to a transformer-based language model. We trained this model for 5M epochs with a cross-entropy loss, as typically used in language models, but were unable to significantly decrease the loss. Understanding the challenges to training such a model and evaluating whether this framing has different model selection or out-of-distribution generalization properties is important future work. In the natural language setting, the intuitive notions of how to appropriately define task families (in our case function classes), precise pretraining mixtures, and convex combinations are all less clear. We believe bridging the gap between the notions presented here and in language modeling may help to improve our understanding of the power of ICL and how to effectively enable it.
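The binning step of the tokenized experiment can be sketched as follows. The range $[-4, 4)$ and the 16 uniform buckets are illustrative assumptions, since the text does not state the binning that was used. The resulting token ids would serve as inputs, and the bucket of $f(x_{n+1})$ as the target, for a cross-entropy next-token loss.

```python
import math

def make_binner(lo, hi, num_bins):
    """Return encode/decode functions for uniform buckets over [lo, hi).

    Values outside the range are clipped into the first or last bucket.
    """
    width = (hi - lo) / num_bins

    def encode(v):
        k = math.floor((v - lo) / width)
        return min(max(k, 0), num_bins - 1)

    def decode(k):
        return lo + (k + 0.5) * width  # bucket center

    return encode, decode

def tokenize_prompt(xs, ys, x_query, encode):
    """Interleave binned 1-D inputs and labels into a sequence of token ids."""
    tokens = []
    for x, y in zip(xs, ys):
        tokens += [encode(x), encode(y)]
    tokens.append(encode(x_query))
    return tokens

encode, decode = make_binner(-4.0, 4.0, 16)  # bucket width 0.5
tokens = tokenize_prompt([0.0, 1.0], [0.3, -1.2], 2.0, encode)  # [8, 8, 10, 5, 12]
```

Decoding a bucket back to its center loses up to half a bucket width of precision, one of the trade-offs such a framing introduces relative to the real-valued setup.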