# Source code for langchain_community.embeddings.tensorflow_hub
from typing import Any, List
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel
# Model URL that ``TensorflowHubEmbeddings.model_url`` falls back to when the
# caller does not supply one.
DEFAULT_MODEL_URL = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3"
class TensorflowHubEmbeddings(BaseModel, Embeddings):
    """TensorflowHub embedding models.

    To use, you should have the ``tensorflow-hub`` and ``tensorflow_text``
    python packages installed; both are imported when the object is created.

    Example:
        .. code-block:: python

            from langchain_community.embeddings import TensorflowHubEmbeddings
            url = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3"
            tf = TensorflowHubEmbeddings(model_url=url)
    """

    # Object returned by ``tensorflow_hub.load(model_url)``; set in ``__init__``.
    embed: Any  #: :meta private:
    model_url: str = DEFAULT_MODEL_URL
    """URL of the TensorFlow Hub model to load."""

    def __init__(self, **kwargs: Any):
        """Validate the fields, then load the model from ``model_url``.

        Args:
            **kwargs: Field values forwarded to the pydantic constructor
                (e.g. ``model_url``).

        Raises:
            ImportError: If ``tensorflow_hub`` or ``tensorflow_text`` cannot
                be imported. The original import failure is chained as the
                cause.
        """
        super().__init__(**kwargs)
        try:
            import tensorflow_hub
        except ImportError as e:
            raise ImportError(
                "Could not import tensorflow-hub python package. "
                "Please install it with `pip install tensorflow-hub`."
            ) from e
        try:
            # Imported only for its side effects; the name itself is unused.
            # NOTE(review): presumably registers the text ops the default
            # model needs at load time -- confirm.
            import tensorflow_text  # noqa
        except ImportError as e:
            raise ImportError(
                "Could not import tensorflow_text python package. "
                "Please install it with `pip install tensorflow_text`."
            ) from e

        self.embed = tensorflow_hub.load(self.model_url)

    class Config:
        # Reject constructor keyword arguments that are not declared fields.
        extra = "forbid"

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute doc embeddings using a TensorflowHub embedding model.

        Newlines in each text are replaced with spaces before embedding.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text. An empty input yields an
            empty list without calling the model.
        """
        if not texts:
            # Nothing to embed; skip the model call entirely.
            return []
        cleaned = [text.replace("\n", " ") for text in texts]
        embeddings = self.embed(cleaned).numpy()
        return embeddings.tolist()

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using a TensorflowHub embedding model.

        Newlines in the text are replaced with spaces before embedding.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        text = text.replace("\n", " ")
        # The model is called with a one-element batch; take its only row.
        embedding = self.embed([text]).numpy()[0]
        return embedding.tolist()