Batching liên tục từ những nguyên tắc cơ bản
Batching liên tục từ những nguyên tắc cơ bản
- 19 min read
Continuous batching từ nguyên lý đầu tiên
TL;DR: Trong bài đăng này, bắt đầu từ các cơ chế attention và KV caching, chúng tôi dẫn xuất continuous batching bằng cách tối ưu hóa thông lượng.
Nếu bạn đã từng sử dụng Qwen, Claude hoặc bất kỳ chatbot AI nào khác, bạn có thể đã nhận thấy điều gì đó: phải mất một lúc để từ đầu tiên của phản hồi xuất hiện, và sau đó các từ xuất hiện lần lượt trên màn hình của bạn với tần suất đều đặn và nhanh chóng (hy vọng là vậy). Đó là bởi vì về cốt lõi, tất cả các LLM chỉ là những bộ dự đoán token tiếp theo rất tốt. Một LLM trước tiên xử lý toàn bộ prompt của bạn để tạo ra một token mới. Sau đó, nó tiếp tục thêm các token từng cái một, mỗi lần đọc lại mọi thứ đã xảy ra trước đó, cho đến khi nó quyết định quá trình tạo đã hoàn tất.
Quá trình tạo này tốn nhiều tài nguyên tính toán: nó yêu cầu truyền đầu vào qua hàng tỷ tham số cho mỗi token được tạo ra. Để làm cho các mô hình này trở nên thiết thực cho các ứng dụng trong thế giới thực, đặc biệt là khi phục vụ nhiều người dùng đồng thời, các nhà nghiên cứu và kỹ sư đã phát triển một loạt các kỹ thuật suy luận hiệu quả. Một trong những tối ưu hóa có tác động nhất là continuous batching, một kỹ thuật cố gắng tối đa hóa hiệu suất bằng cách xử lý nhiều cuộc trò chuyện song song và hoán đổi chúng ra khi hoàn thành.
Để hiểu cách continuous batching hoạt động và tại sao nó lại hiệu quả trong các tình huống phục vụ tải cao, chúng ta sẽ xây dựng dựa trên các nguyên tắc cơ bản về cách LLM xử lý token.
Attention
Cơ chế attention là phần trung tâm của cách LLM hoạt động. Một mô hình ngôn ngữ xử lý văn bản bằng cách chia nó thành các phần mà chúng ta gọi là token. Chúng ta có thể coi “token” là “từ”, nhưng đôi khi một từ có thể bao gồm nhiều token. Đối với mỗi chuỗi token, mạng tính toán một dự đoán về token tiếp theo nên là gì.
Nhiều thao tác trong mạng là token-wise (theo từng token): mỗi token được xử lý độc lập, và đầu ra cho một token nhất định chỉ phụ thuộc vào nội dung của token đó, không phụ thuộc vào bất kỳ token nào khác trong chuỗi. Các thao tác như vậy bao gồm chuẩn hóa lớp hoặc nhân ma trận. Tuy nhiên, để tạo kết nối giữa các từ trong một câu, chúng ta cần các thao tác mà các token có thể ảnh hưởng lẫn nhau.
Đây là nơi cơ chế attention phát huy tác dụng. Các lớp attention là nơi duy nhất mà các token khác nhau tương tác với nhau. Hiểu cách một mạng kết nối các token với nhau có nghĩa là hiểu cơ chế attention.
Hãy xem cách điều này hoạt động trong thực tế, trong trường hợp chỉ có một prompt đầu vào.
Hãy xem xét prompt ban đầu I am sure this project, được token hóa thành 7 token: [<bos>, I, am, sure, this, pro, ject]. Token <bos>, hay “Beginning of Sequence” (Bắt đầu chuỗi), là một token đặc biệt mà chúng ta thêm vào đầu prompt để báo cho mô hình ngôn ngữ rằng một cuộc trò chuyện mới bắt đầu tại đây.
Mỗi token được biểu diễn bên trong mạng với một vector có độ dài d (là hidden dimension). Do đó, bảy token đầu vào tạo thành một tensor $x$ với hình dạng $[1, 7, d]$. $1$ là số lượng chuỗi, hoặc kích thước batch, chỉ là một trong trường hợp của chúng ta. $7$ là độ dài chuỗi, và $d$ là hidden dimension, hoặc kích thước của mỗi biểu diễn token. Chúng ta sẽ sử dụng $n$ thay vì $7$ làm độ dài chuỗi trong các ví dụ tiếp theo.
Tensor đầu vào $x$ sau đó được chiếu bởi ba ma trận: phép chiếu query $W_q$, phép chiếu key $W_k$ và phép chiếu value $W_v$. Điều này tạo ra ba tensor $Q$, $K$ và $V$, tất cả đều có hình dạng $[1, n, A]$, trong đó $A$ là chiều của đầu attention. Chúng ta gọi chúng là các trạng thái query, key và value, tương ứng. Điều này được biểu diễn ở bên trái trong hình dưới đây.
Tiếp theo, các tensor $Q$ và $K$ được nhân với nhau để đo lường sự tương đồng giữa các token, tạo ra một tensor có hình dạng $[1, n, n]$. Đây là lý do tại sao chúng ta nói rằng attention có độ phức tạp bậc hai theo độ dài chuỗi. Việc tính $QK^T$ yêu cầu $O(n^2 d)$ thao tác, vì vậy chi phí là bình phương của độ dài chuỗi $n$. Nó được biểu diễn ở bên phải trong hình trên.
Sau đó, chúng ta áp dụng một attention mask boolean cho $QK^T$ để kiểm soát token nào có thể tương tác, như được biểu diễn trong hình dưới đây. Trong hình này, attention mask là causal mask, có nghĩa là mỗi token chỉ tương tác với các token đi trước nó. Điều này tuân theo trực giác rằng một nguyên nhân phải đi trước kết quả của nó, do đó có tên là causal mask. Attention mask rất quan trọng vì nó quyết định tất cả các tương tác token trong mạng. Đặt tất cả các giá trị attention mask thành False và không token nào có thể tương tác với token nào khác trong toàn bộ mạng. Chúng ta sẽ xem xét kỹ hơn các attention mask trong vài đoạn nữa.
Cuối cùng, sau khi áp dụng attention mask, chúng ta áp dụng softmax theo từng token (tương đương với softmax theo từng hàng) và nhân kết quả với phép chiếu value $V$ để nhận đầu ra của một đầu attention, có hình dạng $[1, n, A]$. Chúng ta cung cấp một bản tóm tắt trực quan về toàn bộ quy trình trong hình sau đây.
Chúng ta sẽ sử dụng nhiều hình ảnh attention trong bài đăng này, vì vậy để đơn giản hóa mọi thứ, chúng ta sẽ cô đọng hình trên lại một chút.
Tại sao điều này quan trọng: Trong continuous batching, $Q$, $K$, và $V$ có thể có số lượng token khác nhau vì, như chúng ta sẽ thấy, chúng ta sẽ xử lý các giai đoạn khác nhau (prefill và decode) cùng một lúc. Để làm cho nó tổng quát hơn, hãy nói rằng $Q$ có hình dạng $[1, n_Q, A]$, $K$ có hình dạng $[1, n_K, A]$, và $V$ có hình dạng $[1, n_V, A]$.
Các điểm attention $QK^T$ sau đó có hình dạng $[1, n_Q, n_K]$, và attention mask có cùng hình dạng vì nó được áp dụng theo điểm tới các điểm. Thay vì biểu diễn điểm attention, chúng ta sẽ biểu diễn attention mask vào vị trí của nó. Cuối cùng, vì $Q$, $K$ và $V$ là các phép chiếu trực tiếp của $x$, không cần biểu diễn $x$. Điều này mang lại hình ảnh đơn giản hóa, nơi chúng ta chỉ biểu diễn $Q$, $K$ và attention mask:
Biểu diễn này cũng nhấn mạnh cách chúng ta có thể đọc một attention mask.
Chúng ta đọc mask theo từng hàng, tương đương với việc đọc theo từng token: mỗi hàng tương ứng với phép tính attention của một token. Một ô màu xanh lá cây tại vị trí (hàng i, cột j) có nghĩa là True: token j có thể ảnh hưởng đến token i. Một ô màu trắng có nghĩa là False: không cho phép tương tác.
Ví dụ, hãy xem hàng thứ ba cho token “am”. Cột “I” có màu xanh lá cây, vì vậy “I” ảnh hưởng đến việc tính toán “am”. Cột “pro” có màu trắng, vì vậy “pro” không ảnh hưởng đến “am”. Đây là cách causal masking hoạt động: các token trong tương lai không thể ảnh hưởng đến các token trong quá khứ.
Lớp cuối cùng của mô hình xuất ra dự đoán token cho mỗi token đầu vào. Trong bối cảnh của chúng ta, việc tạo ra phần tiếp theo của một prompt duy nhất, chúng ta chỉ quan tâm đến dự đoán token tiếp theo từ token cuối cùng. Token cuối cùng là “ject” trong hình trên, và dự đoán tương ứng là “will”.
Quá trình chúng ta vừa mô tả, nơi chúng ta lấy toàn bộ chuỗi đầu vào, truyền nó qua nhiều lớp attention và tính toán điểm cho token tiếp theo, được gọi là prefill. Đây là vì, như chúng ta sẽ thấy trong giây lát, phần lớn tính toán chúng ta đã thực hiện có thể được lưu trữ và tái sử dụng – do đó, chúng ta đang prefilling cache. Nhờ việc sử dụng cache này, việc tạo chuỗi có thể tiến hành với ít tính toán hơn trong giai đoạn được gọi là decoding. Trong giai đoạn decoding, việc tạo ra một token mới sẽ nhanh hơn nhiều so với phép tính toàn bộ chuỗi ban đầu. Hãy xem tại sao.
Để tiếp tục tạo, chúng ta bắt đầu một lượt truyền xuôi mới, mà theo cách thông thường sẽ trông như sau:
Để tính toán điểm attention của token mới, chúng ta vẫn cần các phép chiếu key và value của các token trước đó. Vì vậy, chúng ta cần lặp lại phép nhân ma trận của các token cũ (màu xám trong hình trên) với $W_k$ và $W_v$ để lấy lại một kết quả đã được tính toán trước đó. Nói cách khác, chúng ta đang lãng phí tài nguyên tính toán. Hãy xem cách chúng ta có thể tránh điều đó.
KV-cache
Ngay lập tức, chúng ta nhận thấy rằng token cuối cùng không ảnh hưởng đến phép tính attention của các token khác:
Điều này tuân theo ý tưởng của causal mask: vì “will” đến sau tất cả các token trước đó, nó không làm thay đổi phép tính attention của chúng. Đối với việc tạo văn bản, causal attention là phổ biến nhất, vì vậy chúng ta sẽ tập trung vào trường hợp đó từ bây giờ. Hãy nhớ rằng các sơ đồ attention không gây ra (non-causal) cũng có thể được sử dụng, đặc biệt là khi xử lý hình ảnh. Xem xét rằng chúng ta chỉ cần dự đoán token tiếp theo cho token “will”, chúng ta có thể đơn giản hóa cơ chế attention bằng cách chỉ tính toán đầu ra cho token này.
Hơn nữa, chúng ta đã tính toán các trạng thái $K$ và $V$ cho các token “<bos>”, … , “ject” trong lượt truyền xuôi trước đó: nếu chúng đã được lưu trữ, chúng ta không cần phải tính toán lại chúng. Đây là KV cache: danh sách các trạng thái key và value được tạo ra trong quá trình tạo. Về cơ bản, nó cho phép chúng ta giảm chi phí tính toán để tạo token $n+1$ từ $O(n^2)$ xuống $O(n)$ bằng cách tránh tính toán lại các phép chiếu key và value, trong khi phải trả chi phí bộ nhớ $O(n)$.
Trong hình trên, chỉ các token màu trắng được tính toán: thay vì tính toán key và value cho 8 token, chúng ta tính cho 1 token. Bạn có thể thấy rằng thông qua KV caching, rất nhiều tài nguyên tính toán đã được tiết kiệm. Bạn có thể xem bài đăng này để biết thêm hình ảnh về KV caching, hoặc bài này để biết ví dụ triển khai thực tế.
Hãy cụ thể hơn một chút về kích thước cache, vì đây là cơ hội tốt để kiểm tra các hình dạng có trong mô hình của chúng ta. Đối với một mô hình có $\mathcal{L}$ lớp attention và $H$ đầu attention với chiều đầu $A$, kích thước cache tổng cộng cần để lưu trữ một token sẽ là $2*\mathcal L * AH$ với hệ số $2$ để tính cả $K$ và $V$. Ví dụ, Llama-2-7B với $\mathcal{L}=32$ lớp, $H=32$ đầu, và $A=128$ yêu cầu $2 \times 32 \times 128 = 8,192$ giá trị trên mỗi token trên mỗi lớp. Với độ chính xác float16, điều này tốn $2AH \times 2$ byte $= 16$ KB bộ nhớ.
KV caching rất hữu ích khi chúng ta muốn tạo token tiếp theo, một giai đoạn chúng ta gọi là decoding. Nhưng nó cũng có thể hữu ích trong giai đoạn prefill, khi chúng ta xử lý prompt ban đầu và có nhiều token đầu vào. Đặc biệt là khi có các prompt ban đầu lớn không vừa với bộ nhớ GPU cùng một lúc.
Chunked Prefill
Cho đến nay, chúng ta đã xem xét một ví dụ về prefill nơi chúng ta có $n=7$ token, nhưng trong thực tế các prompt ban đầu có thể dài hơn nhiều. Ví dụ, khi sử dụng Cursor, bạn có thể thêm kho lưu trữ của mình vào prompt, nơi nó hoạt động như một ngữ cảnh: điều này làm tăng đáng kể kích thước prompt. Trong những trường hợp như vậy, bộ nhớ cần thiết để lưu trữ các activation cho $n$ token có thể lớn hơn bộ nhớ có sẵn trên GPU. Do đó, chúng ta không thể thực hiện prefill trong một lần truyền xuôi duy nhất: chúng ta phải chia prefill thành các chunk. Điều này được gọi là chunked prefill, và nó sẽ là một trong những thành phần cần thiết để cho phép suy luận hiệu quả.
Hãy giả sử rằng bộ nhớ có sẵn bị giới hạn rất nhiều, và chúng ta chỉ có thể truyền $m=4$ token mỗi lượt truyền xuôi. Nếu chúng ta có một prompt ban đầu với $n=7$ token, chúng ta cần chia nó thành $\lceil n /m \rceil = 2$ chunk (làm tròn lên 7/4 = 1.75 thành 2). Chúng ta minh họa ví dụ dưới đây bằng cách sử dụng cùng các ký hiệu $n$ và $m$:
Chúng ta có thể làm điều đó nhờ KV cache. Chúng ta lưu trữ các trạng thái KV trong lần chia prefill đầu tiên, và trong lần chia prefill thứ hai, chúng ta ghép các trạng thái KV đã lưu trữ vào đầu các trạng thái KV mới. Chúng ta cũng điều chỉnh attention mask tương ứng. Về mặt hình ảnh, nó trông giống như chúng ta chia prefill không phân chunk thành hai phần.
Ý tưởng chính: các trạng thái KV được lưu trữ cho phép chúng ta xử lý prompt một cách tăng dần mà không làm mất thông tin.
Mặc dù chúng ta đã cho thấy một ví dụ ở đây, nơi chúng ta chia prefill thành 2 chunk, chunked prefill có thể được sử dụng để chia prefill theo bất kỳ cách nào chúng ta muốn, thích ứng linh hoạt với các ràng buộc bộ nhớ.
Chúng ta cuối cùng đã trang bị đủ tất cả các công cụ cần thiết để hiểu về Continuous Batching.
Continuous batching
Trong các ví dụ trước của chúng ta, chúng ta chỉ xem xét trường hợp kích thước batch là một, tức là chúng ta chỉ tạo token cho một prompt tại một thời điểm. Trong bối cảnh đánh giá hoặc phục vụ mô hình, chúng ta muốn tạo token cho một số lượng lớn các prompt. Để tăng thông lượng, là số lượng token được tạo ra mỗi giây, hành động tốt nhất là tạo token song song cho một batch gồm nhiều prompt.
Để batch các prompt lại với nhau, cách thông thường là thêm một trục vào cả hai tensor đầu vào: chuỗi token và attention mask. Tuy nhiên, điều này đi kèm với một ràng buộc về hình dạng của đầu vào: chúng ta cần tất cả các prompt có cùng độ dài, vì các tensor phải có hình chữ nhật. Để đạt được điều này, chúng ta thường thêm padding ở bên trái để dự đoán token mới luôn đến từ token ở ngoài cùng bên phải. Chúng ta cũng sửa đổi attention mask của mỗi prompt tương ứng, như được hiển thị dưới đây:
nơi các token padding <pad> có màu cam. Sau đó, chúng ta có thể thực hiện lượt truyền xuôi như trước đây, với thêm kích thước batch size. Điều này được gọi là batched generation: hiệu quả cho các prompt có cùng độ dài, nhưng lãng phí khi độ dài khác nhau.
Nó được minh họa dưới đây, thông qua 4 bước tạo: một bước prefilling (ở trên cùng) và 3 bước decoding (dưới mỗi dòng “Forward pass”).
trong đó <eos> có nghĩa là “End Of Sequence” (Kết thúc chuỗi), đây là một token đặc biệt để chỉ ra rằng mô hình đã đạt đến cuối quá trình tạo cho chuỗi tương ứng.
Nhược điểm của batched generation là nếu một prompt hoàn thành quá trình tạo trước các prompt khác bằng cách tạo ra token <eos>, tất cả các token được tạo ra sau đó đều vô dụng. Và điều này tiếp tục cho đến khi yêu cầu dài nhất của batch hoàn thành. Tất nhiên, chúng ta có thể loại bỏ các prompt đã đạt đến token <eos> khỏi batch và tiết kiệm một số tài nguyên và bộ nhớ, nhưng mục tiêu ở đây không phải là tiết kiệm tài nguyên: mà là thông lượng.
Thay vì chỉ loại bỏ prompt đã hoàn thành khỏi batch, chúng ta có thể thay thế nó bằng một prompt đang chờ tạo. Chúng ta sẽ gọi đây là dynamic scheduling, hoặc dynamic batching. Dynamic scheduling rất tuyệt để duy trì thông lượng đồng thời đảm bảo bất kỳ token nào được tạo ra bởi một lượt truyền xuôi đều có liên quan. Nhưng do cách chúng ta đã batch các prompt lại với nhau, nó có một nhược điểm lớn: chúng ta cần rất nhiều padding khi hoán đổi prompt. Đó là bởi vì prompt được chèn mới cần phải trải qua quá trình prefill trong khi các prompt khác đang giải mã từng token một. Vì vậy, gần như có nhiều padding như có token trong prompt được chèn mới.
Vấn đề càng trở nên tồi tệ hơn khi kích thước batch tăng lên và các prompt ban đầu dài. Chi phí padding tăng theo cấp số nhân với cả kích thước batch và độ dài prompt. Nếu chúng ta có một batch gồm $B$ prompt đang ở giai đoạn decoding và một prompt hoàn thành, việc đưa một prompt có $n$ token ban đầu vào batch một cách động đòi hỏi $(n-1)(B-1)$ token padding. Ví dụ, với $B=8$ và $n=100$, chúng ta sẽ cần $99 \times 7 = 693$ token padding!
Hơn nữa, các tối ưu hóa thực tế như CUDA graphs hoặc torch.compile yêu cầu các hình dạng tensor tĩnh. Điều này buộc chúng ta phải đệm tất cả các prompt đến một độ dài tối đa cố định, làm tăng đáng kể lãng phí padding.
Tại thời điểm này, vấn đề chính của chúng ta là padding, là hệ quả của trục mà chúng ta đã thêm để batch các câu lại với nhau. Do đó, cách lý tưởng là loại bỏ hoàn toàn trục này, một sự suy nghĩ lại triệt để về batching. Nếu chúng ta làm như vậy, cách duy nhất để batch các prompt lại với nhau là nối chúng lại:
Nhưng chúng ta không muốn các token từ prompt 0 tương tác với các token của prompt 1! May mắn thay, chúng ta có một cách để kiểm soát cách các token tương tác với nhau: attention mask. Cách chúng ta làm điều đó được hiển thị dưới đây:
Mặc dù chúng ta sử dụng các sắc thái xanh lá cây khác nhau để minh họa các phần khác nhau của attention mask, đây vẫn là một mask boolean với chỉ các màu xanh lá cây cho True và màu trắng cho False. Cách batch các prompt này được gọi là ragged batching (vì độ dài chuỗi “ragged” hay không đều), và nó mang lại lợi ích về tăng thông lượng mà không cần thêm token padding.
Trong hình trên, chúng ta sử dụng ragged batching để kết hợp hai prompt đầy đủ với nhau, nhưng chúng ta có thể batch bao nhiêu tùy theo bộ nhớ cho phép. Giới hạn duy nhất là $m$, số lượng token mà chúng ta có thể đưa vào một batch, với $m$ phụ thuộc vào bộ nhớ có sẵn trên GPU.
Ragged batching là một trong những thành phần chính của continuous batching. Để tối đa hóa thông lượng, chúng ta có thể kết hợp các chuỗi prefill và decoding theo một thuật toán như sau:
- Chúng ta cố gắng luôn đạt được ngân sách bộ nhớ của chúng ta là $m$ token trên mỗi batch
- Chúng ta trước tiên thêm tất cả các prompt ở giai đoạn decoding vào batch, mỗi prompt chiếm 1 token
- Chúng ta lấp đầy không gian còn lại với các prompt ở giai đoạn prefill, dựa vào tính linh hoạt của chunked prefill để chia các đầu vào khi cần thiết
Dynamic scheduling là mảnh ghép cuối cùng góp phần vào kỹ thuật continuous batching: chúng ta loại bỏ các prompt đã hoàn thành khỏi batch ngay khi chúng hoàn thành, và thay thế chúng bằng các prompt được phân chunk mới tương ứng với các yêu cầu đến.
Sự kết hợp giữa ragged batching và dynamic scheduling này được gọi là continuous batching, và đây là kỹ thuật cung cấp năng lượng cho các hệ thống phục vụ LLM hiện đại.
Kết luận
Continuous batching kết hợp ba kỹ thuật chính để tối đa hóa thông lượng trong việc phục vụ LLM:
- KV caching để tránh tính toán lại biểu diễn token trước đó
- Chunked prefill để xử lý các prompt có độ dài thay đổi trong giới hạn bộ nhớ
- Ragged batching với dynamic scheduling để loại bỏ lãng phí padding và giữ cho GPU được sử dụng tối đa
Bằng cách loại bỏ chiều batch và sử dụng attention masks để kiểm soát tương tác token, continuous batching cho phép kết hợp các giai đoạn prefill và decode trong cùng một batch, cải thiện đáng kể hiệu quả khi phục vụ nhiều yêu cầu. Đây là lý do tại sao các dịch vụ như ChatGPT có thể xử lý hàng nghìn người dùng đồng thời một cách hiệu quả.
Trong bài viết tiếp theo trong loạt bài này, chúng ta sẽ khám phá quản lý KV cache hiệu quả thông qua paged attention. Nếu bạn muốn xem một phân tích sâu về các chủ đề continuous batching khác, vui lòng cho chúng tôi biết trong phần bình luận!
Ghi nhận: cảm ơn Arthur Zucker đã sản xuất ý tưởng ban đầu cho các hình ảnh được sử dụng trong bài viết này. Và cảm ơn Arthur Zucker, Luc Georges, Lysandre Debut, Merve Noyan và Pedro Cuenca đã cung cấp các bài đánh giá hữu ích.