How are embedding models trained for RAG?
You can ask your RAG pipeline, "What line is the bug on?", and it will tell you the answer almost instantly. How?
Embedding models are the secret sauce that makes RAG work so well. How are they trained for this "asking questions of documents" use case?
In this blog post we'll unpack how embedding models like OpenAI's text-embedding-3-large are typically trained to support this document retrieval and chat use case.
Table of contents
- Training Data and Labels
- Model Architecture
- Loss Functions
- Training Procedure
- Embedding Extraction
- Let's consider a real example
- Summary
Training Data and Labels
In the context of training an embedding model for RAG, the training data usually consists of pairs of queries and documents.
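The labels aren't traditional categorical labels; instead, each label marks a query-document pair as similar (positive) or dissimilar (negative), and the model learns to place the embeddings of positive pairs close together and those of negative pairs far apart.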
- Positive pairs: Queries and their relevant documents.
- Negative pairs: Queries and irrelevant documents (often sampled randomly or using hard negatives).
Here's what building a simple set of labeled pairs might look like in Python:
import random
# Sample data: queries[i] is answered by documents[i]
queries = ["What is RAG?", "Explain embeddings.", "How to train a model?"]
documents = [
    "RAG stands for Retrieval-Augmented Generation.",
    "Embeddings are vector representations of text.",
    "Model training involves adjusting weights based on loss."
]
# Function to create positive and negative pairs
def create_pairs(queries, documents):
    pairs = []
    labels = []
    # Pair each query with its own relevant document rather than a random one
    for query, positive_doc in zip(queries, documents):
        # Positive pair (query and relevant document)
        pairs.append((query, positive_doc))
        labels.append(1)
        # Negative pair (query and a randomly sampled irrelevant document)
        negative_doc = random.choice([doc for doc in documents if doc != positive_doc])
        pairs.append((query, negative_doc))
        labels.append(0)
    return pairs, labels
pairs, labels = create_pairs(queries, documents)
for pair, label in zip(pairs, labels):
    print(f"Query: {pair[0]} \nDocument: {pair[1]} \nLabel: {label}\n")
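Randomly sampled negatives are usually easy to tell apart from the positive. The "hard negatives" mentioned above are irrelevant documents that still look similar to the query. Here is a minimal sketch of mining one, assuming token overlap (Jaccard similarity) as a stand-in for a real retriever such as BM25 or an earlier model checkpoint. The helper names `tokenize`, `jaccard`, and `mine_hard_negative` and the third sample document are hypothetical, made up for illustration.

```python
def tokenize(text):
    """Lowercase, split on whitespace, and strip basic punctuation."""
    return {tok.strip(".,?!").lower() for tok in text.split()}

def jaccard(a, b):
    """Token-overlap similarity between two strings (0.0 to 1.0)."""
    ta, tb = tokenize(a), tokenize(b)
    union = ta | tb
    return len(ta & tb) / len(union) if union else 0.0

def mine_hard_negative(query, positive_doc, documents):
    """Return the non-relevant document that overlaps most with the query."""
    candidates = [doc for doc in documents if doc != positive_doc]
    return max(candidates, key=lambda doc: jaccard(query, doc))

queries = ["What is RAG?", "Explain embeddings."]
documents = [
    "RAG stands for Retrieval-Augmented Generation.",
    "Embeddings are vector representations of text.",
    "What is a vector database? It stores embeddings for RAG.",
]

# Assumption: queries[i] is answered by documents[i].
triplets = [
    (query, positive, mine_hard_negative(query, positive, documents))
    for query, positive in zip(queries, documents)
]

for query, positive, negative in triplets:
    print(f"Query: {query}\n  Positive: {positive}\n  Hard negative: {negative}\n")
```

Both queries pick the vector database sentence as their hard negative, because it shares words with each query without answering it.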
Model Architecture
Many modern embedding models are based on transformer architectures, such as BERT and RoBERTa, or on specialized variants like Sentence-BERT (SBERT). The underlying transformer outputs token-level embeddings, which are then pooled into a single vector.
- Token-level embeddings: Each token (word or subword) in the input sequence gets its own embedding vector. I built a demo showing what word and subword tokens look like here.
- Pooling mechanism: Embeddings for each token are useful, but how do we roll those up into something more meaningful? To get a single vector representation of the entire document or query, a pooling mechanism is applied to the token-level embeddings.
Pooling Mechanism
Pooling mechanisms are used to get an embedding that represents an entire document or query. How can we condense the token-level embeddings into a single vector? There are several common approaches:
Mean Pooling
Mean pooling involves averaging the embeddings of all tokens in the sequence. This method takes the mean of each dimension across all token embeddings, resulting in a single embedding vector that represents the average contextual information of the entire input.
This approach provides a smooth and balanced representation by considering all tokens equally. For example:
import torch
# Example token embeddings (batch_size x seq_length x embedding_dim)
token_embeddings = torch.randn(1, 10, 768)
# Mean pooling
mean_pooled_embedding = torch.mean(token_embeddings, dim=1)
print(mean_pooled_embedding.shape) # Output shape: (1, 768)
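In practice, batches contain sequences of different lengths, and the shorter ones are padded. A plain `torch.mean` would average the padding positions in too. Below is a minimal sketch of mean pooling that uses an attention mask to skip padding. The function name `masked_mean_pool` and the tensor sizes are illustrative assumptions.

```python
import torch

def masked_mean_pool(token_embeddings, attention_mask):
    """Average token embeddings while ignoring padded positions.

    token_embeddings: (batch_size, seq_length, embedding_dim)
    attention_mask:   (batch_size, seq_length), 1 for real tokens, 0 for padding
    """
    mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)  # (batch, seq, 1)
    summed = (token_embeddings * mask).sum(dim=1)                   # (batch, dim)
    counts = mask.sum(dim=1).clamp(min=1e-9)                        # avoid divide-by-zero
    return summed / counts

# Two sequences of length 4; the second has only 2 real tokens
token_embeddings = torch.randn(2, 4, 8)
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [1, 1, 0, 0]])

pooled = masked_mean_pool(token_embeddings, attention_mask)
print(pooled.shape)
```

For the second sequence, the result equals the mean of its first two token embeddings only.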
[CLS] Token Embedding
In models like BERT, a special [CLS] token is added at the beginning of the input sequence. The embedding of this [CLS] token, produced by the final layer of the model, is often used as a representation of the entire sequence. During BERT pre-training, this embedding feeds the next-sentence-prediction classifier, so it is trained to aggregate information from the entire input.
Because [CLS] attends to every other token, it can summarize the full context, but without fine-tuning on a similarity or retrieval objective it often makes a weak sentence embedding.
import torch
from transformers import BertModel, BertTokenizer
# Initialize BERT model and tokenizer
model = BertModel.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Example input text
text = "Example input text for BERT model."
# Tokenize input
inputs = tokenizer(text, return_tensors='pt')
# Get token embeddings from BERT model
outputs = model(**inputs)
# Extract [CLS] token embedding
cls_embedding = outputs.last_hidden_state[:, 0, :]
print(cls_embedding.shape) # Output shape: (1, 768)
Max Pooling
Max pooling selects the maximum value from each dimension across all token embeddings. This method highlights the most significant features in each dimension, providing a single vector representation that emphasizes the most prominent aspects of the input.
It can be useful when a single strongly activated feature matters more than the average, though it discards how often a feature occurs across the sequence.
import torch
# Example token embeddings (batch_size x seq_length x embedding_dim)
token_embeddings = torch.randn(1, 10, 768)
# Max pooling
max_pooled_embedding = torch.max(token_embeddings, dim=1)[0]
print(max_pooled_embedding.shape) # Output shape: (1, 768)
In summary:
- Mean Pooling: Averages all token embeddings to get a balanced representation.
- [CLS] Token Embedding: Uses the embedding of the [CLS] token, which is designed to capture the overall context of the sequence.
- Max Pooling: Selects the maximum value from each dimension to emphasize the most significant features.
These pooling mechanisms transform the token-level embeddings into a single vector that represents the entire input sequence, making it suitable for downstream tasks such as similarity comparisons and document retrieval.
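To make the retrieval step concrete, here is a small sketch of ranking documents against a query by cosine similarity. The vectors are random placeholders rather than real model outputs, and the query is deliberately built as a noisy copy of document 3 so the expected winner is known.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Placeholder data: 5 pooled document embeddings
doc_embeddings = torch.randn(5, 768)
# Placeholder query: a slightly perturbed copy of document 3
query_embedding = doc_embeddings[3] + 0.1 * torch.randn(768)

# L2-normalize so that a dot product equals cosine similarity
doc_norm = F.normalize(doc_embeddings, dim=-1)
query_norm = F.normalize(query_embedding, dim=-1)

scores = doc_norm @ query_norm                      # (5,) cosine similarities
top_scores, top_indices = torch.topk(scores, k=2)   # best two matches

print(top_indices.tolist(), top_scores.tolist())
```

Document 3 ranks first because the query was constructed from it. In a real pipeline the same comparison typically runs inside a vector index over many more documents.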
Loss Functions
The training objective is to learn embeddings such that queries are close to their relevant documents in the vector space and far from irrelevant documents.
Common loss functions include:
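- Contrastive loss: Minimizes the distance between positive pairs while pushing negative pairs apart until they are at least a margin away. This pairwise loss was introduced by Hadsell, Chopra, and LeCun in "Dimensionality Reduction by Learning an Invariant Mapping" (2006). It is a different technique from Geoffrey Hinton's Contrastive Divergence, which is a training algorithm for restricted Boltzmann machines.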
- Triplet loss: Involves a triplet of (query, positive document, negative document) and aims to ensure that the query is closer to the positive document than to the negative document by a certain margin. This paper on FaceNet describes using triplets, and this repository has code samples.
- Cosine similarity loss: Maximizes the cosine similarity between the embeddings of positive pairs and minimizes it for negative pairs.
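The three losses above can be sketched in PyTorch as follows. The embeddings are random placeholders, the margin of 1.0 is an arbitrary choice, and `contrastive_loss` is a hand-written helper following the common margin-based formulation, not a built-in. `F.triplet_margin_loss` and `F.cosine_embedding_loss` are PyTorch functions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Placeholder embeddings: 4 queries with one positive and one negative doc each
query = torch.randn(4, 16)
pos_doc = torch.randn(4, 16)
neg_doc = torch.randn(4, 16)

def contrastive_loss(a, b, label, margin=1.0):
    """Pairwise contrastive loss; label is 1 for positive pairs, 0 for negative."""
    dist = F.pairwise_distance(a, b)
    pos_term = label * dist.pow(2)                          # pull positives together
    neg_term = (1 - label) * F.relu(margin - dist).pow(2)   # push negatives past the margin
    return 0.5 * (pos_term + neg_term).mean()

# Stack positive and negative pairs into one batch of labeled pairs
pairs_a = torch.cat([query, query])
pairs_b = torch.cat([pos_doc, neg_doc])
labels = torch.cat([torch.ones(4), torch.zeros(4)])
loss_contrastive = contrastive_loss(pairs_a, pairs_b, labels)

# Triplet loss: query should be closer to pos_doc than to neg_doc by the margin
loss_triplet = F.triplet_margin_loss(query, pos_doc, neg_doc, margin=1.0)

# Cosine embedding loss: target is +1 for positive pairs, -1 for negative pairs
targets = torch.cat([torch.ones(4), -torch.ones(4)])
loss_cosine = F.cosine_embedding_loss(pairs_a, pairs_b, targets)

print(loss_contrastive.item(), loss_triplet.item(), loss_cosine.item())
```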
Training Procedure
The training process involves feeding pairs of queries and documents through the model, obtaining their embeddings, and then computing the loss based on the similarity or dissimilarity of these embeddings.
- Input pairs: Query and document pairs are fed into the model.
- Embedding generation: The model generates embeddings for the query and document.
- Loss computation: The embeddings are used to compute the loss (e.g., contrastive loss, triplet loss).
- Backpropagation: Gradients of the loss are backpropagated through the model, and an optimizer uses them to update the weights.
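The four steps above can be sketched as a short training loop. To keep it self-contained, a tiny `ToyEncoder` (an embedding table plus mean pooling and a linear layer) stands in for a real transformer. The token IDs are random placeholders, and the triplet loss, learning rate, and step count are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

class ToyEncoder(nn.Module):
    """Stand-in for a transformer: embeds token IDs and mean-pools them."""
    def __init__(self, vocab_size=100, dim=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, token_ids):
        token_embeddings = self.embed(token_ids)   # (batch, seq_len, dim)
        pooled = token_embeddings.mean(dim=1)      # (batch, dim)
        return self.proj(pooled)

encoder = ToyEncoder()
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-2)

# Step 1: input triplets (placeholder token IDs for query, positive, negative)
query_ids = torch.randint(0, 100, (8, 6))
pos_ids = torch.randint(0, 100, (8, 6))
neg_ids = torch.randint(0, 100, (8, 6))

losses = []
for step in range(50):
    # Step 2: embedding generation (the same encoder is used for all inputs)
    q, p, n = encoder(query_ids), encoder(pos_ids), encoder(neg_ids)
    # Step 3: loss computation
    loss = F.triplet_margin_loss(q, p, n, margin=1.0)
    # Step 4: backpropagation and weight update
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")
```

This loop repeatedly fits one fixed batch, so the loss falling only shows that the mechanics work. Real training iterates over many batches and evaluates on held-out queries.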
Embedding Extraction
After training, only the layers up to the point where the desired embeddings are produced are kept. Any layers used solely to compute the training loss (for example, a classification head) are dropped.
For instance:
- Final layer embeddings: In many cases, the embeddings from the final layer of the model are used.
- Intermediate layer embeddings: Sometimes, embeddings from an intermediate layer are used if they are found to be more useful for the specific task.
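Here is a sketch of reading final-layer and intermediate-layer embeddings using the `output_hidden_states` option in Hugging Face Transformers. To avoid a download, it builds a small, randomly initialized BERT from a `BertConfig`, so the sizes below are assumptions for illustration and the values themselves are meaningless. The choice of layer 2 as the intermediate layer is arbitrary.

```python
import torch
from transformers import BertConfig, BertModel

torch.manual_seed(0)

# Tiny randomly initialized BERT (illustrative sizes, no pretrained weights)
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=4,
    num_attention_heads=4,
    intermediate_size=64,
)
model = BertModel(config)
model.eval()

input_ids = torch.randint(0, 100, (1, 10))  # placeholder token IDs
with torch.no_grad():
    outputs = model(input_ids=input_ids, output_hidden_states=True)

# hidden_states holds the embedding-layer output plus one entry per layer
hidden_states = outputs.hidden_states
print(len(hidden_states))  # num_hidden_layers + 1

# Mean-pool the final layer and an intermediate layer
final_layer_embedding = hidden_states[-1].mean(dim=1)
intermediate_embedding = hidden_states[2].mean(dim=1)
print(final_layer_embedding.shape, intermediate_embedding.shape)
```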
Let's consider a real example
Sentence-BERT (SBERT) is a good example of a model specifically designed for producing sentence-level embeddings.
- Model architecture: Based on BERT, but with a pooling layer added to produce a fixed-size vector for each input sentence.
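- Training data: Pairs of sentences with a label, such as the entailment, contradiction, or neutral labels in natural language inference datasets, or a similarity score in semantic textual similarity datasets.
- Training objective: Uses a Siamese network structure, in which both sentences pass through the same weight-sharing encoder. The original SBERT paper describes a classification objective, a regression objective on cosine similarity, and a triplet objective, so that similar sentences end up with embeddings close to each other and dissimilar sentences with embeddings far apart.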
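The Siamese structure can be sketched in a few lines. This example follows the classification objective described in the SBERT paper, where the two sentence embeddings u and v are concatenated with |u - v| and passed to a softmax classifier. The `ToySentenceEncoder` is a hypothetical stand-in for BERT plus pooling, and the token IDs and labels are random placeholders.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

class ToySentenceEncoder(nn.Module):
    """Stand-in for BERT + pooling: embeds token IDs and mean-pools them."""
    def __init__(self, vocab_size=100, dim=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)

    def forward(self, token_ids):
        return self.embed(token_ids).mean(dim=1)   # (batch, dim)

class SiameseClassifier(nn.Module):
    """Encodes both sentences with the SAME encoder, then classifies the pair."""
    def __init__(self, encoder, dim=32, num_labels=3):
        super().__init__()
        self.encoder = encoder                      # shared weights for both inputs
        self.classifier = nn.Linear(3 * dim, num_labels)

    def forward(self, ids_a, ids_b):
        u = self.encoder(ids_a)
        v = self.encoder(ids_b)
        features = torch.cat([u, v, torch.abs(u - v)], dim=-1)  # (u, v, |u - v|)
        return self.classifier(features)

model = SiameseClassifier(ToySentenceEncoder())

# Placeholder batch: 4 sentence pairs with 3-way labels (e.g. NLI classes)
ids_a = torch.randint(0, 100, (4, 7))
ids_b = torch.randint(0, 100, (4, 7))
labels = torch.tensor([0, 1, 2, 1])

logits = model(ids_a, ids_b)
loss = F.cross_entropy(logits, labels)
loss.backward()
print(logits.shape, loss.item())
```

After training, the classifier head is discarded and only the shared encoder is kept to produce sentence embeddings.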
Summary
Training an embedding model for Retrieval-Augmented Generation use cases requires a few key components:
- Training data: Pairs of queries and documents (positive and negative).
- Model output: Typically token-level embeddings pooled to create sentence/document-level embeddings.
- Loss functions: Contrastive loss, triplet loss, cosine similarity loss.
- Training: Involves generating embeddings, computing loss, and updating model weights.
- Embedding extraction: Uses final or intermediate layer embeddings after training.