Tinh chỉnh FLUX.1-dev (LoRA) trên phần cứng tiêu dùng

Tinh chỉnh FLUX.1-dev bằng LoRA trên phần cứng tiêu dùng.

  • 11 min read
Tinh chỉnh FLUX.1-dev (LoRA) trên phần cứng tiêu dùng
Tinh chỉnh FLUX.1-dev bằng LoRA trên phần cứng tiêu dùng.

(LoRA) Tinh chỉnh FLUX.1-dev trên phần cứng tiêu dùng

Trong bài đăng trước của chúng tôi, Khám phá phần phụ trợ lượng tử hóa trong Diffusers, chúng tôi đã đi sâu vào cách các kỹ thuật lượng tử hóa khác nhau có thể thu nhỏ các mô hình khuếch tán như FLUX.1-dev, giúp chúng dễ tiếp cận hơn đáng kể để suy luận mà không ảnh hưởng đáng kể đến hiệu suất. Chúng ta đã thấy cách bitsandbytes, torchao và các công cụ khác làm giảm dấu chân bộ nhớ để tạo hình ảnh.

Thực hiện suy luận là rất tuyệt, nhưng để làm cho các mô hình này thực sự là của riêng mình, chúng ta cũng cần có khả năng tinh chỉnh chúng. Do đó, trong bài đăng này, chúng tôi giải quyết vấn đề tinh chỉnh hiệu quả các mô hình này với mức sử dụng bộ nhớ đỉnh dưới ~10 GB VRAM trên một GPU duy nhất. Bài đăng này sẽ hướng dẫn bạn tinh chỉnh FLUX.1-dev bằng QLoRA với thư viện diffusers. Chúng tôi sẽ giới thiệu kết quả từ NVIDIA RTX 4090. Chúng tôi cũng sẽ nêu bật cách đào tạo FP8 với torchao có thể tối ưu hóa hơn nữa tốc độ trên phần cứng tương thích.

Mục lục

Bộ dữ liệu

Chúng tôi muốn tinh chỉnh black-forest-labs/FLUX.1-dev để áp dụng phong cách nghệ thuật của Alphonse Mucha, sử dụng một bộ dữ liệu nhỏ.

Kiến trúc FLUX

Mô hình bao gồm ba thành phần chính:

  • Bộ mã hóa văn bản (CLIP và T5)
  • Transformer (Mô hình chính - Flux Transformer)
  • Bộ mã hóa tự động biến thiên (VAE)

Trong phương pháp QLoRA của chúng tôi, chúng tôi tập trung độc quyền vào việc tinh chỉnh thành phần transformer. Bộ mã hóa văn bản và VAE vẫn bị đóng băng trong suốt quá trình đào tạo.

QLoRA Tinh chỉnh FLUX.1-dev với diffusers

