Skip to content

Commit 97fb2c3

Browse files
Python: Allow Kernel Functions from Prompt for image and audio content (#11403)
### Motivation and Context <!-- Thank you for your contribution to the semantic-kernel repo! Please help reviewers and future users, providing the following information: 1. Why is this change required? 2. What problem does it solve? 3. What scenario does it contribute to? 4. If it fixes an open issue, please link to the issue here. --> I noticed that even though the input and prompt rendering match what you want to use for image and audio generation, we didn't support that. This introduces just that, with two samples. This unlocks the following scenario's: - Running text to speech pipelines with set intro/outro statements - Creating function calls for image generation with limited scope and a lot of set pieces. ### Description <!-- Describe your changes, the overall approach, the underlying design. These notes will help understanding how your code works. Thanks! --> - Adds a `get_image_content` method to the TextToImageClientBase class - Adds the option to select a TextToImage or TextToAudio client in the service selector (only for non-streaming) - Adds branches in the KernelFunctionFromPrompt _invoke_internal for those types. - Adds handling the output as a FunctionResult - Adds samples for both ### Contribution Checklist <!-- Before submitting this PR, please make sure: --> - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone 😄
1 parent 81e89f0 commit 97fb2c3

File tree

11 files changed

+233
-48
lines changed

11 files changed

+233
-48
lines changed

Diff for: python/samples/concepts/audio/audio_from_prompt.py

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright (c) Microsoft. All rights reserved.
2+
3+
import asyncio
4+
5+
from samples.concepts.audio.audio_player import AudioPlayer
6+
from semantic_kernel import Kernel
7+
from semantic_kernel.connectors.ai import PromptExecutionSettings
8+
from semantic_kernel.connectors.ai.open_ai import OpenAITextToAudio
9+
from semantic_kernel.functions import KernelArguments
10+
11+
"""
12+
This simple sample demonstrates how to use the AzureTextToAudio services
13+
with a prompt and prompt rendering.
14+
15+
Resources required for this sample: An Azure Text to Speech deployment (e.g. tts).
16+
17+
Additional dependencies required for this sample:
18+
- pyaudio: run `pip install pyaudio` or `uv pip install pyaudio` if you are using uv.
19+
"""
20+
21+
22+
async def main():
23+
kernel = Kernel()
24+
kernel.add_service(OpenAITextToAudio(service_id="tts"))
25+
26+
result = await kernel.invoke_prompt(
27+
prompt="speak the following phrase: {{$phrase}}",
28+
arguments=KernelArguments(
29+
phrase="a painting of a flower vase",
30+
settings=PromptExecutionSettings(service_id="tts", voice="coral"),
31+
),
32+
)
33+
if result:
34+
AudioPlayer(audio_content=result.value[0]).play()
35+
36+
37+
if __name__ == "__main__":
38+
asyncio.run(main())

Diff for: python/samples/concepts/images/image_gen_prompt.py

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright (c) Microsoft. All rights reserved.
2+
3+
import asyncio
4+
from urllib.request import urlopen
5+
6+
try:
7+
from PIL import Image
8+
9+
pil_available = True
10+
except ImportError:
11+
pil_available = False
12+
13+
from semantic_kernel import Kernel
14+
from semantic_kernel.connectors.ai import PromptExecutionSettings
15+
from semantic_kernel.connectors.ai.open_ai import OpenAITextToImage
16+
from semantic_kernel.functions import KernelArguments
17+
18+
"""
19+
This sample demonstrates how to use the OpenAI text-to-image service to generate an image from a prompt.
20+
It uses the OpenAITextToImage class to create an image based on the provided prompt and settings.
21+
The generated image is then displayed using the PIL library if available.
22+
"""
23+
24+
25+
async def main():
26+
kernel = Kernel()
27+
kernel.add_service(OpenAITextToImage(service_id="dalle3"))
28+
29+
result = await kernel.invoke_prompt(
30+
prompt="Generate a image of {{$topic}} in the style of a {{$style}}",
31+
arguments=KernelArguments(
32+
topic="a flower vase",
33+
style="painting",
34+
settings=PromptExecutionSettings(
35+
service_id="dalle3",
36+
width=1024,
37+
height=1024,
38+
quality="hd",
39+
style="vivid",
40+
),
41+
),
42+
)
43+
if result and pil_available:
44+
img = Image.open(urlopen(str(result.value[0].uri))) # nosec
45+
img.show()
46+
47+
48+
if __name__ == "__main__":
49+
asyncio.run(main())
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
# Copyright (c) Microsoft. All rights reserved.
22

33
import logging
4-
from typing import Literal
4+
from typing import Annotated, Literal
55

6-
from pydantic import Field, model_validator
6+
from pydantic import Field
77

88
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
9-
from semantic_kernel.exceptions.service_exceptions import ServiceInvalidExecutionSettingsError
109

1110
logger = logging.getLogger(__name__)
1211

@@ -18,13 +17,6 @@ class OpenAITextToAudioExecutionSettings(PromptExecutionSettings):
1817
input: str | None = Field(
1918
None, description="Do not set this manually. It is set by the service based on the text content."
2019
)
21-
voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"] = "alloy"
20+
voice: Literal["alloy", "ash", "ballad", "coral", "echo", "fable", "onyx", "nova", "sage", "shimmer"] = "alloy"
2221
response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] | None = None
23-
speed: float | None = None
24-
25-
@model_validator(mode="after")
26-
def validate_speed(self) -> "OpenAITextToAudioExecutionSettings":
27-
"""Validate the speed parameter."""
28-
if self.speed is not None and (self.speed < 0.25 or self.speed > 4.0):
29-
raise ServiceInvalidExecutionSettingsError("Speed must be between 0.25 and 4.0.")
30-
return self
22+
speed: Annotated[float | None, Field(ge=0.25, le=4.0)] = None

