Skip to content

[ENH] Implements a new embedding function for ChromaDB that uses the modern google-genai library (the recommended replacement for google.generativeai). #4278

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 117 additions & 0 deletions chromadb/utils/embedding_functions/google_embedding_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,123 @@
import numpy.typing as npt
from chromadb.utils.embedding_functions.schemas import validate_config_schema

class GoogleGenAiEmbeddingFunction(EmbeddingFunction[Documents]):
"""To use this EmbeddingFunction, you must have the google-genai Python package installed and have a Google API key.
This uses the newer google-genai package which is the recommended replacement for google.generativeai.
"""

def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "text-embedding-004",
api_key_env_var: str = "CHROMA_GOOGLE_GENAI_API_KEY",
):
"""
Initialize the GoogleGenAiEmbeddingFunction.

Args:
api_key (Optional[str], optional): API key for the Google Generative AI API. If not provided, it will be read from the environment variable.
model_name (str, optional): The name of the model to use for text embeddings.
Defaults to "text-embedding-004".
api_key_env_var (str, optional): Environment variable name that contains your API key for the Google Generative AI API.
Defaults to "CHROMA_GOOGLE_GENAI_API_KEY".
"""
try:
from google.genai import Client
except ImportError:
raise ValueError(
"The Google GenAI python package is not installed. Please install it with `pip install google-genai`"
)

self.api_key_env_var = api_key_env_var
self.api_key = api_key or os.getenv(api_key_env_var)
if not self.api_key:
raise ValueError(f"The {api_key_env_var} environment variable is not set.")

self.model_name = model_name
self._client = Client(api_key=self.api_key)

def __call__(self, input: Documents) -> Embeddings:
"""
Generate embeddings for the given documents.

Args:
input: Documents to generate embeddings for.

Returns:
Embeddings for the documents.
"""
# Google GenAI only works with text documents
if not all(isinstance(item, str) for item in input):
raise ValueError(
"Google GenAI only supports text documents, not images"
)

embeddings_list: List[npt.NDArray[np.float32]] = []
for text in input:
try:
response = self._client.models.embed_content(
model=self.model_name,
contents=text
)
embeddings_list.append(
np.array(response.embeddings[0].values, dtype=np.float32)
)
except Exception as e:
raise ValueError(f"Error generating embedding: {str(e)}")

# Convert to the expected Embeddings type (List[Vector])
return cast(Embeddings, embeddings_list)

@staticmethod
def name() -> str:
return "google_genai"

def default_space(self) -> Space:
return "cosine"

def supported_spaces(self) -> List[Space]:
return ["cosine", "l2", "ip"]

@staticmethod
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
api_key_env_var = config.get("api_key_env_var")
model_name = config.get("model_name")

if api_key_env_var is None or model_name is None:
assert False, "This code should not be reached"

return GoogleGenAiEmbeddingFunction(
api_key_env_var=api_key_env_var, model_name=model_name
)

def get_config(self) -> Dict[str, Any]:
return {
"api_key_env_var": self.api_key_env_var,
"model_name": self.model_name,
}

def validate_config_update(
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
) -> None:
if "model_name" in new_config:
raise ValueError(
"The model name cannot be changed after the embedding function has been initialized."
)

@staticmethod
def validate_config(config: Dict[str, Any]) -> None:
"""
Validate the configuration using the JSON schema.

Args:
config: Configuration to validate

Raises:
ValidationError: If the configuration does not match the schema
"""
validate_config_schema(config, "google_genai")


class GooglePalmEmbeddingFunction(EmbeddingFunction[Documents]):
"""To use this EmbeddingFunction, you must have the google.generativeai Python package installed and have a PaLM API key."""
Expand Down