Đào tạo và tinh chỉnh các mô hình Reranker với Sentence Transformers v4
Bài viết này hướng dẫn cách đào tạo và tinh chỉnh các mô hình Reranker với Sentence Transformers v4.
- 14 min read
Đào tạo và Tinh chỉnh Mô hình Reranker với Sentence Transformers v4
Sentence Transformers là một thư viện Python để sử dụng và đào tạo các mô hình embedding và reranker cho một loạt các ứng dụng, chẳng hạn như tạo sinh tăng cường truy xuất, tìm kiếm ngữ nghĩa, tương đồng văn bản ngữ nghĩa, khai thác diễn giải, v.v. Bản cập nhật v4.0 của nó giới thiệu một phương pháp đào tạo mới cho reranker, còn được gọi là mô hình cross-encoder, tương tự như bản cập nhật v3.0 đã giới thiệu cho mô hình embedding. Trong bài đăng trên blog này, tôi sẽ chỉ cho bạn cách sử dụng nó để tinh chỉnh một mô hình reranker đánh bại tất cả các tùy chọn hiện có trên chính xác dữ liệu của bạn. Phương pháp này cũng có thể đào tạo các mô hình reranker mới cực kỳ mạnh mẽ từ đầu.
Việc tinh chỉnh các mô hình reranker bao gồm một số thành phần: tập dữ liệu, hàm mất mát, đối số đào tạo, bộ đánh giá và chính lớp huấn luyện. Tôi sẽ xem xét từng thành phần này, kèm theo các ví dụ thực tế về cách chúng có thể được sử dụng để tinh chỉnh các mô hình reranker mạnh mẽ.
Cuối cùng, trong phần Đánh giá, tôi sẽ chỉ cho bạn rằng mô hình reranker nhỏ đã được tinh chỉnh tomaarsen/reranker-ModernBERT-base-gooaq-bce mà tôi đã đào tạo cùng với bài đăng trên blog này dễ dàng vượt trội hơn 13 mô hình reranker công khai được sử dụng phổ biến nhất trên tập dữ liệu đánh giá của tôi. Nó thậm chí còn đánh bại các mô hình lớn hơn 4 lần.
Lặp lại công thức với một mô hình cơ sở lớn hơn sẽ tạo ra tomaarsen/reranker-ModernBERT-large-gooaq-bce, một mô hình reranker thổi bay tất cả các mô hình reranker đa năng hiện có trên dữ liệu của tôi.

Nếu bạn quan tâm đến việc tinh chỉnh các mô hình embedding thay vào đó, thì hãy cân nhắc đọc qua bài đăng trên blog trước đây của tôi Đào tạo và Tinh chỉnh Mô hình Embedding với Sentence Transformers v3.
Mục lục
- Mô hình Reranker là gì?
- Tại sao cần tinh chỉnh?
- Các thành phần đào tạo
- Tập dữ liệu
- Hàm mất mát
- Đối số đào tạo
- Bộ đánh giá
- Huấn luyện viên
- Mẹo đào tạo
- Đánh giá
- Tài nguyên bổ sung
Mô hình Reranker là gì?
Các mô hình Reranker, thường được triển khai bằng kiến trúc Cross Encoder, được thiết kế để đánh giá mức độ liên quan giữa các cặp văn bản (ví dụ: truy vấn và tài liệu hoặc hai câu). Không giống như Sentence Transformers (hay còn gọi là bi-encoder, mô hình embedding), nhúng độc lập từng văn bản vào vectơ và tính toán độ tương đồng thông qua một số liệu khoảng cách, Cross Encoder xử lý các văn bản được ghép nối cùng nhau thông qua một mạng nơ-ron dùng chung, dẫn đến một điểm số đầu ra. Bằng cách cho phép hai văn bản chú ý lẫn nhau, các mô hình Cross Encoder có thể vượt trội hơn các mô hình embedding.
Tuy nhiên, sức mạnh này đi kèm với sự đánh đổi: Các mô hình Cross Encoder chậm hơn vì chúng xử lý mọi cặp văn bản có thể (ví dụ: 10 truy vấn với 500 tài liệu ứng viên yêu cầu 5.000 phép tính thay vì 510 cho mô hình embedding). Điều này làm cho chúng kém hiệu quả hơn cho việc truy xuất ban đầu quy mô lớn nhưng lý tưởng cho việc reranking: tinh chỉnh các kết quả top-k được xác định đầu tiên bởi các mô hình Sentence Transformer nhanh hơn. Các hệ thống tìm kiếm mạnh nhất thường sử dụng phương pháp “truy xuất và rerank” hai giai đoạn này.

