Source code for langchain_community.embeddings.voyageai

from __future__ import annotations

import json
import logging
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Tuple,
    Union,
    cast,
)

import requests
from langchain_core._api.deprecation import deprecated
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from tenacity import (
    before_sleep_log,
    retry,
    stop_after_attempt,
    wait_exponential,
)

logger = logging.getLogger(__name__)


def _create_retry_decorator(embeddings: VoyageEmbeddings) -> Callable[[Any], Any]:
    min_seconds = 4
    max_seconds = 10
    # Wait 2^x * 1 second between each retry starting with
    # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
    return retry(
        reraise=True,
        stop=stop_after_attempt(embeddings.max_retries),
        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )


def _check_response(response: dict) -> dict:
    if "data" not in response:
        raise RuntimeError(f"Voyage API Error. Message: {json.dumps(response)}")
    return response


[docs]def embed_with_retry(embeddings: VoyageEmbeddings, **kwargs: Any) -> Any: """Use tenacity to retry the embedding call.""" retry_decorator = _create_retry_decorator(embeddings) @retry_decorator def _embed_with_retry(**kwargs: Any) -> Any: response = requests.post(**kwargs) return _check_response(response.json()) return _embed_with_retry(**kwargs)
[docs]@deprecated( since="0.0.29", removal="0.3", alternative_import="langchain_voyageai.VoyageAIEmbeddings", ) class VoyageEmbeddings(BaseModel, Embeddings): """Voyage embedding models. To use, you should have the environment variable ``VOYAGE_API_KEY`` set with your API key or pass it as a named parameter to the constructor. Example: .. code-block:: python from langchain_community.embeddings import VoyageEmbeddings voyage = VoyageEmbeddings(voyage_api_key="your-api-key", model="voyage-2") text = "This is a test query." query_result = voyage.embed_query(text) """ model: str voyage_api_base: str = "https://api.voyageai.com/v1/embeddings" voyage_api_key: Optional[SecretStr] = None batch_size: int """Maximum number of texts to embed in each API request.""" max_retries: int = 6 """Maximum number of retries to make when generating.""" request_timeout: Optional[Union[float, Tuple[float, float]]] = None """Timeout in seconds for the API request.""" show_progress_bar: bool = False """Whether to show a progress bar when embedding. Must have tqdm installed if set to True.""" truncation: bool = True """Whether to truncate the input texts to fit within the context length. If True, over-length input texts will be truncated to fit within the context length, before vectorized by the embedding model. If False, an error will be raised if any given text exceeds the context length.""" class Config: extra = "forbid" @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" values["voyage_api_key"] = convert_to_secret_str( get_from_dict_or_env(values, "voyage_api_key", "VOYAGE_API_KEY") ) if "model" not in values: values["model"] = "voyage-01" logger.warning( "model will become a required arg for VoyageAIEmbeddings, " "we recommend to specify it when using this class. " "Currently the default is set to voyage-01." ) if "batch_size" not in values: values["batch_size"] = ( 72 if "model" in values and (values["model"] in ["voyage-2", "voyage-02"]) else 7 ) return values def _invocation_params( self, input: List[str], input_type: Optional[str] = None ) -> Dict: api_key = cast(SecretStr, self.voyage_api_key).get_secret_value() params: Dict = { "url": self.voyage_api_base, "headers": {"Authorization": f"Bearer {api_key}"}, "json": { "model": self.model, "input": input, "input_type": input_type, "truncation": self.truncation, }, "timeout": self.request_timeout, } return params def _get_embeddings( self, texts: List[str], batch_size: Optional[int] = None, input_type: Optional[str] = None, ) -> List[List[float]]: embeddings: List[List[float]] = [] if batch_size is None: batch_size = self.batch_size if self.show_progress_bar: try: from tqdm.auto import tqdm except ImportError as e: raise ImportError( "Must have tqdm installed if `show_progress_bar` is set to True. " "Please install with `pip install tqdm`." ) from e _iter = tqdm(range(0, len(texts), batch_size)) else: _iter = range(0, len(texts), batch_size) if input_type and input_type not in ["query", "document"]: raise ValueError( f"input_type {input_type} is invalid. Options: None, 'query', " "'document'." ) for i in _iter: response = embed_with_retry( self, **self._invocation_params( input=texts[i : i + batch_size], input_type=input_type ), ) embeddings.extend(r["embedding"] for r in response["data"]) return embeddings
[docs] def embed_documents(self, texts: List[str]) -> List[List[float]]: """Call out to Voyage Embedding endpoint for embedding search docs. Args: texts: The list of texts to embed. Returns: List of embeddings, one for each text. """ return self._get_embeddings( texts, batch_size=self.batch_size, input_type="document" )
[docs] def embed_query(self, text: str) -> List[float]: """Call out to Voyage Embedding endpoint for embedding query text. Args: text: The text to embed. Returns: Embedding for the text. """ return self._get_embeddings( [text], batch_size=self.batch_size, input_type="query" )[0]
[docs] def embed_general_texts( self, texts: List[str], *, input_type: Optional[str] = None ) -> List[List[float]]: """Call out to Voyage Embedding endpoint for embedding general text. Args: texts: The list of texts to embed. input_type: Type of the input text. Default to None, meaning the type is unspecified. Other options: query, document. Returns: Embedding for the text. """ return self._get_embeddings( texts, batch_size=self.batch_size, input_type=input_type )