Huấn luyện các mô hình ternary bit thấp với Axolotl

Hướng dẫn huấn luyện các mô hình ternary bit thấp sử dụng Axolotl

  • 7 min read
Huấn luyện các mô hình ternary bit thấp với Axolotl
Hướng dẫn huấn luyện các mô hình ternary bit thấp sử dụng Axolotl

Huấn luyện các mô hình tam phân bit thấp với Axolotl

Tác giả: Đội ngũ Axolotl, Younes Belkada từ đội ngũ FalconLLM Mô hình: https://huggingface.co/collections/axolotl-ai-co/falcon-e-bitnet

Các nghiên cứu gần đây như Bonsai-1bit cho thấy sự quan tâm mạnh mẽ trong việc tạo ra các mô hình ngôn ngữ lớn (LLM) bit thấp cho cộng đồng, nhằm triển khai nhiều khả năng hơn trên các thiết bị biên và tài nguyên hạn chế. Trong sự hợp tác này, chúng tôi tập trung vào việc giúp việc huấn luyện các mô hình tam phân 1.58-bit (ternary LLMs) trở nên dễ tiếp cận hơn bằng cách tích hợp quy trình huấn luyện dòng mô hình Falcon BitNet của TII vào Axolotl. Chúng tôi cũng phát hành các mô hình thử nghiệm cho cộng đồng, được huấn luyện thông qua giai đoạn SFT thuần túy bằng Axolotl (bắt đầu từ các phiên bản bfloat16 và tiền lượng tử hóa của các LLM định dạng tam phân hiện có), cũng như các biến thể tinh chỉnh bằng DPO, nhằm chứng minh tính khả thi của việc tinh chỉnh BitNet.

BitNet (LLM định dạng tam phân) là gì?

BitNet được Microsoft giới thiệu trong bài báo “The Era of 1-bit LLMs: All Large Language Models are in 1.58 bits” vào năm 2024. Ý tưởng chính là huấn luyện các mô hình để chúng có khả năng chống chịu với việc lượng tử hóa tam phân (tức là các trọng số chỉ nhận giá trị -1, 0 hoặc 1). Điều này đạt được bằng cách đưa các sai số lượng tử hóa vào trọng số và các giá trị kích hoạt (activations) của mô hình trong quá trình huấn luyện, áp dụng cho tất cả các lớp tuyến tính (ngoại trừ lớp lm head vì lớp này rất nhạy cảm với lượng tử hóa).

Các mô hình sau khi huấn luyện sẽ có trọng số tam phân, chỉ với một hệ số tỷ lệ (scaling factor) duy nhất cho mỗi tensor và đạt được mức giảm bộ nhớ lên đến 7 lần (tùy thuộc vào kích thước từ vựng của mô hình) so với phiên bản bfloat16 tương ứng.

Quá trình huấn luyện diễn ra ở định dạng bfloat16 vì các giá trị kích hoạt và gradient luôn được tính toán ở độ chính xác này. Nói một cách đơn giản, lượng tử hóa tam phân được “mô phỏng” trong khi huấn luyện để làm cho mô hình tương thích với lượng tử hóa tam phân khi thực hiện suy luận (inference).

Dưới đây là mã PyTorch gốc của việc mô phỏng lượng tử hóa BitNet áp dụng cho các lớp tuyến tính. Trong thực tế, các kernel Triton tối ưu hóa sẽ được sử dụng để thực hiện các thao tác này.

def weight_quant_torch(w):
    scale = 1.0 / w.abs().mean().clamp_(min=1e-05)
    u = (w * scale).round().clamp_(-1, 1) / scale
    return u

def activation_quant_torch(x):
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-05)
    y = (x * scale).round().clamp_(-128, 127) / scale
    return y

Lưu ý: Việc đưa các công thức này vào quá trình truyền xuôi (forward pass) sẽ dẫn đến các thao tác không thể tính đạo hàm (torch.clamp, …). Điều này được khắc phục bằng cách sử dụng straight-through estimators.

Thuật ngữ 1.58-bit xuất phát từ việc về mặt lý thuyết, có thể đóng gói các trọng số tam phân với trung bình 1.58 bit cho mỗi tham số nếu đáp ứng một số điều kiện nhất định. Trong thực tế, các trọng số tam phân được đóng gói ở độ chính xác 2-bit, sử dụng các tensor uint8 (tức là 4 tham số cho mỗi tensor).

Bài viết này sẽ tập trung vào việc tinh chỉnh dòng mô hình BitNet Falcon-E.

Hỗ trợ BitNet trong hệ sinh thái LLM

Các mô hình BitNet hiện có sự hỗ trợ tương đối mạnh mẽ trong hệ sinh thái cho việc suy luận trên CPU.

llama.cppik-llama.cpp hỗ trợ các định dạng lượng tử hóa TQ2_0TQ1_0 (trọng số tam phân được đóng gói tương ứng ở độ chính xác 2-bit và 1.58-bit). Thư viện mlx của Apple cũng hỗ trợ các mô hình BitNet thông qua các kernel Metal tối ưu hóa, và bạn cũng có thể thực hiện suy luận bằng thư viện transformers của Hugging Face thông qua torch.compile.

Mặc dù Microsoft gần đây đã phát hành các kernel GPU tối ưu hóa cho các mô hình BitNet, nhưng việc hỗ trợ suy luận tối ưu trên GPU vẫn chưa có mặt trên các khung vận hành phổ biến như vLLM hay sglang.

Cách huấn luyện mô hình BitNet với Axolotl