Diff for: python/semantic_kernel/connectors/ai/open_ai/prompt_execution_settings/open_ai_text_to_image_execution_settings.py

+20-10
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,26 @@ class OpenAITextToImageExecutionSettings(PromptExecutionSettings):
4141
quality: str | None = None
4242
style: str | None = None
4343

44+
@model_validator(mode="before")
45+
@classmethod
46+
def get_size(cls, data: dict[str, Any]) -> dict[str, Any]:
47+
"""Check that the requested image size is valid."""
48+
if isinstance(data, dict):
49+
if "size" not in data and "width" in data and "height" in data:
50+
data["size"] = ImageSize(width=data["width"], height=data["height"])
51+
elif "extension_data" in data:
52+
extension_data = data["extension_data"]
53+
if (
54+
isinstance(extension_data, dict)
55+
and "size" not in extension_data
56+
and "width" in extension_data
57+
and "height" in extension_data
58+
):
59+
data["extension_data"]["size"] = ImageSize(
60+
width=extension_data["width"], height=extension_data["height"]
61+
)
62+
return data
63+
4464
@model_validator(mode="after")
4565
def check_size(self) -> "OpenAITextToImageExecutionSettings":
4666
"""Check that the requested image size is valid."""
@@ -51,16 +71,6 @@ def check_size(self) -> "OpenAITextToImageExecutionSettings":
5171

5272
return self
5373

54-
@model_validator(mode="after")
55-
def check_prompt(self) -> "OpenAITextToImageExecutionSettings":
56-
"""Check that the prompt is not empty."""
57-
prompt = self.prompt or self.extension_data.get("prompt")
58-
59-
if not prompt:
60-
raise ServiceInvalidExecutionSettingsError("The prompt is required.")
61-
62-
return self
63-
6474
def prepare_settings_dict(self, **kwargs) -> dict[str, Any]:
6575
"""Prepare the settings dictionary for the OpenAI API."""
6676
settings_dict = super().prepare_settings_dict(**kwargs)

Diff for: python/semantic_kernel/connectors/ai/open_ai/services/open_ai_text_to_image_base.py

+36-10
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright (c) Microsoft. All rights reserved.
22

33
from typing import Any
4+
from warnings import warn
45

56
from openai.types.images_response import ImagesResponse
67

@@ -11,30 +12,55 @@
1112
from semantic_kernel.connectors.ai.open_ai.services.open_ai_handler import OpenAIHandler
1213
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
1314
from semantic_kernel.connectors.ai.text_to_image_client_base import TextToImageClientBase
14-
from semantic_kernel.exceptions.service_exceptions import ServiceResponseException
15+
from semantic_kernel.exceptions.service_exceptions import ServiceInvalidRequestError, ServiceResponseException
1516

1617

1718
class OpenAITextToImageBase(OpenAIHandler, TextToImageClientBase):
1819
"""OpenAI text to image client."""
1920

