MedQA- Tinh chỉnh AI lâm sàng trên AMD ROCm — Không yêu cầu CUDA

Tinh chỉnh mô hình AI lâm sàng trên nền tảng AMD ROCm mà không cần sử dụng CUDA

  • 10 min read
MedQA- Tinh chỉnh AI lâm sàng trên AMD ROCm — Không yêu cầu CUDA
Tinh chỉnh mô hình AI lâm sàng trên nền tảng AMD ROCm mà không cần sử dụng CUDA

MedQA: Tinh chỉnh AI Lâm sàng trên AMD ROCm — Không cần CUDA

Hướng dẫn chi tiết về quy trình tinh chỉnh LoRA cho mô hình Qwen3-1.7B trên tập dữ liệu MedMCQA sử dụng AMD MI300X, được xây dựng cho cuộc thi AMD Developer Hackathon trên lablab.ai.


Ý tưởng

Giải đáp câu hỏi y tế là một trong những nhiệm vụ có mức độ rủi ro thực sự cao. Một mô hình tự tin chọn sai đáp án trong một câu hỏi trắc nghiệm lâm sàng không chỉ đơn thuần là sai — mà còn gây nguy hiểm. Đồng thời, hầu hết các công trình AI y tế mã nguồn mở đều giả định rằng bạn có GPU NVIDIA. CUDA là mặc định, và mọi thứ khác chỉ là phụ.

Dự án này thách thức giả định đó.

MedQA là một mô hình giải đáp câu hỏi lâm sàng được tinh chỉnh bằng LoRA, xây dựng hoàn toàn trên phần cứng AMD sử dụng ROCm. Mô hình tiếp nhận một câu hỏi y tế trắc nghiệm và trả về cả chữ cái của đáp án đúng cùng với lời giải thích lâm sàng về lập luận. Toàn bộ quy trình huấn luyện — từ tải dữ liệu đến xuất adapter — đều chạy trên AMD Instinct MI300X mà không phụ thuộc vào bất kỳ thư viện CUDA nào.


Tại sao chọn AMD ROCm?

AMD Instinct MI300X là một phần cứng đáng kinh ngạc: 192 GB bộ nhớ HBM3 trong một thiết bị duy nhất. Đối với việc tinh chỉnh LLM, VRAM thường là rào cản chính — nó quyết định kích thước batch (batch size), độ dài chuỗi và việc bạn có cần lượng tử hóa (quantization) hay không. Với 192 GB khả dụng, chúng tôi đã huấn luyện Qwen3-1.7B với LoRA ở định dạng fp16 đầy đủ mà không cần đến các thủ thuật lượng tử hóa 4-bit hay 8-bit.

Quan trọng hơn, mục tiêu là chứng minh rằng hệ sinh thái HuggingFace — Transformers, PEFT, TRL, Accelerate — hoạt động mượt mà trên ROCm. Và thực tế là như vậy. Cùng một mã huấn luyện chạy trên CUDA cũng sẽ chạy trên ROCm nếu thiết lập ba biến môi trường sau:

os.environ["ROCR_VISIBLE_DEVICES"] = "0"
os.environ["HIP_VISIBLE_DEVICES"] = "0"
os.environ["HSA_OVERRIDE_GFX_VERSION"] = "9.4.2"

Chỉ vậy thôi. Không thay đổi mã nguồn. Không cần kernel tùy chỉnh. Không cần các lớp tương thích CUDA.


Tập dữ liệu: MedMCQA

MedMCQA là một tập dữ liệu câu hỏi trắc nghiệm quy mô lớn được trích xuất từ các kỳ thi đầu vào y khoa của Ấn Độ (kiểu AIIMS, USMLE). Mỗi ví dụ bao gồm:

  • Một câu hỏi lâm sàng.
  • Bốn lựa chọn trả lời (A–D).
  • Chỉ số của đáp án đúng.
  • Một lời giải thích bằng văn bản tự do tùy chọn (trường exp).

Trong dự án này, chúng tôi sử dụng 2.000 mẫu huấn luyện — một phần nhỏ được chọn chủ đích để chứng minh rằng việc tinh chỉnh có ý nghĩa có thể đạt được một cách nhanh chóng. Quá trình huấn luyện mất khoảng 5 phút trên MI300X.


Mô hình: Qwen3-1.7B

Mô hình cơ sở là Qwen/Qwen3-1.7B — mô hình ngôn ngữ quy mô nhỏ mới nhất của Alibaba. Với 1,7 tỷ tham số, nó đủ nhỏ để tinh chỉnh với chi phí thấp nhưng đủ mạnh để đưa ra các lập luận lâm sàng mạch lạc. Mô hình hỗ trợ trust_remote_code=True và tải mượt mà thông qua HuggingFace Transformers.


Định dạng Prompt

Sự nhất quán trong định dạng prompt là cực kỳ quan trọng đối với tinh chỉnh hướng dẫn (instruction fine-tuning). Mọi ví dụ huấn luyện và mọi lời gọi suy luận đều sử dụng cùng một mẫu:

