from pathlib import Path
from typing import Any, Dict, List, Optional, Type
import redis.commands.search.reducers as reducers
import yaml
from pydantic.v1 import BaseModel, Field, PrivateAttr
from redis import Redis
from redis.commands.search.aggregation import AggregateRequest, AggregateResult, Reducer
from redis.exceptions import ResponseError
from redisvl.extensions.router.schema import (
DistanceAggregationMethod,
Route,
RouteMatch,
RoutingConfig,
SemanticRouterIndexSchema,
)
from redisvl.index import SearchIndex
from redisvl.query import RangeQuery
from redisvl.redis.utils import convert_bytes, hashify, make_dict
from redisvl.utils.log import get_logger
from redisvl.utils.utils import model_to_dict
from redisvl.utils.vectorize import (
BaseVectorizer,
HFTextVectorizer,
vectorizer_from_dict,
)
# Module-level logger (the stray "[docs]" HTML residue that followed this
# line was an undefined-name expression and raised NameError on import).
logger = get_logger(__name__)
class SemanticRouter(BaseModel):
    """Semantic Router for managing and querying route vectors.

    Each route's reference phrases are embedded with ``vectorizer`` and stored
    in a Redis search index built from ``SemanticRouterIndexSchema``. A query
    statement (or pre-computed vector) is matched against those references with
    a vector range query; per-reference distances are aggregated per route and
    routes whose aggregated distance is within the applicable threshold are
    returned.
    """

    name: str
    """The name of the semantic router."""
    routes: List[Route]
    """List of Route objects."""
    vectorizer: BaseVectorizer = Field(default_factory=HFTextVectorizer)
    """The vectorizer used to embed route references."""
    routing_config: RoutingConfig = Field(default_factory=RoutingConfig)
    """Configuration for routing behavior."""

    # Search index holding one document per (route, reference) pair.
    _index: SearchIndex = PrivateAttr()

    class Config:
        arbitrary_types_allowed = True

    def __init__(
        self,
        name: str,
        routes: List[Route],
        vectorizer: Optional[BaseVectorizer] = None,
        routing_config: Optional[RoutingConfig] = None,
        redis_client: Optional[Redis] = None,
        redis_url: str = "redis://localhost:6379",
        overwrite: bool = False,
        connection_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """Initialize the SemanticRouter.

        Args:
            name (str): The name of the semantic router.
            routes (List[Route]): List of Route objects.
            vectorizer (BaseVectorizer, optional): The vectorizer used to embed
                route references. Defaults to a default HFTextVectorizer.
            routing_config (RoutingConfig, optional): Configuration for routing
                behavior. Defaults to the default RoutingConfig.
            redis_client (Optional[Redis], optional): Redis client for
                connection. Takes precedence over ``redis_url``. Defaults to None.
            redis_url (str, optional): The redis url. Defaults to
                redis://localhost:6379.
            overwrite (bool, optional): Whether to overwrite existing index.
                Defaults to False.
            connection_kwargs (Optional[Dict[str, Any]]): Extra arguments for
                the redis connection; only used when connecting via
                ``redis_url``. Defaults to None (no extra arguments).
            **kwargs: Accepted for compatibility and currently ignored.
        """
        if vectorizer is None:
            vectorizer = HFTextVectorizer()
        if routing_config is None:
            routing_config = RoutingConfig()
        # Use a fresh dict per call instead of a shared mutable default.
        if connection_kwargs is None:
            connection_kwargs = {}

        super().__init__(
            name=name,
            routes=routes,
            vectorizer=vectorizer,
            routing_config=routing_config,
        )
        self._initialize_index(redis_client, redis_url, overwrite, **connection_kwargs)

    def _initialize_index(
        self,
        redis_client: Optional[Redis] = None,
        redis_url: str = "redis://localhost:6379",
        overwrite: bool = False,
        **connection_kwargs,
    ):
        """Create the search index, connect it to Redis and load the routes.

        Route references are only written when the index did not exist before
        this call or when ``overwrite`` is set, so reattaching to an existing
        router index does not re-embed its references.
        """
        schema = SemanticRouterIndexSchema.from_params(self.name, self.vectorizer.dims)
        self._index = SearchIndex(schema=schema)

        if redis_client:
            self._index.set_client(redis_client)
        elif redis_url:
            self._index.connect(redis_url=redis_url, **connection_kwargs)

        # Capture existence before create() so we know whether to load routes.
        existed = self._index.exists()
        self._index.create(overwrite=overwrite)

        if not existed or overwrite:
            self._add_routes(self.routes)

    @property
    def route_names(self) -> List[str]:
        """Get the list of route names.

        Returns:
            List[str]: List of route names, in route order.
        """
        return [route.name for route in self.routes]

    @property
    def route_thresholds(self) -> Dict[str, Optional[float]]:
        """Get the distance thresholds for each route.

        Returns:
            Dict[str, Optional[float]]: Route names mapped to their distance
            thresholds (None when a route defines no threshold of its own).
        """
        return {route.name: route.distance_threshold for route in self.routes}

    def update_routing_config(self, routing_config: RoutingConfig):
        """Update the routing configuration.

        Args:
            routing_config (RoutingConfig): The new routing configuration.
        """
        self.routing_config = routing_config

    def _route_ref_key(self, route_name: str, reference: str) -> str:
        """Build the Redis key for one route reference.

        The key is ``<index prefix>:<route name>:<hash of reference>`` so the
        same reference text always maps to the same key for a route.
        """
        reference_hash = hashify(reference)
        return f"{self._index.prefix}:{route_name}:{reference_hash}"

    def _add_routes(self, routes: List[Route]):
        """Embed and store route references, tracking new routes client side.

        Args:
            routes (List[Route]): List of routes to be added.
        """
        route_references: List[Dict[str, Any]] = []
        keys: List[str] = []

        for route in routes:
            # Embed all of this route's references in a single batch call.
            reference_vectors = self.vectorizer.embed_many(
                list(route.references), as_buffer=True
            )
            for i, reference in enumerate(route.references):
                route_references.append(
                    {
                        "route_name": route.name,
                        "reference": reference,
                        "vector": reference_vectors[i],
                    }
                )
                keys.append(self._route_ref_key(route.name, reference))

            # Register the route client side if it is not already known.
            if not self.get(route.name):
                self.routes.append(route)

        # Nothing to write (e.g. no routes or no references): skip the round trip.
        if route_references:
            self._index.load(route_references, keys=keys)

    def get(self, route_name: str) -> Optional[Route]:
        """Get a route by its name.

        Args:
            route_name (str): Name of the route.

        Returns:
            Optional[Route]: The selected Route object or None if not found.
        """
        return next((route for route in self.routes if route.name == route_name), None)

    def _process_route(self, result: Dict[str, Any]) -> RouteMatch:
        """Convert one raw aggregation row into a RouteMatch.

        NOTE(review): the row is assumed to be a flat key/value sequence that
        ``make_dict`` pairs up, containing ``route_name`` and ``distance``.
        """
        route_dict = make_dict(convert_bytes(result))
        return RouteMatch(
            name=route_dict["route_name"], distance=float(route_dict["distance"])
        )

    def _build_aggregate_request(
        self,
        vector_range_query: RangeQuery,
        aggregation_method: DistanceAggregationMethod,
        max_k: int,
    ) -> AggregateRequest:
        """Build an FT.AGGREGATE request grouping range-query hits by route.

        Distances of matching references are reduced per ``@route_name`` with
        min/sum/avg (avg for any other method), aliased as ``distance``, sorted
        by it and limited to ``max_k`` groups.
        """
        aggregation_func: Type[Reducer]
        if aggregation_method == DistanceAggregationMethod.min:
            aggregation_func = reducers.min
        elif aggregation_method == DistanceAggregationMethod.sum:
            aggregation_func = reducers.sum
        else:
            aggregation_func = reducers.avg

        # Reuse the range query's filter string but drop its RETURN clause,
        # which the aggregate request does not take.
        aggregate_query = str(vector_range_query).split(" RETURN")[0]
        aggregate_request = (
            AggregateRequest(aggregate_query)
            .group_by(
                "@route_name", aggregation_func("vector_distance").alias("distance")
            )
            .sort_by("@distance", max=max_k)
            .dialect(2)
        )
        return aggregate_request

    def _get_route_matches(
        self,
        vector: List[float],
        distance_threshold: float,
        aggregation_method: DistanceAggregationMethod,
        max_k: int,
    ) -> List[RouteMatch]:
        """Run the aggregated range query and return the raw route matches.

        Matches are returned in the order Redis produces them (sorted by the
        aggregated ``@distance``), at most ``max_k`` of them, before any
        per-route threshold filtering.

        Raises:
            RuntimeError: If the Redis server does not support vector search
                inside FT.AGGREGATE.
        """
        # NOTE(review): the range query only uses the query-level threshold, so
        # references farther away than it are excluded before per-route
        # thresholds are applied, even if a route's own threshold is larger.
        vector_range_query = RangeQuery(
            vector=vector,
            vector_field_name="vector",
            distance_threshold=distance_threshold,
            return_fields=["route_name"],
        )
        aggregate_request = self._build_aggregate_request(
            vector_range_query, aggregation_method, max_k
        )

        try:
            aggregation_result: AggregateResult = self._index.client.ft(  # type: ignore
                self._index.name
            ).aggregate(aggregate_request, vector_range_query.params)
        except ResponseError as e:
            if "VSS is not yet supported on FT.AGGREGATE" in str(e):
                raise RuntimeError(
                    "Semantic routing is only available on Redis version 7.x.x or greater"
                ) from e
            raise

        return [self._process_route(row) for row in aggregation_result.rows]

    def _matches_within_threshold(
        self, route_matches: List[RouteMatch], distance_threshold: float
    ) -> List[RouteMatch]:
        """Keep the matches that pass their route's distance threshold.

        A route's own ``distance_threshold`` takes precedence; the supplied
        ``distance_threshold`` is the fallback when the route has none.

        Raises:
            ValueError: If a match names a route unknown to this router.
        """
        passing: List[RouteMatch] = []
        for route_match in route_matches:
            if route_match.name is None:
                continue
            route = self.get(route_match.name)
            if route is None:
                raise ValueError(
                    f"{route_match.name} not a supported route for the {self.name} semantic router."
                )
            # `is not None` so an explicit 0.0 route threshold is honoured.
            threshold = (
                route.distance_threshold
                if route.distance_threshold is not None
                else distance_threshold
            )
            if self._pass_threshold(route_match, threshold):
                passing.append(route_match)
        return passing

    def _classify_route(
        self,
        vector: List[float],
        distance_threshold: float,
        aggregation_method: DistanceAggregationMethod,
    ) -> RouteMatch:
        """Classify to a single route using a vector.

        Only the best aggregated match is considered; if it fails its
        threshold an empty RouteMatch is returned.
        """
        route_matches = self._get_route_matches(
            vector, distance_threshold, aggregation_method, max_k=1
        )
        passing = self._matches_within_threshold(route_matches[:1], distance_threshold)
        # Fall back to an empty route match when nothing qualifies.
        return passing[0] if passing else RouteMatch()

    def _classify_multi_route(
        self,
        vector: List[float],
        max_k: int,
        distance_threshold: float,
        aggregation_method: DistanceAggregationMethod,
    ) -> List[RouteMatch]:
        """Classify to multiple routes, up to max_k (int), using a vector."""
        route_matches = self._get_route_matches(
            vector, distance_threshold, aggregation_method, max_k
        )
        return self._matches_within_threshold(route_matches, distance_threshold)

    def _pass_threshold(
        self, route_match: Optional[RouteMatch], distance_threshold: float
    ) -> bool:
        """Check if a route match passes the distance threshold.

        Args:
            route_match (Optional[RouteMatch]): The route match to check.
            distance_threshold (float): The threshold to compare against
                (inclusive). A threshold of 0.0 is valid and only admits
                exact matches.

        Returns:
            bool: True if the match has a distance <= the threshold, False when
            the match, its distance, or the threshold is missing.
        """
        if route_match is None or route_match.distance is None:
            return False
        if distance_threshold is None:
            return False
        return route_match.distance <= distance_threshold

    def __call__(
        self,
        statement: Optional[str] = None,
        vector: Optional[List[float]] = None,
        distance_threshold: Optional[float] = None,
        aggregation_method: Optional[DistanceAggregationMethod] = None,
    ) -> RouteMatch:
        """Query the semantic router with a given statement or vector.

        Args:
            statement (Optional[str]): The input statement to be queried.
                Ignored when ``vector`` is provided.
            vector (Optional[List[float]]): The input vector to be queried.
            distance_threshold (Optional[float]): The threshold for semantic
                distance. Defaults to ``routing_config.distance_threshold``.
            aggregation_method (Optional[DistanceAggregationMethod]): The
                aggregation method used for vector distances. Defaults to
                ``routing_config.aggregation_method``.

        Returns:
            RouteMatch: The matching route, or an empty RouteMatch.

        Raises:
            ValueError: If neither a vector nor a statement is provided.
        """
        if not vector:
            if not statement:
                raise ValueError("Must provide a vector or statement to the router")
            vector = self.vectorizer.embed(statement)

        # Fall back to the routing config only when no value was supplied,
        # so an explicit 0.0 threshold is respected.
        if distance_threshold is None:
            distance_threshold = self.routing_config.distance_threshold
        if aggregation_method is None:
            aggregation_method = self.routing_config.aggregation_method

        return self._classify_route(vector, distance_threshold, aggregation_method)

    def route_many(
        self,
        statement: Optional[str] = None,
        vector: Optional[List[float]] = None,
        max_k: Optional[int] = None,
        distance_threshold: Optional[float] = None,
        aggregation_method: Optional[DistanceAggregationMethod] = None,
    ) -> List[RouteMatch]:
        """Query the semantic router with a given statement or vector for multiple matches.

        Args:
            statement (Optional[str]): The input statement to be queried.
                Ignored when ``vector`` is provided.
            vector (Optional[List[float]]): The input vector to be queried.
            max_k (Optional[int]): The maximum number of top matches to return.
                Defaults to ``routing_config.max_k``.
            distance_threshold (Optional[float]): The threshold for semantic
                distance. Defaults to ``routing_config.distance_threshold``.
            aggregation_method (Optional[DistanceAggregationMethod]): The
                aggregation method used for vector distances. Defaults to
                ``routing_config.aggregation_method``.

        Returns:
            List[RouteMatch]: The matching routes and their details.

        Raises:
            ValueError: If neither a vector nor a statement is provided.
        """
        if not vector:
            if not statement:
                raise ValueError("Must provide a vector or statement to the router")
            vector = self.vectorizer.embed(statement)

        if distance_threshold is None:
            distance_threshold = self.routing_config.distance_threshold
        # 0 is not a meaningful result count, so any falsy max_k uses the default.
        max_k = max_k or self.routing_config.max_k
        if aggregation_method is None:
            aggregation_method = self.routing_config.aggregation_method

        return self._classify_multi_route(
            vector, max_k, distance_threshold, aggregation_method
        )

    def remove_route(self, route_name: str) -> None:
        """Remove a route and all references from the semantic router.

        Logs a warning and does nothing if the route is unknown.

        Args:
            route_name (str): Name of the route to remove.
        """
        route = self.get(route_name)
        if route is None:
            logger.warning(f"Route {route_name} is not found in the SemanticRouter")
            return

        self._index.drop_keys(
            [self._route_ref_key(route.name, reference) for reference in route.references]
        )
        self.routes = [route for route in self.routes if route.name != route_name]

    def delete(self) -> None:
        """Delete the semantic router index.

        Calls ``delete(drop=True)`` on the index (presumably also removing the
        stored route reference documents).
        """
        self._index.delete(drop=True)

    def clear(self) -> None:
        """Flush all routes from the semantic router index and forget them client side."""
        self._index.clear()
        self.routes = []

    @classmethod
    def from_dict(
        cls,
        data: Dict[str, Any],
        **kwargs,
    ) -> "SemanticRouter":
        """Create a SemanticRouter from a dictionary.

        Args:
            data (Dict[str, Any]): The dictionary containing the semantic
                router data. Requires ``name``, ``routes``, ``vectorizer`` and
                ``routing_config`` keys.
            **kwargs: Forwarded to the constructor (e.g. ``redis_url``).

        Returns:
            SemanticRouter: The semantic router instance.

        Raises:
            ValueError: If required data is missing or invalid.

        .. code-block:: python

            from redisvl.extensions.router import SemanticRouter

            router_data = {
                "name": "example_router",
                "routes": [{"name": "route1", "references": ["ref1"], "distance_threshold": 0.5}],
                "vectorizer": {"type": "openai", "model": "text-embedding-ada-002"},
            }

            router = SemanticRouter.from_dict(router_data)
        """
        try:
            name = data["name"]
            routes_data = data["routes"]
            vectorizer_data = data["vectorizer"]
            routing_config_data = data["routing_config"]
        except KeyError as e:
            raise ValueError(f"Unable to load semantic router from dict: {str(e)}") from e

        try:
            vectorizer = vectorizer_from_dict(vectorizer_data)
        except Exception as e:
            raise ValueError(f"Unable to load vectorizer: {str(e)}") from e

        if not vectorizer:
            raise ValueError(f"Unable to load vectorizer: {vectorizer_data}")

        routes = [Route(**route) for route in routes_data]
        routing_config = RoutingConfig(**routing_config_data)

        return cls(
            name=name,
            routes=routes,
            vectorizer=vectorizer,
            routing_config=routing_config,
            **kwargs,
        )

    def to_dict(self) -> Dict[str, Any]:
        """Convert the SemanticRouter instance to a dictionary.

        Only the vectorizer's ``type`` and ``model`` are serialized.

        Returns:
            Dict[str, Any]: The dictionary representation of the SemanticRouter.

        .. code-block:: python

            from redisvl.extensions.router import SemanticRouter

            router = SemanticRouter(name="example_router", routes=[], redis_url="redis://localhost:6379")
            router_dict = router.to_dict()
        """
        return {
            "name": self.name,
            "routes": [model_to_dict(route) for route in self.routes],
            "vectorizer": {
                "type": self.vectorizer.type,
                "model": self.vectorizer.model,
            },
            "routing_config": model_to_dict(self.routing_config),
        }

    @classmethod
    def from_yaml(
        cls,
        file_path: str,
        **kwargs,
    ) -> "SemanticRouter":
        """Create a SemanticRouter from a YAML file.

        Args:
            file_path (str): The path to the YAML file.
            **kwargs: Forwarded to the constructor (e.g. ``redis_url``).

        Returns:
            SemanticRouter: The semantic router instance.

        Raises:
            ValueError: If the file path is invalid or the file does not
                contain a YAML mapping.
            FileNotFoundError: If the file does not exist.

        .. code-block:: python

            from redisvl.extensions.router import SemanticRouter

            router = SemanticRouter.from_yaml("router.yaml", redis_url="redis://localhost:6379")
        """
        try:
            fp = Path(file_path).resolve()
        except OSError as e:
            raise ValueError(f"Invalid file path: {file_path}") from e

        if not fp.exists():
            raise FileNotFoundError(f"File {file_path} does not exist")

        with open(fp, "r", encoding="utf-8") as f:
            yaml_data = yaml.safe_load(f)

        # An empty file loads as None; a scalar/list is not a router definition.
        if not isinstance(yaml_data, dict):
            raise ValueError(
                f"Invalid semantic router file {file_path}: expected a YAML mapping"
            )

        return cls.from_dict(
            yaml_data,
            **kwargs,
        )

    def to_yaml(self, file_path: str, overwrite: bool = True) -> None:
        """Write the semantic router to a YAML file.

        Args:
            file_path (str): The path to the YAML file.
            overwrite (bool): Whether to overwrite the file if it already exists.

        Raises:
            FileExistsError: If the file already exists and overwrite is False.

        .. code-block:: python

            from redisvl.extensions.router import SemanticRouter

            router = SemanticRouter(
                name="example_router",
                routes=[],
                redis_url="redis://localhost:6379"
            )
            router.to_yaml("router.yaml")
        """
        fp = Path(file_path).resolve()
        if fp.exists() and not overwrite:
            raise FileExistsError(f"Schema file {file_path} already exists.")

        yaml_data = self.to_dict()
        with open(fp, "w", encoding="utf-8") as f:
            yaml.dump(yaml_data, f, sort_keys=False)