Source code for langchain_community.vectorstores.databricks_vector_search

from __future__ import annotations

import json
import logging
import uuid
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    Type,
)

import numpy as np
from langchain_core._api import warn_deprecated
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VST, VectorStore

from langchain_community.vectorstores.utils import maximal_marginal_relevance

if TYPE_CHECKING:
    from databricks.vector_search.client import VectorSearchIndex

logger = logging.getLogger(__name__)


[docs]class DatabricksVectorSearch(VectorStore): """`Databricks Vector Search` vector store. To use, you should have the ``databricks-vectorsearch`` python package installed. Example: .. code-block:: python from langchain_community.vectorstores import DatabricksVectorSearch from databricks.vector_search.client import VectorSearchClient vs_client = VectorSearchClient() vs_index = vs_client.get_index( endpoint_name="vs_endpoint", index_name="ml.llm.index" ) vectorstore = DatabricksVectorSearch(vs_index) Args: index: A Databricks Vector Search index object. embedding: The embedding model. Required for direct-access index or delta-sync index with self-managed embeddings. text_column: The name of the text column to use for the embeddings. Required for direct-access index or delta-sync index with self-managed embeddings. Make sure the text column specified is in the index. columns: The list of column names to get when doing the search. Defaults to ``[primary_key, text_column]``. Delta-sync index with Databricks-managed embeddings manages the ingestion, deletion, and embedding for you. Manually ingestion/deletion of the documents/texts is not supported for delta-sync index. If you want to use a delta-sync index with self-managed embeddings, you need to provide the embedding model and text column name to use for the embeddings. Example: .. code-block:: python from langchain_community.vectorstores import DatabricksVectorSearch from databricks.vector_search.client import VectorSearchClient from langchain_community.embeddings.openai import OpenAIEmbeddings vs_client = VectorSearchClient() vs_index = vs_client.get_index( endpoint_name="vs_endpoint", index_name="ml.llm.index" ) vectorstore = DatabricksVectorSearch( index=vs_index, embedding=OpenAIEmbeddings(), text_column="document_content" ) If you want to manage the documents ingestion/deletion yourself, you can use a direct-access index. Example: .. code-block:: python from langchain_community.vectorstores import DatabricksVectorSearch from databricks.vector_search.client import VectorSearchClient from langchain_community.embeddings.openai import OpenAIEmbeddings vs_client = VectorSearchClient() vs_index = vs_client.get_index( endpoint_name="vs_endpoint", index_name="ml.llm.index" ) vectorstore = DatabricksVectorSearch( index=vs_index, embedding=OpenAIEmbeddings(), text_column="document_content" ) vectorstore.add_texts( texts=["text1", "text2"] ) For more information on Databricks Vector Search, see `Databricks Vector Search documentation: https://docs.databricks.com/en/generative-ai/vector-search.html. """
[docs] def __init__( self, index: VectorSearchIndex, *, embedding: Optional[Embeddings] = None, text_column: Optional[str] = None, columns: Optional[List[str]] = None, ): try: from databricks.vector_search.client import VectorSearchIndex except ImportError as e: raise ImportError( "Could not import databricks-vectorsearch python package. " "Please install it with `pip install databricks-vectorsearch`." ) from e # index self.index = index if not isinstance(index, VectorSearchIndex): raise TypeError("index must be of type VectorSearchIndex.") # index_details index_details = self.index.describe() self.primary_key = index_details["primary_key"] self.index_type = index_details.get("index_type") self._delta_sync_index_spec = index_details.get("delta_sync_index_spec", dict()) self._direct_access_index_spec = index_details.get( "direct_access_index_spec", dict() ) # text_column if self._is_databricks_managed_embeddings(): index_source_column = self._embedding_source_column_name() # check if input text column matches the source column of the index if text_column is not None and text_column != index_source_column: raise ValueError( f"text_column '{text_column}' does not match with the " f"source column of the index: '{index_source_column}'." ) self.text_column = index_source_column else: self._require_arg(text_column, "text_column") self.text_column = text_column # columns self.columns = columns or [] # add primary key column and source column if not in columns if self.primary_key not in self.columns: self.columns.append(self.primary_key) if self.text_column and self.text_column not in self.columns: self.columns.append(self.text_column) # Validate specified columns are in the index if self._is_direct_access_index(): index_schema = self._index_schema() if index_schema: for col in self.columns: if col not in index_schema: raise ValueError( f"column '{col}' is not in the index's schema." ) # embedding model if not self._is_databricks_managed_embeddings(): # embedding model is required for direct-access index # or delta-sync index with self-managed embedding self._require_arg(embedding, "embedding") self._embedding = embedding # validate dimension matches index_embedding_dimension = self._embedding_vector_column_dimension() if index_embedding_dimension is not None: inferred_embedding_dimension = self._infer_embedding_dimension() if inferred_embedding_dimension != index_embedding_dimension: raise ValueError( f"embedding model's dimension '{inferred_embedding_dimension}' " f"does not match with the index's dimension " f"'{index_embedding_dimension}'." ) else: if embedding is not None: logger.warning( "embedding model is not used in delta-sync index with " "Databricks-managed embeddings." ) self._embedding = None
[docs] @classmethod def from_texts( cls: Type[VST], texts: List[str], embedding: Embeddings, metadatas: Optional[List[Dict]] = None, **kwargs: Any, ) -> VST: raise NotImplementedError( "`from_texts` is not supported. " "Use `add_texts` to add to existing direct-access index." )
[docs] def add_texts( self, texts: Iterable[str], metadatas: Optional[List[Dict]] = None, ids: Optional[List[Any]] = None, **kwargs: Any, ) -> List[str]: """Add texts to the index. Only support direct-access index. Args: texts: List of texts to add. metadatas: List of metadata for each text. Defaults to None. ids: List of ids for each text. Defaults to None. If not provided, a random uuid will be generated for each text. Returns: List of ids from adding the texts into the index. """ self._op_require_direct_access_index("add_texts") assert self.embeddings is not None, "embedding model is required." # Wrap to list if input texts is a single string if isinstance(texts, str): texts = [texts] texts = list(texts) vectors = self.embeddings.embed_documents(texts) ids = ids or [str(uuid.uuid4()) for _ in texts] metadatas = metadatas or [{} for _ in texts] updates = [ { self.primary_key: id_, self.text_column: text, self._embedding_vector_column_name(): vector, **metadata, } for text, vector, id_, metadata in zip(texts, vectors, ids, metadatas) ] upsert_resp = self.index.upsert(updates) if upsert_resp.get("status") in ("PARTIAL_SUCCESS", "FAILURE"): failed_ids = upsert_resp.get("result", dict()).get( "failed_primary_keys", [] ) if upsert_resp.get("status") == "FAILURE": logger.error("Failed to add texts to the index.") else: logger.warning("Some texts failed to be added to the index.") return [id_ for id_ in ids if id_ not in failed_ids] return ids
@property def embeddings(self) -> Optional[Embeddings]: """Access the query embedding object if available.""" return self._embedding
[docs] def delete(self, ids: Optional[List[Any]] = None, **kwargs: Any) -> Optional[bool]: """Delete documents from the index. Only support direct-access index. Args: ids: List of ids of documents to delete. Returns: True if successful. """ self._op_require_direct_access_index("delete") if ids is None: raise ValueError("ids must be provided.") self.index.delete(ids) return True
[docs] def similarity_search_with_score( self, query: str, k: int = 4, filter: Optional[Dict[str, Any]] = None, *, query_type: Optional[str] = None, **kwargs: Any, ) -> List[Tuple[Document, float]]: """Return docs most similar to query, along with scores. Args: query: Text to look up documents similar to. k: Number of Documents to return. Defaults to 4. filter: Filters to apply to the query. Defaults to None. query_type: The type of this query. Supported values are "ANN" and "HYBRID". Returns: List of Documents most similar to the embedding and score for each. """ if self._is_databricks_managed_embeddings(): query_text = query query_vector = None else: assert self.embeddings is not None, "embedding model is required." query_text = None query_vector = self.embeddings.embed_query(query) search_resp = self.index.similarity_search( columns=self.columns, query_text=query_text, query_vector=query_vector, filters=filter or _alias_filters(kwargs), num_results=k, query_type=query_type, ) return self._parse_search_response(search_resp)
@staticmethod def _identity_fn(score: float) -> float: return score def _select_relevance_score_fn(self) -> Callable[[float], float]: """ Databricks Vector search uses a normalized score 1/(1+d) where d is the L2 distance. Hence, we simply return the identity function. """ return self._identity_fn
[docs] def max_marginal_relevance_search_by_vector( self, embedding: List[float], k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, filter: Optional[Any] = None, *, query_type: Optional[str] = None, **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance. Maximal marginal relevance optimizes for similarity to query AND diversity among selected documents. Args: embedding: Embedding to look up documents similar to. k: Number of Documents to return. Defaults to 4. fetch_k: Number of Documents to fetch to pass to MMR algorithm. lambda_mult: Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. filter: Filters to apply to the query. Defaults to None. query_type: The type of this query. Supported values are "ANN" and "HYBRID". Returns: List of Documents selected by maximal marginal relevance. """ if not self._is_databricks_managed_embeddings(): embedding_column = self._embedding_vector_column_name() else: raise ValueError( "`max_marginal_relevance_search` is not supported for index with " "Databricks-managed embeddings." ) search_resp = self.index.similarity_search( columns=list(set(self.columns + [embedding_column])), query_text=None, query_vector=embedding, filters=filter or _alias_filters(kwargs), num_results=fetch_k, query_type=query_type, ) embeddings_result_index = ( search_resp.get("manifest").get("columns").index({"name": embedding_column}) ) embeddings = [ doc[embeddings_result_index] for doc in search_resp.get("result").get("data_array") ] mmr_selected = maximal_marginal_relevance( np.array(embedding, dtype=np.float32), embeddings, k=k, lambda_mult=lambda_mult, ) ignore_cols: List = ( [embedding_column] if embedding_column not in self.columns else [] ) candidates = self._parse_search_response(search_resp, ignore_cols=ignore_cols) selected_results = [r[0] for i, r in enumerate(candidates) if i in mmr_selected] return selected_results
[docs] def similarity_search_by_vector( self, embedding: List[float], k: int = 4, filter: Optional[Any] = None, *, query_type: Optional[str] = None, **kwargs: Any, ) -> List[Document]: """Return docs most similar to embedding vector. Args: embedding: Embedding to look up documents similar to. k: Number of Documents to return. Defaults to 4. filter: Filters to apply to the query. Defaults to None. query_type: The type of this query. Supported values are "ANN" and "HYBRID". Returns: List of Documents most similar to the embedding. """ docs_with_score = self.similarity_search_by_vector_with_score( embedding=embedding, k=k, filter=filter, query_type=query_type, **kwargs, ) return [doc for doc, _ in docs_with_score]
[docs] def similarity_search_by_vector_with_score( self, embedding: List[float], k: int = 4, filter: Optional[Any] = None, *, query_type: Optional[str] = None, **kwargs: Any, ) -> List[Tuple[Document, float]]: """Return docs most similar to embedding vector, along with scores. Args: embedding: Embedding to look up documents similar to. k: Number of Documents to return. Defaults to 4. filter: Filters to apply to the query. Defaults to None. query_type: The type of this query. Supported values are "ANN" and "HYBRID". Returns: List of Documents most similar to the embedding and score for each. """ if self._is_databricks_managed_embeddings(): raise ValueError( "`similarity_search_by_vector` is not supported for index with " "Databricks-managed embeddings." ) search_resp = self.index.similarity_search( columns=self.columns, query_vector=embedding, filters=filter or _alias_filters(kwargs), num_results=k, query_type=query_type, ) return self._parse_search_response(search_resp)
def _parse_search_response( self, search_resp: Dict, ignore_cols: Optional[List[str]] = None ) -> List[Tuple[Document, float]]: """Parse the search response into a list of Documents with score.""" if ignore_cols is None: ignore_cols = [] columns = [ col["name"] for col in search_resp.get("manifest", dict()).get("columns", []) ] docs_with_score = [] for result in search_resp.get("result", dict()).get("data_array", []): doc_id = result[columns.index(self.primary_key)] text_content = result[columns.index(self.text_column)] metadata = { col: value for col, value in zip(columns[:-1], result[:-1]) if col not in ([self.primary_key, self.text_column] + ignore_cols) } metadata[self.primary_key] = doc_id score = result[-1] doc = Document(page_content=text_content, metadata=metadata) docs_with_score.append((doc, score)) return docs_with_score def _index_schema(self) -> Optional[Dict]: """Return the index schema as a dictionary. Return None if no schema found. """ if self._is_direct_access_index(): schema_json = self._direct_access_index_spec.get("schema_json") if schema_json is not None: return json.loads(schema_json) return None def _embedding_vector_column_name(self) -> Optional[str]: """Return the name of the embedding vector column. None if the index is not a self-managed embedding index. """ return self._embedding_vector_column().get("name") def _embedding_vector_column_dimension(self) -> Optional[int]: """Return the dimension of the embedding vector column. None if the index is not a self-managed embedding index. """ return self._embedding_vector_column().get("embedding_dimension") def _embedding_vector_column(self) -> Dict: """Return the embedding vector column configs as a dictionary. Empty if the index is not a self-managed embedding index. """ index_spec = ( self._delta_sync_index_spec if self._is_delta_sync_index() else self._direct_access_index_spec ) return next(iter(index_spec.get("embedding_vector_columns") or list()), dict()) def _embedding_source_column_name(self) -> Optional[str]: """Return the name of the embedding source column. None if the index is not a Databricks-managed embedding index. """ return self._embedding_source_column().get("name") def _embedding_source_column(self) -> Dict: """Return the embedding source column configs as a dictionary. Empty if the index is not a Databricks-managed embedding index. """ index_spec = self._delta_sync_index_spec return next(iter(index_spec.get("embedding_source_columns") or list()), dict()) def _is_delta_sync_index(self) -> bool: """Return True if the index is a delta-sync index.""" return self.index_type == "DELTA_SYNC" def _is_direct_access_index(self) -> bool: """Return True if the index is a direct-access index.""" return self.index_type == "DIRECT_ACCESS" def _is_databricks_managed_embeddings(self) -> bool: """Return True if the embeddings are managed by Databricks Vector Search.""" return ( self._is_delta_sync_index() and self._embedding_source_column_name() is not None ) def _infer_embedding_dimension(self) -> int: """Infer the embedding dimension from the embedding function.""" assert self.embeddings is not None, "embedding model is required." return len(self.embeddings.embed_query("test")) def _op_require_direct_access_index(self, op_name: str) -> None: """ Raise ValueError if the operation is not supported for direct-access index.""" if not self._is_direct_access_index(): raise ValueError(f"`{op_name}` is only supported for direct-access index.") @staticmethod def _require_arg(arg: Any, arg_name: str) -> None: """Raise ValueError if the required arg with name `arg_name` is None.""" if not arg: raise ValueError(f"`{arg_name}` is required for this index.")
def _alias_filters(kwargs: Dict[str, Any]) -> Optional[Dict[str, Any]]: """ The `filters` argument was used in the previous versions. It is now replaced with `filter` for consistency with other vector stores, but we still support `filters` for backward compatibility. """ if "filters" in kwargs: warn_deprecated( since="0.2.11", removal="0.3", message="DatabricksVectorSearch received a key `filters` in search_kwargs. " "`filters` was deprecated since langchain-community 0.2.11 and will " "be removed in 0.3. Please use `filter` instead.", ) return kwargs.pop("filters", None)