
Add model_context to SelectorGroupChat for enhanced speaker selection #6330

Status: Open. Wants to merge 6 commits into base: main.

Summary (from the diff): every message-thread append in the group chat managers is routed through a new overridable `update_message_thread` coroutine, and `SelectorGroupChat` gains an optional `model_context: ChatCompletionContext | None` parameter that defaults to `UnboundedChatCompletionContext`. The selector manager mirrors chat messages into this context and builds the speaker-selection history from it rather than from the raw message thread, so a bounded or otherwise customized context can now shape speaker selection.
In the group chat manager base class, thread updates now flow through a single hook:

@@ -115,7 +115,7 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None:
         )

         # Append all messages to thread
-        self._message_thread.extend(message.messages)
+        await self.update_message_thread(message.messages)

         # Check termination condition after processing all messages
         if await self._apply_termination_condition(message.messages):
@@ -139,17 +139,19 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None:
             cancellation_token=ctx.cancellation_token,
         )

+    async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None:
+        self._message_thread.extend(messages)
+
     @event
     async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None:
         try:
             # Append the message to the message thread and construct the delta.
             delta: List[BaseAgentEvent | BaseChatMessage] = []
             if message.agent_response.inner_messages is not None:
                 for inner_message in message.agent_response.inner_messages:
-                    self._message_thread.append(inner_message)
                     delta.append(inner_message)
-            self._message_thread.append(message.agent_response.chat_message)
             delta.append(message.agent_response.chat_message)
+            await self.update_message_thread(delta)

             # Check if the conversation should be terminated.
             if await self._apply_termination_condition(delta, increment_turn_count=True):
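Why funnel every append through one coroutine? A toy sketch, not code from this PR (the class and field names here are invented stand-ins): a subclass overrides the hook once, and every thread update, wherever it originates, is mirrored into the subclass's own state.

```python
import asyncio
from typing import List, Sequence


class BaseManager:
    """Stand-in for the group chat manager base class."""

    def __init__(self) -> None:
        self._message_thread: List[str] = []

    async def update_message_thread(self, messages: Sequence[str]) -> None:
        # Single choke point for every thread update.
        self._message_thread.extend(messages)


class MirroringManager(BaseManager):
    """Stand-in for a subclass such as the selector manager below."""

    def __init__(self) -> None:
        super().__init__()
        self._model_context: List[str] = []

    async def update_message_thread(self, messages: Sequence[str]) -> None:
        await super().update_message_thread(messages)
        # Because all appends funnel through the hook, the mirror cannot drift.
        self._model_context.extend(messages)


async def main() -> None:
    manager = MirroringManager()
    await manager.update_message_thread(["user: hello", "agent1: hi"])
    assert manager._message_thread == manager._model_context


asyncio.run(main())
```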
The Magentic-One orchestrator receives the same treatment: each direct append to `self._message_thread` is routed through the new hook.

@@ -191,7 +191,7 @@ async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None:
         if message.agent_response.inner_messages is not None:
             for inner_message in message.agent_response.inner_messages:
                 delta.append(inner_message)
-        self._message_thread.append(message.agent_response.chat_message)
+        await self.update_message_thread([message.agent_response.chat_message])
         delta.append(message.agent_response.chat_message)

         if self._termination_condition is not None:
@@ -263,7 +263,7 @@ async def _reenter_outer_loop(self, cancellation_token: CancellationToken) -> None:
         )

         # Save my copy
-        self._message_thread.append(ledger_message)
+        await self.update_message_thread([ledger_message])

         # Log it to the output topic.
         await self.publish_message(
@@ -376,7 +376,7 @@ async def _orchestrate_step(self, cancellation_token: CancellationToken) -> None:

         # Broadcast the next step
         message = TextMessage(content=progress_ledger["instruction_or_question"]["answer"], source=self._name)
-        self._message_thread.append(message)  # My copy
+        await self.update_message_thread([message])  # My copy

         await self._log_message(f"Next Speaker: {progress_ledger['next_speaker']['answer']}")
         # Log it to the output topic.
@@ -458,7 +458,7 @@ async def _prepare_final_answer(self, reason: str, cancellation_token: CancellationToken) -> None:
         assert isinstance(response.content, str)
         message = TextMessage(content=response.content, source=self._name)

-        self._message_thread.append(message)  # My copy
+        await self.update_message_thread([message])  # My copy

         # Log it to the output topic.
         await self.publish_message(
In the selector group chat module, the manager gains a model context that mirrors the chat messages:

@@ -4,8 +4,18 @@
 from inspect import iscoroutinefunction
 from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Sequence, Union, cast

-from autogen_core import AgentRuntime, Component, ComponentModel
-from autogen_core.models import AssistantMessage, ChatCompletionClient, ModelFamily, SystemMessage, UserMessage
+from autogen_core import AgentRuntime, CancellationToken, Component, ComponentModel
+from autogen_core.model_context import (
+    ChatCompletionContext,
+    UnboundedChatCompletionContext,
+)
+from autogen_core.models import (
+    AssistantMessage,
+    ChatCompletionClient,
+    ModelFamily,
+    SystemMessage,
+    UserMessage,
+)
 from pydantic import BaseModel
 from typing_extensions import Self
@@ -15,6 +25,7 @@
 from ...messages import (
     BaseAgentEvent,
     BaseChatMessage,
+    HandoffMessage,
     MessageFactory,
 )
 from ...state import SelectorManagerState
@@ -56,6 +67,7 @@ def __init__(
         max_selector_attempts: int,
         candidate_func: Optional[CandidateFuncType],
         emit_team_events: bool,
+        model_context: ChatCompletionContext | None,
     ) -> None:
         super().__init__(
             name,
@@ -79,13 +91,19 @@ def __init__(
         self._max_selector_attempts = max_selector_attempts
         self._candidate_func = candidate_func
         self._is_candidate_func_async = iscoroutinefunction(self._candidate_func)
+        if model_context is not None:
+            self._model_context = model_context
+        else:
+            self._model_context = UnboundedChatCompletionContext()
+        self._cancellation_token = CancellationToken()

     async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
         pass

     async def reset(self) -> None:
         self._current_turn = 0
         self._message_thread.clear()
+        await self._model_context.clear()
         if self._termination_condition is not None:
             await self._termination_condition.reset()
         self._previous_speaker = None
@@ -101,16 +119,37 @@ async def save_state(self) -> Mapping[str, Any]:
     async def load_state(self, state: Mapping[str, Any]) -> None:
         selector_state = SelectorManagerState.model_validate(state)
         self._message_thread = [self._message_factory.create(msg) for msg in selector_state.message_thread]
+        await self._add_messages_to_context(
+            self._model_context, [msg for msg in self._message_thread if isinstance(msg, BaseChatMessage)]
+        )
         self._current_turn = selector_state.current_turn
         self._previous_speaker = selector_state.previous_speaker

+    @staticmethod
+    async def _add_messages_to_context(
+        model_context: ChatCompletionContext,
+        messages: Sequence[BaseChatMessage],
+    ) -> None:
+        """
+        Add incoming messages to the model context.
+        """
+        for msg in messages:
+            if isinstance(msg, HandoffMessage):
+                for llm_msg in msg.context:
+                    await model_context.add_message(llm_msg)
+            await model_context.add_message(msg.to_model_message())
+
+    async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None:
+        self._message_thread.extend(messages)
+        base_chat_messages = [m for m in messages if isinstance(m, BaseChatMessage)]
+        await self._add_messages_to_context(self._model_context, base_chat_messages)
+
     async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str:
         """Selects the next speaker in a group chat using a ChatCompletion client,
         with the selector function as override if it returns a speaker name.

         A key assumption is that the agent type is the same as the topic type, which we use as the agent name.
         """

         # Use the selector function if provided.
         if self._selector_func is not None:
             if self._is_selector_func_async:
@@ -152,18 +191,6 @@ async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str:

         assert len(participants) > 0

-        # Construct the history of the conversation.
-        history_messages: List[str] = []
-        for msg in thread:
-            if not isinstance(msg, BaseChatMessage):
-                # Only process chat messages.
-                continue
-            message = f"{msg.source}: {msg.to_model_text()}"
-            history_messages.append(
-                message.rstrip() + "\n\n"
-            )  # Create some consistency for how messages are separated in the transcript
-        history = "\n".join(history_messages)
-
         # Construct agent roles.
         # Each agent should appear on a single line.
         roles = ""
@@ -173,17 +200,39 @@ async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str:

         # Select the next speaker.
         if len(participants) > 1:
-            agent_name = await self._select_speaker(roles, participants, history, self._max_selector_attempts)
+            agent_name = await self._select_speaker(roles, participants, self._max_selector_attempts)
         else:
             agent_name = participants[0]
         self._previous_speaker = agent_name
         trace_logger.debug(f"Selected speaker: {agent_name}")
         return agent_name

-    async def _select_speaker(self, roles: str, participants: List[str], history: str, max_attempts: int) -> str:
+    def construct_message_history(
+        self, message_history: Sequence[Union[BaseChatMessage, BaseAgentEvent, UserMessage, AssistantMessage]]
+    ) -> str:
+        # Construct the history of the conversation.
+        history_messages: List[str] = []
+        for msg in message_history:
+            if isinstance(msg, BaseChatMessage):
+                message = f"{msg.source}: {msg.to_model_text()}"
+                history_messages.append(message.rstrip() + "\n\n")
+            elif isinstance(msg, UserMessage) or isinstance(msg, AssistantMessage):
+                message = f"{msg.source}: {msg.content}"
+                history_messages.append(
+                    message.rstrip() + "\n\n"
+                )  # Create some consistency for how messages are separated in the transcript
+
+        history: str = "\n".join(history_messages)
+        return history
+
+    async def _select_speaker(self, roles: str, participants: List[str], max_attempts: int) -> str:
+        model_context_messages = await self._model_context.get_messages()
+        model_context_history = self.construct_message_history(model_context_messages)  # type: ignore
+
         select_speaker_prompt = self._selector_prompt.format(
-            roles=roles, participants=str(participants), history=history
+            roles=roles, participants=str(participants), history=model_context_history
         )

         select_speaker_messages: List[SystemMessage | UserMessage | AssistantMessage]
         if ModelFamily.is_openai(self._model_client.model_info["family"]):
             select_speaker_messages = [SystemMessage(content=select_speaker_prompt)]

[Review comment, Collaborator, on `construct_message_history`] At this point, there shouldn't be any BaseChatMessage or BaseAgentEvent in the input list, right?
@@ -453,6 +502,7 @@ def __init__(
         candidate_func: Optional[CandidateFuncType] = None,
         custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
         emit_team_events: bool = False,
+        model_context: ChatCompletionContext | None = None,
     ):
         super().__init__(
             participants,

[Review comment, Collaborator, on the new `model_context` parameter] We need to update the API doc (argument list) and include a code example of using a custom model context.
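As a reply to that request, a minimal usage sketch, not part of this diff: it assumes the `model_context` parameter lands as shown above, and the model name, agent setup, and `buffer_size=5` are illustrative choices.

```python
import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.conditions import TextMentionTermination
from autogen_agentchat.teams import SelectorGroupChat
from autogen_core.model_context import BufferedChatCompletionContext
from autogen_ext.models.openai import OpenAIChatCompletionClient


async def main() -> None:
    model_client = OpenAIChatCompletionClient(model="gpt-4o")
    writer = AssistantAgent("writer", model_client=model_client, description="Writes drafts")
    critic = AssistantAgent("critic", model_client=model_client, description="Reviews drafts")
    team = SelectorGroupChat(
        participants=[writer, critic],
        model_client=model_client,
        termination_condition=TextMentionTermination("TERMINATE"),
        # Only the 5 most recent messages are used to build the selector prompt.
        model_context=BufferedChatCompletionContext(buffer_size=5),
    )
    await team.run(task="Write a haiku about group chats, then have it reviewed.")


asyncio.run(main())
```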
@@ -473,6 +523,7 @@ def __init__(
         self._selector_func = selector_func
         self._max_selector_attempts = max_selector_attempts
         self._candidate_func = candidate_func
+        self._model_context = model_context

     def _create_group_chat_manager_factory(
         self,
@@ -505,6 +556,7 @@ def _create_group_chat_manager_factory(
             self._max_selector_attempts,
             self._candidate_func,
             self._emit_team_events,
+            self._model_context,
         )

     def _to_config(self) -> SelectorGroupChatConfig:
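For intuition about the `_add_messages_to_context` helper above, a self-contained sketch (the message values are invented): the LLM messages carried in `HandoffMessage.context` are replayed into the model context first, then the handoff itself is added via `to_model_message()`.

```python
import asyncio

from autogen_agentchat.messages import HandoffMessage
from autogen_core.model_context import UnboundedChatCompletionContext
from autogen_core.models import AssistantMessage


async def main() -> None:
    model_context = UnboundedChatCompletionContext()
    handoff = HandoffMessage(
        source="agent1",
        target="agent2",
        content="Transferring to agent2.",
        context=[AssistantMessage(content="Partial result so far.", source="agent1")],
    )
    # Same order as the loop in _add_messages_to_context:
    for llm_msg in handoff.context:
        await model_context.add_message(llm_msg)
    await model_context.add_message(handoff.to_model_message())
    print(len(await model_context.get_messages()))  # 2


asyncio.run(main())
```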
79 changes (62 additions, 17 deletions): python/packages/autogen-agentchat/tests/test_group_chat.py

@@ -2,10 +2,27 @@
 import json
 import logging
 import tempfile
-from typing import Any, AsyncGenerator, List, Mapping, Sequence
+from typing import Any, AsyncGenerator, Dict, List, Mapping, Sequence

 import pytest
 import pytest_asyncio
+from autogen_core import AgentId, AgentRuntime, CancellationToken, FunctionCall, SingleThreadedAgentRuntime
+from autogen_core.models import (
+    AssistantMessage,
+    CreateResult,
+    FunctionExecutionResult,
+    FunctionExecutionResultMessage,
+    LLMMessage,
+    RequestUsage,
+    UserMessage,
+)
+from autogen_core.tools import FunctionTool
+from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor
+from autogen_ext.models.openai import OpenAIChatCompletionClient
+from autogen_ext.models.replay import ReplayChatCompletionClient
+from pydantic import BaseModel
+from utils import FileLogHandler

 from autogen_agentchat import EVENT_LOGGER_NAME
 from autogen_agentchat.agents import (
     AssistantAgent,
@@ -32,22 +49,6 @@
 from autogen_agentchat.teams._group_chat._selector_group_chat import SelectorGroupChatManager
 from autogen_agentchat.teams._group_chat._swarm_group_chat import SwarmGroupChatManager
 from autogen_agentchat.ui import Console
-from autogen_core import AgentId, AgentRuntime, CancellationToken, FunctionCall, SingleThreadedAgentRuntime
-from autogen_core.models import (
-    AssistantMessage,
-    CreateResult,
-    FunctionExecutionResult,
-    FunctionExecutionResultMessage,
-    LLMMessage,
-    RequestUsage,
-    UserMessage,
-)
-from autogen_core.tools import FunctionTool
-from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor
-from autogen_ext.models.openai import OpenAIChatCompletionClient
-from autogen_ext.models.replay import ReplayChatCompletionClient
-from pydantic import BaseModel
-from utils import FileLogHandler

 logger = logging.getLogger(EVENT_LOGGER_NAME)
 logger.setLevel(logging.DEBUG)
@@ -691,6 +692,50 @@ async def test_selector_group_chat(runtime: AgentRuntime | None) -> None:
     assert result2 == result


+@pytest.mark.asyncio
+async def test_selector_group_chat_with_model_context(runtime: AgentRuntime | None) -> None:
+    selector_group_chat_model_client = ReplayChatCompletionClient(["agent2", "agent1", "agent1", "agent2"])
+    agent_one_model_client = ReplayChatCompletionClient(
+        ["[Agent One] First generation", "[Agent One] Second generation", "TERMINATE"]
+    )
+    agent_two_model_client = ReplayChatCompletionClient(
+        ["[Agent Two] First generation", "[Agent Two] Second generation", "TERMINATE"]
+    )
+
+    agent1 = AssistantAgent("agent1", model_client=agent_one_model_client, description="Assistant agent 1")
+    agent2 = AssistantAgent("agent2", model_client=agent_two_model_client, description="Assistant agent 2")
+
+    termination = TextMentionTermination("TERMINATE")
+    team = SelectorGroupChat(
+        participants=[agent1, agent2],
+        model_client=selector_group_chat_model_client,
+        termination_condition=termination,
+        runtime=runtime,
+        emit_team_events=True,
+    )
+    await team.run(
+        task="[GroupChat] Task",
+    )
+
+    messages_to_check = {
+        "1": "user: [GroupChat] Task",
+        "2": "agent2: [Agent Two] First generation",
+        "3": "agent1: [Agent One] First generation",
+        "4": "agent1: [Agent One] Second generation",
+        "5": "agent2: [Agent Two] Second generation",
+    }
+
+    create_calls: List[Dict[str, Any]] = selector_group_chat_model_client.create_calls
+    for idx, call in enumerate(create_calls):
+        messages = call["messages"]
+        prompt = messages[0].content
+        prompt_lines = prompt.split("\n")
+        chat_history = [value for _, value in list(messages_to_check.items())[: idx + 1]]
+        assert all(
+            line.strip() in prompt_lines for line in chat_history
+        ), f"Expected all lines {chat_history} to be in prompt, but got {prompt_lines}"
+
+
 @pytest.mark.asyncio
 async def test_selector_group_chat_with_team_event(runtime: AgentRuntime | None) -> None:
     model_client = ReplayChatCompletionClient(

[Review comment, Collaborator, on the `team = SelectorGroupChat(` line] I don't see the model context being customized here?
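A sketch of one way the test could address that comment. This is hypothetical, reuses the fixtures from the test above, and `buffer_size=2` is an arbitrary choice, so the history assertions would then need to expect only the two most recent lines per call:

```python
from autogen_core.model_context import BufferedChatCompletionContext

# Replaces the `team = SelectorGroupChat(...)` construction in the test above.
team = SelectorGroupChat(
    participants=[agent1, agent2],
    model_client=selector_group_chat_model_client,
    termination_condition=termination,
    runtime=runtime,
    emit_team_events=True,
    # A non-default context: only the 2 most recent messages reach the
    # selector prompt, so older entries of messages_to_check must not be asserted.
    model_context=BufferedChatCompletionContext(buffer_size=2),
)
```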