feature(khev): add equation solver env and related configs #331
base: main
@@ -0,0 +1,3 @@
from .env_single_eqn import singleEqn
from .env_single_eqn_easy import singleEqnEasy
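These two relative imports sit in what is evidently the package `__init__.py` for `zoo.custom_envs.equation_solver` (the create-configs below import from that package), so both env classes become importable at package level. A quick usage sketch, assuming that package path:

# Usage sketch only: the package path is taken from the import_names used in the configs below.
from zoo.custom_envs.equation_solver import singleEqn, singleEqnEasy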
@@ -0,0 +1,89 @@
# ==============================================================
# Kev: Adapted from lunarlander_disc_muzero_config
# ==============================================================
from easydict import EasyDict
from lzero.entry import train_muzero

# ==============================================================
# begin of the most frequently changed config specified by the user
# ==============================================================
collector_env_num = 8
n_episode = 8
evaluator_env_num = 3
num_simulations = 100
update_per_collect = 200
batch_size = 256
max_env_step = int(1e5)
reanalyze_ratio = 0.0
# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================

single_eqn_muzero_config = dict(
    exp_name='data_muzero/x+b',
    env=dict(
        env_name='singleEqn_env',  # Changed from LunarLander-v2
        continuous=False,
        manually_discretization=False,
        collector_env_num=collector_env_num,
        evaluator_env_num=evaluator_env_num,
        n_evaluator_episode=evaluator_env_num,
        manager=dict(shared_memory=False, ),
    ),
    policy=dict(
        model=dict(
            observation_shape=41,  # Changed from 8
            action_space_size=50,  # Changed from 4
            model_type='mlp',
            lstm_hidden_size=128,
            latent_state_dim=128,
            self_supervised_learning_loss=True,
            discrete_action_encoding_type='not_one_hot',
            res_connection_in_dynamics=True,
            norm_type='BN',
        ),
        model_path=None,
        cuda=True,
        env_type='not_board_games',
        action_type='varied_action_space',
        game_segment_length=10,
        update_per_collect=update_per_collect,
        batch_size=batch_size,
        optim_type='Adam',
        piecewise_decay_lr_scheduler=False,
        learning_rate=0.001,
        ssl_loss_weight=2,
        grad_clip_value=0.5,
        num_simulations=num_simulations,
        reanalyze_ratio=reanalyze_ratio,
        n_episode=n_episode,
        eval_freq=int(1e3),
        replay_buffer_size=int(1e6),
        collector_env_num=collector_env_num,
        evaluator_env_num=evaluator_env_num,
    ),
)
single_eqn_muzero_config = EasyDict(single_eqn_muzero_config)
main_config = single_eqn_muzero_config

single_eqn_muzero_create_config = dict(
    env=dict(
        type='singleEqn_env',  # Changed from lunarlander
        import_names=['zoo.custom_envs.equation_solver.env_single_eqn'],  # Changed from the lunarlander path
    ),
    env_manager=dict(type='subprocess'),
    policy=dict(
        type='muzero',
        import_names=['lzero.policy.muzero'],
    ),
)
single_eqn_muzero_create_config = EasyDict(single_eqn_muzero_create_config)
create_config = single_eqn_muzero_create_config

if __name__ == "__main__":
    seed = 14850
    train_muzero([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step)
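The config above implies a specific contract for the environment registered as 'singleEqn_env': a flat 41-dimensional observation, 50 discrete actions, and, because of action_type='varied_action_space', a per-step action_mask. The actual env_single_eqn.py is part of this PR but not reproduced in this excerpt, so the following is only a rough sketch of that contract; the class name is made up, and the observation dict keys ('observation', 'action_mask', 'to_play') follow the usual LightZero convention.

import numpy as np


class SingleEqnEnvContractSketch:
    """Sketch of the interface implied by the config above; NOT the PR's code.

    The real environment lives in zoo/custom_envs/equation_solver/env_single_eqn.py
    (added by this PR). observation_dim and action_dim mirror the config's
    observation_shape=41 and action_space_size=50.
    """

    observation_dim = 41
    action_dim = 50

    def reset(self) -> dict:
        obs = np.zeros(self.observation_dim, dtype=np.float32)
        # All 50 actions are allowed here; a real mask would zero out currently illegal ones.
        action_mask = np.ones(self.action_dim, dtype=np.int8)
        # 'to_play': -1 is the usual LightZero convention for single-player envs.
        return {'observation': obs, 'action_mask': action_mask, 'to_play': -1}

    def step(self, action: int):
        next_obs = self.reset()  # placeholder; the real env applies the chosen action to the equation state
        reward, done, info = 0.0, False, {}
        return next_obs, reward, done, info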
@@ -0,0 +1,87 @@
# ==============================================================
# Kev: Adapted from lunarlander_disc_muzero_config
# ==============================================================
from easydict import EasyDict
from lzero.entry import train_muzero

# ==============================================================
# begin of the most frequently changed config specified by the user
# ==============================================================
collector_env_num = 8
n_episode = 8
evaluator_env_num = 3
num_simulations = 1
update_per_collect = 100
batch_size = 128
max_env_step = int(1e5)
reanalyze_ratio = 0.2
# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================

single_eqn_muzero_config = dict(
    exp_name='data_muzero/x+b',
    env=dict(
        env_name='singleEqnEasy_env',  # Changed from LunarLander-v2
        continuous=False,
        manually_discretization=False,
        collector_env_num=collector_env_num,
        evaluator_env_num=evaluator_env_num,
        n_evaluator_episode=evaluator_env_num,
        manager=dict(shared_memory=False, ),
    ),
    policy=dict(
        model=dict(
            observation_shape=41,
            action_space_size=4,
            model_type='mlp',
            latent_state_dim=32,
            self_supervised_learning_loss=False,
            discrete_action_encoding_type='not_one_hot',
            res_connection_in_dynamics=False,
            norm_type='BN',
        ),
        model_path=None,
        cuda=True,
        env_type='not_board_games',
        action_type='fixed_action_space',
        game_segment_length=2,
        update_per_collect=update_per_collect,
        batch_size=batch_size,
        optim_type='Adam',
        piecewise_decay_lr_scheduler=False,
        learning_rate=0.001,
        ssl_loss_weight=1,
        grad_clip_value=1.0,
        num_simulations=num_simulations,
        reanalyze_ratio=reanalyze_ratio,
        n_episode=n_episode,
        eval_freq=int(1e3),
        replay_buffer_size=int(1e4),
        collector_env_num=collector_env_num,
        evaluator_env_num=evaluator_env_num,
    ),
)
single_eqn_muzero_config = EasyDict(single_eqn_muzero_config)
main_config = single_eqn_muzero_config

single_eqn_muzero_create_config = dict(
    env=dict(
        type='singleEqnEasy_env',
        import_names=['zoo.custom_envs.equation_solver.env_single_eqn_easy'],
    ),
    env_manager=dict(type='subprocess'),
    policy=dict(
        type='muzero',
        import_names=['lzero.policy.muzero'],
    ),
)
single_eqn_muzero_create_config = EasyDict(single_eqn_muzero_create_config)
create_config = single_eqn_muzero_create_config

if __name__ == "__main__":
    seed = 14850
    train_muzero([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step)
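One small observation on the two MuZero variants above: both set exp_name='data_muzero/x+b', so their logs and checkpoints land under the same experiment directory. If the two variants are meant to be compared side by side, distinct names keep the outputs apart; a minimal sketch (the names below are illustrative only, not part of the PR):

# Illustrative only: give each variant its own experiment directory.
single_eqn_muzero_config.exp_name = 'data_muzero/single_eqn_x+b'        # in the full-env config file
# single_eqn_muzero_config.exp_name = 'data_muzero/single_eqn_easy_x+b' # in the easy-variant config file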
@@ -0,0 +1,110 @@
from my_train_alphazero import my_train_alphazero
from easydict import EasyDict

# ==============================================================
# Frequently changed config specified by the user (lightweight settings)
# ==============================================================
collector_env_num = 4      # Number of parallel environments for data collection
n_episode = 4              # Number of episodes per training iteration
evaluator_env_num = 1      # Number of evaluator environments
num_simulations = 50       # MCTS simulations per move (try increasing if needed)
update_per_collect = 100   # Number of gradient updates per data collection cycle
batch_size = 32            # Mini-batch size for training
max_env_step = int(1e3)    # Maximum total environment steps for a quick run
model_path = None
mcts_ctree = False

# ==============================================================
# Configurations for singleEqn_env (lightweight version)
# ==============================================================
singleEqn_alphazero_config = dict(
    exp_name='data_alphazero/singleEqn/x+b/',
    env=dict(
        battle_mode='play_with_bot_mode',
        battle_mode_in_simulation_env='self_play_mode',  # For simulation during MCTS
        channel_last=False,
        collector_env_num=collector_env_num,
        evaluator_env_num=evaluator_env_num,
        n_evaluator_episode=evaluator_env_num,
        manager=dict(shared_memory=False),
        agent_vs_human=False,
        prob_random_agent=0,
        prob_expert_agent=0,
        prob_random_action_in_bot=0,
        scale=True,
        render_mode=None,
        replay_path=None,
        alphazero_mcts_ctree=mcts_ctree,
    ),
    policy=dict(
        mcts_ctree=mcts_ctree,
        simulation_env_id='singleEqn_env',  # Must match the registered name of your environment
        model=dict(
            type='AlphaZeroMLPModel',
            import_names=['zoo.custom_envs.equation_solver.my_alphazero_mlp_model'],
            observation_shape=(41,),  # Flat vector of length 41
            action_space_size=50,     # Matches your environment's action_dim
            hidden_sizes=[64, 64],    # MLP hidden layer sizes

[Inline review thread attached to the hidden_sizes line]
Reviewer: Perhaps the network size can be increased to ensure it has the necessary capacity and to test its effectiveness.
Reviewer: Additionally, it is necessary to check whether obs/reward has been properly normalized and whether the action space (action_mask) covers all reasonable actions.
Author: Good idea. I will try. The rewards are already normalized, but the observations are not.

        ),
        cuda=True,
        env_type='not_board_games',
        action_type='varied_action_space',
        update_per_collect=update_per_collect,
        batch_size=batch_size,
        optim_type='Adam',
        lr_piecewise_constant_decay=False,
        # learning_rate=0.003,
        learning_rate=3e-4,
        grad_clip_value=0.5,
        value_weight=1.0,
        entropy_weight=0.0,
        n_episode=n_episode,
        eval_freq=int(2e3),
        mcts=dict(num_simulations=num_simulations),
        collector_env_num=collector_env_num,
        evaluator_env_num=evaluator_env_num,
        other=dict(
            replay_buffer=dict(
                type='advanced',              # Use an advanced (prioritized) replay buffer
                replay_buffer_size=10000,     # Smaller buffer for lightweight runs
                sample_min_limit_ratio=0.25,  # Allow sampling once 25% of the batch size is available
                alpha=0.6,
                beta=0.4,
                anneal_step=100000,
                enable_track_used_data=False,
                deepcopy=False,
                save_episode=False,
            )
        ),
    ),
)
singleEqn_alphazero_config = EasyDict(singleEqn_alphazero_config)
main_config = singleEqn_alphazero_config

singleEqn_alphazero_create_config = dict(
    env=dict(
        type='singleEqn_env',
        import_names=['zoo.custom_envs.equation_solver.env_single_eqn'],  # Adjust this path if needed
    ),
    env_manager=dict(type='subprocess'),
    policy=dict(
        type='MyAlphaZeroPolicy',  # Your custom policy subclass
        import_names=['zoo.custom_envs.equation_solver.my_alphazero_policy'],
    ),
    collector=dict(
        type='episode_alphazero',
        import_names=['lzero.worker.alphazero_collector'],
    ),
    evaluator=dict(
        type='alphazero',
        import_names=['lzero.worker.alphazero_evaluator'],
    ),
)
singleEqn_alphazero_create_config = EasyDict(singleEqn_alphazero_create_config)
create_config = singleEqn_alphazero_create_config

if __name__ == '__main__':
    from lzero.entry import train_alphazero
    # Merge the environment configuration into the policy config.
    main_config.policy.env = main_config.env
    my_train_alphazero([main_config, create_config], seed=0, model_path=model_path, max_env_step=max_env_step)
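Regarding the inline review exchange above (rewards are already normalized, observations are not yet): a minimal, self-contained sketch of running observation normalization that could be applied to the 41-dim observation before the env returns it. This only illustrates the reviewer's suggestion and is not part of the PR; the class name and placement are made up.

import numpy as np


class RunningObsNormalizer:
    """Streaming mean/std normalizer for the 41-dim observation vector (illustration only)."""

    def __init__(self, shape=(41,), eps=1e-8):
        self.mean = np.zeros(shape, dtype=np.float64)
        self.m2 = np.zeros(shape, dtype=np.float64)  # running sum of squared deviations (Welford)
        self.count = 0
        self.eps = eps

    def update(self, obs):
        # Welford's online update with a single new sample.
        obs = np.asarray(obs, dtype=np.float64)
        self.count += 1
        delta = obs - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (obs - self.mean)

    def normalize(self, obs):
        self.update(obs)
        var = self.m2 / max(self.count, 1)
        obs = np.asarray(obs, dtype=np.float64)
        return ((obs - self.mean) / np.sqrt(var + self.eps)).astype(np.float32)


# Possible usage inside the env, e.g. in reset()/step():
#   self._obs_norm = RunningObsNormalizer()
#   obs = self._obs_norm.normalize(raw_obs)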
Maintainer: Hello, this issue was caused by inconsistencies in the interfaces of ctree and ptree due to recent modifications. We will fix it in the coming days. In the meantime, you can follow this workaround to resolve it. Thank you for your patience!
Author: Ah, I see. No rush at all.