Trong suốt bài đăng trên blog này, tôi sẽ sử dụng “mô hình reranker” và “mô hình Cross Encoder” thay thế cho nhau.
Tại sao cần tinh chỉnh?
Các mô hình Reranker thường được giao một vấn đề đầy thách thức:
Trong số k tài liệu có liên quan cao này, tài liệu nào trả lời truy vấn tốt nhất?
Các mô hình reranker đa năng được đào tạo để thực hiện đầy đủ câu hỏi chính xác này trong một loạt các lĩnh vực và chủ đề, ngăn chúng đạt được tiềm năng tối đa trong lĩnh vực cụ thể của bạn. Thông qua tinh chỉnh, mô hình có thể học cách tập trung hoàn toàn vào lĩnh vực và/hoặc ngôn ngữ quan trọng đối với bạn.
Trong phần Đánh giá ở cuối bài đăng trên blog này, tôi sẽ chỉ ra rằng việc đào tạo một mô hình trên miền của bạn có thể vượt trội so với bất kỳ mô hình reranker đa năng nào, ngay cả khi các đường cơ sở đó lớn hơn nhiều. Đừng đánh giá thấp sức mạnh của việc tinh chỉnh trên miền của bạn!
Các thành phần đào tạo
Đào tạo các mô hình reranker liên quan đến các thành phần sau:
- Tập dữ liệu: Dữ liệu được sử dụng để đào tạo và/hoặc đánh giá.
- Hàm mất mát: Một hàm đo lường hiệu suất của mô hình và hướng dẫn quá trình tối ưu hóa.
- Đối số đào tạo (tùy chọn): Các tham số có tác động đến hiệu suất đào tạo, theo dõi và gỡ lỗi.
- Bộ đánh giá (tùy chọn): Một lớp để đánh giá mô hình trước, trong hoặc sau khi đào tạo.
- Huấn luyện viên: Tập hợp tất cả các thành phần đào tạo lại với nhau.
Hãy xem xét kỹ hơn từng thành phần.
Tập dữ liệu
CrossEncoderTrainer sử dụng các phiên bản datasets.Dataset hoặc datasets.DatasetDict để đào tạo và đánh giá. Bạn có thể tải dữ liệu từ Hugging Face Datasets Hub hoặc sử dụng dữ liệu cục bộ của bạn ở bất kỳ định dạng nào bạn thích (ví dụ: CSV, JSON, Parquet, Arrow hoặc SQL).
Lưu ý: Rất nhiều tập dữ liệu công khai hoạt động ngay lập tức với Sentence Transformers đã được gắn thẻ sentence-transformers trên Hugging Face Hub, vì vậy bạn có thể dễ dàng tìm thấy chúng trên https://huggingface.co/datasets?other=sentence-transformers. Hãy cân nhắc duyệt qua những thứ này để tìm các tập dữ liệu sẵn sàng có thể hữu ích cho các tác vụ, miền hoặc ngôn ngữ của bạn.
Dữ liệu trên Hugging Face Hub
Bạn có thể sử dụng hàm load_dataset để tải dữ liệu từ tập dữ liệu trong Hugging Face Hub
from datasets import load_dataset
train_dataset = load_dataset("sentence-transformers/natural-questions", split="train")
print(train_dataset)
"""
Dataset({
features: ['query', 'answer'],
num_rows: 100231
})
"""
Một số tập dữ liệu, như nthakur/swim-ir-monolingual, có nhiều tập hợp con với các định dạng dữ liệu khác nhau. Bạn cần chỉ định tên tập hợp con cùng với tên tập dữ liệu, ví dụ: dataset = load_dataset("nthakur/swim-ir-monolingual", "de", split="train").
Dữ liệu cục bộ (CSV, JSON, Parquet, Arrow, SQL)
Bạn cũng có thể sử dụng load_dataset để tải dữ liệu cục bộ ở một số định dạng tệp nhất định:
from datasets import load_dataset
dataset = load_dataset("csv", data_files="my_file.csv")
# hoặc
dataset = load_dataset("json", data_files="my_file.json")
Dữ liệu cục bộ yêu cầu tiền xử lý
Bạn có thể sử dụng datasets.Dataset.from_dict nếu dữ liệu cục bộ của bạn yêu cầu tiền xử lý. Điều này cho phép bạn khởi tạo tập dữ liệu của mình bằng một từ điển các danh sách:
from datasets import Dataset
queries = []
documents = []
# Mở một tệp, thực hiện tiền xử lý, lọc, làm sạch, v.v.
# và thêm vào danh sách
dataset = Dataset.from_dict({
"query": queries,
"document": documents,
})
Mỗi khóa trong từ điển trở thành một cột trong tập dữ liệu kết quả.
Định dạng tập dữ liệu
Điều quan trọng là định dạng tập dữ liệu của bạn phù hợp với hàm mất mát của bạn (hoặc bạn chọn một hàm mất mát phù hợp với định dạng tập dữ liệu và mô hình của bạn). Xác minh xem định dạng tập dữ liệu và mô hình có hoạt động với hàm mất mát hay không bao gồm ba bước:
- Tất cả các cột không có tên là “label”, “labels”, “score” hoặc “scores” được coi là Đầu vào theo bảng Tổng quan về Mất mát. Số lượng cột còn lại phải khớp với số lượng đầu vào hợp lệ cho tổn thất đã chọn của bạn.
- Nếu hàm mất mát của bạn yêu cầu Nhãn theo bảng Tổng quan về Mất mát, thì tập dữ liệu của bạn phải có một cột có tên là “label”, “labels”, “score” hoặc “scores”. Cột này tự động được lấy làm nhãn.
- Số lượng nhãn đầu ra mô hình khớp với những gì được yêu cầu cho tổn thất theo bảng Tổng quan về Mất mát.
Ví dụ: cho một tập dữ liệu có các cột ["text1", "text2", "label"] trong đó cột “label” có điểm tương đồng float dao động từ 0 đến 1 và một mô hình xuất ra 1 nhãn, chúng ta có thể sử dụng nó với BinaryCrossEntropyLoss vì:
- tập dữ liệu có một cột “label” như được yêu cầu cho hàm mất mát này.
- tập dữ liệu có 2 cột không phải nhãn, chính xác số lượng được yêu cầu bởi hàm mất mát này.
- mô hình có 1 nhãn đầu ra, chính xác như yêu cầu của hàm mất mát này.
Hãy đảm bảo sắp xếp lại các cột tập dữ liệu của bạn bằng Dataset.select_columns nếu các cột của bạn không được sắp xếp chính xác. Ví dụ: nếu tập dữ liệu của bạn có ["good_answer", "bad_answer", "question"] làm cột, thì tập dữ liệu này về mặt kỹ thuật có thể được sử dụng với một tổn thất yêu cầu bộ ba (neo, dương tính, âm tính), nhưng cột good_answer sẽ được coi là neo, bad_answer là dương tính và question là âm tính.
Ngoài ra, nếu tập dữ liệu của bạn có các cột không liên quan (ví dụ: sample_id, siêu dữ liệu, nguồn, loại), bạn nên xóa chúng bằng Dataset.remove_columns vì chúng sẽ được sử dụng làm đầu vào nếu không. Bạn cũng có thể sử dụng Dataset.select_columns để chỉ giữ lại các cột mong muốn.
Khai thác âm bản khó
Sự thành công của việc đào tạo các mô hình reranker thường phụ thuộc vào chất lượng của các âm bản, tức là các đoạn mà điểm số truy vấn-âm bản phải thấp. Các âm bản có thể được chia thành hai loại:
- Âm bản mềm: các đoạn hoàn toàn không liên quan. Còn được gọi là âm bản dễ.
- Âm bản khó: các đoạn có vẻ như có thể liên quan đến truy vấn, nhưng không phải vậy.
Một ví dụ ngắn gọn là:
- Truy vấn: Apple được thành lập ở đâu?
- Âm bản mềm: Cầu Cache River là một giàn cầu pony Parker bắc qua Sông Cache giữa Walnut Ridge và Paragould, Arkansas.
- Âm bản khó: Táo Fuji là một giống táo được phát triển vào cuối những năm 1930 và được đưa ra thị trường vào năm 1962.
Các mô hình CrossEncoder mạnh nhất thường được đào tạo để nhận ra các âm bản khó, và vì vậy, có giá trị để có thể “khai thác” các âm bản khó để đào tạo cùng. Sentence Transformers hỗ trợ một hàm mine_hard_negatives mạnh mẽ có thể hỗ trợ, cho một tập dữ liệu các cặp truy vấn-trả lời:
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import mine_hard_negatives
# Tải tập dữ liệu GooAQ: https://huggingface.co/datasets/sentence-transformers/gooaq
train_dataset = load_dataset("sentence-transformers/gooaq", split="train").select(range(100_000))
print(train_dataset)
# Khai thác các âm bản khó bằng cách sử dụng một mô hình nhúng rất hiệu quả
embedding_model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device="cpu")
hard_train_dataset = mine_hard_negatives(
train_dataset,
embedding_model,
num_negatives=5, # Có bao nhiêu âm bản trên mỗi cặp câu hỏi-trả lời
range_min=10, # Bỏ qua x mẫu tương tự nhất
range_max=100, # Chỉ xem xét x mẫu tương tự nhất
max_score=0.8, # Chỉ xem xét các mẫu có điểm tương đồng tối đa là x
margin=0.1, # Điểm tương đồng giữa truy vấn và mẫu âm bản phải thấp hơn x so với điểm tương đồng truy vấn-dương tính
sampling_strategy="top", # Lấy mẫu ngẫu nhiên âm bản từ phạm vi
batch_size=4096, # Sử dụng kích thước lô là 4096 cho mô hình nhúng
output_format="labeled-pair", # Định dạng đầu ra là (truy vấn, đoạn văn, nhãn), như được yêu cầu bởi BinaryCrossEntropyLoss
use_faiss=True, # Nên sử dụng FAISS để giữ cho mức sử dụng bộ nhớ thấp (pip install faiss-gpu hoặc pip install faiss-cpu)
)
print(hard_train_dataset)
print(hard_train_dataset[1])
Xem kết quả của tập lệnh này.
Dataset({
features: ['question', 'answer'],
num_rows: 100000
})
Batches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 22/22 [00:01<00:00, 13.74it/s]
Batches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 36.49it/s]
Querying FAISS index: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:19<00:00, 2.80s/it]
Metric Positive Negative Difference
Count 100,000 436,925
Mean 0.5882 0.4040 0.2157
Median 0.5989 0.4024 0.1836
Std 0.1425 0.0905 0.1013
Min -0.0514 0.1405 0.1014
25% 0.4993 0.3377 0.1352
50% 0.5989 0.4024 0.1836
75% 0.6888 0.4681 0.2699
Max 0.9748 0.7486 0.7545
Skipped 2420871 potential negatives (23.97%) due to the margin of 0.1.
Skipped 43 potential negatives (0.00%) due to the maximum score of 0.8.
Could not find enough negatives for 63075 samples (12.62%). Consider adjusting the range_max, range_min, margin and max_score parameters if you'd like to find more valid negatives.
Dataset({
features: ['question', 'answer', 'label'],
num_rows: 536925
})
{
'question': 'làm thế nào để chuyển dấu trang từ máy tính xách tay này sang máy tính xách tay khác?',
'answer': 'Sử dụng ổ đĩa ngoài Hầu như bất kỳ ổ đĩa ngoài nào, bao gồm cả ổ USB thumb hoặc thẻ SD đều có thể được sử dụng để chuyển tệp của bạn từ máy tính xách tay này sang máy tính xách tay khác. Kết nối ổ đĩa với máy tính xách tay cũ của bạn; kéo tệp của bạn vào ổ đĩa, sau đó ngắt kết nối nó và chuyển nội dung ổ đĩa vào máy tính xách tay mới của bạn.',
'label': 0
}
Hàm mất mát
Các hàm mất mát giúp đánh giá hiệu suất của mô hình trên một tập dữ liệu và hướng dẫn quá trình đào tạo. Hàm mất mát phù hợp cho tác vụ của bạn phụ thuộc vào dữ liệu bạn có và những gì bạn đang cố gắng đạt được. Bạn có thể tìm thấy danh sách đầy đủ các hàm mất mát có sẵn trong Tổng quan về Mất mát.
Hầu hết các hàm mất mát đều dễ thiết lập - bạn chỉ cần cung cấp mô hình CrossEncoder bạn đang đào tạo:
from datasets import load_dataset
from sentence_transformers import CrossEncoder
from sentence_transformers.cross_encoder.losses import CachedMultipleNegativesRankingLoss
# Tải một mô hình để đào tạo/tinh chỉnh
model = CrossEncoder("xlm-roberta-base", num_labels=1) # num_labels=1 dành cho reranker
# Khởi tạo CachedMultipleNegativesRankingLoss, yêu cầu các cặp
# văn bản hoặc bộ ba có liên quan
loss = CachedMultipleNegativesRankingLoss(model)
# Tải một tập dữ liệu đào tạo ví dụ hoạt động với hàm mất mát của chúng tôi:
train_dataset = load_dataset("sentence-transformers/gooaq", split="train")
...
Đối số đào tạo
Bạn có thể tùy chỉnh quá trình đào tạo bằng cách sử dụng lớp CrossEncoderTrainingArguments. Lớp này cho phép bạn điều chỉnh các tham số có thể tác động đến tốc độ đào tạo và giúp bạn hiểu điều gì đang xảy ra trong quá trình đào tạo.
Để biết thêm thông tin về các đối số đào tạo hữu ích nhất, hãy xem Tổng quan về Đào tạo Cross Encoder > Đối số Đào tạo. Rất đáng để đọc để tận dụng tối đa quá trình đào tạo của bạn.
Đây là một ví dụ về cách thiết lập CrossEncoderTrainingArguments:
from sentence_transformers.cross_encoder import CrossEncoderTrainingArguments
args = CrossEncoderTrainingArguments(
# Tham số bắt buộc:
output_dir="models/reranker-MiniLM-msmarco-v1",
# Tham số đào tạo tùy chọn:
num_train_epochs=1,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
learning_rate=2e-5,
warmup_ratio=0.1,
fp16=False, # Đặt thành False nếu bạn gặp lỗi rằng GPU của bạn không thể chạy trên FP16
bf16=False, # Đặt thành True nếu bạn có GPU hỗ trợ BF16
batch_sampler=BatchSamplers.NO_DUPLICATES, # các tổn thất sử dụng "âm bản trong lô" được hưởng lợi từ việc không có bản sao
# Tham số theo dõi/gỡ lỗi tùy chọn:
eval_strategy="steps",
eval_steps=100,
save_strategy="steps",
save_steps=100,
save_total_limit=2,
logging_steps=100,
run_name="reranker-MiniLM-msmarco-v1", # Sẽ được sử dụng trong W&B nếu `wandb` được cài đặt
)
Bộ đánh giá
Để theo dõi hiệu suất của mô hình của bạn trong quá trình đào tạo, bạn có thể chuyển một eval_dataset đến CrossEncoderTrainer. Tuy nhiên, bạn có thể muốn các số liệu chi tiết hơn ngoài tổn thất đánh giá. Đó là nơi các bộ đánh giá có thể giúp bạn đánh giá hiệu suất của mô hình của bạn bằng cách sử dụng các số liệu cụ thể ở các giai đoạn đào tạo khác nhau. Bạn có thể sử dụng tập dữ liệu đánh giá, bộ đánh giá, cả hai hoặc không, tùy thuộc vào nhu cầu của bạn. Chiến lược và tần suất đánh giá được kiểm soát bởi eval_strategy và eval_steps Đối số Đào tạo.
Sentence Transformers bao gồm các bộ đánh giá tích hợp sau:
| Bộ đánh giá | Dữ liệu bắt buộc |
|---|---|
| CrossEncoderClassificationEvaluator | Các cặp có nhãn lớp (nhị phân hoặc đa lớp) |
| CrossEncoderCorrelationEvaluator | Các cặp có điểm tương đồng |
| CrossEncoderNanoBEIREvaluator | Không yêu cầu dữ liệu |
| CrossEncoderRerankingEvaluator | Danh sách từ điể |
Link bài báo gốc
- Tags:
- Ai
- March 26, 2025
- Huggingface.co