feature(khev): add equation solver env and related configs #331

Open
wants to merge 6 commits into main

Conversation

@Khev (Author) commented Mar 17, 2025

Changes Made

  • Added a custom "equation_solver" environment, as discussed in issue #329 (AlphaZero not working on custom env: solving algebraic equations).
  • Created a custom_env folder containing the environment and all necessary files.
  • Included an environment_description file in the custom_env folder, defining the associated Markov Decision Process (MDP).
  • To run the environment, use the following command:
python zoo/custom_envs/equation_solver/config_single_eqn.py

Issue Encountered

I believe there’s an issue in the existing repository, unrelated to my custom environment. It seems the return signature of the get_next_action method in lzero/policy/alphazero.py has changed: it now returns only action, mcts_probs instead of action, mcts_probs, root. I reproduced this issue by running an existing config file:

python zoo/board_games/connect4/config/connect4_alphazero_bot_mode_config.py

This threw an unpacking error (too few return values to unpack; I can share the exact traceback if needed). As a temporary workaround, I commented out the offending line, and the code ran successfully.

Request

Could someone confirm if this is a legitimate issue with the repo? It’d be great to know if the get_next_action method was intentionally updated or if this is a bug. Any feedback on the custom environment implementation is also welcome!

Thanks! Looking forward to feedback.

@Khev (Author) commented Mar 17, 2025

Some context: I have uploaded only a single_eqn.py env. Once we have figured this one out and the models are working on it as expected, I plan to add a more realistic mult_eqn.py env, in which the agent tries to solve a range of equations.

Other context: I created a custom AlphaZero model that uses a regular MLP instead of a CNN.
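
Schematically, the model is just a shared MLP trunk with separate policy and value heads. Below is an illustrative sketch (not the actual my_alphazero_mlp_model.py); the sizes are taken from the config snippet quoted later in this thread (observation_shape=(41,), action_space_size=50, hidden_sizes=[64, 64]).

import torch
import torch.nn as nn

class PolicyValueMLP(nn.Module):
    """Illustrative AlphaZero-style policy-value network built from plain linear layers."""

    def __init__(self, obs_dim: int = 41, action_dim: int = 50, hidden_sizes=(64, 64)):
        super().__init__()
        layers, in_dim = [], obs_dim
        for h in hidden_sizes:
            layers += [nn.Linear(in_dim, h), nn.ReLU()]
            in_dim = h
        self.trunk = nn.Sequential(*layers)
        self.policy_head = nn.Linear(in_dim, action_dim)  # logits over the (padded) action space
        self.value_head = nn.Linear(in_dim, 1)            # scalar state value, squashed to [-1, 1]

    def forward(self, obs: torch.Tensor):
        x = self.trunk(obs)
        return self.policy_head(x), torch.tanh(self.value_head(x))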

puyuan1996 added the enhancement (New feature or request) label on Mar 18, 2025.
@puyuan1996 (Collaborator):

Hello,

Thank you very much for your valuable contribution. After reviewing your environment definition and the corresponding configuration, I believe the current equation solver environment falls under a single-player task. Given this, the MuZero algorithm might be a more appropriate choice. Based on previous successful experiments in the Atari environment, MuZero has already demonstrated strong performance in single-player settings.

Therefore, I suggest adding a MuZero-specific configuration for the equation solver environment in the configuration file and conducting initial tests. If the environment definition is complete—which seems likely, given your observation that the PPO algorithm converges relatively quickly—there is good reason to believe that MuZero will achieve even faster convergence.

Regarding the implementation of AlphaZero, its original version was primarily designed for two-player board games. While there are now extensions that support single-player board games (by setting battle_mode='play_with_bot_mode' and explicitly configuring a bot for the environment), we have not yet tested AlphaZero’s convergence behavior on non-board game tasks. Theoretically, AlphaZero should be feasible, but certain implementation details may require adjustments and fine-tuning for single-player tasks.

Overall, prioritizing MuZero for the initial tests seems to better align with the characteristics of the current task and is more likely to yield optimal convergence results. You can first experiment with MuZero and then compare and refine AlphaZero based on the experimental outcomes. Regarding the issue of AlphaZero not converging, it is likely related to the transformation of player_index, which might be causing inaccuracies in the value function’s sign. Debugging this aspect could help identify the specific issue.

Once again, thank you for your valuable contribution.

puyuan1996 changed the title from "Add equation solver env" to "feature(khev): add equation solver env and related configs" on Mar 18, 2025.
action, mcts_probs = self._collect_mcts.get_next_action(state_config_for_simulation_env_reset, self._policy_value_fn, self.collect_mcts_temperature, True)

# Kev: commented this out
#action, mcts_probs, root = self._collect_mcts.get_next_action(state_config_for_simulation_env_reset, self._policy_value_fn, self.collect_mcts_temperature, True)
@puyuan1996 (Collaborator) commented Mar 18, 2025

Hello, this issue was caused by inconsistencies in the interfaces of ctree and ptree due to recent modifications. We will fix it in the coming days. In the meantime, you can follow this workaround to resolve it. Thank you for your patience!

@Khev (Author):

Ah I see. No rush at all.

import_names=['zoo.custom_envs.equation_solver.my_alphazero_mlp_model'],
observation_shape=(41,), # Flat vector of length 41
action_space_size=50, # Matches your environment's action_dim
hidden_sizes=[64, 64], # MLP hidden layer sizes
@puyuan1996 (Collaborator):

Perhaps the network size can be increased to ensure it has the necessary capacity and to test its effectiveness.

@puyuan1996 (Collaborator) commented Mar 18, 2025

Additionally, it is necessary to check whether obs/reward has been properly normalized and whether the action space (action_mask) covers all reasonable actions.

@Khev (Author):

Good idea, I will try that. The rewards are already normalized, but the observations are not.
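
For example, a simple fixed-range rescaling would look like this (the bounds here are illustrative placeholders, not the env's real observation range):

import numpy as np

OBS_LOW, OBS_HIGH = -10.0, 10.0  # illustrative bounds; the real values depend on the obs encoding

def normalize_obs(obs: np.ndarray) -> np.ndarray:
    """Clip and rescale a raw observation vector into [-1, 1] before feeding it to the network."""
    obs = np.clip(obs, OBS_LOW, OBS_HIGH)
    return 2.0 * (obs - OBS_LOW) / (OBS_HIGH - OBS_LOW) - 1.0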


@ENV_REGISTRY.register('singleEqn_env')
class singleEqn(BaseEnv):
"""Environment for solving simple algebraic equations using RL in a LightZero‐compatible format."""
@puyuan1996 (Collaborator):

After the evaluation tests on MuZero and AlphaZero have converged, the file structure and annotation format should be optimized according to LightZero's standards before merging. However, these tasks can be addressed at a later stage.

@Khev (Author):

Sounds good.

@Khev (Author) commented Mar 18, 2025

Thanks for your reply. I am trying to get MuZero working now. I am encountering a bug, described below, which I think is due to the dynamic action space / action mask, and I was hoping you could provide some insight.

Some context: I think there were previous issues with action masks for non-board games (#158) which might be related to the current problem.

Bug

Running the command below reproduces the bug:

python zoo/custom_envs/equation_solver/config_muzero_single_eqn.py 

The problematic line is in lzero/mcts/buffer/game_buffer_muzero.py, line 741; I quote the code block below. As you can see from the if statement I have added, sometimes len(distributions) != len(legal_actions[policy_index]). In other words, distributions sometimes implies action_dim = 24, say, while legal_actions implies action_dim = 50.

Context: my env has a variable action space. I set action_dim = 50 as an upper bound; if I'm in a state where the action set is smaller than 50, I pad the remaining slots with an identity operation.
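
To illustrate the padding scheme (a sketch with illustrative names, not the actual env code):

import numpy as np

ACTION_DIM = 50  # fixed upper bound, matching action_space_size in the config

def pad_actions(valid_actions, identity_action):
    """Pad this state's valid action list up to ACTION_DIM with an identity operation."""
    padded = list(valid_actions) + [identity_action] * (ACTION_DIM - len(valid_actions))
    action_mask = np.ones(ACTION_DIM, dtype=np.int8)  # every slot, including padding, stays "legal"
    return padded, action_mask

In the buffer code quoted below, legal_actions is built from action_mask while distributions comes from the stored child_visits, so if the MCTS root only ever sees the un-padded action set while action_mask has length 50, the two lengths will disagree in exactly the way described above.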

I'm trying to figure out where the size of distributions is set. I think it comes from child_visits, but I can't figure out where that is set.

Any help / insight / feedback would be most welcome.

Thanks again!


    def _compute_target_policy_non_reanalyzed(
            self, policy_non_re_context: List[Any], policy_shape: Optional[int]
    ) -> np.ndarray:
        """
        Overview:
            prepare policy targets from the non-reanalyzed context of policies
        Arguments:
            - policy_non_re_context (:obj:`List`): List containing:
                - pos_in_game_segment_list
                - child_visits
                - game_segment_lens
                - action_mask_segment
                - to_play_segment
            - policy_shape: self._cfg.model.action_space_size
        Returns:
            - batch_target_policies_non_re
        """
        batch_target_policies_non_re = []
        if policy_non_re_context is None:
            return batch_target_policies_non_re

        pos_in_game_segment_list, child_visits, game_segment_lens, action_mask_segment, to_play_segment = policy_non_re_context
        game_segment_batch_size = len(pos_in_game_segment_list)
        transition_batch_size = game_segment_batch_size * (self._cfg.num_unroll_steps + 1)

        to_play, action_mask = self._preprocess_to_play_and_action_mask(
            game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list
        )

        if self._cfg.model.continuous_action_space is True:
            # when the action space of the environment is continuous, action_mask[:] is None.
            action_mask = [
                list(np.ones(self._cfg.model.num_of_sampled_actions, dtype=np.int8)) for _ in range(transition_batch_size)
            ]
            # NOTE: in continuous action space env: we set all legal_actions as -1
            legal_actions = [
                [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size)
            ]
        else:
            legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)]

        with torch.no_grad():
            policy_index = 0
            # 0 -> Invalid target policy for padding outside of game segments,
            # 1 -> Previous target policy for game segments.
            policy_mask = []
            for game_segment_len, child_visit, state_index in zip(game_segment_lens, child_visits,
                                                                  pos_in_game_segment_list):
                target_policies = []

                for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1):
                    if current_index < game_segment_len:
                        policy_mask.append(1)
                        # NOTE: child_visit is already a distribution
                        distributions = child_visit[current_index]
                        if self._cfg.action_type == 'fixed_action_space':
                            # for atari/classic_control/box2d environments that only have one player.
                            target_policies.append(distributions)
                        else:
                            # for board games that have two players.
                            policy_tmp = [0 for _ in range(policy_shape)]
                            for index, legal_action in enumerate(legal_actions[policy_index]):
                                # only the actions listed in ``legal_actions`` get nonzero policy probabilities
                                #breakpoint()
                                print(f"len(distributions)={len(distributions)}, len(legal_actions[policy_index])={len(legal_actions[policy_index])}")                
                                if len(distributions) != len(legal_actions[policy_index]):
                                    breakpoint()
                                policy_tmp[legal_action] = distributions[index]
                            target_policies.append(policy_tmp)
                    else:
                        # NOTE: the invalid padding target policy; all zeros ensures the corresponding cross_entropy_loss = 0
                        policy_mask.append(0)
                        target_policies.append([0 for _ in range(policy_shape)])

                    policy_index += 1

                batch_target_policies_non_re.append(target_policies)
        batch_target_policies_non_re = np.asarray(batch_target_policies_non_re)
        return batch_target_policies_non_re
        

@Khev (Author) commented Mar 19, 2025

Hi there,

I have got MuZero working on a simpler environment where the action set is small and fixed: (add,b), (sub,b), (mul,b), (div,b). Moreover, the episode length is 2. Since the equation is x+b = 0, the solution is simply choosing (sub,b).
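
Schematically, the dynamics described above can be sketched with sympy (an illustrative reconstruction; the actual env_single_eqn_easy.py may differ in its reward and observation encoding):

import sympy as sp

x, b = sp.symbols('x b')

def apply_action(lhs, rhs, action):
    """Apply one of the four fixed operations to both sides of the equation lhs = rhs."""
    op, operand = action
    if op == 'add':
        lhs, rhs = lhs + operand, rhs + operand
    elif op == 'sub':
        lhs, rhs = lhs - operand, rhs - operand
    elif op == 'mul':
        lhs, rhs = lhs * operand, rhs * operand
    elif op == 'div':
        lhs, rhs = lhs / operand, rhs / operand
    return sp.simplify(lhs), sp.simplify(rhs)

lhs, rhs = x + b, sp.Integer(0)   # the equation x + b = 0
lhs, rhs = apply_action(lhs, rhs, ('sub', b))
print(lhs, rhs)   # x -b, i.e. the equation is solved in a single step
print(lhs == x)   # True
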
[screenshot attached]

I thought MuZero would solve this environment instantly, but it is still failing: see the learning curve and config below. I have tuned some of the parameters, which hasn't helped much.

Any ideas about what I might be doing wrong?

# ==============================================================
# Kev: Adapted from lunarlander_disc_muzero_config
# ==============================================================

from easydict import EasyDict
from lzero.entry import train_muzero


# ==============================================================
# begin of the most frequently changed config specified by the user
# ==============================================================
collector_env_num = 8
n_episode = 8
evaluator_env_num = 3
num_simulations = 1
update_per_collect = 100
batch_size = 128
max_env_step = int(1e5)
reanalyze_ratio = 0.2

# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================

single_eqn_muzero_config = dict(
    exp_name=f'data_muzero/x+b',
    env=dict(
        env_name='singleEqnEasy_env',  # Changed from LunarLander-v2
        continuous=False,
        manually_discretization=False,
        collector_env_num=collector_env_num,
        evaluator_env_num=evaluator_env_num,
        n_evaluator_episode=evaluator_env_num,
        manager=dict(shared_memory=False, ),
    ),
    policy=dict(
        model=dict(
            observation_shape=41,  
            action_space_size=4, 
            model_type='mlp',
            latent_state_dim=32,
            self_supervised_learning_loss=False,
            discrete_action_encoding_type='not_one_hot',
            res_connection_in_dynamics=False,
            norm_type='BN',
        ),
        model_path=None,
        cuda=True,
        env_type='not_board_games',
        action_type= "fixed_action_space",
        game_segment_length=2,
        update_per_collect=update_per_collect,
        batch_size=batch_size,
        optim_type='Adam',
        piecewise_decay_lr_scheduler=False,
        learning_rate=0.001,
        ssl_loss_weight=1,
        grad_clip_value=1.0,
        num_simulations=num_simulations,
        reanalyze_ratio=reanalyze_ratio,
        n_episode=n_episode,
        eval_freq=int(1e3),
        replay_buffer_size=int(1e4),
        collector_env_num=collector_env_num,
        evaluator_env_num=evaluator_env_num,
    ),
)
single_eqn_muzero_config = EasyDict(single_eqn_muzero_config)
main_config = single_eqn_muzero_config

single_eqn_muzero_create_config = dict(
    env=dict(
        type='singleEqnEasy_env', 
        import_names=['zoo.custom_envs.equation_solver.env_single_eqn_easy'],
    ),
    env_manager=dict(type='subprocess'),
    policy=dict(
        type='muzero',
        import_names=['lzero.policy.muzero'],
    ),
)
single_eqn_muzero_create_config = EasyDict(single_eqn_muzero_create_config)
create_config = single_eqn_muzero_create_config

if __name__ == "__main__":
    seed = 14850
    train_muzero([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step)
    

@puyuan1996 (Collaborator):

Hello, thank you for your contribution. After adjusting some parameters, x+b with a fixed action space of 4 achieves rapid convergence, as shown by the blue line in the figure below (the orange line represents the situation before adjusting the norm).

[figure: learning curve for x+b with a fixed action space of 4]

Additionally, x+b with a fixed action space of 50 also converges quickly:

[figure: learning curve for x+b with a fixed action space of 50]

We observed that under your original settings, the latent-state norm quickly saturated at 1. Therefore, we switched the norm type to LN and modified some hyperparameters. Please refer to this commit for the specific changes: [GitHub Commit](05373e0).
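
For reference, the norm change (reconstructed from the configs in this thread, not the literal commit diff) amounts to switching the model's norm_type; the commit also adjusts some other hyperparameters not shown here.

from easydict import EasyDict

model_cfg = EasyDict(dict(
    observation_shape=41,
    action_space_size=4,
    model_type='mlp',
    latent_state_dim=32,
    norm_type='LN',  # was 'BN' in the config posted above
))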

For environments with multiple equations, you can adjust the config accordingly and test it out. 😊

@Khev (Author) commented Mar 20, 2025

Excellent -- thank you :) I will run the experiments for multiple equations and let you know the results.

@Khev (Author) commented Mar 21, 2025

One question: why did you pick the value below?

update_per_collect = int(collector_env_num*max_steps*replay_ratio)

@puyuan1996 (Collaborator):

The replay ratio is a crucial hyperparameter responsible for balancing sample efficiency and network plasticity. You can refer to this paper. In our setting, collector_env_num * max_steps roughly equals the total number of steps collected in this iteration, and a replay ratio of 0.25 is generally a good initial parameter. Of course, you can fine-tune it based on specific circumstances.
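
For concreteness, plugging in the values used in the configs in this thread:

collector_env_num = 8
max_steps = 5
replay_ratio = 0.25
# 8 envs x 5 steps = 40 env steps collected per iteration; a replay ratio of 0.25 then
# gives 10 gradient updates per collection phase.
update_per_collect = int(collector_env_num * max_steps * replay_ratio)
print(update_per_collect)  # 10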

PaParaZz1 added the environment (New or improved environment) label on Mar 27, 2025.
@Khev (Author) commented Mar 29, 2025

Hi there,

Just checking in. Things are going well. The learning curve below is for the multi-eqn environment (list of equations below), which you can see is being solved perfectly. I plan to scale to even more complicated equations next. After that, I will clean up the code, and we can maybe do an official pull request.

        self.train_eqns = [
                     'a*x','b*x','c*x','d*x',
                     'x+a','x+b','x+c','x+d',
                     'a*x+b','b*x+a','c*x+d', 'd*x+c',
                     'a/x+b','b/x+a','c/x+d','d/x+c',
                     'c*(a*x+b)+d', 'd*(a*x+b)+c', 'c*(b*x+a)+d', 'd*(b*x+a)+c'
                     ]

@Khev (Author) commented Mar 29, 2025

[learning-curve screenshot for the multi-equation environment]

@Khev (Author) commented Mar 29, 2025

One question: is there an exploration / action entropy parameter I can tweak? Sometimes the model collapses to just taking one action, which I want to discourage.

@puyuan1996 (Collaborator) commented Mar 30, 2025

Hello! I'm very glad to hear that your project is progressing smoothly.

Regarding how to enhance the exploration performance of the MCTS agent, I recommend referring to Section 5.1 "Exploration Strategies in MCTS" of the paper (https://arxiv.org/abs/2310.08348), which provides relevant analyses of various exploration strategies. Combined with the implementation of LightZero, you may consider the following approaches to improve the exploration capability of the policy:

  1. Enable temperature decay: you can refer to the implementation here: https://github.com/opendilab/LightZero/blob/main/lzero/policy/muzero.py#L189;
  2. Enhance the exploration performance of the policy during training; reference parameter settings here: https://github.com/opendilab/LightZero/blob/main/lzero/policy/muzero.py#L180;
  3. Enable the ε-greedy exploration strategy (eps_greedy_collect); see this line: https://github.com/opendilab/LightZero/blob/main/lzero/policy/muzero.py#L227.

We recommend trying the first two approaches as a priority.
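
For reference, a sketch of how the first two options appear in the policy config (flag names as in the full config posted later in this thread; the values here are only illustrative):

from easydict import EasyDict

exploration_cfg = EasyDict(dict(
    # 1. manually decayed MCTS temperature during collection (1 -> 0.5 -> 0.25)
    manual_temperature_decay=True,
    threshold_training_steps_for_final_temperature=int(1e5),
    # 2. entropy bonus on the policy loss, to discourage collapsing onto a single action
    policy_entropy_weight=0.05,
))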

We look forward to your submission of a well-formatted and complete Pull Request after debugging and optimization in more environments. Thank you for your interest in and support of the LightZero project—we sincerely welcome your contributions!

If you’d like to discuss further, please don’t hesitate to reach out.

@Khev (Author) commented Apr 1, 2025

Thanks for the info. Can you take a quick look at the config below? Does it look ok?

# ==============================================================
# Kev: Adapted from lunarlander_disc_muzero_config
# ==============================================================
from easydict import EasyDict
from lzero.entry import train_muzero

# ==============================================================
# begin of the most frequently changed config specified by the user
# ==============================================================
collector_env_num = 8
n_episode = 8
evaluator_env_num = 3
num_simulations = 25
max_steps = 5
replay_ratio = 0.25
update_per_collect = int(collector_env_num*max_steps*replay_ratio)
batch_size = 128
max_env_step = int(5*1e6)
reanalyze_ratio = 0
# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================

multi_eqn_muzero_config = dict(
    exp_name=f'data_muzero/multieqn/easy/',
    env=dict(
        env_name='multiEqnEasy_env',  # Changed from LunarLander-v2
        max_steps=max_steps,
        continuous=False,
        manually_discretization=False,
        collector_env_num=collector_env_num,
        evaluator_env_num=evaluator_env_num,
        n_evaluator_episode=evaluator_env_num,
        manager=dict(shared_memory=False, ),
    ),
    policy=dict(
        model=dict(
            observation_shape=41,  
            action_space_size=21, 
            model_type='mlp',
            hidden_size_list=[1024,1024,1024],
            latent_state_dim=1024,
            self_supervised_learning_loss=True,
            discrete_action_encoding_type='not_one_hot',
            res_connection_in_dynamics=True,
            norm_type='LN',
        ),
        root_dirichlet_alpha=1.0,       # e.g., 0.3
        root_exploration_fraction=0.5, # e.g., 0.25
        td_steps=5,
        num_unroll_steps=5,
        model_path=None,
        cuda=True,
        env_type='not_board_games',
        action_type= "fixed_action_space",
        game_segment_length=max_steps,
        update_per_collect=update_per_collect,
        batch_size=batch_size,
        optim_type='Adam',
        piecewise_decay_lr_scheduler=False,
        learning_rate=0.0001,
        ssl_loss_weight=2,
        grad_clip_value=1.0,
        num_simulations=num_simulations,
        reanalyze_ratio=reanalyze_ratio,
        n_episode=n_episode,
        eval_freq=int(1e3),
        replay_buffer_size=int(1e5),
        policy_entropy_weight=0.1,          # ADDED for entropy-based exploration
        fixed_temperature_value=1.00,  # you can increase this for more exploration
        collector_env_num=collector_env_num,
        evaluator_env_num=evaluator_env_num,
    ),
)
multi_eqn_muzero_config = EasyDict(multi_eqn_muzero_config)
main_config = multi_eqn_muzero_config

multi_eqn_muzero_create_config = dict(
    env=dict(
        type='multiEqnEasy_env', 
        import_names=['zoo.custom_envs.equation_solver.env_multi_eqn_easy'],
    ),
    env_manager=dict(type='subprocess'),
    policy=dict(
        type='muzero',
        import_names=['lzero.policy.muzero'],
    ),
)
multi_eqn_muzero_create_config = EasyDict(multi_eqn_muzero_create_config)
create_config = multi_eqn_muzero_create_config

if __name__ == "__main__":
    seed = 14850
    train_muzero([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step)

@puyuan1996 (Collaborator) commented Apr 1, 2025

You may use the following configuration for multi_eqn_env; please pay attention to the NOTE comments.

from easydict import EasyDict
from lzero.entry import train_muzero

# ==============================================================
# begin of the most frequently changed config specified by the user
# ==============================================================
# Number of environments to collect data from
collector_env_num = 8
# Number of episodes to run
n_episode = 8
# Number of environments for evaluation
evaluator_env_num = 3
# Number of simulations to perform
num_simulations = 50  # NOTE: sometimes, 25 is too small for complex environments
# Maximum number of steps per episode
max_steps = 5  # NOTE: max_steps should be at least larger than the optimal episode length
# Ratio of replay buffer updates
replay_ratio = 0.25
# Number of updates per collection of data
update_per_collect = int(collector_env_num * max_steps * replay_ratio)
# Size of each batch for training
batch_size = 256  # NOTE: can be larger
# Maximum number of environment steps
max_env_step = int(5e6)
# Ratio for reanalyzing the data
reanalyze_ratio = 0
# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================

multi_eqn_muzero_config = dict(
    exp_name=f'data_muzero/multieqn/easy/',
    env=dict(
        env_name='multiEqnEasy_env',  # Changed from LunarLander-v2
        max_steps=max_steps,
        continuous=False,
        manually_discretization=False,
        collector_env_num=collector_env_num,
        evaluator_env_num=evaluator_env_num,
        n_evaluator_episode=evaluator_env_num,
        manager=dict(shared_memory=False, ),
    ),
    policy=dict(
        model=dict(
            observation_shape=41,  
            action_space_size=21, 
            model_type='mlp',
            hidden_size_list=[1024, 1024, 1024],
            latent_state_dim=512,  # NOTE: Typically, 512 is sufficient for medium complexity environments, but you can test different configurations
            self_supervised_learning_loss=True,
            discrete_action_encoding_type='not_one_hot',
            res_connection_in_dynamics=True,
            norm_type='LN',
        ),
        # root_dirichlet_alpha=0.3, # NOTE: Typically, we use the default value
        # root_exploration_fraction=0.25,
        td_steps=5,
        num_unroll_steps=5,
        model_path=None,
        cuda=True,
        env_type='not_board_games',
        action_type="fixed_action_space",
        game_segment_length=max_steps,
        update_per_collect=update_per_collect,
        batch_size=batch_size,
        optim_type='Adam',
        piecewise_decay_lr_scheduler=False,
        learning_rate=0.0001,
        ssl_loss_weight=2,
        grad_clip_value=1.0,
        num_simulations=num_simulations,
        reanalyze_ratio=reanalyze_ratio,
        n_episode=n_episode,
        eval_freq=int(1e3),
        replay_buffer_size=int(1e6),
        manual_temperature_decay=True,  # NOTE: Use manually decayed temperature: 1 -> 0.5 -> 0.25 
        threshold_training_steps_for_final_temperature=int(1e5),  # NOTE: The number of final training iterations to control temperature. Please refer to [here](https://github.com/opendilab/LightZero/blob/main/lzero/policy/scaling_transform.py#L131).
        policy_entropy_weight=0.05,          # NOTE: Entropy-based exploration; typically, 0.1 is too large, but you can test it
        collector_env_num=collector_env_num,
        evaluator_env_num=evaluator_env_num,
    ),
)
multi_eqn_muzero_config = EasyDict(multi_eqn_muzero_config)
main_config = multi_eqn_muzero_config

multi_eqn_muzero_create_config = dict(
    env=dict(
        type='multiEqnEasy_env', 
        import_names=['zoo.custom_envs.equation_solver.env_multi_eqn_easy'],
    ),
    env_manager=dict(type='subprocess'),
    policy=dict(
        type='muzero',
        import_names=['lzero.policy.muzero'],
    ),
)
multi_eqn_muzero_create_config = EasyDict(multi_eqn_muzero_create_config)
create_config = multi_eqn_muzero_create_config

if __name__ == "__main__":
    seed = 14850
    train_muzero([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step)