20-
async def generate_image(self, description: str, width: int, height: int, **kwargs: Any) -> bytes | str:
21+
async def generate_image(
22+
self,
23+
description: str,
24+
width: int | None = None,
25+
height: int | None = None,
26+
settings: PromptExecutionSettings | None = None,
27+
**kwargs: Any,
28+
) -> bytes | str:
2129
"""Generate image from text.
2230
2331
Args:
2432
description: Description of the image.
25-
width: Width of the image, check the openai documentation for the supported sizes.
26-
height: Height of the image, check the openai documentation for the supported sizes.
33+
width: Deprecated, use settings instead.
34+
height: Deprecated, use settings instead.
35+
settings: Execution settings for the prompt.
2736
kwargs: Additional arguments, check the openai images.generate documentation for the supported arguments.
2837
2938
Returns:
3039
bytes | str: Image bytes or image URL.
3140
"""
32-
settings = OpenAITextToImageExecutionSettings(
33-
prompt=description,
34-
size=ImageSize(width=width, height=height),
35-
ai_model_id=self.ai_model_id,
36-
**kwargs,
37-
)
41+
if not settings:
42+
settings = OpenAITextToImageExecutionSettings(**kwargs)
43+
if not isinstance(settings, OpenAITextToImageExecutionSettings):
44+
settings = OpenAITextToImageExecutionSettings.from_prompt_execution_settings(settings)
45+
if width:
46+
warn("The 'width' argument is deprecated. Use 'settings.size' instead.", DeprecationWarning)
47+
if settings.size and not settings.size.width:
48+
settings.size.width = width
49+
if height:
50+
warn("The 'height' argument is deprecated. Use 'settings.size' instead.", DeprecationWarning)
51+
if settings.size and not settings.size.height:
52+
settings.size.height = height
53+
if not settings.size and width and height:
54+
settings.size = ImageSize(width=width, height=height)
55+
56+
if not settings.prompt:
57+
settings.prompt = description
58+
59+
if not settings.prompt:
60+
raise ServiceInvalidRequestError("Prompt is required.")
61+
62+
if not settings.ai_model_id:
63+
settings.ai_model_id = self.ai_model_id
3864

3965
response = await self._send_request(settings)
4066

Diff for: python/semantic_kernel/connectors/ai/prompt_execution_settings.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def __init__(self, service_id: str | None = None, **kwargs: Any):
6868
@property
6969
def keys(self):
7070
"""Get the keys of the prompt execution settings."""
71-
return self.model_fields.keys()
71+
return self.__class__.model_fields.keys()
7272

