Skip to content

Commit 9e4b565

Browse files
add hybrid search
1 parent 97fb2c3 commit 9e4b565

File tree

6 files changed

+222
-14
lines changed

6 files changed

+222
-14
lines changed

Diff for: python/samples/concepts/memory/complex_memory.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
vectorstoremodel,
4545
)
4646
from semantic_kernel.data.const import DISTANCE_FUNCTION_DIRECTION_HELPER, DistanceFunction, IndexKind
47-
from semantic_kernel.data.vector_search import add_vector_to_records
47+
from semantic_kernel.data.vector_search import KeywordHybridSearchMixin, add_vector_to_records
4848

4949
# This is a rather complex sample, showing how to use the vector store
5050
# with a number of different collections.
@@ -279,6 +279,29 @@ async def main(collection: str, use_azure_openai: bool):
279279
print_with_color("Now we can start searching.", Colors.CBLUE)
280280
print_with_color(" For each type of search, enter a search term, for instance `python`.", Colors.CBLUE)
281281
print_with_color(" Enter exit to exit, and skip or nothing to skip this search.", Colors.CBLUE)
282+
if isinstance(record_collection, KeywordHybridSearchMixin):
283+
search_text = input("Enter search text for hybrid text search: ")
284+
if search_text.lower() == "exit":
285+
await cleanup(record_collection)
286+
return
287+
if not search_text or search_text.lower() != "skip":
288+
print("-" * 30)
289+
print_with_color(
290+
f"Using hybrid text search, for {distance_function.value}, "
291+
f"the {'higher' if DISTANCE_FUNCTION_DIRECTION_HELPER[distance_function](1, 0) else 'lower'} the score the better", # noqa: E501
292+
Colors.CBLUE,
293+
)
294+
try:
295+
vector = (await embedder.generate_raw_embeddings([search_text]))[0]
296+
search_results = await record_collection.hybrid_search(
297+
keywords=search_text, vector=vector, options=options
298+
)
299+
if search_results.total_count == 0:
300+
print("\nNothing found...\n")
301+
else:
302+
[print_record(result) async for result in search_results.results]
303+
except Exception as e:
304+
print(f"Error: {e}")
282305
if isinstance(record_collection, VectorTextSearchMixin):
283306
search_text = input("Enter search text for text search: ")
284307
if search_text.lower() == "exit":

Diff for: python/semantic_kernel/connectors/memory/azure_ai_search/azure_ai_search_collection.py

+26-3
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,14 @@
1717
get_search_client,
1818
get_search_index_client,
1919
)
20-
from semantic_kernel.data.record_definition import VectorStoreRecordDefinition, VectorStoreRecordVectorField
20+
from semantic_kernel.data.record_definition import (
21+
VectorStoreRecordDataField,
22+
VectorStoreRecordDefinition,
23+
VectorStoreRecordVectorField,
24+
)
2125
from semantic_kernel.data.text_search import AnyTagsEqualTo, EqualTo, KernelSearchResults
2226
from semantic_kernel.data.vector_search import (
27+
KeywordHybridSearchMixin,
2328
VectorizableTextSearchMixin,
2429
VectorizedSearchMixin,
2530
VectorSearchFilter,
@@ -33,6 +38,7 @@
3338
VectorStoreInitializationException,
3439
VectorStoreOperationException,
3540
)
41+
from semantic_kernel.kernel_types import OptionalOneOrMany
3642
from semantic_kernel.utils.feature_stage_decorator import experimental
3743

3844
if sys.version_info >= (3, 12):
@@ -49,6 +55,7 @@ class AzureAISearchCollection(
4955
VectorizableTextSearchMixin[TKey, TModel],
5056
VectorizedSearchMixin[TKey, TModel],
5157
VectorTextSearchMixin[TKey, TModel],
58+
KeywordHybridSearchMixin[TKey, TModel],
5259
Generic[TKey, TModel],
5360
):
5461
"""Azure AI Search collection implementation."""
@@ -241,13 +248,15 @@ async def _inner_search(
241248
search_text: str | None = None,
242249
vectorizable_text: str | None = None,
243250
vector: list[float | int] | None = None,
251+
keywords: OptionalOneOrMany[str] = None,
244252
**kwargs: Any,
245253
) -> KernelSearchResults[VectorSearchResult[TModel]]:
246254
search_args: dict[str, Any] = {
247255
"top": options.top,
248256
"skip": options.skip,
249257
"include_total_count": options.include_total_count,
250258
}
259+
vector_field = self.data_model_definition.try_get_vector_field(options.vector_field_name)
251260
if options.filter.filters:
252261
search_args["filter"] = self._build_filter_string(options.filter)
253262
if search_text is not None:
@@ -257,15 +266,29 @@ async def _inner_search(
257266
VectorizableTextQuery(
258267
text=vectorizable_text,
259268
k_nearest_neighbors=options.top,
260-
fields=options.vector_field_name,
269+
fields=vector_field.name if vector_field else None,
261270
)
262271
]
263272
if vector is not None:
273+
if keywords is not None:
274+
# hybrid search
275+
search_args["search_fields"] = (
276+
[options.keyword_field_name]
277+
if options.keyword_field_name
278+
else [
279+
field.name
280+
for field in self.data_model_definition.fields
281+
if isinstance(field, VectorStoreRecordDataField) and field.is_full_text_searchable
282+
]
283+
)
284+
if not search_args["search_fields"]:
285+
raise VectorStoreOperationException("No searchable fields found for hybrid search.")
286+
search_args["search_text"] = keywords if isinstance(keywords, str) else ", ".join(keywords)
264287
search_args["vector_queries"] = [
265288
VectorizedQuery(
266289
vector=vector,
267290
k_nearest_neighbors=options.top,
268-
fields=options.vector_field_name,
291+
fields=vector_field.name if vector_field else None,
269292
)
270293
]
271294
if "vector_queries" not in search_args and "search_text" not in search_args:

