Trainer: Getting Training Running Quickly
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
# Llama ships no pad token; reuse EOS for padding
tokenizer.pad_token = tokenizer.eos_token

args = TrainingArguments(
    output_dir="checkpoints/l10",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,  # effective batch size of 8 per device
    learning_rate=2e-5,
    warmup_ratio=0.03,
    num_train_epochs=3,
    fp16=True,
    logging_steps=10,
    save_strategy="epoch",
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,  # tokenized dataset, prepared earlier
    data_collator=lambda batch: causal_lm_collator(batch, tokenizer.pad_token_id),
)
trainer.train()
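The snippet assumes `train_ds` and `causal_lm_collator` were prepared earlier. For reference, a minimal collator along these lines might look like the sketch below: right-pad each example to the batch maximum and mask padded positions out of the loss with `-100`. This is an illustrative implementation, not necessarily the exact helper used above.

import torch

def causal_lm_collator(batch, pad_token_id):
    # Right-pad variable-length examples to the longest one in the batch.
    max_len = max(len(ex["input_ids"]) for ex in batch)
    input_ids, attention_mask, labels = [], [], []
    for ex in batch:
        ids = list(ex["input_ids"])
        pad = max_len - len(ids)
        input_ids.append(ids + [pad_token_id] * pad)
        attention_mask.append([1] * len(ids) + [0] * pad)
        # -100 tells the cross-entropy loss to ignore padded positions
        labels.append(ids + [-100] * pad)
    return {
        "input_ids": torch.tensor(input_ids),
        "attention_mask": torch.tensor(attention_mask),
        "labels": torch.tensor(labels),
    }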
- Pros: wraps the optimizer, LR scheduler, distributed training, and logging
- Extend with custom logic via compute_metrics/callbacks (see the sketch after this list)
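As a concrete illustration of the callback hook, a minimal sketch is below; the callback class and what it prints are illustrative, not part of the lesson code. A `TrainerCallback` subclass can intercept lifecycle events such as `on_log` and is registered with `trainer.add_callback`.

import math
from transformers import TrainerCallback

class LossPrinterCallback(TrainerCallback):
    # Print perplexity alongside the loss at each logging step.
    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs and "loss" in logs:
            print(f"step {state.global_step}: ppl={math.exp(logs['loss']):.2f}")

trainer.add_callback(LossPrinterCallback())

compute_metrics works similarly: pass a function that takes the evaluation predictions and returns a dict of named metrics to the Trainer constructor (`Trainer(..., compute_metrics=...)`), and it runs on every evaluation pass.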