feature(khev): add equation solver env and related configs #331

Open · wants to merge 6 commits into base: main
10 changes: 9 additions & 1 deletion lzero/mcts/buffer/game_buffer_muzero.py
@@ -734,7 +734,15 @@ def _compute_target_policy_non_reanalyzed(
policy_tmp = [0 for _ in range(policy_shape)]
for index, legal_action in enumerate(legal_actions[policy_index]):
# only the actions in ``legal_actions`` have nonzero policy logits
policy_tmp[legal_action] = distributions[index]
# Temporary fix: guard against a length mismatch between ``distributions``
# and ``legal_actions[policy_index]``; this may be masking an underlying error.
if index < len(distributions):
policy_tmp[legal_action] = distributions[index]
else:
policy_tmp[legal_action] = 0
target_policies.append(policy_tmp)
else:
# NOTE: for the invalid padding target policy, an all-zero vector is used to make sure the corresponding cross_entropy_loss = 0
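For context, a standalone sketch (not part of this PR; the helper name sparse_to_dense_policy is hypothetical) of what the guarded loop above computes: it scatters a sparse per-legal-action visit distribution into a dense policy vector, skipping indices when the two lists fall out of sync.

from typing import List, Sequence

def sparse_to_dense_policy(distributions: Sequence[float],
                           legal_actions: Sequence[int],
                           policy_shape: int) -> List[float]:
    """Scatter per-legal-action visit counts into a dense policy vector."""
    policy = [0.0] * policy_shape
    for index, legal_action in enumerate(legal_actions):
        # Guard against a distributions/legal_actions length mismatch,
        # mirroring the temporary fix in the diff above.
        if index < len(distributions):
            policy[legal_action] = distributions[index]
    return policy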
6 changes: 5 additions & 1 deletion lzero/policy/alphazero.py
@@ -264,7 +264,11 @@ def _forward_collect(self, obs: Dict, temperature: float = 1) -> Dict[str, torch
init_state=init_state[env_id],
katago_policy_init=False,
katago_game_state=katago_game_state[env_id]))
action, mcts_probs, root = self._collect_mcts.get_next_action(state_config_for_simulation_env_reset, self._policy_value_fn, self.collect_mcts_temperature, True)
action, mcts_probs = self._collect_mcts.get_next_action(state_config_for_simulation_env_reset, self._policy_value_fn, self.collect_mcts_temperature, True)

# Kev: the original call also returned ``root``; kept commented out for reference.
#action, mcts_probs, root = self._collect_mcts.get_next_action(state_config_for_simulation_env_reset, self._policy_value_fn, self.collect_mcts_temperature, True)
puyuan1996 (Collaborator) · Mar 18, 2025:
Hello, this issue was caused by inconsistencies in the interfaces of ctree and ptree due to recent modifications. We will fix it in the coming days. In the meantime, you can follow this workaround to resolve it. Thank you for your patience!

Author:
Ah I see. No rush at all.
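Until the ctree/ptree interfaces are unified, one defensive pattern is to normalize the return value of get_next_action so the caller works with either signature. This is a sketch only, not part of this PR, and the helper name unpack_next_action is hypothetical:

from typing import Any, Optional, Sequence, Tuple

def unpack_next_action(result: Sequence[Any]) -> Tuple[Any, Any, Optional[Any]]:
    """Accept either (action, mcts_probs) or (action, mcts_probs, root)
    and always return (action, mcts_probs, root_or_None)."""
    if len(result) == 3:
        action, mcts_probs, root = result
    else:
        action, mcts_probs = result
        root = None
    return action, mcts_probs, root

# Usage at the call site shown above:
# action, mcts_probs, _ = unpack_next_action(
#     self._collect_mcts.get_next_action(
#         state_config_for_simulation_env_reset, self._policy_value_fn,
#         self.collect_mcts_temperature, True))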


output[env_id] = {
'action': action,
3 changes: 3 additions & 0 deletions zoo/custom_envs/equation_solver/__init__.py
@@ -0,0 +1,3 @@
from .env_single_eqn import singleEqn
from .env_single_eqn_easy import singleEqnEasy

89 changes: 89 additions & 0 deletions zoo/custom_envs/equation_solver/config_muzero_single_eqn.py
@@ -0,0 +1,89 @@
# ==============================================================
# Kev: Adapted from lunarlander_disc_muzero_config
# ==============================================================


from easydict import EasyDict
from lzero.entry import train_muzero


# ==============================================================
# begin of the most frequently changed config specified by the user
# ==============================================================
collector_env_num = 8
n_episode = 8
evaluator_env_num = 3
num_simulations = 100
update_per_collect = 200
batch_size = 256
max_env_step = int(1e5)
reanalyze_ratio = 0.0

# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================

single_eqn_muzero_config = dict(
exp_name='data_muzero/x+b',
env=dict(
env_name='singleEqn_env', # Changed from LunarLander-v2
continuous=False,
manually_discretization=False,
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
n_evaluator_episode=evaluator_env_num,
manager=dict(shared_memory=False, ),
),
policy=dict(
model=dict(
observation_shape=41, # Changed from 8
action_space_size=50, # Changed from 4
model_type='mlp',
lstm_hidden_size=128,
latent_state_dim=128,
self_supervised_learning_loss=True,
discrete_action_encoding_type='not_one_hot',
res_connection_in_dynamics=True,
norm_type='BN',
),
model_path=None,
cuda=True,
env_type='not_board_games',
action_type='varied_action_space',
game_segment_length=10,
update_per_collect=update_per_collect,
batch_size=batch_size,
optim_type='Adam',
piecewise_decay_lr_scheduler=False,
learning_rate=0.001,
ssl_loss_weight=2,
grad_clip_value=0.5,
num_simulations=num_simulations,
reanalyze_ratio=reanalyze_ratio,
n_episode=n_episode,
eval_freq=int(1e3),
replay_buffer_size=int(1e6),
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
),
)
single_eqn_muzero_config = EasyDict(single_eqn_muzero_config)
main_config = single_eqn_muzero_config

single_eqn_muzero_create_config = dict(
env=dict(
type='singleEqn_env', # Changed from lunarlander
import_names=['zoo.custom_envs.equation_solver.env_single_eqn'], # Changed from lunarlander path
),
env_manager=dict(type='subprocess'),
policy=dict(
type='muzero',
import_names=['lzero.policy.muzero'],
),
)
single_eqn_muzero_create_config = EasyDict(single_eqn_muzero_create_config)
create_config = single_eqn_muzero_create_config

if __name__ == "__main__":
seed = 14850
train_muzero([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step)
87 changes: 87 additions & 0 deletions zoo/custom_envs/equation_solver/config_muzero_single_eqn_easy.py
@@ -0,0 +1,87 @@
# ==============================================================
# Kev: Adapted from lunarlander_disc_muzero_config
# ==============================================================

from easydict import EasyDict
from lzero.entry import train_muzero


# ==============================================================
# begin of the most frequently changed config specified by the user
# ==============================================================
collector_env_num = 8
n_episode = 8
evaluator_env_num = 3
num_simulations = 1
update_per_collect = 100
batch_size = 128
max_env_step = int(1e5)
reanalyze_ratio = 0.2

# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================

single_eqn_muzero_config = dict(
exp_name='data_muzero/x+b',
env=dict(
env_name='singleEqnEasy_env', # Changed from LunarLander-v2
continuous=False,
manually_discretization=False,
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
n_evaluator_episode=evaluator_env_num,
manager=dict(shared_memory=False, ),
),
policy=dict(
model=dict(
observation_shape=41,
action_space_size=4,
model_type='mlp',
latent_state_dim=32,
self_supervised_learning_loss=False,
discrete_action_encoding_type='not_one_hot',
res_connection_in_dynamics=False,
norm_type='BN',
),
model_path=None,
cuda=True,
env_type='not_board_games',
action_type='fixed_action_space',
game_segment_length=2,
update_per_collect=update_per_collect,
batch_size=batch_size,
optim_type='Adam',
piecewise_decay_lr_scheduler=False,
learning_rate=0.001,
ssl_loss_weight=1,
grad_clip_value=1.0,
num_simulations=num_simulations,
reanalyze_ratio=reanalyze_ratio,
n_episode=n_episode,
eval_freq=int(1e3),
replay_buffer_size=int(1e4),
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
),
)
single_eqn_muzero_config = EasyDict(single_eqn_muzero_config)
main_config = single_eqn_muzero_config

single_eqn_muzero_create_config = dict(
env=dict(
type='singleEqnEasy_env',
import_names=['zoo.custom_envs.equation_solver.env_single_eqn_easy'],
),
env_manager=dict(type='subprocess'),
policy=dict(
type='muzero',
import_names=['lzero.policy.muzero'],
),
)
single_eqn_muzero_create_config = EasyDict(single_eqn_muzero_create_config)
create_config = single_eqn_muzero_create_config

if __name__ == "__main__":
seed = 14850
train_muzero([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step)
110 changes: 110 additions & 0 deletions zoo/custom_envs/equation_solver/config_single_eqn.py
@@ -0,0 +1,110 @@
from my_train_alphazero import my_train_alphazero
from easydict import EasyDict

# ==============================================================
# Frequently changed config specified by the user (lightweight settings)
# ==============================================================
collector_env_num = 4 # Number of parallel environments for data collection
n_episode = 4 # Number of episodes per training iteration
evaluator_env_num = 1 # Number of evaluator environments
num_simulations = 50 # MCTS simulations per move (try increasing if needed)
update_per_collect = 100 # Number of gradient updates per data collection cycle
batch_size = 32 # Mini-batch size for training
max_env_step = int(1e3) # Maximum total environment steps for a quick run
model_path = None
mcts_ctree = False

# ==============================================================
# Configurations for singleEqn_env (lightweight version)
# ==============================================================
singleEqn_alphazero_config = dict(
exp_name='data_alphazero/singleEqn/x+b/',
env=dict(
battle_mode='play_with_bot_mode',
battle_mode_in_simulation_env='self_play_mode', # For simulation during MCTS
channel_last=False,
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
n_evaluator_episode=evaluator_env_num,
manager=dict(shared_memory=False),
agent_vs_human=False,
prob_random_agent=0,
prob_expert_agent=0,
prob_random_action_in_bot=0,
scale=True,
render_mode=None,
replay_path=None,
alphazero_mcts_ctree=mcts_ctree,
),
policy=dict(
mcts_ctree=mcts_ctree,
simulation_env_id='singleEqn_env', # Must match the registered name of your environment
model=dict(
type='AlphaZeroMLPModel',
import_names=['zoo.custom_envs.equation_solver.my_alphazero_mlp_model'],
observation_shape=(41,), # Flat vector of length 41
action_space_size=50, # Matches your environment's action_dim
hidden_sizes=[64, 64], # MLP hidden layer sizes
Collaborator:
Perhaps the network size can be increased to ensure it has the necessary capacity and to test its effectiveness.
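For illustration only (the sizes below are assumptions, not something tested in this PR), a wider MLP could be configured like this:

# Illustrative only: a wider policy/value MLP for the AlphaZero model.
wider_model = dict(
    type='AlphaZeroMLPModel',
    import_names=['zoo.custom_envs.equation_solver.my_alphazero_mlp_model'],
    observation_shape=(41,),
    action_space_size=50,
    hidden_sizes=[256, 256, 128],  # larger than the [64, 64] used above
)
# e.g. main_config.policy.model.update(wider_model)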

puyuan1996 (Collaborator) · Mar 18, 2025:
Additionally, it is necessary to check whether the observations and rewards have been properly normalized and whether the action space (action_mask) covers all reasonable actions.

Author:
Good idea. I will try. The rewards are already normalized, but the observations are not.
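As a follow-up on the normalization point, here is a minimal sketch of a running observation normalizer that could be applied to the 41-dim observation vector; the class name RunningObsNormalizer and its placement are assumptions, not part of this PR:

import numpy as np

class RunningObsNormalizer:
    """Track a running mean/variance and return standardized observations."""

    def __init__(self, shape, eps: float = 1e-8):
        self.mean = np.zeros(shape, dtype=np.float64)
        self.var = np.ones(shape, dtype=np.float64)
        self.count = eps
        self.eps = eps

    def update(self, obs: np.ndarray) -> None:
        # Incremental (Welford-style) update of mean and variance with one sample.
        self.count += 1.0
        delta = obs - self.mean
        self.mean += delta / self.count
        self.var += (delta * (obs - self.mean) - self.var) / self.count

    def __call__(self, obs: np.ndarray) -> np.ndarray:
        self.update(obs)
        return (obs - self.mean) / np.sqrt(self.var + self.eps)

# Usage sketch, e.g. inside the env's reset()/step():
# normalizer = RunningObsNormalizer(shape=(41,))
# obs = normalizer(raw_obs)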

),
cuda=True,
env_type='not_board_games',
action_type='varied_action_space',
update_per_collect=update_per_collect,
batch_size=batch_size,
optim_type='Adam',
lr_piecewise_constant_decay=False,
# learning_rate=0.003,
learning_rate=3e-4,
grad_clip_value=0.5,
value_weight=1.0,
entropy_weight=0.0,
n_episode=n_episode,
eval_freq=int(2e3),
mcts=dict(num_simulations=num_simulations),
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
other=dict(
replay_buffer=dict(
type='advanced', # Use advanced (or prioritized) replay buffer
replay_buffer_size=10000, # Set a smaller buffer for lightweight runs
sample_min_limit_ratio=0.25, # Allow sampling even if only 25% of the batch size is available.
alpha=0.6,
beta=0.4,
anneal_step=100000,
enable_track_used_data=False,
deepcopy=False,
save_episode=False,
)
),
),
)
singleEqn_alphazero_config = EasyDict(singleEqn_alphazero_config)
main_config = singleEqn_alphazero_config

singleEqn_alphazero_create_config = dict(
env=dict(
type='singleEqn_env',
import_names=['zoo.custom_envs.equation_solver.env_single_eqn'], # Adjust this path if needed
),
env_manager=dict(type='subprocess'),
policy=dict(
type='MyAlphaZeroPolicy', # Your custom policy subclass
import_names=['zoo.custom_envs.equation_solver.my_alphazero_policy'],
),
collector=dict(
type='episode_alphazero',
import_names=['lzero.worker.alphazero_collector'],
),
evaluator=dict(
type='alphazero',
import_names=['lzero.worker.alphazero_evaluator'],
)
)
singleEqn_alphazero_create_config = EasyDict(singleEqn_alphazero_create_config)
create_config = singleEqn_alphazero_create_config

if __name__ == '__main__':
# Merge the environment configuration into the policy config.
main_config.policy.env = main_config.env
my_train_alphazero([main_config, create_config], seed=0, model_path=model_path, max_env_step=max_env_step)