From 0be461d8779492ab16a56ff268a539371f4ef28a Mon Sep 17 00:00:00 2001 From: Steven Cybinski <41087154+StevenCyb@users.noreply.github.com> Date: Mon, 15 Apr 2024 08:44:39 +0200 Subject: [PATCH 1/4] chore: chat message is not inner class anymore --- LLama.Examples/Examples/ChatChineseGB2312.cs | 2 +- .../Examples/ChatSessionStripRoleName.cs | 2 +- .../Examples/ChatSessionWithHistory.cs | 2 +- .../Examples/ChatSessionWithRestart.cs | 10 +-- .../Examples/ChatSessionWithRoleName.cs | 2 +- LLama.Examples/Examples/LoadAndSaveSession.cs | 2 +- LLama.WebAPI/Controllers/ChatController.cs | 2 +- LLama.WebAPI/Services/StatefulChatService.cs | 4 +- LLama/ChatSession.cs | 55 +++++++------ LLama/Common/ChatHistory.cs | 80 ++++++++++++------- 10 files changed, 93 insertions(+), 68 deletions(-) diff --git a/LLama.Examples/Examples/ChatChineseGB2312.cs b/LLama.Examples/Examples/ChatChineseGB2312.cs index f3a964b4c..6d367e044 100644 --- a/LLama.Examples/Examples/ChatChineseGB2312.cs +++ b/LLama.Examples/Examples/ChatChineseGB2312.cs @@ -105,7 +105,7 @@ in session.RegenerateAssistantMessageAsync( await foreach ( var text in session.ChatAsync( - new ChatHistory.Message(AuthorRole.User, userInput), + new Message(AuthorRole.User, userInput), inferenceParams)) { Console.ForegroundColor = ConsoleColor.White; diff --git a/LLama.Examples/Examples/ChatSessionStripRoleName.cs b/LLama.Examples/Examples/ChatSessionStripRoleName.cs index b46c92e49..d9aa1a51f 100644 --- a/LLama.Examples/Examples/ChatSessionStripRoleName.cs +++ b/LLama.Examples/Examples/ChatSessionStripRoleName.cs @@ -46,7 +46,7 @@ public static async Task Run() await foreach ( var text in session.ChatAsync( - new ChatHistory.Message(AuthorRole.User, userInput), + new Message(AuthorRole.User, userInput), inferenceParams)) { Console.ForegroundColor = ConsoleColor.White; diff --git a/LLama.Examples/Examples/ChatSessionWithHistory.cs b/LLama.Examples/Examples/ChatSessionWithHistory.cs index 31b6a7718..58758b25c 100644 --- a/LLama.Examples/Examples/ChatSessionWithHistory.cs +++ b/LLama.Examples/Examples/ChatSessionWithHistory.cs @@ -92,7 +92,7 @@ in session.RegenerateAssistantMessageAsync( await foreach ( var text in session.ChatAsync( - new ChatHistory.Message(AuthorRole.User, userInput), + new Message(AuthorRole.User, userInput), inferenceParams)) { Console.ForegroundColor = ConsoleColor.White; diff --git a/LLama.Examples/Examples/ChatSessionWithRestart.cs b/LLama.Examples/Examples/ChatSessionWithRestart.cs index 923f78f67..6dc3d2e71 100644 --- a/LLama.Examples/Examples/ChatSessionWithRestart.cs +++ b/LLama.Examples/Examples/ChatSessionWithRestart.cs @@ -20,7 +20,7 @@ public static async Task Run() var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json"); ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory(); - ChatSession prototypeSession = + ChatSession prototypeSession = await ChatSession.InitializeSessionFromHistoryAsync(executor, chatHistory); prototypeSession.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform( new string[] { "User:", "Assistant:" }, @@ -50,7 +50,7 @@ public static async Task Run() while (userInput != "exit") { // Load the session state from the reset state - if(userInput == "reset") + if (userInput == "reset") { session.LoadSession(resetState); Console.WriteLine($"Reset to history:\n{session.HistoryTransform.HistoryToText(session.History)}"); @@ -75,10 +75,10 @@ public static async Task Run() Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("Provide assistant input: "); - + Console.ForegroundColor = ConsoleColor.Green; string assistantInputOverride = Console.ReadLine() ?? ""; - + await session.AddAndProcessUserMessage(userInputOverride); await session.AddAndProcessAssistantMessage(assistantInputOverride); @@ -90,7 +90,7 @@ public static async Task Run() await foreach ( var text in session.ChatAsync( - new ChatHistory.Message(AuthorRole.User, userInput), + new Message(AuthorRole.User, userInput), inferenceParams)) { Console.ForegroundColor = ConsoleColor.White; diff --git a/LLama.Examples/Examples/ChatSessionWithRoleName.cs b/LLama.Examples/Examples/ChatSessionWithRoleName.cs index de3314130..9965e220c 100644 --- a/LLama.Examples/Examples/ChatSessionWithRoleName.cs +++ b/LLama.Examples/Examples/ChatSessionWithRoleName.cs @@ -41,7 +41,7 @@ public static async Task Run() await foreach ( var text in session.ChatAsync( - new ChatHistory.Message(AuthorRole.User, userInput), + new Message(AuthorRole.User, userInput), inferenceParams)) { Console.ForegroundColor = ConsoleColor.White; diff --git a/LLama.Examples/Examples/LoadAndSaveSession.cs b/LLama.Examples/Examples/LoadAndSaveSession.cs index fded50e03..50530011c 100644 --- a/LLama.Examples/Examples/LoadAndSaveSession.cs +++ b/LLama.Examples/Examples/LoadAndSaveSession.cs @@ -33,7 +33,7 @@ public static async Task Run() await foreach ( var text in session.ChatAsync( - new ChatHistory.Message(AuthorRole.User, prompt), + new Message(AuthorRole.User, prompt), new InferenceParams() { Temperature = 0.6f, diff --git a/LLama.WebAPI/Controllers/ChatController.cs b/LLama.WebAPI/Controllers/ChatController.cs index 9643ccf80..012a6137e 100644 --- a/LLama.WebAPI/Controllers/ChatController.cs +++ b/LLama.WebAPI/Controllers/ChatController.cs @@ -43,7 +43,7 @@ public async Task SendHistory([FromBody] HistoryInput input, [FromServic { var history = new ChatHistory(); - var messages = input.Messages.Select(m => new ChatHistory.Message(Enum.Parse(m.Role), m.Content)); + var messages = input.Messages.Select(m => new Message(Enum.Parse(m.Role), m.Content)); history.Messages.AddRange(messages); diff --git a/LLama.WebAPI/Services/StatefulChatService.cs b/LLama.WebAPI/Services/StatefulChatService.cs index ae2401c90..c4c706d76 100644 --- a/LLama.WebAPI/Services/StatefulChatService.cs +++ b/LLama.WebAPI/Services/StatefulChatService.cs @@ -46,7 +46,7 @@ public async Task Send(SendMessageInput input) } _logger.LogInformation("Input: {text}", input.Text); var outputs = _session.ChatAsync( - new Common.ChatHistory.Message(Common.AuthorRole.User, input.Text), + new Common.Message(Common.AuthorRole.User, input.Text), new Common.InferenceParams() { RepeatPenalty = 1.0f, @@ -74,7 +74,7 @@ public async IAsyncEnumerable SendStream(SendMessageInput input) _logger.LogInformation(input.Text); var outputs = _session.ChatAsync( - new Common.ChatHistory.Message(Common.AuthorRole.User, input.Text!) + new Common.Message(Common.AuthorRole.User, input.Text!) , new Common.InferenceParams() { RepeatPenalty = 1.0f, diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs index 0a5accc5e..b1f2d8270 100644 --- a/LLama/ChatSession.cs +++ b/LLama/ChatSession.cs @@ -165,7 +165,7 @@ public SessionState GetSessionState() { var executorState = ((StatefulExecutorBase)Executor).GetStateData(); return new SessionState( - executorState.PastTokensCount > 0 + executorState.PastTokensCount > 0 ? Executor.Context.GetState() : null, executorState, History, @@ -221,7 +221,7 @@ public void LoadSession(string path, bool loadTransforms = true) if (state.ExecutorState is null) { var executorPath = Path.Combine(path, EXECUTOR_STATE_FILENAME); - ((StatefulExecutorBase) Executor).LoadState(filename: executorPath); + ((StatefulExecutorBase)Executor).LoadState(filename: executorPath); } LoadSession(state, loadTransforms); } @@ -231,7 +231,7 @@ public void LoadSession(string path, bool loadTransforms = true) /// /// /// - public ChatSession AddMessage(ChatHistory.Message message) + public ChatSession AddMessage(Message message) { // If current message is a system message, only allow the history to be empty if (message.AuthorRole == AuthorRole.System && History.Messages.Count > 0) @@ -243,7 +243,7 @@ public ChatSession AddMessage(ChatHistory.Message message) // or the previous message to be a system message or assistant message. if (message.AuthorRole == AuthorRole.User) { - ChatHistory.Message? lastMessage = History.Messages.LastOrDefault(); + Message? lastMessage = History.Messages.LastOrDefault(); if (lastMessage is not null && lastMessage.AuthorRole == AuthorRole.User) { throw new ArgumentException("Cannot add a user message after another user message", nameof(message)); @@ -254,7 +254,7 @@ public ChatSession AddMessage(ChatHistory.Message message) // the previous message must be a user message. if (message.AuthorRole == AuthorRole.Assistant) { - ChatHistory.Message? lastMessage = History.Messages.LastOrDefault(); + Message? lastMessage = History.Messages.LastOrDefault(); if (lastMessage is null || lastMessage.AuthorRole != AuthorRole.User) { @@ -272,7 +272,7 @@ public ChatSession AddMessage(ChatHistory.Message message) /// /// public ChatSession AddSystemMessage(string content) - => AddMessage(new ChatHistory.Message(AuthorRole.System, content)); + => AddMessage(new Message(AuthorRole.System, content)); /// /// Add an assistant message to the chat history. @@ -280,7 +280,7 @@ public ChatSession AddSystemMessage(string content) /// /// public ChatSession AddAssistantMessage(string content) - => AddMessage(new ChatHistory.Message(AuthorRole.Assistant, content)); + => AddMessage(new Message(AuthorRole.Assistant, content)); /// /// Add a user message to the chat history. @@ -288,7 +288,7 @@ public ChatSession AddAssistantMessage(string content) /// /// public ChatSession AddUserMessage(string content) - => AddMessage(new ChatHistory.Message(AuthorRole.User, content)); + => AddMessage(new Message(AuthorRole.User, content)); /// /// Remove the last message from the chat history. @@ -305,7 +305,7 @@ public ChatSession RemoveLastMessage() /// /// /// - public async Task AddAndProcessMessage(ChatHistory.Message message) + public async Task AddAndProcessMessage(Message message) { if (Executor is not StatefulExecutorBase statefulExecutor) { @@ -329,19 +329,19 @@ public async Task AddAndProcessMessage(ChatHistory.Message message) /// Compute KV cache for the system message and add it to the chat history. /// public Task AddAndProcessSystemMessage(string content) - => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.System, content)); + => AddAndProcessMessage(new Message(AuthorRole.System, content)); /// /// Compute KV cache for the user message and add it to the chat history. /// public Task AddAndProcessUserMessage(string content) - => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.User, content)); + => AddAndProcessMessage(new Message(AuthorRole.User, content)); /// /// Compute KV cache for the assistant message and add it to the chat history. /// public Task AddAndProcessAssistantMessage(string content) - => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.Assistant, content)); + => AddAndProcessMessage(new Message(AuthorRole.Assistant, content)); /// /// Replace a user message with a new message and remove all messages after the new message. @@ -351,8 +351,8 @@ public Task AddAndProcessAssistantMessage(string content) /// /// public ChatSession ReplaceUserMessage( - ChatHistory.Message oldMessage, - ChatHistory.Message newMessage) + Message oldMessage, + Message newMessage) { if (oldMessage.AuthorRole != AuthorRole.User) { @@ -388,7 +388,7 @@ public ChatSession ReplaceUserMessage( /// /// public async IAsyncEnumerable ChatAsync( - ChatHistory.Message message, + Message message, bool applyInputTransformPipeline, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) @@ -460,7 +460,7 @@ in ChatAsyncInternal( /// /// public IAsyncEnumerable ChatAsync( - ChatHistory.Message message, + Message message, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default) { @@ -486,11 +486,11 @@ public IAsyncEnumerable ChatAsync( IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default) { - ChatHistory.Message lastMessage = history.Messages.LastOrDefault() + Message lastMessage = history.Messages.LastOrDefault() ?? throw new ArgumentException("History must contain at least one message", nameof(history)); foreach ( - ChatHistory.Message message + Message message in history.Messages.Take(history.Messages.Count - 1)) { // Apply input transform pipeline @@ -546,7 +546,7 @@ public async IAsyncEnumerable RegenerateAssistantMessageAsync( [EnumeratorCancellation] CancellationToken cancellationToken = default) { // Make sure the last message is an assistant message (reponse from the LLM). - ChatHistory.Message? lastAssistantMessage = History.Messages.LastOrDefault(); + Message? lastAssistantMessage = History.Messages.LastOrDefault(); if (lastAssistantMessage is null || lastAssistantMessage.AuthorRole != AuthorRole.Assistant) @@ -558,7 +558,7 @@ public async IAsyncEnumerable RegenerateAssistantMessageAsync( RemoveLastMessage(); // Get the last user message. - ChatHistory.Message? lastUserMessage = History.Messages.LastOrDefault(); + Message? lastUserMessage = History.Messages.LastOrDefault(); if (lastUserMessage is null || lastUserMessage.AuthorRole != AuthorRole.User) @@ -629,11 +629,11 @@ public record SessionState /// The history transform used in this session. /// public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform(); - + /// /// The the chat history messages for this session. /// - public ChatHistory.Message[] History { get; set; } = Array.Empty(); + public Message[] History { get; set; } = Array.Empty(); /// /// Create a new session state. @@ -645,7 +645,7 @@ public record SessionState /// /// public SessionState( - State? contextState, ExecutorBaseState executorState, + State? contextState, ExecutorBaseState executorState, ChatHistory history, List inputTransformPipeline, ITextStreamTransform outputTransform, IHistoryTransform historyTransform) { @@ -717,7 +717,7 @@ public static SessionState Load(string path) } string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME); - var contextState = File.Exists(modelStateFilePath) ? + var contextState = File.Exists(modelStateFilePath) ? State.FromByteArray(File.ReadAllBytes(modelStateFilePath)) : null; @@ -733,7 +733,7 @@ public static SessionState Load(string path) ITextTransform[] inputTransforms; try { - inputTransforms = File.Exists(inputTransformFilepath) ? + inputTransforms = File.Exists(inputTransformFilepath) ? (JsonSerializer.Deserialize(File.ReadAllText(inputTransformFilepath)) ?? throw new ArgumentException("Input transform file is invalid", nameof(path))) : Array.Empty(); @@ -744,11 +744,10 @@ public static SessionState Load(string path) } string outputTransformFilepath = Path.Combine(path, ChatSession.OUTPUT_TRANSFORM_FILENAME); - ITextStreamTransform outputTransform; try { - outputTransform = File.Exists(outputTransformFilepath) ? + outputTransform = File.Exists(outputTransformFilepath) ? (JsonSerializer.Deserialize(File.ReadAllText(outputTransformFilepath)) ?? throw new ArgumentException("Output transform file is invalid", nameof(path))) : new LLamaTransforms.EmptyTextOutputStreamTransform(); @@ -762,7 +761,7 @@ public static SessionState Load(string path) IHistoryTransform historyTransform; try { - historyTransform = File.Exists(historyTransformFilepath) ? + historyTransform = File.Exists(historyTransformFilepath) ? (JsonSerializer.Deserialize(File.ReadAllText(historyTransformFilepath)) ?? throw new ArgumentException("History transform file is invalid", nameof(path))) : new LLamaTransforms.DefaultHistoryTransform(); diff --git a/LLama/Common/ChatHistory.cs b/LLama/Common/ChatHistory.cs index c22cc7c06..01f7cbacd 100644 --- a/LLama/Common/ChatHistory.cs +++ b/LLama/Common/ChatHistory.cs @@ -35,39 +35,65 @@ public enum AuthorRole /// /// The chat history class /// - public class ChatHistory + public class Message { - private static readonly JsonSerializerOptions _jsonOptions = new() { WriteIndented = true }; + /// + /// Role of the message author, e.g. user/assistant/system + /// + [JsonConverter(typeof(JsonStringEnumConverter))] + [JsonPropertyName("author_role")] + public AuthorRole AuthorRole { get; set; } /// - /// Chat message representation + /// Message content /// - public class Message + [JsonPropertyName("content")] + public string Content { get; set; } + + /// + /// Create a new instance + /// + /// Role of message author + /// Message content + public Message(AuthorRole authorRole, string content) { - /// - /// Role of the message author, e.g. user/assistant/system - /// - [JsonConverter(typeof(JsonStringEnumConverter))] - [JsonPropertyName("author_role")] - public AuthorRole AuthorRole { get; set; } - - /// - /// Message content - /// - [JsonPropertyName("content")] - public string Content { get; set; } - - /// - /// Create a new instance - /// - /// Role of message author - /// Message content - public Message(AuthorRole authorRole, string content) - { - this.AuthorRole = authorRole; - this.Content = content; - } + this.AuthorRole = authorRole; + this.Content = content; } + } + + /// + /// Interface for chat history + /// + public interface IChatHistory + { + /// + /// List of messages in the chat + /// + List Messages { get; set; } + + /// + /// Add a message to the chat history + /// + /// Role of the message author + /// Message content + void AddMessage(AuthorRole authorRole, string content); + + /// + /// Serialize the chat history to JSON + /// + /// + string ToJson(); + } + + + // copy from semantic-kernel + /// + /// The chat history class + /// + public class ChatHistory : IChatHistory + { + private static readonly JsonSerializerOptions _jsonOptions = new() { WriteIndented = true }; /// /// List of messages in the chat From d0d09d026c86050cade56d24f7011ae577980ddd Mon Sep 17 00:00:00 2001 From: Steven Cybinski <41087154+StevenCyb@users.noreply.github.com> Date: Mon, 15 Apr 2024 09:29:14 +0200 Subject: [PATCH 2/4] chore: move chat history serialization out to a generic class `ChatHistory` --- LLama.Examples/Examples/ChatChineseGB2312.cs | 2 +- .../Examples/ChatSessionStripRoleName.cs | 2 +- .../Examples/ChatSessionWithHistory.cs | 2 +- .../Examples/ChatSessionWithRestart.cs | 2 +- .../Examples/ChatSessionWithRoleName.cs | 2 +- LLama/ChatSession.cs | 4 +-- LLama/Common/ChatHistory.cs | 25 ++++++++++--------- 7 files changed, 20 insertions(+), 19 deletions(-) diff --git a/LLama.Examples/Examples/ChatChineseGB2312.cs b/LLama.Examples/Examples/ChatChineseGB2312.cs index 6d367e044..50cc1e8a5 100644 --- a/LLama.Examples/Examples/ChatChineseGB2312.cs +++ b/LLama.Examples/Examples/ChatChineseGB2312.cs @@ -49,7 +49,7 @@ public static async Task Run() else { var chatHistoryJson = File.ReadAllText("Assets/chat-with-kunkun-chinese.json"); - ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory(); + ChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson) ?? new ChatHistory(); session = new ChatSession(executor, chatHistory); } diff --git a/LLama.Examples/Examples/ChatSessionStripRoleName.cs b/LLama.Examples/Examples/ChatSessionStripRoleName.cs index d9aa1a51f..6739c041d 100644 --- a/LLama.Examples/Examples/ChatSessionStripRoleName.cs +++ b/LLama.Examples/Examples/ChatSessionStripRoleName.cs @@ -21,7 +21,7 @@ public static async Task Run() var executor = new InteractiveExecutor(context); var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json"); - ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory(); + ChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson) ?? new ChatHistory(); ChatSession session = new(executor, chatHistory); session.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform( diff --git a/LLama.Examples/Examples/ChatSessionWithHistory.cs b/LLama.Examples/Examples/ChatSessionWithHistory.cs index 58758b25c..3abada19c 100644 --- a/LLama.Examples/Examples/ChatSessionWithHistory.cs +++ b/LLama.Examples/Examples/ChatSessionWithHistory.cs @@ -31,7 +31,7 @@ public static async Task Run() else { var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json"); - ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory(); + ChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson) ?? new ChatHistory(); session = new ChatSession(executor, chatHistory); } diff --git a/LLama.Examples/Examples/ChatSessionWithRestart.cs b/LLama.Examples/Examples/ChatSessionWithRestart.cs index 6dc3d2e71..0b0510055 100644 --- a/LLama.Examples/Examples/ChatSessionWithRestart.cs +++ b/LLama.Examples/Examples/ChatSessionWithRestart.cs @@ -19,7 +19,7 @@ public static async Task Run() var executor = new InteractiveExecutor(context); var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json"); - ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory(); + ChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson) ?? new ChatHistory(); ChatSession prototypeSession = await ChatSession.InitializeSessionFromHistoryAsync(executor, chatHistory); prototypeSession.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform( diff --git a/LLama.Examples/Examples/ChatSessionWithRoleName.cs b/LLama.Examples/Examples/ChatSessionWithRoleName.cs index 9965e220c..59fdede89 100644 --- a/LLama.Examples/Examples/ChatSessionWithRoleName.cs +++ b/LLama.Examples/Examples/ChatSessionWithRoleName.cs @@ -19,7 +19,7 @@ public static async Task Run() var executor = new InteractiveExecutor(context); var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json"); - ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory(); + ChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson) ?? new ChatHistory(); ChatSession session = new(executor, chatHistory); diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs index b1f2d8270..6d3db2ad5 100644 --- a/LLama/ChatSession.cs +++ b/LLama/ChatSession.cs @@ -686,7 +686,7 @@ public void Save(string path) File.WriteAllText(executorStateFilepath, JsonSerializer.Serialize(ExecutorState)); string historyFilepath = Path.Combine(path, ChatSession.HISTORY_STATE_FILENAME); - File.WriteAllText(historyFilepath, new ChatHistory(History).ToJson()); + File.WriteAllText(historyFilepath, ChatHistorySerializer.ToJson(new ChatHistory(History))); string inputTransformFilepath = Path.Combine(path, ChatSession.INPUT_TRANSFORM_FILENAME); File.WriteAllText(inputTransformFilepath, JsonSerializer.Serialize(InputTransformPipeline)); @@ -726,7 +726,7 @@ public static SessionState Load(string path) string historyFilepath = Path.Combine(path, ChatSession.HISTORY_STATE_FILENAME); string historyJson = File.ReadAllText(historyFilepath); - var history = ChatHistory.FromJson(historyJson) + var history = ChatHistorySerializer.FromJson(historyJson) ?? throw new ArgumentException("History file is invalid", nameof(path)); string inputTransformFilepath = Path.Combine(path, ChatSession.INPUT_TRANSFORM_FILENAME); diff --git a/LLama/Common/ChatHistory.cs b/LLama/Common/ChatHistory.cs index 01f7cbacd..af39d676e 100644 --- a/LLama/Common/ChatHistory.cs +++ b/LLama/Common/ChatHistory.cs @@ -33,7 +33,7 @@ public enum AuthorRole // copy from semantic-kernel /// - /// The chat history class + /// The message class /// public class Message { @@ -78,15 +78,8 @@ public interface IChatHistory /// Role of the message author /// Message content void AddMessage(AuthorRole authorRole, string content); - - /// - /// Serialize the chat history to JSON - /// - /// - string ToJson(); } - // copy from semantic-kernel /// /// The chat history class @@ -125,14 +118,22 @@ public void AddMessage(AuthorRole authorRole, string content) { this.Messages.Add(new Message(authorRole, content)); } + } + + /// + /// Serializer for chat history + /// + public class ChatHistorySerializer where T : IChatHistory + { + private static readonly JsonSerializerOptions _jsonOptions = new() { WriteIndented = true }; /// /// Serialize the chat history to JSON /// /// - public string ToJson() + public static string ToJson(T chatHistory) { - return JsonSerializer.Serialize(this, _jsonOptions); + return JsonSerializer.Serialize(chatHistory, _jsonOptions); } /// @@ -140,9 +141,9 @@ public string ToJson() /// /// /// - public static ChatHistory? FromJson(string json) + public static T? FromJson(string json) { - return JsonSerializer.Deserialize(json); + return JsonSerializer.Deserialize(json); } } } From 24def66ae916b4b203ea97597df8f8a238d97abc Mon Sep 17 00:00:00 2001 From: Steven Cybinski <41087154+StevenCyb@users.noreply.github.com> Date: Mon, 15 Apr 2024 13:16:57 +0200 Subject: [PATCH 3/4] chore: use chat history interface --- LLama.Examples/Examples/ChatChineseGB2312.cs | 4 +- .../Examples/ChatSessionStripRoleName.cs | 2 +- .../Examples/ChatSessionWithHistory.cs | 2 +- .../Examples/ChatSessionWithRestart.cs | 2 +- .../Examples/ChatSessionWithRoleName.cs | 2 +- .../ChatCompletion/HistoryTransform.cs | 4 +- LLama.WebAPI/Services/StatelessChatService.cs | 5 +-- LLama/Abstractions/IHistoryTransform.cs | 6 +-- LLama/ChatSession.cs | 38 +++++++++---------- LLama/Common/ChatHistory.cs | 8 ++-- LLama/LLamaTransforms.cs | 13 ++++--- 11 files changed, 43 insertions(+), 43 deletions(-) diff --git a/LLama.Examples/Examples/ChatChineseGB2312.cs b/LLama.Examples/Examples/ChatChineseGB2312.cs index 50cc1e8a5..27a6f0519 100644 --- a/LLama.Examples/Examples/ChatChineseGB2312.cs +++ b/LLama.Examples/Examples/ChatChineseGB2312.cs @@ -49,13 +49,13 @@ public static async Task Run() else { var chatHistoryJson = File.ReadAllText("Assets/chat-with-kunkun-chinese.json"); - ChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson) ?? new ChatHistory(); + IChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson) ?? new ChatHistory(); session = new ChatSession(executor, chatHistory); } session - .WithHistoryTransform(new LLamaTransforms.DefaultHistoryTransform("用户", "坤坤")); + .WithHistoryTransform(new LLamaTransforms.DefaultHistoryTransform("用户", "坤坤")); InferenceParams inferenceParams = new InferenceParams() { diff --git a/LLama.Examples/Examples/ChatSessionStripRoleName.cs b/LLama.Examples/Examples/ChatSessionStripRoleName.cs index 6739c041d..e8f7e13a1 100644 --- a/LLama.Examples/Examples/ChatSessionStripRoleName.cs +++ b/LLama.Examples/Examples/ChatSessionStripRoleName.cs @@ -21,7 +21,7 @@ public static async Task Run() var executor = new InteractiveExecutor(context); var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json"); - ChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson) ?? new ChatHistory(); + IChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson) ?? new ChatHistory(); ChatSession session = new(executor, chatHistory); session.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform( diff --git a/LLama.Examples/Examples/ChatSessionWithHistory.cs b/LLama.Examples/Examples/ChatSessionWithHistory.cs index 3abada19c..f0003f983 100644 --- a/LLama.Examples/Examples/ChatSessionWithHistory.cs +++ b/LLama.Examples/Examples/ChatSessionWithHistory.cs @@ -31,7 +31,7 @@ public static async Task Run() else { var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json"); - ChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson) ?? new ChatHistory(); + IChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson) ?? new ChatHistory(); session = new ChatSession(executor, chatHistory); } diff --git a/LLama.Examples/Examples/ChatSessionWithRestart.cs b/LLama.Examples/Examples/ChatSessionWithRestart.cs index 0b0510055..e588f6fda 100644 --- a/LLama.Examples/Examples/ChatSessionWithRestart.cs +++ b/LLama.Examples/Examples/ChatSessionWithRestart.cs @@ -19,7 +19,7 @@ public static async Task Run() var executor = new InteractiveExecutor(context); var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json"); - ChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson) ?? new ChatHistory(); + IChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson) ?? new ChatHistory(); ChatSession prototypeSession = await ChatSession.InitializeSessionFromHistoryAsync(executor, chatHistory); prototypeSession.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform( diff --git a/LLama.Examples/Examples/ChatSessionWithRoleName.cs b/LLama.Examples/Examples/ChatSessionWithRoleName.cs index 59fdede89..8ec972a3c 100644 --- a/LLama.Examples/Examples/ChatSessionWithRoleName.cs +++ b/LLama.Examples/Examples/ChatSessionWithRoleName.cs @@ -19,7 +19,7 @@ public static async Task Run() var executor = new InteractiveExecutor(context); var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json"); - ChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson) ?? new ChatHistory(); + IChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson) ?? new ChatHistory(); ChatSession session = new(executor, chatHistory); diff --git a/LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs b/LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs index f1a0ebcb6..04605e39e 100644 --- a/LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs +++ b/LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs @@ -7,10 +7,10 @@ namespace LLamaSharp.SemanticKernel.ChatCompletion; /// /// Default HistoryTransform Patch /// -public class HistoryTransform : DefaultHistoryTransform +public class HistoryTransform : DefaultHistoryTransform { /// - public override string HistoryToText(global::LLama.Common.ChatHistory history) + public string HistoryToText(global::LLama.Common.ChatHistory history) { return base.HistoryToText(history) + $"{AuthorRole.Assistant}: "; } diff --git a/LLama.WebAPI/Services/StatelessChatService.cs b/LLama.WebAPI/Services/StatelessChatService.cs index 3520c29b0..98b662b6a 100644 --- a/LLama.WebAPI/Services/StatelessChatService.cs +++ b/LLama.WebAPI/Services/StatelessChatService.cs @@ -46,12 +46,11 @@ public async Task SendAsync(ChatHistory history) } } - public class HistoryTransform : DefaultHistoryTransform + public class HistoryTransform : DefaultHistoryTransform { - public override string HistoryToText(ChatHistory history) + public override string HistoryToText(IChatHistory history) { return base.HistoryToText(history) + "\n Assistant:"; } - } } diff --git a/LLama/Abstractions/IHistoryTransform.cs b/LLama/Abstractions/IHistoryTransform.cs index 9644b3e1d..d85d8d5b4 100644 --- a/LLama/Abstractions/IHistoryTransform.cs +++ b/LLama/Abstractions/IHistoryTransform.cs @@ -14,15 +14,15 @@ public interface IHistoryTransform /// /// The ChatHistory instance /// - string HistoryToText(ChatHistory history); - + string HistoryToText(IChatHistory history); + /// /// Converts plain text to a ChatHistory instance. /// /// The role for the author. /// The chat history as plain text. /// The updated history. - ChatHistory TextToHistory(AuthorRole role, string text); + IChatHistory TextToHistory(AuthorRole role, string text); /// /// Copy the transform. diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs index 6d3db2ad5..351ed36e9 100644 --- a/LLama/ChatSession.cs +++ b/LLama/ChatSession.cs @@ -52,12 +52,12 @@ public class ChatSession /// /// The chat history for this session. /// - public ChatHistory History { get; private set; } = new(); + public IChatHistory History { get; private set; } = new ChatHistory(); /// /// The history transform used in this session. /// - public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform(); + public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform(); /// /// The input transform pipeline used in this session. @@ -73,17 +73,17 @@ public class ChatSession /// Create a new chat session and preprocess history. /// /// The executor for this session - /// History for this session + /// History for this session /// public static async Task InitializeSessionFromHistoryAsync( - ILLamaExecutor executor, ChatHistory history) + ILLamaExecutor executor, IChatHistory chatHistory) { if (executor is not StatefulExecutorBase statefulExecutor) { throw new ArgumentException("Executor must have a StatefulExecutorBase", nameof(executor)); } - var session = new ChatSession(executor, history); - await statefulExecutor.PrefillPromptAsync(session.HistoryTransform.HistoryToText(history)); + var session = new ChatSession(executor, chatHistory); + await statefulExecutor.PrefillPromptAsync(session.HistoryTransform.HistoryToText(chatHistory)); return session; } @@ -107,7 +107,7 @@ public ChatSession(ILLamaExecutor executor) /// /// /// - public ChatSession(ILLamaExecutor executor, ChatHistory history) + public ChatSession(ILLamaExecutor executor, IChatHistory history) : this(executor) { History = history; @@ -198,7 +198,7 @@ public void LoadSession(SessionState state, bool loadTransforms = true) { Executor.Context.LoadState(state.ContextState); } - History = new ChatHistory(state.History); + History = state.SessionChatHistory; if (loadTransforms) { InputTransformPipeline = state.InputTransformPipeline.Select(t => t.Clone()).ToList(); @@ -431,7 +431,7 @@ public async IAsyncEnumerable ChatAsync( // If the session was restored from a previous session, // convert only the current message to the prompt with the prompt template // specified in the HistoryTransform class implementation that is provided. - ChatHistory singleMessageHistory = HistoryTransform.TextToHistory(message.AuthorRole, message.Content); + IChatHistory singleMessageHistory = HistoryTransform.TextToHistory(message.AuthorRole, message.Content); prompt = HistoryTransform.HistoryToText(singleMessageHistory); } @@ -481,7 +481,7 @@ public IAsyncEnumerable ChatAsync( /// /// public IAsyncEnumerable ChatAsync( - ChatHistory history, + IChatHistory history, bool applyInputTransformPipeline, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default) @@ -523,7 +523,7 @@ var inputTransform /// /// public IAsyncEnumerable ChatAsync( - ChatHistory history, + IChatHistory history, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default) { @@ -628,30 +628,30 @@ public record SessionState /// /// The history transform used in this session. /// - public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform(); + public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform(); /// /// The the chat history messages for this session. /// - public Message[] History { get; set; } = Array.Empty(); + public IChatHistory SessionChatHistory; /// /// Create a new session state. /// /// /// - /// + /// /// /// /// public SessionState( State? contextState, ExecutorBaseState executorState, - ChatHistory history, List inputTransformPipeline, + IChatHistory chatHistory, List inputTransformPipeline, ITextStreamTransform outputTransform, IHistoryTransform historyTransform) { ContextState = contextState; ExecutorState = executorState; - History = history.Messages.ToArray(); + SessionChatHistory = chatHistory; InputTransformPipeline = inputTransformPipeline.Select(t => t.Clone()).ToArray(); OutputTransform = outputTransform.Clone(); HistoryTransform = historyTransform.Clone(); @@ -686,7 +686,7 @@ public void Save(string path) File.WriteAllText(executorStateFilepath, JsonSerializer.Serialize(ExecutorState)); string historyFilepath = Path.Combine(path, ChatSession.HISTORY_STATE_FILENAME); - File.WriteAllText(historyFilepath, ChatHistorySerializer.ToJson(new ChatHistory(History))); + File.WriteAllText(historyFilepath, ChatHistorySerializer.ToJson(new ChatHistory(SessionChatHistory.Messages.ToArray()))); string inputTransformFilepath = Path.Combine(path, ChatSession.INPUT_TRANSFORM_FILENAME); File.WriteAllText(inputTransformFilepath, JsonSerializer.Serialize(InputTransformPipeline)); @@ -726,7 +726,7 @@ public static SessionState Load(string path) string historyFilepath = Path.Combine(path, ChatSession.HISTORY_STATE_FILENAME); string historyJson = File.ReadAllText(historyFilepath); - var history = ChatHistorySerializer.FromJson(historyJson) + var history = ChatHistorySerializer.FromJson(historyJson) ?? throw new ArgumentException("History file is invalid", nameof(path)); string inputTransformFilepath = Path.Combine(path, ChatSession.INPUT_TRANSFORM_FILENAME); @@ -764,7 +764,7 @@ public static SessionState Load(string path) historyTransform = File.Exists(historyTransformFilepath) ? (JsonSerializer.Deserialize(File.ReadAllText(historyTransformFilepath)) ?? throw new ArgumentException("History transform file is invalid", nameof(path))) - : new LLamaTransforms.DefaultHistoryTransform(); + : new LLamaTransforms.DefaultHistoryTransform(); } catch (JsonException) { diff --git a/LLama/Common/ChatHistory.cs b/LLama/Common/ChatHistory.cs index af39d676e..755458452 100644 --- a/LLama/Common/ChatHistory.cs +++ b/LLama/Common/ChatHistory.cs @@ -123,7 +123,7 @@ public void AddMessage(AuthorRole authorRole, string content) /// /// Serializer for chat history /// - public class ChatHistorySerializer where T : IChatHistory + public class ChatHistorySerializer { private static readonly JsonSerializerOptions _jsonOptions = new() { WriteIndented = true }; @@ -131,7 +131,7 @@ public class ChatHistorySerializer where T : IChatHistory /// Serialize the chat history to JSON /// /// - public static string ToJson(T chatHistory) + public static string ToJson(IChatHistory chatHistory) { return JsonSerializer.Serialize(chatHistory, _jsonOptions); } @@ -141,9 +141,9 @@ public static string ToJson(T chatHistory) /// /// /// - public static T? FromJson(string json) + public static IChatHistory? FromJson(string json) { - return JsonSerializer.Deserialize(json); + return JsonSerializer.Deserialize(json); } } } diff --git a/LLama/LLamaTransforms.cs b/LLama/LLamaTransforms.cs index d74d9ddaf..f04d98b19 100644 --- a/LLama/LLamaTransforms.cs +++ b/LLama/LLamaTransforms.cs @@ -1,5 +1,6 @@ using LLama.Abstractions; using LLama.Common; +using System; using System.Collections.Generic; using System.Linq; using System.Text; @@ -17,7 +18,7 @@ public class LLamaTransforms /// Uses plain text with the following format: /// [Author]: [Message] /// - public class DefaultHistoryTransform : IHistoryTransform + public class DefaultHistoryTransform : IHistoryTransform where T : IChatHistory { private const string defaultUserName = "User"; private const string defaultAssistantName = "Assistant"; @@ -44,7 +45,7 @@ public class DefaultHistoryTransform : IHistoryTransform /// /// /// - public DefaultHistoryTransform(string? userName = null, string? assistantName = null, + public DefaultHistoryTransform(string? userName = null, string? assistantName = null, string? systemName = null, string? unknownName = null, bool isInstructMode = false) { _userName = userName ?? defaultUserName; @@ -57,11 +58,11 @@ public DefaultHistoryTransform(string? userName = null, string? assistantName = /// public IHistoryTransform Clone() { - return new DefaultHistoryTransform(_userName, _assistantName, _systemName, _unknownName, _isInstructMode); + return (IHistoryTransform)new DefaultHistoryTransform(_userName, _assistantName, _systemName, _unknownName, _isInstructMode); } /// - public virtual string HistoryToText(ChatHistory history) + public virtual string HistoryToText(IChatHistory history) { StringBuilder sb = new(); foreach (var message in history.Messages) @@ -87,9 +88,9 @@ public virtual string HistoryToText(ChatHistory history) } /// - public virtual ChatHistory TextToHistory(AuthorRole role, string text) + public virtual IChatHistory TextToHistory(AuthorRole role, string text) { - ChatHistory history = new ChatHistory(); + T history = Activator.CreateInstance(); history.AddMessage(role, TrimNamesFromText(text, role)); return history; } From 8672de429ded0c6f212f16d7537d79a0179e5a40 Mon Sep 17 00:00:00 2001 From: Steven Cybinski <41087154+StevenCyb@users.noreply.github.com> Date: Mon, 15 Apr 2024 17:39:37 +0200 Subject: [PATCH 4/4] fix: a bug with chat history serialization --- LLama.Examples/Examples/ChatChineseGB2312.cs | 4 +- .../Examples/ChatSessionStripRoleName.cs | 2 +- .../Examples/ChatSessionWithHistory.cs | 2 +- .../Examples/ChatSessionWithRestart.cs | 4 +- .../Examples/ChatSessionWithRoleName.cs | 2 +- .../ChatCompletion/HistoryTransform.cs | 2 +- LLama.WebAPI/Services/StatefulChatService.cs | 2 +- LLama.WebAPI/Services/StatelessChatService.cs | 2 +- LLama/Abstractions/IHistoryTransform.cs | 4 +- LLama/ChatSession.cs | 44 +++++++++---------- LLama/Common/ChatHistory.cs | 8 ++-- LLama/LLamaTransforms.cs | 8 ++-- 12 files changed, 44 insertions(+), 40 deletions(-) diff --git a/LLama.Examples/Examples/ChatChineseGB2312.cs b/LLama.Examples/Examples/ChatChineseGB2312.cs index 27a6f0519..2d6bf1785 100644 --- a/LLama.Examples/Examples/ChatChineseGB2312.cs +++ b/LLama.Examples/Examples/ChatChineseGB2312.cs @@ -49,13 +49,13 @@ public static async Task Run() else { var chatHistoryJson = File.ReadAllText("Assets/chat-with-kunkun-chinese.json"); - IChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson) ?? new ChatHistory(); + IChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson, typeof(ChatHistory)) ?? new ChatHistory(); session = new ChatSession(executor, chatHistory); } session - .WithHistoryTransform(new LLamaTransforms.DefaultHistoryTransform("用户", "坤坤")); + .WithHistoryTransform(new LLamaTransforms.DefaultHistoryTransform("用户", "坤坤")); InferenceParams inferenceParams = new InferenceParams() { diff --git a/LLama.Examples/Examples/ChatSessionStripRoleName.cs b/LLama.Examples/Examples/ChatSessionStripRoleName.cs index e8f7e13a1..93f5c394d 100644 --- a/LLama.Examples/Examples/ChatSessionStripRoleName.cs +++ b/LLama.Examples/Examples/ChatSessionStripRoleName.cs @@ -21,7 +21,7 @@ public static async Task Run() var executor = new InteractiveExecutor(context); var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json"); - IChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson) ?? new ChatHistory(); + IChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson, typeof(ChatHistory)) ?? new ChatHistory(); ChatSession session = new(executor, chatHistory); session.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform( diff --git a/LLama.Examples/Examples/ChatSessionWithHistory.cs b/LLama.Examples/Examples/ChatSessionWithHistory.cs index f0003f983..4e96acdef 100644 --- a/LLama.Examples/Examples/ChatSessionWithHistory.cs +++ b/LLama.Examples/Examples/ChatSessionWithHistory.cs @@ -31,7 +31,7 @@ public static async Task Run() else { var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json"); - IChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson) ?? new ChatHistory(); + IChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson, typeof(ChatHistory)) ?? new ChatHistory(); session = new ChatSession(executor, chatHistory); } diff --git a/LLama.Examples/Examples/ChatSessionWithRestart.cs b/LLama.Examples/Examples/ChatSessionWithRestart.cs index e588f6fda..c8bac7f90 100644 --- a/LLama.Examples/Examples/ChatSessionWithRestart.cs +++ b/LLama.Examples/Examples/ChatSessionWithRestart.cs @@ -19,7 +19,7 @@ public static async Task Run() var executor = new InteractiveExecutor(context); var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json"); - IChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson) ?? new ChatHistory(); + IChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson, typeof(ChatHistory)) ?? new ChatHistory(); ChatSession prototypeSession = await ChatSession.InitializeSessionFromHistoryAsync(executor, chatHistory); prototypeSession.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform( @@ -53,7 +53,7 @@ public static async Task Run() if (userInput == "reset") { session.LoadSession(resetState); - Console.WriteLine($"Reset to history:\n{session.HistoryTransform.HistoryToText(session.History)}"); + Console.WriteLine($"Reset to history:\n{session.HistoryTransform.HistoryToText(session.SessionChatHistory)}"); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("Session reset."); } diff --git a/LLama.Examples/Examples/ChatSessionWithRoleName.cs b/LLama.Examples/Examples/ChatSessionWithRoleName.cs index 8ec972a3c..080e9544e 100644 --- a/LLama.Examples/Examples/ChatSessionWithRoleName.cs +++ b/LLama.Examples/Examples/ChatSessionWithRoleName.cs @@ -19,7 +19,7 @@ public static async Task Run() var executor = new InteractiveExecutor(context); var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json"); - IChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson) ?? new ChatHistory(); + IChatHistory chatHistory = ChatHistorySerializer.FromJson(chatHistoryJson, typeof(ChatHistory)) ?? new ChatHistory(); ChatSession session = new(executor, chatHistory); diff --git a/LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs b/LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs index 04605e39e..4cb803b81 100644 --- a/LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs +++ b/LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs @@ -7,7 +7,7 @@ namespace LLamaSharp.SemanticKernel.ChatCompletion; /// /// Default HistoryTransform Patch /// -public class HistoryTransform : DefaultHistoryTransform +public class HistoryTransform : DefaultHistoryTransform { /// public string HistoryToText(global::LLama.Common.ChatHistory history) diff --git a/LLama.WebAPI/Services/StatefulChatService.cs b/LLama.WebAPI/Services/StatefulChatService.cs index c4c706d76..0691ce18e 100644 --- a/LLama.WebAPI/Services/StatefulChatService.cs +++ b/LLama.WebAPI/Services/StatefulChatService.cs @@ -28,7 +28,7 @@ public StatefulChatService(IConfiguration configuration, ILogger SendAsync(ChatHistory history) } } - public class HistoryTransform : DefaultHistoryTransform + public class HistoryTransform : DefaultHistoryTransform { public override string HistoryToText(IChatHistory history) { diff --git a/LLama/Abstractions/IHistoryTransform.cs b/LLama/Abstractions/IHistoryTransform.cs index d85d8d5b4..a5ff28cc9 100644 --- a/LLama/Abstractions/IHistoryTransform.cs +++ b/LLama/Abstractions/IHistoryTransform.cs @@ -1,4 +1,5 @@ using LLama.Common; +using System; using System.Text.Json.Serialization; namespace LLama.Abstractions @@ -21,8 +22,9 @@ public interface IHistoryTransform /// /// The role for the author. /// The chat history as plain text. + /// The type of the chat history. /// The updated history. - IChatHistory TextToHistory(AuthorRole role, string text); + IChatHistory TextToHistory(AuthorRole role, string text, Type type); /// /// Copy the transform. diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs index 351ed36e9..db18cbc04 100644 --- a/LLama/ChatSession.cs +++ b/LLama/ChatSession.cs @@ -52,12 +52,12 @@ public class ChatSession /// /// The chat history for this session. /// - public IChatHistory History { get; private set; } = new ChatHistory(); + public IChatHistory SessionChatHistory { get; private set; } = new ChatHistory(); /// /// The history transform used in this session. /// - public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform(); + public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform(); /// /// The input transform pipeline used in this session. @@ -110,7 +110,7 @@ public ChatSession(ILLamaExecutor executor) public ChatSession(ILLamaExecutor executor, IChatHistory history) : this(executor) { - History = history; + SessionChatHistory = history; } /// @@ -168,7 +168,7 @@ public SessionState GetSessionState() executorState.PastTokensCount > 0 ? Executor.Context.GetState() : null, executorState, - History, + SessionChatHistory, InputTransformPipeline, OutputTransform, HistoryTransform); @@ -198,7 +198,7 @@ public void LoadSession(SessionState state, bool loadTransforms = true) { Executor.Context.LoadState(state.ContextState); } - History = state.SessionChatHistory; + SessionChatHistory = state.SessionChatHistory; if (loadTransforms) { InputTransformPipeline = state.InputTransformPipeline.Select(t => t.Clone()).ToList(); @@ -216,7 +216,7 @@ public void LoadSession(SessionState state, bool loadTransforms = true) /// public void LoadSession(string path, bool loadTransforms = true) { - var state = SessionState.Load(path); + var state = SessionState.Load(path, this.SessionChatHistory.GetType()); // Handle non-polymorphic serialization of executor state if (state.ExecutorState is null) { @@ -234,7 +234,7 @@ public void LoadSession(string path, bool loadTransforms = true) public ChatSession AddMessage(Message message) { // If current message is a system message, only allow the history to be empty - if (message.AuthorRole == AuthorRole.System && History.Messages.Count > 0) + if (message.AuthorRole == AuthorRole.System && SessionChatHistory.Messages.Count > 0) { throw new ArgumentException("Cannot add a system message after another message", nameof(message)); } @@ -243,7 +243,7 @@ public ChatSession AddMessage(Message message) // or the previous message to be a system message or assistant message. if (message.AuthorRole == AuthorRole.User) { - Message? lastMessage = History.Messages.LastOrDefault(); + Message? lastMessage = SessionChatHistory.Messages.LastOrDefault(); if (lastMessage is not null && lastMessage.AuthorRole == AuthorRole.User) { throw new ArgumentException("Cannot add a user message after another user message", nameof(message)); @@ -254,7 +254,7 @@ public ChatSession AddMessage(Message message) // the previous message must be a user message. if (message.AuthorRole == AuthorRole.Assistant) { - Message? lastMessage = History.Messages.LastOrDefault(); + Message? lastMessage = SessionChatHistory.Messages.LastOrDefault(); if (lastMessage is null || lastMessage.AuthorRole != AuthorRole.User) { @@ -262,7 +262,7 @@ public ChatSession AddMessage(Message message) } } - History.AddMessage(message.AuthorRole, message.Content); + SessionChatHistory.AddMessage(message.AuthorRole, message.Content); return this; } @@ -296,7 +296,7 @@ public ChatSession AddUserMessage(string content) /// public ChatSession RemoveLastMessage() { - History.Messages.RemoveAt(History.Messages.Count - 1); + SessionChatHistory.Messages.RemoveAt(SessionChatHistory.Messages.Count - 1); return this; } @@ -364,16 +364,16 @@ public ChatSession ReplaceUserMessage( throw new ArgumentException("New message must be a user message", nameof(newMessage)); } - int index = History.Messages.IndexOf(oldMessage); + int index = SessionChatHistory.Messages.IndexOf(oldMessage); if (index == -1) { throw new ArgumentException("Old message does not exist in history", nameof(oldMessage)); } - History.Messages[index] = newMessage; + SessionChatHistory.Messages[index] = newMessage; // Remove all message after the new message - History.Messages.RemoveRange(index + 1, History.Messages.Count - index - 1); + SessionChatHistory.Messages.RemoveRange(index + 1, SessionChatHistory.Messages.Count - index - 1); return this; } @@ -424,14 +424,14 @@ public async IAsyncEnumerable ChatAsync( // If the session history was added as part of new chat session history, // convert the complete history includsing system message and manually added history // to a prompt that adhere to the prompt template specified in the HistoryTransform class implementation. - prompt = HistoryTransform.HistoryToText(History); + prompt = HistoryTransform.HistoryToText(SessionChatHistory); } else { // If the session was restored from a previous session, // convert only the current message to the prompt with the prompt template // specified in the HistoryTransform class implementation that is provided. - IChatHistory singleMessageHistory = HistoryTransform.TextToHistory(message.AuthorRole, message.Content); + IChatHistory singleMessageHistory = HistoryTransform.TextToHistory(message.AuthorRole, message.Content, SessionChatHistory.GetType()); prompt = HistoryTransform.HistoryToText(singleMessageHistory); } @@ -546,7 +546,7 @@ public async IAsyncEnumerable RegenerateAssistantMessageAsync( [EnumeratorCancellation] CancellationToken cancellationToken = default) { // Make sure the last message is an assistant message (reponse from the LLM). - Message? lastAssistantMessage = History.Messages.LastOrDefault(); + Message? lastAssistantMessage = SessionChatHistory.Messages.LastOrDefault(); if (lastAssistantMessage is null || lastAssistantMessage.AuthorRole != AuthorRole.Assistant) @@ -558,7 +558,7 @@ public async IAsyncEnumerable RegenerateAssistantMessageAsync( RemoveLastMessage(); // Get the last user message. - Message? lastUserMessage = History.Messages.LastOrDefault(); + Message? lastUserMessage = SessionChatHistory.Messages.LastOrDefault(); if (lastUserMessage is null || lastUserMessage.AuthorRole != AuthorRole.User) @@ -628,7 +628,7 @@ public record SessionState /// /// The history transform used in this session. /// - public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform(); + public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform(); /// /// The the chat history messages for this session. @@ -704,7 +704,7 @@ public void Save(string path) /// /// /// Throws when session state is incorrect - public static SessionState Load(string path) + public static SessionState Load(string path, Type type) { if (string.IsNullOrWhiteSpace(path)) { @@ -726,7 +726,7 @@ public static SessionState Load(string path) string historyFilepath = Path.Combine(path, ChatSession.HISTORY_STATE_FILENAME); string historyJson = File.ReadAllText(historyFilepath); - var history = ChatHistorySerializer.FromJson(historyJson) + var history = ChatHistorySerializer.FromJson(historyJson, type) ?? throw new ArgumentException("History file is invalid", nameof(path)); string inputTransformFilepath = Path.Combine(path, ChatSession.INPUT_TRANSFORM_FILENAME); @@ -764,7 +764,7 @@ public static SessionState Load(string path) historyTransform = File.Exists(historyTransformFilepath) ? (JsonSerializer.Deserialize(File.ReadAllText(historyTransformFilepath)) ?? throw new ArgumentException("History transform file is invalid", nameof(path))) - : new LLamaTransforms.DefaultHistoryTransform(); + : new LLamaTransforms.DefaultHistoryTransform(); } catch (JsonException) { diff --git a/LLama/Common/ChatHistory.cs b/LLama/Common/ChatHistory.cs index 755458452..052d1b538 100644 --- a/LLama/Common/ChatHistory.cs +++ b/LLama/Common/ChatHistory.cs @@ -1,4 +1,5 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; using System.Linq; using System.Text.Json; using System.Text.Json.Serialization; @@ -140,10 +141,11 @@ public static string ToJson(IChatHistory chatHistory) /// Deserialize a chat history from JSON /// /// + /// /// - public static IChatHistory? FromJson(string json) + public static IChatHistory? FromJson(string json, Type type) { - return JsonSerializer.Deserialize(json); + return JsonSerializer.Deserialize(json, type) as IChatHistory; } } } diff --git a/LLama/LLamaTransforms.cs b/LLama/LLamaTransforms.cs index f04d98b19..6ed86e777 100644 --- a/LLama/LLamaTransforms.cs +++ b/LLama/LLamaTransforms.cs @@ -18,7 +18,7 @@ public class LLamaTransforms /// Uses plain text with the following format: /// [Author]: [Message] /// - public class DefaultHistoryTransform : IHistoryTransform where T : IChatHistory + public class DefaultHistoryTransform : IHistoryTransform { private const string defaultUserName = "User"; private const string defaultAssistantName = "Assistant"; @@ -58,7 +58,7 @@ public DefaultHistoryTransform(string? userName = null, string? assistantName = /// public IHistoryTransform Clone() { - return (IHistoryTransform)new DefaultHistoryTransform(_userName, _assistantName, _systemName, _unknownName, _isInstructMode); + return new DefaultHistoryTransform(_userName, _assistantName, _systemName, _unknownName, _isInstructMode); } /// @@ -88,9 +88,9 @@ public virtual string HistoryToText(IChatHistory history) } /// - public virtual IChatHistory TextToHistory(AuthorRole role, string text) + public virtual IChatHistory TextToHistory(AuthorRole role, string text, Type type) { - T history = Activator.CreateInstance(); + IChatHistory history = (IChatHistory)(Activator.CreateInstance(type) ?? new ChatHistory()); history.AddMessage(role, TrimNamesFromText(text, role)); return history; }