import os
from typing import Any, Callable, Dict, List, Optional
from pydantic.v1 import PrivateAttr
from tenacity import retry, stop_after_attempt, wait_random_exponential
from tenacity.retry import retry_if_not_exception_type
from redisvl.utils.vectorize.base import BaseVectorizer
# ignore that cohere isn't imported
# mypy: disable-error-code="name-defined"
class CohereTextVectorizer(BaseVectorizer):
    """The CohereTextVectorizer class utilizes Cohere's API to generate
    embeddings for text data.

    This vectorizer is designed to interact with Cohere's /embed API,
    requiring an API key for authentication. The key can be provided
    directly in the `api_config` dictionary or through the `COHERE_API_KEY`
    environment variable. User must obtain an API key from Cohere's website
    (https://dashboard.cohere.com/). Additionally, the `cohere` python
    client must be installed with `pip install cohere`.

    The vectorizer supports only synchronous operations, allows for batch
    processing of texts and flexibility in handling preprocessing tasks.

    .. code-block:: python

        from redisvl.utils.vectorize import CohereTextVectorizer

        vectorizer = CohereTextVectorizer(
            model="embed-english-v3.0",
            api_config={"api_key": "your-cohere-api-key"} # OR set COHERE_API_KEY in your env
        )
        query_embedding = vectorizer.embed(
            text="your input query text here",
            input_type="search_query"
        )
        doc_embeddings = vectorizer.embed_many(
            texts=["your document text", "more document text"],
            input_type="search_document"
        )

    """

    # Synchronous Cohere client; assigned in `_initialize_client`.
    _client: Any = PrivateAttr()

    def __init__(
        self, model: str = "embed-english-v3.0", api_config: Optional[Dict] = None
    ):
        """Initialize the Cohere vectorizer.

        Visit https://cohere.ai/embed to learn about embeddings.

        Args:
            model (str): Model to use for embedding. Defaults to 'embed-english-v3.0'.
            api_config (Optional[Dict], optional): Dictionary containing the API key.
                Defaults to None.

        Raises:
            ImportError: If the cohere library is not installed.
            ValueError: If the API key is not provided, or if the embedding
                dimensions cannot be determined from a test request.
        """
        # The client must exist before the base initializer runs because the
        # `dims` value is discovered by issuing a real embed request.
        self._initialize_client(api_config)
        super().__init__(model=model, dims=self._set_model_dims(model))

    def _initialize_client(self, api_config: Optional[Dict]):
        """Set up the Cohere client from the provided API key or the
        `COHERE_API_KEY` environment variable.

        Args:
            api_config (Optional[Dict]): Optional dictionary that may hold
                the key under ``"api_key"``.

        Raises:
            ImportError: If the cohere library is not installed.
            ValueError: If no API key is found in either location.
        """
        # Dynamic import so that `cohere` stays an optional dependency.
        try:
            from cohere import Client
        except ImportError as exc:
            raise ImportError(
                "Cohere vectorizer requires the cohere library. "
                "Please install with `pip install cohere`"
            ) from exc

        # Prefer an explicit key in api_config; otherwise fall back to the
        # environment, even when api_config is non-empty but has no api_key.
        api_key = (api_config or {}).get("api_key") or os.getenv("COHERE_API_KEY")
        if not api_key:
            raise ValueError(
                "Cohere API key is required. "
                "Provide it in api_config or set the COHERE_API_KEY environment variable."
            )
        self._client = Client(api_key=api_key, client_name="redisvl")

    def _set_model_dims(self, model) -> int:
        """Determine the embedding size of `model` by embedding a probe text.

        Args:
            model: Name of the Cohere embedding model to probe.

        Returns:
            int: Length of the embedding returned for the probe text.

        Raises:
            ValueError: If the request fails or the response has an
                unexpected shape.
        """
        try:
            embedding = self._client.embed(
                texts=["dimension test"],
                model=model,
                input_type="search_document",
            ).embeddings[0]
        except (KeyError, IndexError) as ke:
            raise ValueError(
                f"Unexpected response from the Cohere API: {str(ke)}"
            ) from ke
        except Exception as e:  # pylint: disable=broad-except
            # fall back (TODO get more specific)
            raise ValueError(
                f"Error setting embedding model dimensions: {str(e)}"
            ) from e
        return len(embedding)

    def embed(
        self,
        text: str,
        preprocess: Optional[Callable] = None,
        as_buffer: bool = False,
        **kwargs,
    ) -> List[float]:
        """Embed a chunk of text using the Cohere Embeddings API.

        Must provide the embedding `input_type` as a `kwarg` to this method
        that specifies the type of input you're giving to the model.

        Supported input types:
            - ``search_document``: Used for embeddings stored in a vector database for search use-cases.
            - ``search_query``: Used for embeddings of search queries run against a vector DB to find relevant documents.
            - ``classification``: Used for embeddings passed through a text classifier
            - ``clustering``: Used for the embeddings run through a clustering algorithm.

        When hydrating your Redis DB, the documents you want to search over
        should be embedded with input_type="search_document" and when you are
        querying the database, you should set input_type="search_query".
        If you want to use the embeddings for a classification or clustering
        task downstream, you should set input_type="classification" or
        "clustering".

        Args:
            text (str): Chunk of text to embed.
            preprocess (Optional[Callable], optional): Optional preprocessing callable to
                perform before vectorization. Defaults to None.
            as_buffer (bool, optional): Whether to convert the raw embedding
                to a byte string. Defaults to False.
            input_type (str): Specifies the type of input passed to the model.
                Required for embedding models v3 and higher.

        Returns:
            List[float]: Embedding.

        Raises:
            TypeError: If `text` is not a str or an invalid input_type is provided.
        """
        input_type = kwargs.get("input_type")
        if not isinstance(text, str):
            raise TypeError("Must pass in a str value to embed.")
        if not isinstance(input_type, str):
            raise TypeError(
                "Must pass in a str value for cohere embedding input_type. "
                "See https://docs.cohere.com/reference/embed."
            )
        if preprocess:
            text = preprocess(text)
        embedding = self._client.embed(
            texts=[text], model=self.model, input_type=input_type
        ).embeddings[0]
        return self._process_embedding(embedding, as_buffer)

    # TypeError signals invalid caller arguments, so it is excluded from the
    # retry policy; any other failure is retried with randomized backoff.
    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        retry=retry_if_not_exception_type(TypeError),
    )
    def embed_many(
        self,
        texts: List[str],
        preprocess: Optional[Callable] = None,
        batch_size: int = 10,
        as_buffer: bool = False,
        **kwargs,
    ) -> List[List[float]]:
        """Embed many chunks of text using the Cohere Embeddings API.

        Must provide the embedding `input_type` as a `kwarg` to this method
        that specifies the type of input you're giving to the model.

        Supported input types:
            - ``search_document``: Used for embeddings stored in a vector database for search use-cases.
            - ``search_query``: Used for embeddings of search queries run against a vector DB to find relevant documents.
            - ``classification``: Used for embeddings passed through a text classifier
            - ``clustering``: Used for the embeddings run through a clustering algorithm.

        When hydrating your Redis DB, the documents you want to search over
        should be embedded with input_type="search_document" and when you are
        querying the database, you should set input_type="search_query".
        If you want to use the embeddings for a classification or clustering
        task downstream, you should set input_type="classification" or
        "clustering".

        Args:
            texts (List[str]): List of text chunks to embed.
            preprocess (Optional[Callable], optional): Optional preprocessing callable to
                perform before vectorization. Defaults to None.
            batch_size (int, optional): Batch size of texts to use when creating
                embeddings. Defaults to 10.
            as_buffer (bool, optional): Whether to convert the raw embedding
                to a byte string. Defaults to False.
            input_type (str): Specifies the type of input passed to the model.
                Required for embedding models v3 and higher.

        Returns:
            List[List[float]]: List of embeddings.

        Raises:
            TypeError: If `texts` is not a list of str or an invalid
                input_type is provided.
        """
        input_type = kwargs.get("input_type")
        if not isinstance(texts, list):
            raise TypeError("Must pass in a list of str values to embed.")
        # Only the first element is inspected, as a cheap sanity check.
        if len(texts) > 0 and not isinstance(texts[0], str):
            raise TypeError("Must pass in a list of str values to embed.")
        if not isinstance(input_type, str):
            raise TypeError(
                "Must pass in a str value for cohere embedding input_type. "
                "See https://docs.cohere.com/reference/embed."
            )

        embeddings: List = []
        # One API request per batch; results are accumulated in input order.
        for batch in self.batchify(texts, batch_size, preprocess):
            response = self._client.embed(
                texts=batch, model=self.model, input_type=input_type
            )
            embeddings += [
                self._process_embedding(embedding, as_buffer)
                for embedding in response.embeddings
            ]
        return embeddings

    async def aembed_many(
        self,
        texts: List[str],
        preprocess: Optional[Callable] = None,
        batch_size: int = 1000,
        as_buffer: bool = False,
        **kwargs,
    ) -> List[List[float]]:
        """Asynchronous batch embedding is not supported by this vectorizer.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    async def aembed(
        self,
        text: str,
        preprocess: Optional[Callable] = None,
        as_buffer: bool = False,
        **kwargs,
    ) -> List[float]:
        """Asynchronous embedding is not supported by this vectorizer.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    @property
    def type(self) -> str:
        """Return the identifier string for this vectorizer, ``"cohere"``."""
        return "cohere"