Cách an toàn nhất để huấn luyện một mô hình BitNet là huấn luyện lại từ đầu bằng kiến trúc BitNet, hoặc lấy một mô hình BitNet hiện có — phiên bản tiền lượng tử hóa (tức là state dict bfloat16 thuần túy của checkpoint) — và thực hiện huấn luyện tiền tiếp (continuous pre-training) hoặc tinh chỉnh (fine-tuning).

Việc chia sẻ các trọng số BitNet tiền lượng tử hóa đã được thực hiện thông qua các bản phát hành gần đây như Falcon-E hoặc bản Bitnet-2B mới nhất của Microsoft.

Sau khi huấn luyện, các checkpoint có thể được chuyển đổi an toàn sang định dạng tam phân bằng một phương pháp chuyển đổi đơn giản và đóng gói trọng số tam phân vào các tensor uint8. Điều này có thể thực hiện dễ dàng với hàm tiện ích quantize_to_1bit từ thư viện onebitllms của TII:

# uv pip install onebitllms
onebitllms quantize_to_1bit INPUT_PATH OUTPUT_PATH

Huấn luyện mô hình Falcon-E với Axolotl

Đầu tiên, hãy đảm bảo cài đặt gói onebitllms và sử dụng một trong các checkpoint dưới đây để tinh chỉnh các mô hình Falcon-E:

  • tiiuae/Falcon-E-1B-prequantized-bf16
  • tiiuae/Falcon-E-3B-prequantized-bf16
  • tiiuae/Falcon3-10B-1.58bit-prequantized-bf16 (thử nghiệm)

Checkpoint tiiuae/Falcon3-10B-1.58bit-prequantized-bf16 là bản thử nghiệm, vì checkpoint tiền lượng tử hóa gốc không được phát hành, do đó chúng tôi đã xấp xỉ các trọng số bằng cách đưa các hệ số tỷ lệ vào trọng số tam phân và lưu mô hình ở định dạng bfloat16.

Trong tệp cấu hình Axolotl, hãy bật tính năng tinh chỉnh BitNet với cờ sau:

...
use_onebitllms: true

Vậy là xong! Bạn có thể thử nghiệm với tập dữ liệu và siêu tham số của riêng mình để xây dựng LLM tam phân bằng CLI huấn luyện của Axolotl. Chúng tôi cũng cung cấp một cấu hình mẫu trong examples/bitnet/falcon-e-1b.yaml, cấu hình này sẽ sử dụng onebitllms và FSDP.

axolotl train examples/bitnet/falcon-e-1b.yaml

Khi quá trình huấn luyện kết thúc, hãy nhớ chuyển đổi các checkpoint sang định dạng tam phân bằng hàm quantize_to_1bit từ onebitllms:

onebitllms quantize_to_1bit INPUT_PATH OUTPUT_PATH

Checkpoint đã lưu tương thích trực tiếp với thư viện transformers của HuggingFace và Apple MLX. Để tạo các tệp GGUF cho llama.cpp, hãy chuyển đổi checkpoint HF bằng tập lệnh convert_hf_to_gguf.py từ kho lưu trữ đó ở định dạng bfloat16 ( llama.cpp sẽ tự động xử lý việc đưa hệ số tỷ lệ vào và chuyển đổi trọng số tam phân sang bfloat16) và sử dụng công cụ llama-quantize để lượng tử hóa các checkpoint bfloat16 sang TQ2_0 hoặc TQ1_0.

Lưu ý: Đối với tinh chỉnh BitNet, hiện tại chỉ hỗ trợ tinh chỉnh toàn bộ (full finetuning) — việc bật LoRA cho các mô hình BitNet vẫn là một lĩnh vực nghiên cứu chưa được khám phá. Việc huấn luyện bất kỳ mô hình tùy ý nào với nó về mặt kỹ thuật là khả thi, nhưng bạn sẽ phải tự chịu rủi ro!

Cách sử dụng Falcon-E-3B-1.2-Exp / Falcon-E-10B-1.2-Exp

Chúng tôi cũng phát hành các checkpoint bfloat16 tiền lượng tử hóa của các mô hình đã được tinh chỉnh (SFT và DPO) để cộng đồng có thể tiếp tục khám phá việc tinh chỉnh thêm. Chỉ cần bật use_onebitllms trong cấu hình Axolotl và bắt đầu.

Định hướng tương lai

Trong tương lai, một lĩnh vực khám phá thú vị sẽ là áp dụng các phương pháp RL on-policy cho các mô hình BitNet. Đối với huấn luyện off-policy (ví dụ: DPO), không có nhiều điều phải lo lắng, tuy nhiên đối với RL on-policy, chúng ta cần một mẹo trong giai đoạn rollout: xấp xỉ mô hình BitNet bằng phiên bản bfloat16 của nó (bao gồm mô hình tam phân với các hệ số tỷ lệ được tích hợp).

Chúng tôi cũng hy vọng thấy được sự hỗ trợ nhiều hơn trong hệ sinh thái GPU, thông qua các khung vận hành như vLLM hoặc sglang cho các kernel BitNet tối ưu hóa trên GPU.

Recommended for You

Cách định hướng Tác nhân AI tiếng Hàn theo nhân khẩu học thực tế với các Persona tổng hợp

Cách định hướng Tác nhân AI tiếng Hàn theo nhân khẩu học thực tế với các Persona tổng hợp

QIMMA قِمّة ⛰- Bảng xếp hạng LLM tiếng Ả Rập ưu tiên chất lượng

QIMMA قِمّة ⛰- Bảng xếp hạng LLM tiếng Ả Rập ưu tiên chất lượng