Skip to content

Commit 8d6b45b

Browse files
author
Your Name
committed
重构代码
1 parent d93ae95 commit 8d6b45b

File tree

3 files changed

+7
-81
lines changed

3 files changed

+7
-81
lines changed

examples/train/rft/rft.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ def do_sample(model: str, model_type: str, dataset: List[str], iter: int):
2222
for device in range(device_count):
2323
sample_cmd = (f'{conda_prefix} CUDA_VISIBLE_DEVICES={device} swift sample '
2424
f'--model {model} --model_type {model_type} '
25-
f'--dataset {' '.join(dataset)} '
25+
f'--dataset {'
26+
'.join(dataset)} '
2627
f'--data_range {device} {device_count} '
2728
f'--max_length 2048 '
2829
f'--system "You are a math model, you should **think step by step** carefully, '
@@ -61,7 +62,8 @@ def do_sample(model: str, model_type: str, dataset: List[str], iter: int):
6162
sample_cmd = (
6263
f'{conda_prefix} CUDA_VISIBLE_DEVICES={device} swift sample '
6364
f'--model {model} --model_type {model_type} ' # change to --resume_from_checkpoint to use the latest optimizer state # noqa
64-
f'--dataset {' '.join(dataset)} '
65+
f'--dataset {'
66+
'.join(dataset)} '
6567
f'--data_range {device} {device_count} '
6668
f'--max_length 2048 '
6769
f'--system "You are a math model, you should **think step by step** carefully, '
@@ -108,7 +110,8 @@ def do_train(model: str, model_type: str, datasets: List[str], iter, cmd='sft'):
108110
ga = 128 // get_device_count() // 2
109111
train_cmd = (f'{conda_prefix} {gpu_prefix} swift {cmd} '
110112
f'--model {model} --model_type {model_type} '
111-
f'--dataset {' '.join(datasets)} '
113+
f'--dataset {'
114+
'.join(datasets)} '
112115
f'--max_length 2048 '
113116
f'--num_train_epochs 1 '
114117
f'--load_args false '

gen_data.py

-77
This file was deleted.

swift/llm/dataset/dataset/mllm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,7 @@ def preprocess_row(row: Dict[str, Any]) -> Dict[str, Any]:
566566
what = ''
567567
if ':' in action:
568568
action, what = action[:action.find(':')], action[action.find(':') + 1:]
569-
row['response'] = f'Action: {action.strip()}\nAction Input: {where.strip()}{',' + what.strip()}'
569+
row['response'] = f'Action: {action.strip()}\nAction Input: {where.strip()}{', ' + what.strip()}'
570570
return row
571571

572572
conversations = []

0 commit comments

Comments
 (0)