Diff for: python/semantic_kernel/data/vector_search.py

+67-4
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
)
3333
from semantic_kernel.kernel import Kernel
3434
from semantic_kernel.kernel_pydantic import KernelBaseModel
35-
from semantic_kernel.kernel_types import OneOrMany
35+
from semantic_kernel.kernel_types import OneOrMany, OptionalOneOrMany
3636
from semantic_kernel.utils.feature_stage_decorator import experimental
3737
from semantic_kernel.utils.list_handler import desync_list
3838

@@ -82,6 +82,7 @@ class VectorSearchOptions(SearchOptions):
8282

8383
filter: VectorSearchFilter = Field(default_factory=VectorSearchFilter)
8484
vector_field_name: str | None = None
85+
keyword_field_name: str | None = None
8586
top: Annotated[int, Field(gt=0)] = 3
8687
skip: Annotated[int, Field(ge=0)] = 0
8788
include_vectors: bool = False
@@ -116,6 +117,7 @@ def options_class(self) -> type[SearchOptions]:
116117
async def _inner_search(
117118
self,
118119
options: VectorSearchOptions,
120+
keywords: OptionalOneOrMany[str] = None,
119121
search_text: str | None = None,
120122
vectorizable_text: str | None = None,
121123
vector: list[float | int] | None = None,
@@ -144,6 +146,7 @@ async def _inner_search(
144146
145147
Args:
146148
options: The search options, can be None.
149+
keywords: The text to search for, optional.
147150
search_text: The text to search for, optional.
148151
vectorizable_text: The text to search for, will be vectorized downstream, optional.
149152
vector: The vector to search for, optional.
@@ -216,7 +219,7 @@ async def _get_vector_search_results_from_results(
216219

217220
@experimental
218221
class VectorizedSearchMixin(VectorSearchBase[TKey, TModel], Generic[TKey, TModel]):
219-
"""The mixin for searching with vectors."""
222+
"""The mixin for searching with vectors. To be used in combination with VectorStoreRecordCollection."""
220223

221224
async def vectorized_search(
222225
self,
@@ -280,7 +283,7 @@ def create_text_search_from_vectorized_search(
280283
class VectorizableTextSearchMixin(VectorSearchBase[TKey, TModel], Generic[TKey, TModel]):
281284
"""The mixin for searching with text that get's vectorized downstream.
282285
283-
To be used in combination with VectorSearchBase.
286+
To be used in combination with VectorStoreRecordCollection.
284287
"""
285288

286289
async def vectorizable_text_search(
@@ -341,7 +344,7 @@ def create_text_search_from_vectorizable_text_search(
341344

342345
@experimental
343346
class VectorTextSearchMixin(VectorSearchBase[TKey, TModel], Generic[TKey, TModel]):
344-
"""The mixin for text search, to be used in combination with VectorSearchBase."""
347+
"""The mixin for text search, to be used in combination with VectorStoreRecordCollection."""
345348

346349
async def text_search(
347350
self,
@@ -394,6 +397,66 @@ def create_text_search_from_vector_text_search(
394397
return VectorStoreTextSearch.from_vector_text_search(self, string_mapper, text_search_results_mapper)
395398

396399

400+
# region: Keyword Hybrid Search
401+
402+
403+
@experimental
404+
class KeywordHybridSearchMixin(VectorSearchBase[TKey, TModel], Generic[TKey, TModel]):
405+
"""The mixin for hybrid vector and text search, to be used in combination with VectorStoreRecordCollection."""
406+
407+
async def hybrid_search(
408+
self,
409+
vector: list[float | int],
410+
keywords: OneOrMany[str],
411+
options: SearchOptions | None = None,
412+
**kwargs: Any,
413+
) -> "KernelSearchResults[VectorSearchResult[TModel]]":
414+
"""Search the vector store for records that match the given text and filters.
415+
416+
Args:
417+
vector: The vector to search for.
418+
keywords: The keywords to search for.
419+
options: options, should include query_text
420+
**kwargs: if options are not set, this is used to create them.
421+
422+
Raises:
423+
VectorSearchExecutionException: If an error occurs during the search.
424+
VectorStoreModelDeserializationException: If an error occurs during deserialization.
425+
VectorSearchOptionsException: If the search options are invalid.
426+
VectorStoreMixinException: raised when the method is not used in combination with the VectorSearchBase.
427+
428+
"""
429+
options = create_options(self.options_class, options, **kwargs) # type: ignore
430+
try:
431+
return await self._inner_search(vector=vector, keywords=keywords, options=options) # type: ignore
432+
except (VectorStoreModelDeserializationException, VectorSearchOptionsException, VectorSearchExecutionException):
433+
raise # pragma: no cover
434+
except Exception as exc:
435+
raise VectorSearchExecutionException(f"An error occurred during the search: {exc}") from exc
436+
437+
def create_text_search_from_vector_text_search(
438+
self,
439+
string_mapper: Callable[[TModel], str] | None = None,
440+
text_search_results_mapper: Callable[[TModel], TextSearchResult] | None = None,
441+
) -> "VectorStoreTextSearch[TModel]":
442+
"""Create a VectorStoreTextSearch object.
443+
444+
This method is used to create a VectorStoreTextSearch object that can be used to search the vector store
445+
for records that match the given text and filter.
446+
The text string will be vectorized downstream and used for the vector search.
447+
448+
Args:
449+
string_mapper: A function that maps the record to a string.
450+
text_search_results_mapper: A function that maps the record to a TextSearchResult.
451+
452+
Returns:
453+
VectorStoreTextSearch: The created VectorStoreTextSearch object.
454+
"""
455+
from semantic_kernel.data.vector_store_text_search import VectorStoreTextSearch
456+
457+
return VectorStoreTextSearch.from_vector_text_search(self, string_mapper, text_search_results_mapper)
458+
459+
397460
# region: add_vector_to_records
398461

399462

Diff for: python/tests/conftest.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -365,9 +365,7 @@ def data_model_definition(
365365
fields={
366366
"id": VectorStoreRecordKeyField(property_type="str"),
367367
"content": VectorStoreRecordDataField(
368-
has_embedding=True,
369-
embedding_property_name="vector",
370-
property_type="str",
368+
has_embedding=True, embedding_property_name="vector", property_type="str", is_full_text_searchable=True
371369
),
372370
"vector": VectorStoreRecordVectorField(
373371
dimensions=dimensions,

Diff for: python/tests/unit/connectors/memory/azure_ai_search/test_azure_ai_search.py

+103
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright (c) Microsoft. All rights reserved.
22

33

4+
import asyncio
45
from unittest.mock import MagicMock, Mock, patch
56

67
from pytest import fixture, mark, raises
@@ -14,6 +15,7 @@
1415
data_model_definition_to_azure_ai_search_index,
1516
get_search_index_client,
1617
)
18+
from semantic_kernel.data.vector_search import VectorSearchOptions
1719
from semantic_kernel.exceptions import (
1820
ServiceInitializationError,
1921
VectorStoreInitializationException,
@@ -72,6 +74,17 @@ def mock_get():
7274
yield mock_get_document
7375

7476

77+
@fixture
78+
def mock_search():
79+
async def iter_search_results(*args, **kwargs):
80+
yield {"id": "id1", "content": "content", "vector": [1.0, 2.0, 3.0]}
81+
await asyncio.sleep(0.0)
82+
83+
with patch(f"{BASE_PATH_SEARCH_CLIENT}.search") as mock_search:
84+
mock_search.side_effect = iter_search_results
85+
yield mock_search
86+
87+
7588
@fixture
7689
def mock_delete():
7790
with patch(f"{BASE_PATH_SEARCH_CLIENT}.delete_documents") as mock_delete_documents:
@@ -293,3 +306,93 @@ def test_get_search_index_client(azure_ai_search_unit_test_env):
293306

294307
with raises(ServiceInitializationError):
295308
get_search_index_client(settings)
309+
310+
311+
@mark.parametrize("include_vectors", [True, False])
312+
async def test_search_text_search(collection, mock_search, include_vectors):
313+
options = VectorSearchOptions(include_vectors=include_vectors)
314+
results = await collection.text_search("test", options=options)
315+
assert results is not None
316+
async for result in results.results:
317+
assert result is not None
318+
assert result.record is not None
319+
assert result.record["id"] == "id1"
320+
assert result.record["content"] == "content"
321+
if include_vectors:
322+
assert result.record["vector"] == [1.0, 2.0, 3.0]
323+
mock_search.assert_awaited_once_with(
324+
top=3,
325+
skip=0,
326+
include_total_count=False,
327+
search_text="test",
328+
select=["*"] if include_vectors else ["id", "content"],
329+
)
330+
331+
332+
@mark.parametrize("include_vectors", [True, False])
333+
async def test_search_vectorized_search(collection, mock_search, include_vectors):
334+
options = VectorSearchOptions(include_vectors=include_vectors)
335+
results = await collection.vectorized_search([0.1, 0.2, 0.3], options=options)
336+
assert results is not None
337+
async for result in results.results:
338+
assert result is not None
339+
assert result.record is not None
340+
assert result.record["id"] == "id1"
341+
assert result.record["content"] == "content"
342+
if include_vectors:
343+
assert result.record["vector"] == [1.0, 2.0, 3.0]
344+
for call in mock_search.call_args_list:
345+
assert call[1]["top"] == 3
346+
assert call[1]["skip"] == 0
347+
assert call[1]["include_total_count"] is False
348+
assert call[1]["select"] == ["*"] if include_vectors else ["id", "content"]
349+
assert call[1]["vector_queries"][0].vector == [0.1, 0.2, 0.3]
350+
assert call[1]["vector_queries"][0].fields == "vector"
351+
assert call[1]["vector_queries"][0].k_nearest_neighbors == 3
352+
353+
354+
@mark.parametrize("include_vectors", [True, False])
355+
async def test_search_vectorizable_search(collection, mock_search, include_vectors):
356+
options = VectorSearchOptions(include_vectors=include_vectors)
357+
results = await collection.vectorizable_text_search("test", options=options)
358+
assert results is not None
359+
async for result in results.results:
360+
assert result is not None
361+
assert result.record is not None
362+
assert result.record["id"] == "id1"
363+
assert result.record["content"] == "content"
364+
if include_vectors:
365+
assert result.record["vector"] == [1.0, 2.0, 3.0]
366+
for call in mock_search.call_args_list:
367+
assert call[1]["top"] == 3
368+
assert call[1]["skip"] == 0
369+
assert call[1]["include_total_count"] is False
370+
assert call[1]["select"] == ["*"] if include_vectors else ["id", "content"]
371+
assert call[1]["vector_queries"][0].text == "test"
372+
assert call[1]["vector_queries"][0].fields == "vector"
373+
assert call[1]["vector_queries"][0].k_nearest_neighbors == 3
374+
375+
376+
@mark.parametrize("include_vectors", [True, False])
377+
@mark.parametrize("keywords", ["test", ["test1", "test2"]], ids=["single", "multiple"])
378+
async def test_search_keyword_hybrid_search(collection, mock_search, include_vectors, keywords):
379+
options = VectorSearchOptions(include_vectors=include_vectors, keyword_field_name="content")
380+
results = await collection.hybrid_search(keywords=keywords, vector=[0.1, 0.2, 0.3], options=options)
381+
assert results is not None
382+
async for result in results.results:
383+
assert result is not None
384+
assert result.record is not None
385+
assert result.record["id"] == "id1"
386+
assert result.record["content"] == "content"
387+
if include_vectors:
388+
assert result.record["vector"] == [1.0, 2.0, 3.0]
389+
for call in mock_search.call_args_list:
390+
assert call[1]["top"] == 3
391+
assert call[1]["skip"] == 0
392+
assert call[1]["include_total_count"] is False
393+
assert call[1]["select"] == ["*"] if include_vectors else ["id", "content"]
394+
assert call[1]["search_fields"] == ["content"]
395+
assert call[1]["search_text"] == "test" if keywords == "test" else "test1, test2"
396+
assert call[1]["vector_queries"][0].vector == [0.1, 0.2, 0.3]
397+
assert call[1]["vector_queries"][0].fields == "vector"
398+
assert call[1]["vector_queries"][0].k_nearest_neighbors == 3

Diff for: python/tests/unit/data/test_vector_search_base.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@
1212
async def test_search(vector_store_record_collection: VectorSearchBase):
1313
record = {"id": "test_id", "content": "test_content", "vector": [1.0, 2.0, 3.0]}
1414
await vector_store_record_collection.upsert(record)
15-
results = await vector_store_record_collection._inner_search(
16-
options=VectorSearchOptions(), search_text="test_content"
17-
)
15+
results = await vector_store_record_collection._inner_search(options=VectorSearchOptions(), keywords="test_content")
1816
records = [rec async for rec in results.results]
1917
assert records[0].record == record
2018

0 commit comments

Comments
 (0)