Add `model_context` to `SelectorGroupChat` for enhanced speaker selection #6330
base: main
Changes from all commits: 352cc80, d839ec3, b37ec18, 4f0b015, 3bc42fe, 35e1da9
Changes to `autogen_agentchat.teams._group_chat._selector_group_chat`:
```diff
@@ -4,8 +4,18 @@
 from inspect import iscoroutinefunction
 from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Sequence, Union, cast
 
-from autogen_core import AgentRuntime, Component, ComponentModel
-from autogen_core.models import AssistantMessage, ChatCompletionClient, ModelFamily, SystemMessage, UserMessage
+from autogen_core import AgentRuntime, CancellationToken, Component, ComponentModel
+from autogen_core.model_context import (
+    ChatCompletionContext,
+    UnboundedChatCompletionContext,
+)
+from autogen_core.models import (
+    AssistantMessage,
+    ChatCompletionClient,
+    ModelFamily,
+    SystemMessage,
+    UserMessage,
+)
 from pydantic import BaseModel
 from typing_extensions import Self
 
@@ -15,6 +25,7 @@
 from ...messages import (
     BaseAgentEvent,
     BaseChatMessage,
+    HandoffMessage,
     MessageFactory,
 )
 from ...state import SelectorManagerState
 
@@ -56,6 +67,7 @@ def __init__(
         max_selector_attempts: int,
         candidate_func: Optional[CandidateFuncType],
         emit_team_events: bool,
+        model_context: ChatCompletionContext | None,
     ) -> None:
         super().__init__(
             name,
 
@@ -79,13 +91,19 @@ def __init__(
         self._max_selector_attempts = max_selector_attempts
         self._candidate_func = candidate_func
         self._is_candidate_func_async = iscoroutinefunction(self._candidate_func)
+        if model_context is not None:
+            self._model_context = model_context
+        else:
+            self._model_context = UnboundedChatCompletionContext()
+        self._cancellation_token = CancellationToken()
 
     async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
         pass
 
     async def reset(self) -> None:
         self._current_turn = 0
         self._message_thread.clear()
+        await self._model_context.clear()
         if self._termination_condition is not None:
             await self._termination_condition.reset()
         self._previous_speaker = None
 
@@ -101,16 +119,37 @@ async def save_state(self) -> Mapping[str, Any]:
     async def load_state(self, state: Mapping[str, Any]) -> None:
         selector_state = SelectorManagerState.model_validate(state)
         self._message_thread = [self._message_factory.create(msg) for msg in selector_state.message_thread]
+        await self._add_messages_to_context(
+            self._model_context, [msg for msg in self._message_thread if isinstance(msg, BaseChatMessage)]
+        )
         self._current_turn = selector_state.current_turn
         self._previous_speaker = selector_state.previous_speaker
 
+    @staticmethod
+    async def _add_messages_to_context(
+        model_context: ChatCompletionContext,
+        messages: Sequence[BaseChatMessage],
+    ) -> None:
+        """
+        Add incoming messages to the model context.
+        """
+        for msg in messages:
+            if isinstance(msg, HandoffMessage):
+                for llm_msg in msg.context:
+                    await model_context.add_message(llm_msg)
+            await model_context.add_message(msg.to_model_message())
+
     async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None:
         self._message_thread.extend(messages)
+        base_chat_messages = [m for m in messages if isinstance(m, BaseChatMessage)]
+        await self._add_messages_to_context(self._model_context, base_chat_messages)
 
     async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str:
         """Selects the next speaker in a group chat using a ChatCompletion client,
         with the selector function as override if it returns a speaker name.
 
         A key assumption is that the agent type is the same as the topic type, which we use as the agent name.
         """
 
         # Use the selector function if provided.
         if self._selector_func is not None:
             if self._is_selector_func_async:
```
```diff
@@ -152,18 +191,6 @@ async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str:
 
         assert len(participants) > 0
 
-        # Construct the history of the conversation.
-        history_messages: List[str] = []
-        for msg in thread:
-            if not isinstance(msg, BaseChatMessage):
-                # Only process chat messages.
-                continue
-            message = f"{msg.source}: {msg.to_model_text()}"
-            history_messages.append(
-                message.rstrip() + "\n\n"
-            )  # Create some consistency for how messages are separated in the transcript
-        history = "\n".join(history_messages)
-
         # Construct agent roles.
         # Each agent should appear on a single line.
         roles = ""
 
@@ -173,17 +200,39 @@ async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str:
 
         # Select the next speaker.
         if len(participants) > 1:
-            agent_name = await self._select_speaker(roles, participants, history, self._max_selector_attempts)
+            agent_name = await self._select_speaker(roles, participants, self._max_selector_attempts)
         else:
             agent_name = participants[0]
         self._previous_speaker = agent_name
         trace_logger.debug(f"Selected speaker: {agent_name}")
         return agent_name
 
-    async def _select_speaker(self, roles: str, participants: List[str], history: str, max_attempts: int) -> str:
+    def construct_message_history(
+        self, message_history: Sequence[Union[BaseChatMessage, BaseAgentEvent, UserMessage, AssistantMessage]]
+    ) -> str:
+        # Construct the history of the conversation.
+        history_messages: List[str] = []
+        for msg in message_history:
+            if isinstance(msg, BaseChatMessage):
+                message = f"{msg.source}: {msg.to_model_text()}"
+                history_messages.append(message.rstrip() + "\n\n")
+            elif isinstance(msg, UserMessage) or isinstance(msg, AssistantMessage):
+                message = f"{msg.source}: {msg.content}"
+                history_messages.append(
+                    message.rstrip() + "\n\n"
+                )  # Create some consistency for how messages are separated in the transcript
+
+        history: str = "\n".join(history_messages)
+        return history
+
+    async def _select_speaker(self, roles: str, participants: List[str], max_attempts: int) -> str:
+        model_context_messages = await self._model_context.get_messages()
+        model_context_history = self.construct_message_history(model_context_messages)  # type: ignore
+
         select_speaker_prompt = self._selector_prompt.format(
-            roles=roles, participants=str(participants), history=history
+            roles=roles, participants=str(participants), history=model_context_history
         )
 
         select_speaker_messages: List[SystemMessage | UserMessage | AssistantMessage]
         if ModelFamily.is_openai(self._model_client.model_info["family"]):
             select_speaker_messages = [SystemMessage(content=select_speaker_prompt)]
```
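To make the prompt format concrete, here is a small sketch of the transcript that `construct_message_history` builds from model context messages (values are illustrative, not from this PR):

```python
from autogen_core.models import AssistantMessage, UserMessage

messages = [
    UserMessage(content="[GroupChat] Task", source="user"),
    AssistantMessage(content="[Agent Two] First generation", source="agent2"),
]

# Same shape as the new code path: "<source>: <content>" per message, each
# entry padded with a trailing blank line for a consistent transcript.
history = "\n".join(f"{m.source}: {m.content}".rstrip() + "\n\n" for m in messages)
print(history)
# user: [GroupChat] Task
# (two blank lines)
# agent2: [Agent Two] First generation
```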
```diff
@@ -453,6 +502,7 @@ def __init__(
         candidate_func: Optional[CandidateFuncType] = None,
         custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
         emit_team_events: bool = False,
+        model_context: ChatCompletionContext | None = None,
     ):
         super().__init__(
             participants,
```

> **Review comment** (on the `model_context` parameter): We need to update the API doc (argument list) and include a code example of using a custom model context.
```diff
@@ -473,6 +523,7 @@ def __init__(
         self._selector_func = selector_func
         self._max_selector_attempts = max_selector_attempts
         self._candidate_func = candidate_func
+        self._model_context = model_context
 
     def _create_group_chat_manager_factory(
         self,
 
@@ -505,6 +556,7 @@ def _create_group_chat_manager_factory(
             self._max_selector_attempts,
             self._candidate_func,
             self._emit_team_events,
+            self._model_context,
         )
 
     def _to_config(self) -> SelectorGroupChatConfig:
```
Changes to the group chat tests:
```diff
@@ -2,10 +2,27 @@
 import json
 import logging
 import tempfile
-from typing import Any, AsyncGenerator, List, Mapping, Sequence
+from typing import Any, AsyncGenerator, Dict, List, Mapping, Sequence
 
 import pytest
 import pytest_asyncio
+from autogen_core import AgentId, AgentRuntime, CancellationToken, FunctionCall, SingleThreadedAgentRuntime
+from autogen_core.models import (
+    AssistantMessage,
+    CreateResult,
+    FunctionExecutionResult,
+    FunctionExecutionResultMessage,
+    LLMMessage,
+    RequestUsage,
+    UserMessage,
+)
+from autogen_core.tools import FunctionTool
+from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor
+from autogen_ext.models.openai import OpenAIChatCompletionClient
+from autogen_ext.models.replay import ReplayChatCompletionClient
+from pydantic import BaseModel
+from utils import FileLogHandler
 
 from autogen_agentchat import EVENT_LOGGER_NAME
 from autogen_agentchat.agents import (
     AssistantAgent,
 
@@ -32,22 +49,6 @@
 from autogen_agentchat.teams._group_chat._selector_group_chat import SelectorGroupChatManager
 from autogen_agentchat.teams._group_chat._swarm_group_chat import SwarmGroupChatManager
 from autogen_agentchat.ui import Console
-from autogen_core import AgentId, AgentRuntime, CancellationToken, FunctionCall, SingleThreadedAgentRuntime
-from autogen_core.models import (
-    AssistantMessage,
-    CreateResult,
-    FunctionExecutionResult,
-    FunctionExecutionResultMessage,
-    LLMMessage,
-    RequestUsage,
-    UserMessage,
-)
-from autogen_core.tools import FunctionTool
-from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor
-from autogen_ext.models.openai import OpenAIChatCompletionClient
-from autogen_ext.models.replay import ReplayChatCompletionClient
-from pydantic import BaseModel
-from utils import FileLogHandler
 
 logger = logging.getLogger(EVENT_LOGGER_NAME)
 logger.setLevel(logging.DEBUG)
```
```diff
@@ -691,6 +692,50 @@ async def test_selector_group_chat(runtime: AgentRuntime | None) -> None:
     assert result2 == result
 
 
+@pytest.mark.asyncio
+async def test_selector_group_chat_with_model_context(runtime: AgentRuntime | None) -> None:
+    selector_group_chat_model_client = ReplayChatCompletionClient(["agent2", "agent1", "agent1", "agent2"])
+    agent_one_model_client = ReplayChatCompletionClient(
+        ["[Agent One] First generation", "[Agent One] Second generation", "TERMINATE"]
+    )
+    agent_two_model_client = ReplayChatCompletionClient(
+        ["[Agent Two] First generation", "[Agent Two] Second generation", "TERMINATE"]
+    )
+
+    agent1 = AssistantAgent("agent1", model_client=agent_one_model_client, description="Assistant agent 1")
+    agent2 = AssistantAgent("agent2", model_client=agent_two_model_client, description="Assistant agent 2")
+
+    termination = TextMentionTermination("TERMINATE")
+    team = SelectorGroupChat(
+        participants=[agent1, agent2],
+        model_client=selector_group_chat_model_client,
+        termination_condition=termination,
+        runtime=runtime,
+        emit_team_events=True,
+    )
+    await team.run(
+        task="[GroupChat] Task",
+    )
+
+    messages_to_check = {
+        "1": "user: [GroupChat] Task",
+        "2": "agent2: [Agent Two] First generation",
+        "3": "agent1: [Agent One] First generation",
+        "4": "agent1: [Agent One] Second generation",
+        "5": "agent2: [Agent Two] Second generation",
+    }
+
+    create_calls: List[Dict[str, Any]] = selector_group_chat_model_client.create_calls
+    for idx, call in enumerate(create_calls):
+        messages = call["messages"]
+        prompt = messages[0].content
+        prompt_lines = prompt.split("\n")
+        chat_history = [value for _, value in list(messages_to_check.items())[: idx + 1]]
+        assert all(
+            line.strip() in prompt_lines for line in chat_history
+        ), f"Expected all lines {chat_history} to be in prompt, but got {prompt_lines}"
+
+
 @pytest.mark.asyncio
 async def test_selector_group_chat_with_team_event(runtime: AgentRuntime | None) -> None:
     model_client = ReplayChatCompletionClient(
```

> **Review comment** (on the `team = SelectorGroupChat(` line): I don't see the model context being customized here?
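One way to address the comment above would be to pass a non-default context when constructing the team and assert that older messages drop out of the selector prompt. A sketch of the changed construction only, reusing the names defined in the test above (not part of this PR as shown; the assertion expectations would need updating to match):

```python
from autogen_core.model_context import BufferedChatCompletionContext

team = SelectorGroupChat(
    participants=[agent1, agent2],
    model_client=selector_group_chat_model_client,
    termination_condition=termination,
    runtime=runtime,
    emit_team_events=True,
    # With buffer_size=2, later selector prompts should contain only the two
    # most recent messages instead of the full thread.
    model_context=BufferedChatCompletionContext(buffer_size=2),
)
```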
> **Review comment:** At this point, there shouldn't be any `BaseChatMessage` or `BaseAgentEvent` in the input list, right?
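For context on that question, which presumably refers to the `construct_message_history` input type: `ChatCompletionContext.get_messages()` yields LLM message types (`UserMessage`, `AssistantMessage`, ...), not `BaseChatMessage` or `BaseAgentEvent`, so only the second `isinstance` branch should ever match when the input comes from the model context. A quick sketch:

```python
import asyncio

from autogen_core.model_context import UnboundedChatCompletionContext
from autogen_core.models import UserMessage


async def main() -> None:
    context = UnboundedChatCompletionContext()
    await context.add_message(UserMessage(content="hello", source="user"))
    for msg in await context.get_messages():
        # Messages come back as LLM message types (UserMessage here),
        # not as BaseChatMessage or BaseAgentEvent.
        print(type(msg).__name__)


asyncio.run(main())
```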