Skip to content

【开源实习】bart模型微调 #2026

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions llm/finetune/bart/bart_finetune.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from mindspore import nn, ops, Tensor
from mindspore.dataset import GeneratorDataset
from mindnlp.transformers import BartForConditionalGeneration, BartTokenizer
from mindnlp.engine import Trainer, TrainingArguments
from datasets import load_dataset

import evaluate
import mindspore as ms


rouge_metric = evaluate.load("rouge")
# Load dataset and tokenizer
tokenizer = BartTokenizer.from_pretrained("./bart-base")

dataset = load_dataset("xsum", split="train")
val_dataset = load_dataset("xsum", split="validation")


def preprocess_function(examples):
inputs = tokenizer(examples["document"], max_length=512,
truncation=True, padding="max_length")
targets = tokenizer(
examples["summary"], max_length=128, truncation=True, padding="max_length")
inputs["labels"] = targets["input_ids"]
return inputs


tokenized_data = dataset.map(preprocess_function, batched=True, remove_columns=[
"document", "summary", "id"], num_proc=24)
tokenized_val_data = val_dataset.map(preprocess_function, batched=True, remove_columns=[
"document", "summary", "id"], num_proc=24)


# Load model
model = BartForConditionalGeneration.from_pretrained("./bart-base")


def create_mindspore_dataset(data, batch_size=8):
data_list = list(data)

def generator():
for item in data_list:
yield (
Tensor(item["input_ids"], dtype=ms.int32),
Tensor(item["attention_mask"], dtype=ms.int32),
Tensor(item["labels"], dtype=ms.int32)
)

return GeneratorDataset(generator, column_names=["input_ids", "attention_mask", "labels"]).batch(batch_size)


def compute_metrics(pred):

labels_ids = pred.label_ids
pred_ids = pred.predictions[0]

pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
labels_ids[labels_ids == -100] = tokenizer.pad_token_id
label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

rouge_output = rouge_metric.compute(
predictions=pred_str,
references=label_str,
rouge_types=["rouge1", "rouge2", "rougeL", "rougeLsum"],
)

return {
"R1": round(rouge_output["rouge1"], 4),
"R2": round(rouge_output["rouge2"], 4),
"RL": round(rouge_output["rougeL"], 4),
"RLsum": round(rouge_output["rougeLsum"], 4),
}


def preprocess_logits_for_metrics(logits, labels):
"""
防止内存溢出
"""
pred_ids = ms.mint.argmax(logits[0], dim=-1)
return pred_ids, labels


train_dataset = create_mindspore_dataset(tokenized_data, batch_size=4)
eval_dataset = create_mindspore_dataset(tokenized_val_data, batch_size=2)

training_args = TrainingArguments(
output_dir="./results",
evaluation_strategy="epoch",
learning_rate=2e-5,
per_device_train_batch_size=4,
per_device_eval_batch_size=2,
num_train_epochs=3,
weight_decay=0.01,
save_total_limit=2,
)

trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
compute_metrics=compute_metrics,
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)

trainer.train()
32 changes: 32 additions & 0 deletions llm/finetune/bart/bart_finetune_readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
## bart模型微调报告

### 任务
- **任务编号**:#IAUOXU
- **任务链接**:[【开源实习】bart模型微调](https://gitee.com/mindspore/community/issues/IAUOXU)
- **实现内容**:实现了bart模型在XSum数据集上的微调。
- **模型**:`facebook/bart-base`
- **数据集**:`EdinburghNLP/xsum`

---

### 结果对比

#### **Mindnlp+D910B**

| Epoch | Eval Loss | R1 (ROUGE-1) | R2 (ROUGE-2) | RL (ROUGE-L) | RLsum (ROUGE-Lsum) |
|------:|----------:|-------------:|-------------:|-------------:|-------------------:|
| 1 | 0.4504 | 0.5265 | 0.2512 | 0.5003 | 0.5004 |
| 2 | 0.4481 | 0.5272 | 0.2538 | 0.5026 | 0.5025 |
| 3 | 0.4440 | 0.5316 | 0.2580 | 0.5061 | 0.5062 |

---

#### **Pytorch+3090**

| Epoch | Eval Loss | R1 (ROUGE-1) | R2 (ROUGE-2) | RL (ROUGE-L) | RLsum (ROUGE-Lsum) |
|------:|----------:|-------------:|-------------:|-------------:|-------------------:|
| 1 | 0.4364 | 0.5226 | 0.2432 | 0.4965 | 0.4961 |
| 2 | 0.4297 | 0.5309 | 0.2547 | 0.5066 | 0.5065 |
| 3 | 0.4290 | 0.5318 | 0.2563 | 0.5065 | 0.5062 |

---