from typing import Any, Dict, List, Optional
from redis import Redis
from redisvl.extensions.llmcache.base import BaseLLMCache
from redisvl.index import SearchIndex
from redisvl.query import RangeQuery
from redisvl.redis.utils import array_to_buffer
from redisvl.schema.schema import IndexSchema
from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer
class SemanticCache(BaseLLMCache):
"""Semantic Cache for Large Language Models."""
entry_id_field_name: str = "id"
prompt_field_name: str = "prompt"
vector_field_name: str = "prompt_vector"
response_field_name: str = "response"
metadata_field_name: str = "metadata"
def __init__(
self,
name: str = "llmcache",
prefix: Optional[str] = None,
distance_threshold: float = 0.1,
ttl: Optional[int] = None,
vectorizer: Optional[BaseVectorizer] = None,
redis_client: Optional[Redis] = None,
redis_url: str = "redis://localhost:6379",
connection_args: Dict[str, Any] = {},
**kwargs,
):
"""Semantic Cache for Large Language Models.
Args:
name (str, optional): The name of the semantic cache search index.
Defaults to "llmcache".
prefix (Optional[str], optional): The prefix for Redis keys
associated with the semantic cache search index. Defaults to
None, and the index name will be used as the key prefix.
distance_threshold (float, optional): Semantic threshold for the
cache. Defaults to 0.1.
ttl (Optional[int], optional): The time-to-live for records cached
in Redis. Defaults to None.
vectorizer (BaseVectorizer, optional): The vectorizer for the cache.
Defaults to HFTextVectorizer.
            redis_client (Redis, optional): A redis client connection instance.
Defaults to None.
redis_url (str, optional): The redis url. Defaults to
"redis://localhost:6379".
            connection_args (Dict[str, Any], optional): The connection arguments
                for the redis client. Defaults to an empty dictionary.
Raises:
TypeError: If an invalid vectorizer is provided.
TypeError: If the TTL value is not an int.
ValueError: If the threshold is not between 0 and 1.
            ValueError: If the index name is not provided.
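
        Example of constructing a cache instance (a minimal sketch; the Redis
        URL, TTL, and threshold shown are illustrative values, not required
        settings):

        .. code-block:: python

            from redisvl.extensions.llmcache import SemanticCache

            # Cache backed by a local Redis instance; entries expire after one
            # hour and a hit requires a cosine distance of 0.1 or less.
            cache = SemanticCache(
                name="llmcache",
                redis_url="redis://localhost:6379",
                distance_threshold=0.1,
                ttl=3600,
            )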
"""
super().__init__(ttl)
# Use the index name as the key prefix by default
if prefix is None:
prefix = name
# Set vectorizer default
if vectorizer is None:
vectorizer = HFTextVectorizer(
model="sentence-transformers/all-mpnet-base-v2"
)
# build cache index schema
schema = IndexSchema.from_dict({"index": {"name": name, "prefix": prefix}})
# add fields
schema.add_fields(
[
{"name": self.prompt_field_name, "type": "text"},
{"name": self.response_field_name, "type": "text"},
{
"name": self.vector_field_name,
"type": "vector",
"attrs": {
"dims": vectorizer.dims,
"datatype": "float32",
"distance_metric": "cosine",
"algorithm": "flat",
},
},
]
)
# build search index
self._index = SearchIndex(schema=schema)
# handle redis connection
if redis_client:
self._index.set_client(redis_client)
else:
self._index.connect(redis_url=redis_url, **connection_args)
# initialize other components
self.default_return_fields = [
self.entry_id_field_name,
self.prompt_field_name,
self.response_field_name,
self.vector_field_name,
self.metadata_field_name,
]
self.set_vectorizer(vectorizer)
self.set_threshold(distance_threshold)
self._index.create(overwrite=False)
@property
def index(self) -> SearchIndex:
"""The underlying SearchIndex for the cache.
Returns:
SearchIndex: The search index.
"""
return self._index
@property
def distance_threshold(self) -> float:
"""The semantic distance threshold for the cache.
Returns:
float: The semantic distance threshold.
"""
return self._distance_threshold
def set_threshold(self, distance_threshold: float) -> None:
"""Sets the semantic distance threshold for the cache.
Args:
distance_threshold (float): The semantic distance threshold for
the cache.
Raises:
ValueError: If the threshold is not between 0 and 1.
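
        A brief usage sketch (the value shown is an arbitrary example):

        .. code-block:: python

            # Loosen the matching criteria so slightly less similar prompts
            # still count as cache hits.
            cache.set_threshold(0.2)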
"""
if not 0 <= float(distance_threshold) <= 1:
raise ValueError(
f"Distance must be between 0 and 1, got {distance_threshold}"
)
self._distance_threshold = float(distance_threshold)
def set_vectorizer(self, vectorizer: BaseVectorizer) -> None:
"""Sets the vectorizer for the LLM cache.
Must be a valid subclass of BaseVectorizer and have equivalent
dimensions to the vector field defined in the schema.
Args:
vectorizer (BaseVectorizer): The RedisVL vectorizer to use for
vectorizing cache entries.
Raises:
TypeError: If the vectorizer is not a valid type.
ValueError: If the vector dimensions are mismatched.
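
        A minimal sketch (assumes the replacement model produces vectors with
        the same dimensions as the existing ``prompt_vector`` field):

        .. code-block:: python

            from redisvl.utils.vectorize import HFTextVectorizer

            # Swap in a HuggingFace sentence-transformer; its dims must match
            # the schema's vector field or a ValueError is raised.
            cache.set_vectorizer(
                HFTextVectorizer(model="sentence-transformers/all-mpnet-base-v2")
            )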
"""
if not isinstance(vectorizer, BaseVectorizer):
raise TypeError("Must provide a valid redisvl.vectorizer class.")
schema_vector_dims = self._index.schema.fields[self.vector_field_name].attrs.dims # type: ignore
if schema_vector_dims != vectorizer.dims:
            raise ValueError(
                "Invalid vector dimensions! "
                f"Vectorizer has dims defined as {vectorizer.dims}, "
                f"but the vector field has dims defined as {schema_vector_dims}."
            )
self._vectorizer = vectorizer
def clear(self) -> None:
"""Clear the cache of all keys while preserving the index."""
with self._index.client.pipeline(transaction=False) as pipe: # type: ignore
for key in self._index.client.scan_iter(match=f"{self._index.prefix}:*"): # type: ignore
pipe.delete(key)
pipe.execute()
def delete(self) -> None:
"""Clear the semantic cache of all keys and remove the underlying search
index."""
self._index.delete(drop=True)
def _refresh_ttl(self, key: str) -> None:
"""Refresh the time-to-live for the specified key."""
if self.ttl:
self._index.client.expire(key, self.ttl) # type: ignore
def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]:
"""Converts a text prompt to its vector representation using the
configured vectorizer."""
if not isinstance(prompt, str):
raise TypeError("Prompt must be a string.")
return self._vectorizer.embed(prompt)
def _search_cache(
self, vector: List[float], num_results: int, return_fields: Optional[List[str]]
) -> List[Dict[str, Any]]:
"""Searches the semantic cache for similar prompt vectors and returns
the specified return fields for each cache hit."""
# Setup and type checks
if not isinstance(vector, list):
raise TypeError("Vector must be a list of floats")
return_fields = return_fields or self.default_return_fields
if not isinstance(return_fields, list):
raise TypeError("return_fields must be a list of field names")
# Construct vector RangeQuery for the cache check
query = RangeQuery(
vector=vector,
vector_field_name=self.vector_field_name,
return_fields=return_fields,
distance_threshold=self._distance_threshold,
num_results=num_results,
return_score=True,
)
# Gather and return the cache hits
cache_hits: List[Dict[str, Any]] = self._index.query(query)
# Process cache hits
for hit in cache_hits:
self._refresh_ttl(hit[self.entry_id_field_name])
# Check for metadata and deserialize
if self.metadata_field_name in hit:
hit[self.metadata_field_name] = self.deserialize(
hit[self.metadata_field_name]
)
return cache_hits
def check(
self,
prompt: Optional[str] = None,
vector: Optional[List[float]] = None,
num_results: int = 1,
return_fields: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
"""Checks the semantic cache for results similar to the specified prompt
or vector.
This method searches the cache using vector similarity with
either a raw text prompt (converted to a vector) or a provided vector as
input. It checks for semantically similar prompts and fetches the cached
LLM responses.
Args:
prompt (Optional[str], optional): The text prompt to search for in
the cache.
vector (Optional[List[float]], optional): The vector representation
of the prompt to search for in the cache.
num_results (int, optional): The number of cached results to return.
Defaults to 1.
return_fields (Optional[List[str]], optional): The fields to include
in each returned result. If None, defaults to all available
fields in the cached entry.
Returns:
List[Dict[str, Any]]: A list of dicts containing the requested
return fields for each similar cached response.
Raises:
ValueError: If neither a `prompt` nor a `vector` is specified.
TypeError: If `return_fields` is not a list when provided.
.. code-block:: python
response = cache.check(
prompt="What is the captial city of France?"
)
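
        A slightly fuller sketch of a read-through pattern (``ask_llm`` is a
        hypothetical placeholder for whatever LLM call you use; it is not part
        of this library):

        .. code-block:: python

            prompt = "What is the capital city of France?"
            results = cache.check(prompt=prompt)
            if results:
                answer = results[0]["response"]   # semantic cache hit
            else:
                answer = ask_llm(prompt)          # cache miss: call the LLM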
"""
if not (prompt or vector):
raise ValueError("Either prompt or vector must be specified.")
# Use provided vector or create from prompt
vector = vector or self._vectorize_prompt(prompt)
# Check for cache hits by searching the cache
cache_hits = self._search_cache(vector, num_results, return_fields)
return cache_hits
def store(
self,
prompt: str,
response: str,
vector: Optional[List[float]] = None,
metadata: Optional[dict] = None,
) -> str:
"""Stores the specified key-value pair in the cache along with metadata.
Args:
prompt (str): The user prompt to cache.
response (str): The LLM response to cache.
vector (Optional[List[float]], optional): The prompt vector to
cache. Defaults to None, and the prompt vector is generated on
demand.
metadata (Optional[dict], optional): The optional metadata to cache
alongside the prompt and response. Defaults to None.
Returns:
            str: The Redis key for the entry added to the semantic cache.
Raises:
ValueError: If neither prompt nor vector is specified.
TypeError: If provided metadata is not a dictionary.
.. code-block:: python
key = cache.store(
prompt="What is the captial city of France?",
response="Paris",
metadata={"city": "Paris", "country": "France"}
)
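
        A sketch of reusing a precomputed prompt vector so the prompt is not
        embedded twice (``precomputed_vector`` is a hypothetical
        ``List[float]`` of matching dimensions, e.g. from an earlier
        embedding call):

        .. code-block:: python

            key = cache.store(
                prompt="What is the capital city of France?",
                response="Paris",
                vector=precomputed_vector,
            )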
"""
# Vectorize prompt if necessary and create cache payload
vector = vector or self._vectorize_prompt(prompt)
# Construct semantic cache payload
id_field = self.entry_id_field_name
payload = {
id_field: self.hash_input(prompt),
self.prompt_field_name: prompt,
self.response_field_name: response,
self.vector_field_name: array_to_buffer(vector),
}
if metadata is not None:
if not isinstance(metadata, dict):
raise TypeError("If specified, cached metadata must be a dictionary.")
# Serialize the metadata dict and add to cache payload
payload[self.metadata_field_name] = self.serialize(metadata)
# Load LLMCache entry with TTL
keys = self._index.load(data=[payload], ttl=self._ttl, id_field=id_field)
return keys[0]