Chúng tôi đã sử dụng tập lệnh đào tạo diffusers (được sửa đổi một chút từ https://github.com/huggingface/diffusers/blob/main/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py) được thiết kế để tinh chỉnh LoRA theo phong cách DreamBooth của các mô hình FLUX. Ngoài ra, một phiên bản rút gọn để tái tạo kết quả trong bài đăng trên blog này (và được sử dụng trong Google Colab) có sẵn tại đây. Hãy cùng xem xét các phần quan trọng cho QLoRA và hiệu quả bộ nhớ:

Các kỹ thuật tối ưu hóa chính

LoRA (Điều chỉnh thứ hạng thấp) Deep Dive: LoRA giúp việc đào tạo mô hình hiệu quả hơn bằng cách theo dõi các cập nhật trọng số với các ma trận thứ hạng thấp. Thay vì cập nhật ma trận trọng số đầy đủ W W W, LoRA tìm hiểu hai ma trận nhỏ hơn A A AB B B. Bản cập nhật cho trọng số cho mô hình là ΔW=BA \Delta W = B A ΔW=BA, trong đó ARr×k A \in \mathbb{R}^{r \times k} ARr×kBRd×r B \in \mathbb{R}^{d \times r} BRd×r. Số r r r (gọi là hạng) nhỏ hơn nhiều so với các kích thước ban đầu, có nghĩa là ít tham số để cập nhật hơn. Cuối cùng, α \alpha α là một hệ số tỷ lệ cho các kích hoạt LoRA. Điều này ảnh hưởng đến mức độ ảnh hưởng của LoRA đến các bản cập nhật và thường được đặt thành cùng giá trị với r r r hoặc bội số của nó. Nó giúp cân bằng ảnh hưởng của mô hình được đào tạo trước và bộ điều hợp LoRA. Để biết giới thiệu chung về khái niệm này, hãy xem bài đăng trên blog trước của chúng tôi: Sử dụng LoRA để tinh chỉnh Stable Diffusion hiệu quả.

QLoRA: Nhà máy điện hiệu quả: QLoRA tăng cường LoRA bằng cách tải mô hình cơ sở được đào tạo trước ở định dạng lượng tử hóa trước (thường là 4-bit thông qua bitsandbytes), cắt giảm đáng kể dấu chân bộ nhớ của mô hình cơ sở. Sau đó, nó đào tạo bộ điều hợp LoRA (thường là ở FP16/BF16) trên đầu cơ sở lượng tử hóa này. Điều này làm giảm đáng kể VRAM cần thiết để giữ mô hình cơ sở.

Ví dụ: trong tập lệnh đào tạo DreamBooth cho HiDream lượng tử hóa 4-bit với bitsandbytes làm giảm mức sử dụng bộ nhớ đỉnh của tinh chỉnh LoRA từ ~60GB xuống ~37GB mà không làm giảm chất lượng. Nguyên tắc tương tự là những gì chúng ta áp dụng ở đây để tinh chỉnh FLUX.1 trên phần cứng cấp tiêu dùng.

Trình tối ưu hóa 8-bit (AdamW):

Trình tối ưu hóa AdamW tiêu chuẩn duy trì ước tính thời điểm đầu tiên và thứ hai cho mỗi tham số ở 32-bit (FP32), tiêu tốn rất nhiều bộ nhớ. AdamW 8-bit sử dụng lượng tử hóa theo khối để lưu trữ trạng thái trình tối ưu hóa với độ chính xác 8-bit, đồng thời duy trì tính ổn định của quá trình đào tạo. Kỹ thuật này có thể giảm mức sử dụng bộ nhớ trình tối ưu hóa ~75% so với FP32 AdamW tiêu chuẩn. Bật nó trong tập lệnh rất đơn giản:

# Kiểm tra cờ --use_8bit_adam
if args.use_8bit_adam:
    optimizer_class = bnb.optim.AdamW8bit
else:
    optimizer_class = torch.optim.AdamW

optimizer = optimizer_class(
    params_to_optimize,
    betas=(args.adam_beta1, args.adam_beta2),
    weight_decay=args.adam_weight_decay,
    eps=args.adam_epsilon,
)

Kiểm tra gradient:

Trong quá trình chuyển tiếp, các kích hoạt trung gian thường được lưu trữ để tính toán gradient chuyển ngược. Kiểm tra điểm gradient đánh đổi tính toán cho bộ nhớ bằng cách chỉ lưu trữ một số kích hoạt điểm kiểm tra nhất định và tính toán lại các kích hoạt khác trong quá trình lan truyền ngược.

if args.gradient_checkpointing:
    transformer.enable_gradient_checkpointing()

Bộ nhớ đệm tiềm ẩn:

Kỹ thuật tối ưu hóa này xử lý trước tất cả hình ảnh đào tạo thông qua bộ mã hóa VAE trước khi bắt đầu đào tạo. Nó lưu trữ các biểu diễn tiềm ẩn kết quả trong bộ nhớ. Trong quá trình đào tạo, thay vì mã hóa hình ảnh một cách nhanh chóng, các tiềm ẩn được lưu trong bộ nhớ đệm được sử dụng trực tiếp. Cách tiếp cận này mang lại hai lợi ích chính:

  1. loại bỏ các phép tính mã hóa VAE dư thừa trong quá trình đào tạo, tăng tốc mỗi bước đào tạo
  2. cho phép loại bỏ hoàn toàn VAE khỏi bộ nhớ GPU sau khi lưu vào bộ nhớ đệm. Sự đánh đổi là tăng mức sử dụng RAM để lưu trữ tất cả các tiềm ẩn được lưu trong bộ nhớ đệm, nhưng điều này thường có thể quản lý được đối với các bộ dữ liệu nhỏ.
# Bộ nhớ đệm tiềm ẩn trước khi đào tạo nếu cờ được đặt
    if args.cache_latents:
        latents_cache = []
        for batch in tqdm(train_dataloader, desc="Caching latents"):
            with torch.no_grad():
                batch["pixel_values"] = batch["pixel_values"].to(
                    accelerator.device, non_blocking=True, dtype=weight_dtype
                )
                latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
        # VAE không còn cần thiết nữa, giải phóng bộ nhớ của nó
        del vae
        free_memory()

Thiết lập lượng tử hóa 4-bit (BitsAndBytesConfig):

Phần này trình bày cấu hình QLoRA cho mô hình cơ sở:

# Xác định kiểu dữ liệu tính toán dựa trên độ chính xác hỗn hợp
bnb_4bit_compute_dtype = torch.float32
if args.mixed_precision == "fp16":
    bnb_4bit_compute_dtype = torch.float16
elif args.mixed_precision == "bf16":
    bnb_4bit_compute_dtype = torch.bfloat16

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
)

