Add GRPO multi-turn tool-call training #3503
base: main
Conversation
Could the dataset be hosted on ModelScope and referenced by its model_id? Also, could the files currently in the top-level directory be moved into their own folder under examples/train/grpo, with a best-practices document written to introduce them?
Please run lint; it will tidy up the code.
OK.
I have uploaded the dataset to ModelScope and added a best-practices guide for multi-turn tool calling.
examples/train/rft/rft.py
Outdated
```diff
@@ -22,7 +22,8 @@ def do_sample(model: str, model_type: str, dataset: List[str], iter: int):
     for device in range(device_count):
         sample_cmd = (f'{conda_prefix} CUDA_VISIBLE_DEVICES={device} swift sample '
                       f'--model {model} --model_type {model_type} '
-                      f'--dataset {" ".join(dataset)} '
+                      f'--dataset {'
```
There is a syntax problem here; please check.
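For context on the syntax problem flagged above: before Python 3.12, an f-string cannot reuse its own quote character inside a replacement field, so the inner string literal must use the other quote style. A minimal sketch (the dataset names are hypothetical):

```python
# The f-string uses single quotes, so the inner literal must use double
# quotes (nesting the same quote character is a SyntaxError before 3.12).
dataset = ["ds1", "ds2"]  # hypothetical dataset names
sample_cmd = f'--dataset {" ".join(dataset)} '
print(sample_cmd)  # --dataset ds1 ds2
```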
scripts/benchmark/exp_utils.py
Outdated
```diff
@@ -122,7 +122,7 @@ def run(self, exp: Experiment):
         exp.runtime = runtime
         envs = deepcopy(runtime.get('env', {}))
         envs.update(os.environ)
-        logger.info(f'Running cmd: {runtime["running_cmd"]}, env: {runtime.get("env", {})}')
+        logger.info(f'Running cmd: {runtime['running_cmd']}, env: {runtime.get('env', {})}')
```
Same check here, +1.
Looking forward to GRPO support with tool calling.
test_grpo_tool.py: the training test script.
math_tool.py: the tool used for testing; it defines a new arithmetic operation. Its interface mainly decides whether to continue the rollout, assigns the format reward, and accepts the online result as input.
The related datasets are also placed in this directory, which is a bit messy; the main changes are in grpo_trainer.py.
New GRPO args to be added:
is_reward_tool_call: whether to accumulate the format reward for each tool_call. An upper bound should be set, otherwise the model may learn to call tools endlessly without ever producing the correct answer.
tool_call_weight: the weight of the tool_call format reward.
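The two proposed args could combine into a capped per-call format reward, so the policy cannot farm reward through unlimited tool calls. A minimal sketch; the function name and the `max_rewarded_calls` cap are assumptions for illustration, not part of the PR:

```python
def tool_call_format_reward(n_well_formed_calls: int,
                            tool_call_weight: float = 0.5,
                            max_rewarded_calls: int = 3) -> float:
    """Accumulate a small format reward per well-formed tool call.

    The min() cap bounds the total, so the model gains nothing by
    calling tools beyond `max_rewarded_calls` times without answering.
    """
    return tool_call_weight * min(n_well_formed_calls, max_rewarded_calls)

print(tool_call_format_reward(2))  # 1.0 — below the cap, rewarded per call
print(tool_call_format_reward(5))  # 1.5 — capped at 3 rewarded calls
```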