这个问题是由于在使用梯度检查点(gradient checkpointing)时,use_cache=True
这个问题是由于在使用梯度检查点(gradient checkpointing)时,use_cache=True
与梯度检查点不兼容导致的。为了解决这个问题,你需要将use_cache
设置为False
。
你可以在你的代码中找到以下部分:
from transformers import TrainingArguments, Trainer
training_args = TrainingArguments(
output_dir="./results",
per_device_train_batch_size=4,
per_device_eval_batch_size=16,
gradient_accumulation_steps=2,
learning_rate=5e-5,
weight_decay=0.01,
evaluation_strategy="epoch",
save_strategy="epoch",
num_train_epochs=3,
logging_dir="./logs",
logging_steps=10,
use_cache=True, # 这里需要修改为 False
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=test_dataset,
)
将use_cache=True
修改为use_cache=False
,然后重新运行你的代码。