transformer = FluxTransformer2DModel.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="transformer",
    quantization_config=nf4_config,
    torch_dtype=bnb_4bit_compute_dtype,
)
# Chuẩn bị mô hình cho đào tạo k-bit
transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False)
# Kiểm tra gradient được bật sau đó thông qua transformer.enable_gradient_checkpointing() nếu arg được đặt

Xác định cấu hình LoRA (LoraConfig): Bộ điều hợp được thêm vào transformer lượng tử hóa:

transformer_lora_config = LoraConfig(
    r=args.rank,
    lora_alpha=args.rank, 
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"], # Các khối chú ý FLUX
)
transformer.add_adapter(transformer_lora_config)
print(f"tham số có thể đào tạo: {transformer.num_parameters(only_trainable=True)} || tất cả tham số: {transformer.num_parameters()}")
# các tham số có thể đào tạo: 4.669.440 || tất cả các tham số: 11.906.077.760

Chỉ những tham số LoRA này mới có thể đào tạo được.

Tính toán trước các nhúng văn bản (CLIP/T5)

Trước khi khởi chạy tinh chỉnh QLoRA, chúng ta có thể tiết kiệm một lượng lớn VRAM và thời gian thực bằng cách lưu vào bộ nhớ đệm các đầu ra của bộ mã hóa văn bản một lần.

Tại thời điểm đào tạo, trình tải dữ liệu chỉ cần đọc các nhúng được lưu trong bộ nhớ đệm thay vì mã hóa lại chú thích, vì vậy bộ mã hóa CLIP/T5 không bao giờ phải nằm trong bộ nhớ GPU.

# https://github.com/huggingface/diffusers/blob/main/examples/research_projects/flux_lora_quantization/compute_embeddings.py
import argparse

import pandas as pd
import torch
from datasets import load_dataset
from huggingface_hub.utils import insecure_hashlib
from tqdm.auto import tqdm
from transformers import T5EncoderModel

from diffusers import FluxPipeline


MAX_SEQ_LENGTH = 77
OUTPUT_PATH = "embeddings.parquet"


def generate_image_hash(image):
    return insecure_hashlib.sha256(image.tobytes()).hexdigest()


def load_flux_dev_pipeline():
    id = "black-forest-labs/FLUX.1-dev"
    text_encoder = T5EncoderModel.from_pretrained(id, subfolder="text_encoder_2", load_in_8bit=True, device_map="auto")
    pipeline = FluxPipeline.from_pretrained(
        id, text_encoder_2=text_encoder, transformer=None, vae=None, device_map="balanced"
    )
    return pipeline


@torch.no_grad()
def compute_embeddings(pipeline, prompts, max_sequence_length):
    all_prompt_embeds = []
    all_pooled_prompt_embeds = []
    all_text_ids = []
    for prompt in tqdm(prompts, desc="Encoding prompts."):
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
        ) = pipeline.encode_prompt(prompt=prompt, prompt_2=None, max_sequence_length=max_sequence_length)
        all_prompt_embeds.append(prompt_embeds)
        all_pooled_prompt_embeds.append(pooled_prompt_embeds)
        all_text_ids.append(text_ids)

    max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024
    print(f"Bộ nhớ tối đa được phân bổ: {max_memory:.3f} GB")
    return all_prompt_embeds, all_pooled_prompt_embeds, all_text_ids


def run(args):
    dataset = load_dataset("Norod78/Yarn-art-style", split="train")
    image_prompts = {generate_image_hash(sample["image"]): sample["text"] for sample in dataset}
    all_prompts = list(image_prompts.values())
    print(f"{len(all_prompts)=}")

    pipeline = load_flux_dev_pipeline()
    all_prompt_embeds, all_pooled_prompt_embeds, all_text_ids = compute_embeddings(
        pipeline, all_prompts, args.max_sequence_length
    )

    data = []
    for i, (image_hash, _) in enumerate(image_prompts.items()):
        data.append((image_hash, all_prompt_embeds[i], all_pooled_prompt_embeds[i], all_text_ids[i]))
    print(f"{len(data)=}")

    # Tạo DataFrame
    embedding_cols = ["prompt_embeds", "pooled_prompt_embeds", "text_ids"]
    df = pd.DataFrame(data, columns=["image_hash"] + embedding_cols)
    print(f"{len(df)=}")

    # Chuyển đổi danh sách nhúng thành mảng (để lưu trữ đúng cách trong parquet)
    for col in embedding_cols:
        df[col] = df[col].apply(lambda x: x.cpu().numpy().flatten().tolist())

    # Lưu dataframe vào một tệp parquet
    df.to_parquet(args.output_path)
    print(f"Dữ liệu đã được tuần tự hóa thành công thành {args.output_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--max_sequence_length",
        type=int,
        default=MAX_SEQ_LENGTH,
        help="Độ dài chuỗi tối đa để sử dụng để tính toán các nhúng. Càng nhiều chi phí tính toán càng cao.",
    )
    parser.add_argument("--output_path", type=str, default=OUTPUT_PATH, help="Đường dẫn để tuần tự hóa tệp parquet.")
    args = parser.parse_args()

    run(args)

