|
1 | 1 | from datetime import datetime, timezone
|
2 |
| -from typing import List, Optional, Sequence |
| 2 | +from typing import Any, List, Optional, Sequence, Union |
3 | 3 |
|
4 | 4 | from langchain_core.messages import AnyMessage
|
| 5 | +from langchain_core.runnables import RunnableConfig |
5 | 6 |
|
6 | 7 | from app.agent import AgentType, get_agent_executor
|
7 | 8 | from app.lifespan import get_pg_pool
|
@@ -98,37 +99,36 @@ async def get_thread(user_id: str, thread_id: str) -> Optional[Thread]:
|
98 | 99 | )
|
99 | 100 |
|
100 | 101 |
|
101 |
| -async def get_thread_messages(user_id: str, thread_id: str): |
| 102 | +async def get_thread_state(user_id: str, thread_id: str): |
102 | 103 | """Get all messages for a thread."""
|
103 | 104 | app = get_agent_executor([], AgentType.GPT_35_TURBO, "", False)
|
104 | 105 | state = await app.aget_state({"configurable": {"thread_id": thread_id}})
|
105 | 106 | return {
|
106 |
| - "messages": state.values, |
107 |
| - "resumeable": bool(state.next), |
| 107 | + "values": state.values, |
| 108 | + "next": state.next, |
108 | 109 | }
|
109 | 110 |
|
110 | 111 |
|
111 |
| -async def post_thread_messages( |
112 |
| - user_id: str, thread_id: str, messages: Sequence[AnyMessage] |
| 112 | +async def update_thread_state( |
| 113 | + config: RunnableConfig, messages: Union[Sequence[AnyMessage], dict[str, Any]] |
113 | 114 | ):
|
114 | 115 | """Add messages to a thread."""
|
115 | 116 | app = get_agent_executor([], AgentType.GPT_35_TURBO, "", False)
|
116 |
| - await app.aupdate_state({"configurable": {"thread_id": thread_id}}, messages) |
| 117 | + return await app.aupdate_state(config, messages) |
117 | 118 |
|
118 | 119 |
|
119 | 120 | async def get_thread_history(user_id: str, thread_id: str):
|
120 | 121 | """Get the history of a thread."""
|
121 | 122 | app = get_agent_executor([], AgentType.GPT_35_TURBO, "", False)
|
| 123 | + config = {"configurable": {"thread_id": thread_id}} |
122 | 124 | return [
|
123 | 125 | {
|
124 | 126 | "values": c.values,
|
125 |
| - "resumeable": bool(c.next), |
| 127 | + "next": c.next, |
126 | 128 | "config": c.config,
|
127 | 129 | "parent": c.parent_config,
|
128 | 130 | }
|
129 |
| - async for c in app.aget_state_history( |
130 |
| - {"configurable": {"thread_id": thread_id}} |
131 |
| - ) |
| 131 | + async for c in app.aget_state_history(config) |
132 | 132 | ]
|
133 | 133 |
|
134 | 134 |
|
|
0 commit comments