mmBERT- ModernBERT trở nên đa ngôn ngữ
- 18 min read
mmBERT: ModernBERT vươn mình ra Đa ngôn ngữ
Bài viết này giới thiệu mmBERT, một mô hình mã hóa đa ngôn ngữ lớn hiện đại, được huấn luyện trên hơn 3 nghìn tỷ token văn bản ở hơn 1800 ngôn ngữ. Nó cho thấy sự cải thiện đáng kể về hiệu suất và tốc độ so với các mô hình đa ngôn ngữ trước đây, là mô hình đầu tiên cải thiện XLM-R, đồng thời phát triển các chiến lược mới để học hiệu quả các ngôn ngữ ít tài nguyên. mmBERT xây dựng dựa trên ModernBERT cho một kiến trúc cực nhanh và thêm các thành phần mới để cho phép học đa ngôn ngữ hiệu quả.
Nếu bạn quan tâm đến việc tự mình thử các mô hình, một số boilerplate ví dụ có sẵn ở cuối bài đăng trên blog này!
TL;DR
- mmBERT là một mô hình mã hóa đa ngôn ngữ hiện đại, được huấn luyện trên hơn 3 nghìn tỷ token ở hơn 1800 ngôn ngữ.
- mmBERT cho thấy sự cải thiện đáng kể về hiệu suất và tốc độ so với các mô hình đa ngôn ngữ trước đây.
- mmBERT phát triển các chiến lược mới để học hiệu quả các ngôn ngữ ít tài nguyên.
Dữ liệu huấn luyện
mmBERT được huấn luyện trên một tập dữ liệu đa ngôn ngữ được tuyển chọn cẩn thận với tổng cộng hơn 3T token qua ba giai đoạn huấn luyện riêng biệt. Nền tảng của dữ liệu huấn luyện của chúng tôi bao gồm ba lần thu thập dữ liệu web nguồn mở và chất lượng cao chính, cho phép cả độ bao phủ đa ngôn ngữ và chất lượng dữ liệu:
DCLM và DCLM được lọc cung cấp nội dung tiếng Anh chất lượng cao nhất hiện có, đóng vai trò là xương sống cho hiệu suất tiếng Anh mạnh mẽ (với dữ liệu được lọc đến từ Dolmino). Tập dữ liệu này thể hiện các kỹ thuật lọc web hiện đại và tạo thành một thành phần quan trọng. Do chất lượng cao của dữ liệu này, chúng tôi sử dụng tỷ lệ tiếng Anh cao hơn đáng kể so với các mô hình mã hóa đa ngôn ngữ thế hệ trước (lên đến 18%).
FineWeb2 cung cấp nội dung web đa ngôn ngữ rộng lớn multilingual web content bao gồm hơn 1.800 ngôn ngữ. Tập dữ liệu này cho phép độ bao phủ đa ngôn ngữ mở rộng của chúng tôi trong khi vẫn duy trì các tiêu chuẩn chất lượng hợp lý trên các họ ngôn ngữ và bảng chữ cái đa dạng.
FineWeb2-HQ bao gồm một filtered subset of FineWeb2 tập trung vào 20 ngôn ngữ có nhiều tài nguyên. Phiên bản được lọc này cung cấp nội dung đa ngôn ngữ chất lượng cao hơn, thu hẹp khoảng cách giữa dữ liệu được lọc chỉ bằng tiếng Anh và độ bao phủ đa ngôn ngữ rộng rãi.
Dữ liệu huấn luyện cũng kết hợp các kho ngữ liệu chuyên dụng từ Dolma, MegaWika v2, ProLong và hơn thế nữa: kho mã (StarCoder, ProLong), nội dung học thuật (ArXiv, PeS2o), tài liệu tham khảo (Wikipedia, sách giáo khoa) và thảo luận cộng đồng (StackExchange), cùng với hướng dẫn và tập dữ liệu toán học.
Đổi mới quan trọng trong phương pháp tiếp cận dữ liệu của chúng tôi là chiến lược bao gồm ngôn ngữ lũy tiến được hiển thị trong Hình 1. Ở mỗi giai đoạn, chúng tôi lấy mẫu dần dần từ phân phối phẳng hơn (tức là gần với đồng nhất hơn), đồng thời thêm các ngôn ngữ mới. Điều này có nghĩa là các ngôn ngữ có nhiều tài nguyên như tiếng Nga bắt đầu với tỷ lệ dữ liệu cao (tức là 9%) và sau đó trong giai đoạn huấn luyện cuối cùng kết thúc khoảng một nửa tỷ lệ đó. Chúng tôi bắt đầu với 60 ngôn ngữ có nhiều tài nguyên trong quá trình huấn luyện trước, mở rộng lên 110 ngôn ngữ trong quá trình huấn luyện giữa và cuối cùng bao gồm tất cả 1.833 ngôn ngữ từ FineWeb2 trong giai đoạn phân rã. Điều này cho phép chúng tôi tối đa hóa tác động của dữ liệu ngôn ngữ ít tài nguyên hạn chế mà không cần lặp lại quá mức và đồng thời duy trì chất lượng dữ liệu tổng thể cao.
Công thức huấn luyện và các thành phần mới
mmBERT xây dựng dựa trên kiến trúc ModernBERT nhưng giới thiệu một số đổi mới chính để học đa ngôn ngữ:
Kiến trúc
Chúng tôi sử dụng cùng một kiến trúc cốt lõi với ModernBERT-base với 22 lớp và 1152 kích thước trung gian, nhưng chuyển sang bộ mã hóa thông báo Gemma 2 để xử lý văn bản đa ngôn ngữ tốt hơn. Mô hình cơ sở có 110 triệu tham số không nhúng (tổng cộng 307 triệu do từ vựng lớn hơn), trong khi biến thể nhỏ có 42 triệu tham số không nhúng (tổng cộng 140 triệu).
Phương pháp huấn luyện ba giai đoạn
Quá trình huấn luyện của chúng tôi tuân theo một lịch trình ba giai đoạn được thiết kế cẩn thận:
- Huấn luyện trước (2,3T token): Khởi động và giai đoạn tốc độ học ổn định sử dụng 60 ngôn ngữ với tỷ lệ che phủ 30%
- Huấn luyện giữa (600B token): Mở rộng ngữ cảnh lên 8192 token, dữ liệu chất lượng cao hơn, mở rộng lên 110 ngôn ngữ với tỷ lệ che phủ 15%
- Giai đoạn phân rã (100B token): Phân rã tốc độ học căn bậc hai nghịch đảo, bao gồm tất cả 1.833 ngôn ngữ với tỷ lệ che phủ 5%
Kỹ thuật huấn luyện mới
Lịch trình Tỷ lệ Che phủ Nghịch đảo: Thay vì sử dụng tỷ lệ che phủ cố định, chúng tôi giảm dần tỷ lệ che phủ từ 30% → 15% → 5% trong các giai đoạn huấn luyện. Điều này cho phép mô hình học các biểu diễn cơ bản với độ che phủ cao hơn sớm hơn, sau đó tập trung vào sự hiểu biết sắc thái hơn với tỷ lệ che phủ thấp hơn.
Học Ngôn ngữ Tôi luyện: Chúng tôi điều chỉnh động nhiệt độ để lấy mẫu dữ liệu đa ngôn ngữ từ τ=0,7 → 0,5 → 0,3. Điều này tạo ra một sự tiến triển từ độ lệch ngôn ngữ tài nguyên cao sang lấy mẫu đồng đều hơn, cho phép mô hình xây dựng nền tảng đa ngôn ngữ mạnh mẽ trước khi học các ngôn ngữ ít tài nguyên.
Bổ sung Ngôn ngữ Lũy tiến: Thay vì huấn luyện trên tất cả các ngôn ngữ đồng thời, chúng tôi thêm một cách chiến lược các ngôn ngữ ở mỗi giai đoạn (60 → 110 → 1.833). Điều này tối đa hóa hiệu quả học tập bằng cách tránh các kỷ nguyên quá mức trên dữ liệu ít tài nguyên hạn chế trong khi vẫn đạt được hiệu suất mạnh mẽ.
Hợp nhất Mô hình: Chúng tôi huấn luyện ba biến thể khác nhau trong giai đoạn phân rã (tập trung vào tiếng Anh, 110 ngôn ngữ và tất cả ngôn ngữ) và sử dụng hợp nhất TIES để kết hợp các điểm mạnh của chúng vào mô hình cuối cùng.
Kết quả
Hiểu Ngôn ngữ Tự nhiên (NLU)
Hiệu suất Tiếng Anh: Trên điểm chuẩn GLUE tiếng Anh (Bảng 1), mmBERT base đạt được hiệu suất mạnh mẽ, vượt trội đáng kể so với các mô hình đa ngôn ngữ khác như XLM-R (đa ngôn ngữ RoBERTa) base và mGTE base, đồng thời vẫn cạnh tranh với các mô hình chỉ bằng tiếng Anh mặc dù ít hơn 25% dữ liệu huấn luyện mmBERT là tiếng Anh.
Hiệu suất Đa ngôn ngữ: mmBERT cho thấy những cải thiện đáng kể trên điểm chuẩn XTREME so với XLM-R như được chứng minh trong Bảng 2. Những lợi ích đáng chú ý bao gồm hiệu suất mạnh mẽ trên phân loại XNLI, những cải thiện đáng kể trong các nhiệm vụ trả lời câu hỏi như TyDiQA và kết quả cạnh tranh trên PAWS-X và XCOPA để hiểu đa ngôn ngữ.
Mô hình hoạt động tốt trên hầu hết các danh mục, ngoại trừ một số nhiệm vụ dự đoán có cấu trúc như NER và gắn thẻ POS, có thể là do sự khác biệt về bộ mã hóa thông báo ảnh hưởng đến việc phát hiện ranh giới từ. Trên các danh mục này, nó hoạt động tương tự như thế hệ trước, nhưng có thể được áp dụng cho nhiều ngôn ngữ hơn.
Hiệu suất truy xuất
Truy xuất Tiếng Anh: Mặc dù mmBERT được thiết kế cho các cài đặt đa ngôn ngữ lớn, nhưng trong điểm chuẩn MTEB v2 tiếng Anh (Bảng 3), mmBERT cho thấy những lợi ích đáng kể so với các mô hình đa ngôn ngữ trước đây và thậm chí ngang bằng với khả năng của các mô hình chỉ bằng tiếng Anh như ModernBERT!
Truy xuất Đa ngôn ngữ: mmBERT cho thấy những cải thiện nhất quán trên điểm chuẩn MTEB v2 đa ngôn ngữ so với các mô hình khác (Bảng 4).
Truy xuất Mã: Do bộ mã hóa thông báo hiện đại (dựa trên Gemma 2), mmBERT cũng cho thấy hiệu suất mã hóa mạnh mẽ (Bảng 5), khiến mmBERT phù hợp với bất kỳ loại dữ liệu văn bản nào. Mô hình duy nhất hoạt động tốt hơn nó là EuroBERT, có thể sử dụng tập dữ liệu Stack v2 không thể truy cập công khai.
Học ngôn ngữ trong giai đoạn phân rã
Một trong những tính năng mới quan trọng nhất của mmBERT là chứng minh rằng các ngôn ngữ ít tài nguyên có thể được học hiệu quả trong giai đoạn phân rã ngắn của quá trình huấn luyện. Chúng tôi đã xác nhận phương pháp này bằng cách kiểm tra trên các ngôn ngữ chỉ được giới thiệu trong giai đoạn phân rã 100B token cuối cùng.
Lợi ích Hiệu suất Đáng kể: Kiểm tra trên TiQuaD (Tigrinya) và FoQA (Faroese), chúng tôi quan sát thấy những cải thiện đáng kể khi các ngôn ngữ này được bao gồm trong giai đoạn phân rã, như được hiển thị trong Hình 2. Kết quả chứng minh tính hiệu quả của phương pháp học ngôn ngữ lũy tiến của chúng tôi.
Cạnh tranh với các Mô hình Lớn: Mặc dù chỉ nhìn thấy các ngôn ngữ này trong giai đoạn huấn luyện cuối cùng, mmBERT đạt được mức hiệu suất vượt quá các mô hình lớn hơn nhiều. Về câu trả lời câu hỏi Faroese, nơi LLM đã được đánh giá, mmBERT hoạt động tốt hơn Google Gemini 2.5 Pro và OpenAI o3.
Cơ chế Học Tập Nhanh chóng: Sự thành công của việc học ngôn ngữ trong giai đoạn phân rã xuất phát từ khả năng của mô hình tận dụng nền tảng đa ngôn ngữ mạnh mẽ của nó được xây dựng trong các giai đoạn trước đó. Khi tiếp xúc với các ngôn ngữ mới, mô hình có thể nhanh chóng điều chỉnh các biểu diễn đa ngôn ngữ hiện có thay vì học từ đầu.
Lợi ích Hợp nhất Mô hình: Các mô hình mmBERT cuối cùng giữ lại thành công hầu hết các cải tiến trong giai đoạn phân rã trong khi được hưởng lợi từ các biến thể tập trung vào tiếng Anh và tài nguyên cao thông qua hợp nhất TIES.
Cải thiện hiệu quả
mmBERT mang lại những cải thiện đáng kể về hiệu quả so với các mô hình mã hóa đa ngôn ngữ trước đó thông qua những cải tiến kiến trúc được kế thừa từ ModernBERT:
Hiệu suất Thông lượng: mmBERT xử lý văn bản nhanh hơn đáng kể so với các mô hình đa ngôn ngữ hiện có trên các độ dài chuỗi khác nhau, như được chứng minh trong Hình 3. Cả mô hình nhỏ và mô hình cơ sở đều cho thấy những cải thiện đáng kể về tốc độ so với bộ mã hóa đa ngôn ngữ trước đây.
Lợi ích của Kiến trúc Hiện đại: Các lợi ích về hiệu quả đến từ hai cải tiến kỹ thuật chính:
- Flash Attention 2: Tính toán chú ý được tối ưu hóa để sử dụng bộ nhớ và tốc độ tốt hơn
- Kỹ thuật bỏ đệm: Loại bỏ các mã thông báo đệm không cần thiết trong quá trình xử lý
Tỷ lệ Độ dài Chuỗi: Không giống như các mô hình cũ hơn giới hạn ở 512 token, mmBERT xử lý hiệu quả tới 8.192 token trong khi vẫn duy trì thông lượng cao. Điều này làm cho nó phù hợp với các tác vụ xử lý tài liệu dài hơn ngày càng phổ biến trong các ứng dụng đa ngôn ngữ.
Hiệu quả Năng lượng: Sự kết hợp giữa thông lượng tốt hơn và kiến trúc hiện đại dẫn đến chi phí tính toán thấp hơn cho suy luận, làm cho mmBERT thiết thực hơn cho các triển khai sản xuất nơi cần hỗ trợ đa ngôn ngữ ở quy mô lớn.
Những cải thiện về hiệu quả này làm cho mmBERT không chỉ chính xác hơn bộ mã hóa đa ngôn ngữ trước đây, mà còn thiết thực hơn đáng kể cho việc sử dụng thực tế.
Ví dụ sử dụng
Bạn có thể sử dụng các mô hình này chỉ với một vài dòng mã!
python from transformers import AutoTokenizer, AutoModelForMaskedLM import torch
tokenizer = AutoTokenizer.from_pretrained(“jhu-clsp/mmBERT-base”) model = AutoModelForMaskedLM.from_pretrained(“jhu-clsp/mmBERT-base”)
def predict_masked_token(text): inputs = tokenizer(text, return_tensors=“pt”) with torch.no_grad(): outputs = model(**inputs) mask_indices = torch.where(inputs[“input_ids”] == tokenizer.mask_token_id) predictions = outputs.logits[mask_indices] top_tokens, top_indices = torch.topk(predictions, 5, dim=-1) return [tokenizer.decode(token) for token in top_indices[0]]
Hoạt động trên các ngôn ngữ
texts = [ “The capital of France is .”, “La capital de España es .”, “Die Hauptstadt von Deutschland ist .”, ]
for text in texts: predictions = predict_masked_token(text) print(f"Text: {text}") print(f"Predictions: {predictions}\n")
Ví dụ tinh chỉnh
Bộ mã hóa
python import argparse
from datasets import load_dataset from sentence_transformers import ( SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments, ) from sentence_transformers.evaluation import TripletEvaluator from sentence_transformers.losses import CachedMultipleNegativesRankingLoss from sentence_transformers.training_args import BatchSamplers
def main(): # phân tích cú pháp lr & tên mô hình parser = argparse.ArgumentParser() parser.add_argument("–lr", type=float, default=8e-5) parser.add_argument("–model_name", type=str, default=“jhu-clsp/mmBERT-small”) args = parser.parse_args() lr = args.lr model_name = args.model_name model_shortname = model_name.split("/")[-1]
# 1. Tải một mô hình để tinh chỉnh
model = SentenceTransformer(model_name)
# 2. Tải một tập dữ liệu để tinh chỉnh trên
dataset = load_dataset(
"sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1",
"triplet-hard",
split="train",
)
dataset_dict = dataset.train_test_split(test_size=1_000, seed=12)
train_dataset = dataset_dict["train"].select(range(1_250_000))
eval_dataset = dataset_dict["test"]
# 3. Xác định một hàm mất mát
loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=16) # Tăng mini_batch_size nếu bạn có đủ VRAM
run_name = f"{model_shortname}-DPR-{lr}"
# 4. (Tùy chọn) Chỉ định các đối số huấn luyện
args = SentenceTransformerTrainingArguments(
# Tham số bắt buộc:
output_dir=f"output/{model_shortname}/{run_name}",
# Các tham số huấn luyện tùy chọn:
num_train_epochs=1,
per_device_train_batch_size=512,
per_device_eval_batch_size=512,
warmup_ratio=0.05,
fp16=False, # Đặt thành False nếu GPU không thể xử lý FP16
bf16=True, # Đặt thành True nếu GPU hỗ trợ BF16
batch_sampler=BatchSamplers.NO_DUPLICATES, # (Cached) MultipleNegativesRankingLoss hưởng lợi từ việc không có bản sao
learning_rate=lr,
# Các tham số theo dõi/gỡ lỗi tùy chọn:
save_strategy="steps",
save_steps=500,
save_total_limit=2,
logging_steps=500,
run_name=run_name, # Được sử dụng trong `wandb`, `tensorboard`, `neptune`, v.v. nếu được cài đặt
)
# 5. (Tùy chọn) Tạo một trình đánh giá & đánh giá mô hình cơ sở
dev_evaluator = TripletEvaluator(
anchors=eval_dataset["query"],
positives=eval_dataset["positive"],
negatives=eval_dataset["negative"],
name="msmarco-co-condenser-dev",
)
dev_evaluator(model)
# 6. Tạo một trình huấn luyện & huấn luyện
trainer = SentenceTransformerTrainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
loss=loss,
evaluator=dev_evaluator,
)
trainer.train()
# 7. (Tùy chọn) Đánh giá mô hình được huấn luyện trên trình đánh giá sau khi huấn luyện
dev_evaluator(model)
# 8. Lưu mô hình
model.save_pretrained(f"output/{model_shortname}/{run_name}/final")
# 9. (Tùy chọn) Đẩy nó lên Hugging Face Hub
model.push_to_hub(run_name, private=False)
if name == “main”: main()
python from datasets import load_dataset from pylate import losses, models, utils from sentence_transformers import ( SentenceTransformerTrainer, SentenceTransformerTrainingArguments, )
def main(): # Tải các tập dữ liệu cần thiết cho chưng cất kiến thức (huấn luyện, truy vấn, tài liệu) train = load_dataset( path=“lightonai/ms-marco-en-bge”, name=“train”, )
queries = load_dataset(
path="lightonai/ms-marco-en-bge",
name="queries",
)
documents = load_dataset(
path="lightonai/ms-marco-en-bge",
name="documents",
)
# Đặt biến đổi để tải các văn bản tài liệu/truy vấn bằng cách sử dụng các id tương ứng một cách nhanh chóng
train.set_transform(
utils.KDProcessing(queries=queries, documents=documents).transform,
)
# Xác định mô hình cơ sở, các tham số huấn luyện và thư mục đầu ra
num_train_epochs = 1
lr = 8e-5
batch_size = 16
accum_steps = 1
model_name = "jhu-clsp/mmBERT-small"
model_shortname = model_name.split("/")[-1]
# Đặt tên chạy cho mục đích ghi nhật ký và thư mục đầu ra
run_name = f"{model_shortname}-colbert-KD-{lr}"
output_dir = f"output/{model_shortname}/{run_name}"
# Khởi tạo mô hình ColBERT từ mô hình cơ sở
model = models.ColBERT(model_name_or_path=model_name)
# Định cấu hình các đối số huấn luyện (ví dụ: số epoch, kích thước lô, tốc độ học)
args = SentenceTransformerTrainingArguments(
output_dir=output_dir,
num_train_epochs=num_train_epochs,
per_device_train_batch_size=batch_size,
fp16=False, # Đặt thành False nếu bạn gặp lỗi rằng GPU của bạn không thể chạy trên FP16
bf16=True, # Đặt thành True nếu bạn có GPU hỗ trợ BF16
run_name=run_name,
logging_steps=10,
learning_rate=lr,
gradient_accumulation_steps=accum_steps,
warmup_ratio=0.05,
)
# Sử dụng hàm mất mát Chưng cất cho huấn luyện
train_loss = losses.Distillation(model=model)
# Khởi tạo trình huấn luyện
trainer = SentenceTransformerTrainer(
model=model,
args=args,
train_dataset=train,
loss=train_loss,
data_collator=utils.ColBERTCollator(tokenize_fn=model.tokenize),
)
# Bắt đầu quá trình huấn luyện
trainer.train()
model.save_pretrained(f"{output_dir}/final")
if name == “main”: main()
python import logging
from datasets import load_dataset
from sentence_transformers import ( SparseEncoder, SparseEncoderModelCardData, SparseEncoderTrainer, SparseEncoderTrainingArguments, ) from sentence_transformers.sparse_encoder.evaluation import SparseNanoBEIREvaluator from sentence_transformers.sparse_encoder.losses import SparseMultipleNegativesRankingLoss, SpladeLoss from sentence_transformers.training_args import BatchSamplers
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
1. Tải một mô hình để tinh chỉnh với 2. (Tùy chọn) dữ liệu thẻ mô hình
model = SparseEncoder( “jhu-clsp/mmBERT-small”, model_card_data=SparseEncoderModelCardData( language=“en”, license=“apache-2.0”, ) )
3. Tải một tập dữ liệu để tinh chỉnh trên
full_dataset = load_dataset(“sentence-transformers/natural-questions”, split=“train”).select(range(100_000)) dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12) train_dataset = dataset_dict[“train”] eval_dataset = dataset_dict[“test”]
4. Xác định một hàm mất mát
loss = SpladeLoss( model=model, loss=SparseMultipleNegativesRankingLoss(model=model), query_regularizer_weight=5e-5, document_regularizer_weight=3e-5, )
5. (Tùy chọn) Chỉ định các đối số huấn luyện
run_name = “splade-distilbert-base-uncased-nq”
args = SparseEncoderTrainingArguments(
# Tham số bắt buộc:
output_dir=f"models/{run_name}",
# Các tham số huấn luyện tùy chọn:
num_train_epochs=1,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
learning_rate=2e-5,
warmup_ratio=0.1,
fp16=True, # Đặt thành False nếu bạn gặp lỗi rằng GPU của bạn không thể chạy trên FP16
bf16=False, # Đặt thành True nếu bạn có GPU hỗ trợ BF16
batch_sampler=BatchSamplers.NO_DUPLICATES, # MultipleNegativesRankingLoss hưởng lợi từ việc không có các mẫu trùng lặp trong một lô
# Các tham số theo dõi/gỡ lỗi tùy chọn:
eval_strategy=“steps”,
eval_steps=1000,
save_strategy=“steps”,
save_steps=1000,
save_total_limit=2,
logging_steps=200,
run_name=run_name, # Sẽ được sử dụng trong W&B nếu wandb được cài đặt
)
6. (Tùy chọn) Tạo một trình đánh giá & đánh giá mô hình cơ sở
dev_evaluator = SparseNanoBEIREvaluator(dataset_names=[“msmarco”, “nfcorpus”, “nq”], batch_size=16)
7. Tạo một trình huấn luyện & huấn luyện
trainer = SparseEncoderTrainer( model=model, args=args, train_dataset=train_dataset, eval_dataset=eval_dataset, loss=loss, evaluator=dev_evaluator, ) trainer.train()
8. Đánh giá lại hiệu suất mô hình sau khi huấn luyện
dev_evaluator(model)
9. Lưu mô hình đã huấn luyện
model.save_pretrained(f"models/{run_name}/final")
10. (Tùy chọn) Đẩy nó lên Hugging Face Hub
model.push_to_hub(run_name)
python import logging import traceback
import torch from datasets import load_dataset
from sentence_transformers import SentenceTransformer from sentence_transformers.cross_encoder import ( CrossEncoder, CrossEncoderModelCardData, CrossEncoderTrainer, CrossEncoderTrainingArguments, ) from sentence_transformers.cross_encoder.evaluation import ( CrossEncoderNanoBEIREvaluator, CrossEncoderRerankingEvaluator, ) from sentence_transformers.cross_encoder.losses import BinaryCrossEntropyLoss from sentence_transformers.evaluation import SequentialEvaluator from sentence_transformers.util import mine_hard_negatives
Đặt mức ghi nhật ký thành INFO để có thêm thông tin
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
def main(): model_name = “jhu-clsp/mmBERT-small”
train_batch_size = 64
num_epochs = 1
num_hard_negatives = 5 # Có bao nhiêu số âm khó nên được khai thác cho mỗi cặp câu hỏi-trả lời
# 1a. Tải một mô hình để tinh chỉnh với 1b. (Tùy chọn) dữ liệu thẻ mô hình
model = CrossEncoder(
model_name,
model_card_data=CrossEncoderModelCardData(
language="en",
license="apache-2.0",
),
)
print("Model max length:", model.max_length)
print("Model num labels:", model.num_labels)
# 2a. Tải tập dữ liệu GooAQ: https://huggingface.co/datasets/sentence-transformers/gooaq
logging.info("Đọc tập dữ liệu huấn luyện gooaq")
full_dataset = load_dataset("sentence-transformers/gooaq", split="train").select(range(100_000))
dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
train_dataset = dataset_dict["train"]
eval_dataset = dataset_dict["test"]
logging.info(train_dataset)
logging.info(eval_dataset)
# 2b. Sửa đổi tập dữ liệu huấn luyện của chúng tôi để bao gồm các số âm khó bằng cách sử dụng một mô hình nhúng rất hiệu quả
embedding_model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device="cpu")
hard_train_dataset = mine_hard_negatives(
train_dataset,
embedding_model,
num_negatives=num_hard_negatives, # Có bao nhiêu số âm trên mỗi cặp câu hỏi-trả lời
margin=0, # Mức độ tương đồng giữa truy vấn và các mẫu âm nên thấp hơn x so với mức độ tương đồng giữa truy vấn-dương
range_min=0, # Bỏ qua x mẫu tương tự nhất
range_max=100, # Chỉ xem xét x mẫu tương tự nhất
sampling_strategy="top", # Lấy mẫu các số âm hàng đầu từ phạm vi
batch_size=4096, # Sử dụng kích thước lô là 4096 cho mô hình nhúng
output_format="labeled-pair", # Định dạng đầu ra là (truy vấn, đoạn văn, nhãn), như được yêu cầu bởi BinaryCrossEntropyLoss
use_faiss=True,
)
logging.info(hard_train_dataset)
# 2c. (Tùy chọn) Lưu tập dữ liệu huấn luyện khó vào đĩa
# hard_train_dataset.save_to_disk("gooaq-hard-train")
# Tải lại với:
# hard_train_dataset = load_from_disk("gooaq-hard-train")
# 3. Xác định hàm mất mát huấn luyện của chúng tôi.
# pos_weight được khuyến nghị đặt làm tỷ lệ giữa dương tính và âm tính, hay còn gọi là `num_hard_negatives`
loss = BinaryCrossEntropyLoss(model=model, pos_weight=torch.tensor(num_hard_negatives))
# 4a. Xác định trình đánh giá. Chúng tôi sử dụng CrossEncoderNanoBEIREvaluator, đây là một trình đánh giá nhẹ cho việc xếp hạng lại tiếng Anh
nano_beir_evaluator = CrossEncoderNanoBEIREvaluator(
dataset_names=["msmarco", "nfcorpus", "nq"],
batch_size=train_batch_size,
)
# 4b. Xác định một trình đánh giá xếp hạng lại bằng cách khai thác các số âm khó cho các cặp câu hỏi-trả lời
# Chúng tôi bao gồm câu trả lời dương tính trong danh sách các số âm, vì vậy trình đánh giá có thể sử dụng hiệu suất của
# mô hình nhúng làm đường cơ sở.
hard_eval