Cách sử dụng

python compute_embeddings.py \
  --max_sequence_length 77 \
  --output_path embeddings_alphonse_mucha.parquet

Bằng cách kết hợp điều này với các tiềm ẩn VAE được lưu trong bộ nhớ đệm (--cache_latents), bạn giảm mô hình hoạt động xuống chỉ còn transformer được lượng tử hóa + bộ điều hợp LoRA, giữ cho toàn bộ tinh chỉnh thoải mái dưới 10 GB bộ nhớ GPU.

Thiết lập & Kết quả

Để trình diễn này, chúng tôi đã tận dụng NVIDIA RTX 4090 (24GB VRAM) để khám phá hiệu suất của nó. Lệnh đào tạo đầy đủ sử dụng accelerate được hiển thị bên dưới.

# Bạn cần tính toán trước các nhúng văn bản trước. Xem kho lưu trữ diffusers.
# https://github.com/huggingface/diffusers/tree/main/examples/research_projects/flux_lora_quantization
accelerate launch --config_file=accelerate.yaml \
  train_dreambooth_lora_flux_miniature.py \
  --pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
  --data_df_path="embeddings_alphonse_mucha.parquet" \
  --output_dir="alphonse_mucha_lora_flux_nf4" \
  --mixed_precision="bf16" \
  --use_8bit_adam \
  --weighting_scheme="none" \
  --width=512 \
  --height=768 \
  --train_batch_size=1 \
  --repeats=1 \
  --learning_rate=1e-4 \
  --guidance_scale=1 \
  --report_to="wandb" \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \ # có thể bỏ kiểm tra điểm khi HW có hơn 16 GB.
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --cache_latents \
  --rank=4 \
  --max_train_steps=700 \
  --seed="0"

Cấu hình cho RTX 4090:

Trên RTX 4090 của chúng tôi, chúng tôi đã sử dụng train_batch_size là 1, gradient_accumulation_steps là 4, mixed_precision="bf16", gradient_checkpointing=True, use_8bit_adam=True, LoRA rank là 4 và độ phân giải 512x768. Các tiềm ẩn đã được lưu trong bộ nhớ đệm với cache_latents=True.

Dấu chân bộ nhớ (RTX 4090):

  • QLoRA: Mức sử dụng VRAM đỉnh cho tinh chỉnh QLoRA là khoảng 9GB.
  • BF16 LoRA: Chạy LoRA tiêu chuẩn (với FLUX.1-dev cơ sở ở FP16) trên cùng một thiết lập đã tiêu thụ 26 GB VRAM.
  • BF16 tinh chỉnh đầy đủ: Ước tính sẽ là ~120 GB VRAM mà không có tối ưu hóa bộ nhớ.

Thời gian đào tạo (RTX 4090):

Tinh chỉnh trong 700 bước trên bộ dữ liệu Alphonse Mucha mất khoảng 41 phút trên RTX 4090 với train_batch_size là 1 và độ phân giải là 512x768.

Chất lượng đầu ra:

Thước đo cuối cùng là nghệ thuật được tạo ra. Dưới đây là các mẫu từ mô hình được tinh chỉnh QLoRA của chúng tôi trên bộ dữ liệu derekl35/alphonse-mucha-style:

Bảng này so sánh các kết quả chính xác bf16 chính. Mục tiêu của việc tinh chỉnh là dạy cho mô hình phong cách riêng biệt của Alphonse Mucha.

Mô hình được tinh chỉnh đã nắm bắt được phong cách nghệ thuật nouveau mang tính biểu tượng của Mucha một cách độc đáo, thể hiện rõ ở các họa tiết trang trí và bảng màu đặc biệt. Quy trình QLoRA duy trì độ trung thực tuyệt vời đồng thời tìm hiểu phong cách mới.

Kết quả gần như giống hệt nhau, cho thấy QLoRA hoạt động hiệu quả với cả độ chính xác hỗn hợp fp16bf16.

So sánh mô hình: Cơ sở so với Q

Recommended for You

CodeAgents + Cấu trúc- Một Cách Tốt Hơn để Thực Hiện Hành Động

CodeAgents + Cấu trúc- Một Cách Tốt Hơn để Thực Hiện Hành Động

Đánh giá các tác nhân GUI của bạn một cách dễ dàng!

Groq trên các nhà cung cấp suy luận Hugging Face 🔥

Groq trên các nhà cung cấp suy luận Hugging Face 🔥

Giới thiệu Groq trên Hugging Face Inference Providers