### Question:
{question}

### Options:
A) {opa}
B) {opb}
C) {opc}
D) {opd}

### Answer:
{answer_letter}) {answer_text}

### Explanation:
{explanation}

Trong khi huấn luyện, mô hình nhìn thấy toàn bộ chuỗi bao gồm cả đáp án và lời giải thích. Khi suy luận, chúng tôi cung cấp mọi thứ cho đến phần ### Answer:\n và để mô hình tự hoàn thiện phần còn lại.


Huấn luyện với LoRA

Thay vì tinh chỉnh toàn bộ 1,5 tỷ tham số, chúng tôi sử dụng LoRA (Low-Rank Adaptation) thông qua thư viện PEFT. LoRA chèn các ma trận phân rã hạng thấp có thể huấn luyện vào các lớp attention, trong khi giữ nguyên các trọng số cơ sở.

Cấu hình LoRA

from peft import LoraConfig, get_peft_model, TaskType

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    bias="none",
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# trainable params: 2,228,224 || all params: 1,543,901,184 || trainable%: 0.1443

Chỉ khoảng 2,2 triệu trong số 1,5 tỷ tham số của mô hình được huấn luyện. Điều này giúp giảm mức sử dụng bộ nhớ và tăng tốc độ huấn luyện.

Đối số huấn luyện (Training Arguments)

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./outputs",
    num_train_epochs=2,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,     # effective batch size = 16
    learning_rate=2e-4,
    fp16=True,
    bf16=False,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    gradient_checkpointing=True,
    optim="adamw_torch",
    warmup_ratio=0.05,
    lr_scheduler_type="cosine",
    report_to="none",
)

Một vài điểm đáng lưu ý:

  • fp16=True, bf16=False: Chúng tôi sử dụng fp16 tiêu chuẩn. Trong các thử nghiệm ban đầu với bfloat16, chúng tôi gặp lỗi NaN loss; việc chuyển sang fp16 đã giải quyết hoàn toàn vấn đề này.
  • gradient_checkpointing=True: Đổi tài nguyên tính toán để lấy bộ nhớ. Điều này không thực sự cần thiết trên MI300X vì có 192 GB VRAM, nhưng là thói quen tốt để đảm bảo khả năng tái lập trên các GPU nhỏ hơn.
  • gradient_accumulation_steps=4: Kích thước batch hiệu dụng là 16 với batch vật lý là 4.
  • Lịch trình LR Cosine với warmup: Giúp hội tụ mượt mà hơn so với lịch trình phẳng cho các đợt huấn luyện ngắn.

Toàn bộ vòng lặp huấn luyện

from transformers import DataCollatorForSeq2Seq, Trainer

collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    padding=True,
    pad_to_multiple_of=8,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=collator,
)

trainer.train()

# Lưu adapter + tokenizer
model.save_pretrained("./outputs")
tokenizer.save_pretrained("./outputs")

Sau khi huấn luyện, thư mục ./outputs sẽ chứa các trọng số của LoRA adapter — chỉ là vài MB tệp thay vì một bản checkpoint mô hình đầy đủ nặng nhiều GB.


Suy luận (Inference)

Khi suy luận, chúng tôi tải mô hình cơ sở, gắn adapter LoRA và tùy chọn hợp nhất (merge) các trọng số:

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch

tokenizer = AutoTokenizer.from_pretrained("./outputs", trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-1.7B",
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)

model = PeftModel.from_pretrained(base_model, "./outputs")
model.eval()

Quá trình tạo văn bản sử dụng giải mã tham lam (do_sample=False) với hình phạt lặp lại (repetition penalty) để ngăn mô hình bị lặp từ:

def generate(prompt, model, tokenizer):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=200,
            do_sample=False,
            temperature=1.0,
            repetition_penalty=1.1,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )

    new_tokens = output[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

Ví dụ kết quả đầu ra

Câu hỏi: Điều nào sau đây là phương pháp điều trị hàng đầu cho tình trạng cấp cứu tăng huyết áp (hypertensive emergency)?

A) Amlodipine đường uống
B) Labetalol IV hoặc Nitroprusside IV
C) Nifedipine ngậm dưới lưỡi
D) Hydralazine tiêm bắp

Kết quả từ mô hình:

B) Labetalol IV hoặc Nitroprusside IV

Giải thích:
Labetalol đường tĩnh mạch (thuốc chẹn beta) hoặc nitroprusside giúp giảm nhanh huyết áp trong các tình huống cấp cứu. Các thuốc đường uống tác dụng quá chậm đối với các trường hợp cấp cứu tăng huyết áp cần kiểm soát huyết áp ngay lập tức để ngăn tổn thương cơ quan đích.

Mô hình không chỉ đưa ra một chữ cái — nó giải thích tại sao, điều này làm cho nó trở nên hữu ích trong lâm sàng.


Tải từ HuggingFace Hub

Adapter đã tinh chỉnh được cung cấp công khai. Bạn có thể tải trực tiếp mà không cần clone repo:

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch

tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen3-1.7B", trust_remote_code=True
)

base = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-1.7B",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

model = PeftModel.from_pretrained(base, "HK2184/medqa-qwen3-lora")
model = model.merge_and_unload()
model.eval()

Thách thức và Cách khắc phục

Không dự án AMD ROCm nào hoàn tất mà thiếu phần kể về những “trận chiến”. Đây là những gì chúng tôi đã gặp phải:

Thách thức Nguyên nhân gốc rễ Cách khắc phục
NaN loss Độ mất ổn định của độ chính xác hỗn hợp (mixed precision) Chuyển từ bfloat16 $\rightarrow$ fp16
Không phát hiện GPU Thiếu các biến môi trường ROCm Thiết lập ROCR_VISIBLE_DEVICES, HIP_VISIBLE_DEVICES, HSA_OVERRIDE_GFX_VERSION
bitsandbytes không hỗ trợ Không có bản build bitsandbytes cho ROCm Loại bỏ hoàn toàn lượng tử hóa — MI300X có đủ VRAM
Kết quả suy luận rác Cấu hình padding của tokenizer sai Thiết lập pad_token = eos_token và sửa padding_side
Lỗi đánh giá của Trainer Sai lệch phiên bản Transformers Cố định phiên bản transformers>=4.40.0

Vấn đề về bitsandbytes đáng được lưu ý: trên phần cứng NVIDIA, lượng tử hóa 4-bit thường là bắt buộc để nhét mô hình vào bộ nhớ. Trên MI300X với 192 GB HBM3, điều này đơn giản là không cần thiết. Đây là một lợi thế phần cứng thực sự — huấn luyện sạch hơn, không có sai số do lượng tử hóa.


Kết quả

Chỉ số Giá trị
Tham số có thể huấn luyện ~2.2M (0.15% tổng số)
Thời gian huấn luyện trên MI300X ~5 phút
Kích thước tập dữ liệu sử dụng 2.000 mẫu
Độ chính xác cơ sở MedMCQA ~45%
Framework PyTorch + ROCm 6.1

Tự trải nghiệm

Không có GPU? Không vấn đề gì. Bản demo Gradio trực tiếp chạy trên HuggingFace Spaces (suy luận bằng CPU):

👉 Demo trực tiếp trên HuggingFace Spaces

Có phần cứng AMD? Clone repo và chạy trực tiếp:

git clone https://github.com/HK2184/MedQA-Medical-AI-on-AMD-ROCm.git
cd MedQA-Medical-AI-on-AMD-ROCm
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1
pip install transformers datasets peft accelerate trl gradio
python train.py   # ~5 phút
python infer.py   # chạy các câu hỏi mẫu
python app.py     # khởi chạy giao diện Gradio

Bước tiếp theo

Dự án này chứng minh quy trình hoạt động ổn định. Các bước tiếp theo là mở rộng và hoàn thiện:

  • Tập dữ liệu lớn hơn — Huấn luyện trên toàn bộ kho dữ liệu MedMCQA (~180k câu hỏi) và thêm PubMedQA.
  • Chấm điểm độ tin cậy — Thêm các ước tính độ tin cậy đã được hiệu chuẩn bên cạnh các câu trả lời.
  • Tích hợp RAG — Căn cứ câu trả lời vào việc truy xuất tài liệu y khoa thời gian thực.
  • Bộ khung đánh giá — Xây dựng benchmark độ chính xác trên tập dữ liệu độc lập (held-out) thay vì chỉ dùng tập huấn luyện.

Kết luận

MedQA cho thấy việc xây dựng một AI y tế có năng lực và có khả năng giải thích trên phần cứng mã nguồn mở của AMD không chỉ khả thi mà còn đơn giản. Khả năng tương thích ROCm của hệ sinh thái HuggingFace thực sự tốt. Dung lượng bộ nhớ của MI300X loại bỏ hoàn toàn một nhóm các vấn đề kỹ thuật phức tạp. Và LoRA giúp việc tinh chỉnh một mô hình 1.7B trở thành công việc chỉ mất 5 phút.

Nếu bạn đang xây dựng trên AMD ROCm và gặp khó khăn, những cách khắc phục trên sẽ giúp bạn tiết kiệm nhiều giờ làm việc. Và nếu bạn xây dựng AI y tế, việc nhấn mạnh vào “lời giải thích” thay vì chỉ là “độ chính xác thuần túy” là điều cực kỳ quan trọng.


Xây dựng cho cuộc thi AMD Developer Hackathon trên lablab.ai · Hỗ trợ bởi AMD ROCm + Hệ sinh thái HuggingFace

— Harikrishna Sivanand Iyer và Srijan Sivaram A

Recommended for You

Cải thiện độ bền vững của Depth Anything V2 đối với nén video

Cải thiện độ bền vững của Depth Anything V2 đối với nén video

Cải thiện độ bền vững của Depth Anything V2 đối với nén video

Học Toán lần cuối cùng

Học Toán lần cuối cùng

Học Toán lần cuối cùng