7373
def prepare_settings_dict(self, **kwargs) -> dict[str, Any]:
7474
"""Prepare the settings as a dictionary for sending to the AI service.
@@ -86,7 +86,7 @@ def prepare_settings_dict(self, **kwargs) -> dict[str, Any]:
8686
by_alias=True,
8787
)
8888

89-
def update_from_prompt_execution_settings(self, config: _T) -> None:
89+
def update_from_prompt_execution_settings(self, config: "PromptExecutionSettings") -> None:
9090
"""Update the prompt execution settings from a completion config."""
9191
if config.service_id is not None:
9292
self.service_id = config.service_id
@@ -95,7 +95,7 @@ def update_from_prompt_execution_settings(self, config: _T) -> None:
9595
self.unpack_extension_data()
9696

9797
@classmethod
98-
def from_prompt_execution_settings(cls: type[_T], config: _T) -> _T:
98+
def from_prompt_execution_settings(cls: type[_T], config: "PromptExecutionSettings") -> _T:
9999
"""Create a prompt execution settings from a completion config."""
100100
config.pack_extension_data()
101101
return cls(

Diff for: python/semantic_kernel/connectors/ai/text_to_image_client_base.py

+34-3
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,54 @@
33
from abc import ABC, abstractmethod
44
from typing import Any
55

6+
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
7+
from semantic_kernel.contents.image_content import ImageContent
68
from semantic_kernel.services.ai_service_client_base import AIServiceClientBase
79

810

911
class TextToImageClientBase(AIServiceClientBase, ABC):
1012
"""Base class for text to image client."""
1113

1214
@abstractmethod
13-
async def generate_image(self, description: str, width: int, height: int, **kwargs: Any) -> bytes | str:
15+
async def generate_image(
16+
self,
17+
description: str,
18+
width: int | None = None,
19+
height: int | None = None,
20+
settings: PromptExecutionSettings | None = None,
21+
**kwargs: Any,
22+
) -> bytes | str:
1423
"""Generate image from text.
1524
1625
Args:
1726
description: Description of the image.
18-
width: Width of the image.
19-
height: Height of the image.
27+
width: Deprecated, use settings instead.
28+
height: Deprecated, use settings instead.
29+
settings: Execution settings for the prompt.
2030
kwargs: Additional arguments.
2131
2232
Returns:
2333
bytes | str: Image bytes or image URL.
2434
"""
2535
raise NotImplementedError
36+
37+
async def get_image_content(
38+
self,
39+
description: str,
40+
settings: PromptExecutionSettings,
41+
**kwargs: Any,
42+
) -> ImageContent:
43+
"""Generate an image from prompt and return an ImageContent.
44+
45+
Args:
46+
description: Description of the image.
47+
settings: Execution settings for the prompt.
48+
kwargs: Additional arguments.
49+
50+
Returns:
51+
ImageContent: Image content.
52+
"""
53+
image = await self.generate_image(description=description, settings=settings, **kwargs)
54+
if isinstance(image, str):
55+
return ImageContent(uri=image)
56+
return ImageContent(data=image)

Diff for: python/semantic_kernel/functions/kernel_function_from_prompt.py

+36-2
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,13 @@
1212
from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase
1313
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
1414
from semantic_kernel.connectors.ai.text_completion_client_base import TextCompletionClientBase
15+
from semantic_kernel.connectors.ai.text_to_audio_client_base import TextToAudioClientBase
16+
from semantic_kernel.connectors.ai.text_to_image_client_base import TextToImageClientBase
1517
from semantic_kernel.const import DEFAULT_SERVICE_NAME
18+
from semantic_kernel.contents.audio_content import AudioContent
1619
from semantic_kernel.contents.chat_history import ChatHistory
1720
from semantic_kernel.contents.chat_message_content import ChatMessageContent
21+
from semantic_kernel.contents.image_content import ImageContent
1822
from semantic_kernel.contents.text_content import TextContent
1923
from semantic_kernel.exceptions import FunctionExecutionException, FunctionInitializationError
2024
from semantic_kernel.exceptions.function_exceptions import PromptRenderingException
@@ -204,6 +208,34 @@ async def _invoke_internal(self, context: FunctionInvocationContext) -> None:
204208
)
205209
return
206210

211+
if isinstance(prompt_render_result.ai_service, TextToImageClientBase):
212+
try:
213+
images = await prompt_render_result.ai_service.get_image_content(
214+
description=unescape(prompt_render_result.rendered_prompt),
215+
settings=prompt_render_result.execution_settings,
216+
)
217+
except Exception as exc:
218+
raise FunctionExecutionException(f"Error occurred while invoking function {self.name}: {exc}") from exc
219+
220+
context.result = self._create_function_result(
221+
completions=[images], arguments=context.arguments, prompt=prompt_render_result.rendered_prompt
222+
)
223+
return
224+
225+
if isinstance(prompt_render_result.ai_service, TextToAudioClientBase):
226+
try:
227+
audio = await prompt_render_result.ai_service.get_audio_content(
228+
text=unescape(prompt_render_result.rendered_prompt),
229+
settings=prompt_render_result.execution_settings,
230+
)
231+
except Exception as exc:
232+
raise FunctionExecutionException(f"Error occurred while invoking function {self.name}: {exc}") from exc
233+
234+
context.result = self._create_function_result(
235+
completions=[audio], arguments=context.arguments, prompt=prompt_render_result.rendered_prompt
236+
)
237+
return
238+
207239
raise ValueError(f"Service `{type(prompt_render_result.ai_service).__name__}` is not a valid AI service")
208240

209241
async def _invoke_internal_stream(self, context: FunctionInvocationContext) -> None:
@@ -253,7 +285,9 @@ async def _render_prompt(
253285
if prompt_render_context.rendered_prompt is None:
254286
raise PromptRenderingException("Prompt rendering failed, no rendered prompt was returned.")
255287
selected_service: tuple["AIServiceClientBase", PromptExecutionSettings] = context.kernel.select_ai_service(
256-
function=self, arguments=context.arguments
288+
function=self,
289+
arguments=context.arguments,
290+
type=(TextCompletionClientBase, ChatCompletionClientBase) if prompt_render_context.is_streaming else None,
257291
)
258292
return PromptRenderingResult(
259293
rendered_prompt=prompt_render_context.rendered_prompt,
@@ -268,7 +302,7 @@ async def _inner_render_prompt(self, context: PromptRenderContext) -> None:
268302

269303
def _create_function_result(
270304
self,
271-
completions: list[ChatMessageContent] | list[TextContent],
305+
completions: list[ChatMessageContent] | list[TextContent] | list[ImageContent] | list[AudioContent],
272306
arguments: KernelArguments,
273307
chat_history: ChatHistory | None = None,
274308
prompt: str | None = None,

0 commit comments

Comments
 (0)