Large Language Models (LLMs) are trained on vast volumes of data and use billions of parameters to support tasks like answering questions, translating languages, and completing sentences. Working with LLMs poses a few challenges, such as domain knowledge gaps, factuality issues, and hallucination, which affect their reliability, especially in fields that require high levels of accuracy, such as healthcare, law, or engineering. Retrieval Augmented Generation (RAG) mitigates some of these issues by augmenting LLMs with a specific domain's or an organization's internal knowledge base, without the need to retrain the model.
RAG knowledge sources are generally business-specific databases, which are typically deployed on general-purpose CPU infrastructure. Deploying RAG on the same general-purpose CPU infrastructure, alongside the related business services, is therefore both efficient and cost-effective. With this motivation, we evaluated RAG deployment on AWS Graviton-based Amazon EC2 instances, which have delivered up to 40% price-performance advantage over comparable instances for the majority of workloads, including databases, in-memory caches, big data analytics, media codecs, gaming servers, and machine learning inference.
In the past, we published a few blog posts on how PyTorch was optimized for AWS Graviton processors to accelerate ML inference performance for both eager mode (blog) and torch.compile mode (blog). In this blog, we cover how to deploy a typical RAG workload using PyTorch and torch.compile, how we improved its performance by up to 1.7x for the embedding model and 1.3x for the RAG query on an AWS Graviton3-based m7g.xlarge instance compared to the default PyTorch “eager mode”, and finally a few recommendations that you can apply to your RAG use cases.
How to Optimize RAG?
Without RAG, the LLM takes the user input and creates a response based on information it was trained on (what it already knows). With RAG, an information retrieval component is introduced that utilizes the user input to first pull information from a new data source. The user query and the relevant information are both given to the LLM. The LLM uses the new knowledge and its training data to create better responses. The following diagram shows the conceptual flow of using RAG with LLMs.
Image 1: Conceptual flow of using RAG with LLMs
Source: https://aws.amazon.com/what-is/retrieval-augmented-generation/
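To make the augmentation step concrete, here is a minimal sketch of how the retrieved passages and the user question might be combined into a single prompt for the LLM. The template wording and the `build_rag_prompt` helper are hypothetical illustrations, not part of any library; real applications typically use a framework template such as LangChain's `PromptTemplate`.

```python
from typing import Sequence

# Hypothetical prompt template: retrieved context first, then the user question.
RAG_TEMPLATE = (
    "Answer the question using only the context below.\n\n"
    "Context:\n{context}\n\n"
    "Question: {question}\n"
    "Answer:"
)


def build_rag_prompt(question: str, passages: Sequence[str]) -> str:
    """Join retrieved passages and the user question into one LLM prompt."""
    # Number each passage so the model (and a reader) can tell them apart.
    context = "\n".join(f"[{i}] {p.strip()}" for i, p in enumerate(passages, start=1))
    return RAG_TEMPLATE.format(context=context, question=question)


prompt = build_rag_prompt(
    "What's new in PyTorch 2.5?",
    ["First retrieved passage.", "Second retrieved passage."],
)
print(prompt)
```

Only the template is filled in here; the passages themselves would come from the similarity search described in the next section.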
Embedding model
At the core of RAG is an embedding model that takes text data and converts it into a vector representation. These vectors are then stored in a vector database. When a user makes a query, the query is first converted to a vector, and RAG performs a similarity search on the vector database. Hence, the first step in optimizing RAG performance is optimizing the embedding model's inference performance. We used the AWS Graviton3-based m7g.xlarge instance and the HuggingFace sentence-transformer embedding model for the optimization work. Here is a sample script for profiling the HuggingFace sentence-transformer embedding model inference with PyTorch eager mode.
import torch
from torch.profiler import profile, ProfilerActivity, record_function
from transformers import AutoModel, AutoTokenizer
model_name = "sentence-transformers/all-mpnet-base-v2"
input_text = ["This is an example sentence", "Each sentence is converted"]
model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
encoded_input = tokenizer(
    input_text, padding=True, truncation=True, return_tensors="pt"
)
warmup, actual = 100, 100
model.eval()
with torch.no_grad():
    # warmup
    for i in range(warmup):
        embeddings = model(**encoded_input)
    with profile(activities=[ProfilerActivity.CPU]) as prof:
        with record_function("model_inference"):
            for i in range(actual):
                embeddings = model(**encoded_input)
print(prof.key_averages().table(sort_by="self_cpu_time_total"))
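The script above only runs the encoder, which returns one vector per token. To get one vector per sentence and then search those vectors, one common recipe (the one shown on the all-mpnet-base-v2 model card) is attention-mask-weighted mean pooling followed by L2 normalization, after which a plain dot product equals cosine similarity. The sketch below is illustrative only: it uses NumPy, random arrays stand in for `model(**encoded_input).last_hidden_state` and the attention mask, and the function names are our own.

```python
import numpy as np


def mean_pool(token_embeddings: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Average token vectors per sentence, ignoring padding positions.

    token_embeddings: (batch, seq_len, dim); attention_mask: (batch, seq_len) of 0/1.
    """
    mask = attention_mask[..., None].astype(token_embeddings.dtype)  # (batch, seq_len, 1)
    summed = (token_embeddings * mask).sum(axis=1)                    # (batch, dim)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)                    # avoid divide-by-zero
    return summed / counts


def l2_normalize(x: np.ndarray) -> np.ndarray:
    """Scale each row to unit length so dot products become cosine similarities."""
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    return x / np.clip(norms, 1e-12, None)


def top_k(query: np.ndarray, corpus: np.ndarray, k: int) -> np.ndarray:
    """Return indices of the k corpus rows most similar to the query (cosine)."""
    scores = corpus @ query  # (num_docs,)
    return np.argsort(-scores)[:k]


rng = np.random.default_rng(0)
# Random stand-ins for last_hidden_state (4 sentences, 6 tokens, dim 8) and the mask;
# the last two token positions of every sentence are padding.
tokens = rng.standard_normal((4, 6, 8)).astype(np.float32)
mask = np.array([[1, 1, 1, 1, 0, 0]] * 4)
doc_vecs = l2_normalize(mean_pool(tokens, mask))
query_vec = doc_vecs[2]  # a query identical to document 2
print(top_k(query_vec, doc_vecs, k=2))  # document 2 ranks first
```

A vector store performs the same kind of nearest-neighbor lookup, but with indexing structures that scale to many more vectors than a brute-force matrix product.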
Eager mode
Since PyTorch eager mode was already optimized on AWS Graviton processors with the following runtime environment settings, we included them in the baseline and measured the following performance. Please refer to Optimized PyTorch 2.0 Inference with AWS Graviton processors for more details on how we optimized the PyTorch eager mode on AWS Graviton processors.
# Enable the fast math GEMM kernels, to accelerate fp32 inference with bfloat16 gemm
export DNNL_DEFAULT_FPMATH_MODE=BF16
# Enable Linux Transparent Huge Page (THP) allocations,
# to reduce the tensor memory allocation latency
export THP_MEM_ALLOC_ENABLE=1
# Set LRU Cache capacity to cache the primitives and avoid redundant
# memory allocations
export LRU_CACHE_CAPACITY=1024
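If inference is launched from Python rather than from a shell, a minimal sketch is to put the same three variables into `os.environ` at the very top of the entry-point script. We assume they need to be in the environment before `torch` is imported, so the sketch deliberately does not import torch; `setdefault` leaves any value already exported in the shell untouched. The helper name `configure_graviton_env` is hypothetical.

```python
import os

# Same settings as the shell exports above; environment values are strings.
GRAVITON_TORCH_ENV = {
    "DNNL_DEFAULT_FPMATH_MODE": "BF16",  # fast-math: bfloat16 GEMM kernels for fp32 inference
    "THP_MEM_ALLOC_ENABLE": "1",         # Transparent Huge Page allocations for tensors
    "LRU_CACHE_CAPACITY": "1024",        # capacity of the primitive cache
}


def configure_graviton_env() -> dict:
    """Set each variable unless it is already defined; return the effective values.

    Call this before `import torch` so the runtime sees the settings at startup.
    """
    for name, value in GRAVITON_TORCH_ENV.items():
        os.environ.setdefault(name, value)
    return {name: os.environ[name] for name in GRAVITON_TORCH_ENV}


if __name__ == "__main__":
    print(configure_graviton_env())
    # import torch  # import only after the environment is configured
```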
---------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Name                           Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
---------------------------  ------------  ------------  ------------  ------------  ------------  ------------
aten::addmm                        61.01%        2.638s        62.49%        2.702s     370.197us          7300
model_inference                    12.01%     519.161ms       100.00%        4.324s        4.324s             1
aten::bmm                           6.25%     270.084ms        11.96%     517.089ms     215.454us          2400
aten::select                        3.98%     172.165ms         5.34%     230.863ms       1.331us        173500
aten::copy_                         2.11%      91.133ms         2.11%      91.133ms       6.200us         14700
---------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 4.324s
Table 1: Profiler output for HuggingFace sentence-transformer embedding model inference on AWS Graviton3-based m7g.xlarge instance with PyTorch Eager mode
Next, we added torch.compile, weights pre-packing, and torch.inference_mode and observed around 1.7x performance improvement. The following sections describe each of these optimizations and the resulting speedup.
torch.compile
In contrast to eager mode, torch.compile pre-compiles the entire model into a single graph in a manner that's optimized for running on the given hardware. Please refer to Accelerated PyTorch Inference with torch.compile on AWS Graviton processors for more details on torch.compile features and how we optimized them on AWS Graviton processors. Invoke torch.compile as shown in the following snippet to trigger PyTorch dynamo compilation for the model. This resulted in around 1.04x performance improvement over the baseline.
model = torch.compile(model)
----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Name                            Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
aten::addmm                         64.46%        2.675s        66.66%        2.766s     378.905us          7300
Torch-Compiled Region               19.76%     820.085ms        99.04%        4.109s      41.094ms           100
aten::bmm                            6.66%     276.216ms        12.52%     519.527ms     216.470us          2400
aten::select                         3.98%     164.991ms         5.41%     224.488ms       1.299us        172800
aten::as_strided                     1.66%      69.039ms         1.66%      69.039ms       0.383us        180100
----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 4.149s
Table 2: Profiler output for HuggingFace sentence-transformer embedding model inference on AWS Graviton3-based m7g.xlarge instance with torch.compile mode
Weights pre-packing
torch.compile opens up opportunities like pre-packing the model weights, during compilation, into a format that is more suitable for the given hardware, thus improving performance. Set the following config before the model is compiled to trigger weights pre-packing. This resulted in around 1.69x improvement over the baseline.
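import torch._inductor.config as config
# Pre-pack weights into a backend-friendly layout during compilation
config.cpp.weight_prepack = True
# Freeze parameters as constants (inference only) so Inductor can fold and pre-pack them
config.freezing = True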
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Name                             Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
mkldnn::_linear_pointwise            39.10%     994.821ms        41.50%        1.056s     144.628us          7300
Torch-Compiled Region                35.12%     893.675ms        98.42%        2.504s      25.043ms           100
aten::bmm                            10.96%     278.859ms        21.66%     551.073ms     229.614us          2400
aten::select                          7.34%     186.838ms         9.98%     253.840ms       1.469us        172800
aten::as_strided                      2.63%      67.002ms         2.63%      67.002ms       0.388us        172800
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 2.544s
Table 3: Profiler output for HuggingFace sentence-transformer embedding model inference on AWS Graviton3-based m7g.xlarge instance with torch.compile and weights pre-packing
torch.inference_mode
Additionally, use torch.inference_mode() instead of torch.no_grad() to save the overhead of tensor version counter tracking and view tracking. Please refer to the PyTorch documentation for more details.
with torch.inference_mode():  # instead of: with torch.no_grad():
    embeddings = model(**encoded_input)
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Name                             Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
mkldnn::_linear_pointwise            38.92%     987.276ms        41.17%        1.044s     143.056us          7300
Torch-Compiled Region                34.92%     885.895ms        98.45%        2.498s      24.975ms           100
aten::bmm                            11.25%     285.292ms        22.22%     563.594ms     234.831us          2400
aten::select                          7.74%     196.223ms        10.22%     259.251ms       1.500us        172800
aten::as_strided                      2.48%      63.027ms         2.48%      63.027ms       0.365us        172800
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 2.537s
Table 4: Profiler output for HuggingFace sentence-transformer embedding model inference on AWS Graviton3-based m7g.xlarge instance with torch.compile, weights pre-packing, and inference_mode
The following table shows the incremental performance improvements achieved for the standalone embedding model inference.
Optimization level              | Latency per inference (s) | Improvement over the baseline
------------------------------- | ------------------------- | -----------------------------
PyTorch eager mode (baseline)   | 0.04324                   | N/A
+ torch.compile                 | 0.04149                   | 1.04x
+ weights pre-packing           | 0.02544                   | 1.69x
+ torch.inference_mode          | 0.02537                   | 1.70x
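The table values can be reproduced from the profiler totals in Tables 1-4: each profiling run executed 100 inferences, so the per-inference latency is "Self CPU time total" divided by 100, and the improvement is the baseline latency divided by the optimized latency. The sketch below shows that arithmetic. Truncating (rather than rounding) to two decimals reproduces the published figures; rounding would give 1.70x for the pre-packing row.

```python
import math

ITERATIONS = 100  # measured iterations per profiling run (`actual` in the script)

# "Self CPU time total" in seconds, taken from Tables 1-4.
totals = {
    "PyTorch eager mode (baseline)": 4.324,
    "+ torch.compile": 4.149,
    "+ weights pre-packing": 2.544,
    "+ torch.inference_mode": 2.537,
}

baseline = totals["PyTorch eager mode (baseline)"] / ITERATIONS
results = {}
for name, total in totals.items():
    latency = total / ITERATIONS                          # seconds per inference
    speedup = math.floor(baseline / latency * 100) / 100  # truncate to 2 decimals
    results[name] = (latency, speedup)
    print(f"{name:32s} {latency:.5f} s  {speedup:.2f}x")
```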
The following script is an updated example of the embedding model inference with the previously discussed optimizations included. The optimizations are marked with "# Optimization" comments.
import torch
from torch.profiler import profile, record_function, ProfilerActivity
from transformers import AutoTokenizer, AutoModel
# Optimization: weights pre-packing with graph freezing
import torch._inductor.config as config
config.cpp.weight_prepack = True
config.freezing = True
model_name = "sentence-transformers/all-mpnet-base-v2"
input_text = ['This is an example sentence', 'Each sentence is converted']
model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
encoded_input = tokenizer(input_text, padding=True, truncation=True, return_tensors='pt')
warmup, actual = 100, 100
model.eval()
# Optimization: torch.compile
model = torch.compile(model)
# Optimization: torch.inference_mode() instead of torch.no_grad()
with torch.inference_mode():
    # warmup
    for i in range(warmup):
        embeddings = model(**encoded_input)
    with profile(activities=[ProfilerActivity.CPU]) as prof:
        with record_function("model_inference"):
            for i in range(actual):
                embeddings = model(**encoded_input)
print(prof.key_averages().table(sort_by="self_cpu_time_total"))
End-to-End RAG scenario on CPU
After optimizing the embedding model inference, we started with a PyTorch eager mode-based RAG setup, mainly to validate the functionality on the CPU backend. We built the RAG solution with HuggingFaceEmbeddings from langchain_community.embeddings, as shown in the following code snippet.
import time
from typing import Any, List
import torch
from bs4 import BeautifulSoup as Soup
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import format_document
from transformers import AutoTokenizer, AutoModel
url = "https://pytorch.org/blog/pytorch2-5/"
chunk_size = 1000
chunk_overlap = 0
embedding_model = "sentence-transformers/all-mpnet-base-v2"
N = 5
question = "What's new in PyTorch 2.5?"
# Load the page (and pages it links to, up to max_depth) as plain text
loader = RecursiveUrlLoader(
    url=url, max_depth=3, extractor=lambda x: Soup(x, "html.parser").text
)
docs = loader.load()
# Split the document into chunks with a specified chunk size
text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
all_splits = text_splitter.split_documents(docs)
# Store the document into a vector store with a specific embedding model
model = HuggingFaceEmbeddings(model_name=embedding_model)
warmup, actual = 100, 100
with torch.inference_mode():
    vectorstore = FAISS.from_documents(all_splits, model)
    for i in range(warmup):
        searchDocs = vectorstore.similarity_search(question, k=N)
    start = time.time()
    for i in range(actual):
        searchDocs = vectorstore.similarity_search(question, k=N)
    end = time.time()
    print(f"Time for 1 inference is {(end-start)/actual} seconds")
# Concatenate the retrieved chunks into a context string for the LLM prompt
doc_prompt = PromptTemplate.from_template("{page_content}")
context = ""
for i, doc in enumerate(searchDocs):
    context += f"\n{format_document(doc, doc_prompt)}\n"
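The timing loop above uses `time.time()`, which reads the wall clock and can jump if the system clock is adjusted. A minimal alternative sketch, assuming you only need the wall-clock latency of a Python callable, uses `time.perf_counter()`, keeps warmup separate from measurement, and reports the median and maximum as well as the mean, so that a single slow iteration (for example, one that triggers a recompilation) is easy to spot. The `benchmark` helper is our own, not a PyTorch or LangChain API.

```python
import statistics
import time
from typing import Callable, Dict


def benchmark(fn: Callable[[], object], warmup: int = 100, iterations: int = 100) -> Dict[str, float]:
    """Time fn() after a warmup phase; return latency statistics in seconds."""
    for _ in range(warmup):
        fn()  # warmup: absorbs one-time costs such as compilation and cache setup
    samples = []
    for _ in range(iterations):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return {
        "mean": statistics.fmean(samples),
        "median": statistics.median(samples),
        "max": max(samples),
    }


# Usage with the vector store from the script above (not executed here):
# stats = benchmark(lambda: vectorstore.similarity_search(question, k=N))
# print(f"median latency: {stats['median']:.4f} s")

if __name__ == "__main__":
    print(benchmark(lambda: sum(range(10_000)), warmup=5, iterations=20))
```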
Next, our goal was to optimize the end-to-end RAG use case with torch.compile and weights pre-packing, which gave a 1.7x improvement for the standalone embedding model inference. However, these optimizations didn't work out of the box for the RAG scenario.
What are the challenges and solutions to achieve similar gains in an end-to-end RAG scenario?
Challenge 1: model handle
There was no way to get a handle to the model instantiated by HuggingFaceEmbeddings, and the wrapper class doesn't provide compile APIs. So, there was no way for our application to invoke torch.compile to trigger the PyTorch dynamo compilation process.
Solution
We implemented a custom embedding class so that we could get a handle to the model. This class instantiates the embedding model from sentence-transformers and keeps the handle, so the model can be compiled immediately or at a later stage. With this, we were able to invoke torch.compile and hence trigger the dynamo compilation.
class CustomEmbedding(HuggingFaceEmbeddings):
    def __init__(self, **kwargs: Any):
        """Initialize the sentence_transformer."""
        super().__init__(**kwargs)
        # Load model from HuggingFace Hub and keep a handle to it for torch.compile
        self.client = AutoModel.from_pretrained(self.model_name)
    class Config:
        arbitrary_types_allowed = True
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute doc embeddings using a HuggingFace transformer model.
        Args:
            texts: The list of texts to embed.
        Returns:
            List of embeddings, one for each text.
        """
        texts = list(map(lambda x: x.replace("\n", " "), texts))
        # Tokenize sentences (note: the tokenizer is reloaded on every call)
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
        embeddings = self.client(
            **encoded_input, output_hidden_states=True
        )
        embeddings = embeddings.pooler_output.detach().numpy()
        return embeddings.tolist()
# instead of model = HuggingFaceEmbeddings(model_name=embedding_model)
model = CustomEmbedding(model_name=embedding_model)
# torch.compile the model
model.client = torch.compile(model.client)
Challenge 2: triggering the optimization
For a typical inference scenario where the graph is frozen and gradient calculations are disabled, Torch inductor (the compiler backend we used for CPUs) invokes hardware-specific optimizations such as graph rewrites into more performant operators, operator fusion, and weights pre-packing. Though Torch dynamo was able to see the model and trigger generic compilation, it failed to trigger these additional FX passes in Torch inductor.
There were two main reasons Torch inductor did not trigger the optimization passes: (1) the application didn't set no_grad() or inference_mode() for Torch inductor to detect that the graph was frozen; and (2) we hit a limitation of the torch.compile framework: if no_grad is set just at the beginning of the compiled region, torch.compile cannot detect it while invoking the inductor FX passes, because it has not yet entered the no_grad region at that point. Please refer to this GitHub issue for more details.
Solution
We worked around this limitation by moving the no_grad() context from within the model class into the application code. With this, the model compiled as expected and gave around 1.3x performance improvement when we profiled the stable inference pass for the eager and compiled versions.
Challenge 3: extra compilation
With the previous fixes, the query lookup inference performance improved, but the total execution time of the benchmarking script did not. We root-caused this to redundant compilation of the model during RAG inference. Further investigation revealed that it was caused by a batch size mismatch between the document embedding (vectorization) stage and the RAG query stage. For example, in our benchmarking script, when the database was vectorized and stored in the vector database, we used a batch size of 16, so the model was compiled with shapes of 16xNxK. The RAG query lookup, however, is usually a single request of shape 1xNxK. This batch size mismatch (dimension “0” of these tensors) triggered recompilation for the query lookup stage. We confirmed it by enabling recompilation logging with TORCH_LOGS="recompiles":
TORCH_LOGS="recompiles" python rag_compile.py
V1103 02:48:08.805986 34281 site-packages/torch/_dynamo/guards.py:2813] [0/1] [__recompiles] Recompiling function forward in site-packages/transformers/models/mpnet/modeling_mpnet.py:502
V1103 02:48:08.805986 34281 site-packages/torch/_dynamo/guards.py:2813] [0/1] [__recompiles] triggered by the following guard failure(s):
V1103 02:48:08.805986 34281 site-packages/torch/_dynamo/guards.py:2813] [0/1] [__recompiles] - 0/0: tensor 'L['input_ids']' size mismatch at index 0. expected 16, actual 1
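To make this failure mode concrete, here is a deliberately simplified toy model (not Dynamo's actual guard logic) of a compile cache that specializes on the batch dimension only, mirroring the `size mismatch at index 0` guard failure in the log above: a cache miss stands in for a recompilation. It shows why the first batch-size-1 query recompiles after a batch-size-16 document pass, and why later queries of the same shape do not. The `ShapeSpecializingCompiler` class is hypothetical.

```python
from typing import Callable, Dict, List


class ShapeSpecializingCompiler:
    """Toy stand-in for a compiler that keeps one compiled variant per batch size."""

    def __init__(self, fn: Callable[[List[List[int]]], List[int]]):
        self.fn = fn
        self.cache: Dict[int, Callable[[List[List[int]]], List[int]]] = {}
        self.compile_log: List[int] = []

    def __call__(self, batch: List[List[int]]) -> List[int]:
        key = len(batch)  # the batch size acts as the "guard"
        if key not in self.cache:
            self.compile_log.append(key)  # cache miss == (re)compilation
            self.cache[key] = self.fn     # a real compiler would emit specialized code here
        return self.cache[key](batch)


def row_sums(batch: List[List[int]]) -> List[int]:
    return [sum(row) for row in batch]


model = ShapeSpecializingCompiler(row_sums)
model([[1] * 8] * 16)     # document embedding stage: batch of 16 -> compile
model([[1] * 8])          # first query: batch of 1 -> recompile
model([[2] * 8])          # later queries with the same shape hit the cache
print(model.compile_log)  # [16, 1]
```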
Solution
Torch dynamo provides decorators to mark a dimension of a given tensor as dynamic, so that a change in its size does not trigger recompilation. For example, marking dimension “0” of input_ids and attention_mask as dynamic, and indicating that a value of “1” is allowed in that dimension (as shown in the following code snippet), should have avoided the redundant compilations.
torch._dynamo.decorators.mark_unbacked(encoded_input['input_ids'], 0)
torch._dynamo.mark_dynamic(encoded_input['input_ids'], 1)
torch._dynamo.decorators.mark_unbacked(encoded_input['attention_mask'], 0)
torch._dynamo.mark_dynamic(encoded_input['attention_mask'], 1)
However, the Torch dynamo decorator and marking didn't work in this particular case; moreover, using the decorator created graph breaks. So, we added some warmup iterations to hide the compilation latency, and profiled the query lookup performance in the steady state. The good news is that, in practice, this re-compilation is triggered only for the first query, so it might not affect the production scenario if the database size is fixed. In addition, PyTorch AOT Inductor (a new feature in PyTorch) addresses the re-compilation and warmup challenges of torch.compile. In a follow-up blog, we will cover how to use AOT Inductor to address these challenges in a production environment.
With these solutions, we were able to apply torch.compile, weights pre-packing, and the AWS Graviton-specific optimizations to an end-to-end RAG scenario and improve its performance by 1.3x over the baseline eager mode.
Deployment
A detailed guide on how to deploy torch compiled RAG on AWS Graviton-based Amazon EC2 instances and how to deploy it in conjunction with Llama using TorchServe can be found on the PyTorch website.
Conclusion
In this blog, we covered how we optimized embedding model inference performance on AWS Graviton3-based EC2 instances. We also shared the challenges faced, the solutions we implemented to bring those optimizations for a RAG use case, and the resulting speedups. We hope that you will give it a try! If you need any support with ML software on Graviton, please open an issue on the AWS Graviton Technical Guide GitHub.
We would like to express our gratitude to Eli Uriegas for the support in making this blog post happen.
Authors
Sunita Nadampalli is a Principal Engineer and AI/ML expert at AWS. She leads AWS Graviton software performance optimizations for AI/ML and HPC workloads. She is passionate about open source software development and delivering high-performance and sustainable software solutions for SoCs based on the Arm ISA.
Ankith Gunapal is an AI Partner Engineer at Meta (PyTorch). He leads customer support, evangelizing & release engineering of TorchServe. He is passionate about solving production problems in model inference and model serving. He also enjoys distilling technically complex material in a user friendly format.
Hamid Shojanazeri leads the AI Frameworks Partner Engineering team at Meta. He is passionate about building scalable AI solutions and specializes in working with PyTorch to tackle the challenges of large-scale distributed training, inference, model serving, and optimization.