Source code for langchain_core.runnables.fallbacks

import asyncio
import inspect
import typing
from contextvars import copy_context
from functools import wraps
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncIterator,
    Dict,
    Iterator,
    List,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
    cast,
)

from langchain_core.load.dump import dumpd
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import Runnable, RunnableSerializable
from langchain_core.runnables.config import (
    RunnableConfig,
    _set_config_context,
    ensure_config,
    get_async_callback_manager_for_config,
    get_callback_manager_for_config,
    get_config_list,
    patch_config,
)
from langchain_core.runnables.utils import (
    ConfigurableFieldSpec,
    Input,
    Output,
    asyncio_accepts_context,
    get_unique_config_specs,
)
from langchain_core.utils.aiter import py_anext

if TYPE_CHECKING:
    from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun


class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
    """Runnable that can fallback to other Runnables if it fails.

    External APIs (e.g., APIs for a language model) may at times experience
    degraded performance or even downtime.

    In these cases, it can be useful to have a fallback Runnable that can be
    used in place of the original Runnable (e.g., fallback to another LLM provider).

    Fallbacks can be defined at the level of a single Runnable, or at the level
    of a chain of Runnables. Fallbacks are tried in order until one succeeds or
    all fail.

    While you can instantiate a ``RunnableWithFallbacks`` directly, it is usually
    more convenient to use the ``with_fallbacks`` method on a Runnable.

    Example:

        .. code-block:: python

            from langchain_openai import ChatOpenAI
            from langchain_anthropic import ChatAnthropic

            model = ChatAnthropic(
                model="claude-3-haiku-20240307"
            ).with_fallbacks([ChatOpenAI(model="gpt-3.5-turbo-0125")])
            # Will usually use ChatAnthropic, but fallback to ChatOpenAI
            # if ChatAnthropic fails.
            model.invoke('hello')

            # And you can also use fallbacks at the level of a chain.
            # Here if both LLM providers fail, we'll fallback to a good hardcoded
            # response.

            from langchain_core.prompts import PromptTemplate
            from langchain_core.output_parsers import StrOutputParser
            from langchain_core.runnables import RunnableLambda

            def when_all_is_lost(inputs):
                return ("Looks like our LLM providers are down. "
                        "Here's a nice 🦜️ emoji for you instead.")

            chain_with_fallback = (
                PromptTemplate.from_template('Tell me a joke about {topic}')
                | model
                | StrOutputParser()
            ).with_fallbacks([RunnableLambda(when_all_is_lost)])
    """

    runnable: Runnable[Input, Output]
    """The Runnable to run first."""
    fallbacks: Sequence[Runnable[Input, Output]]
    """A sequence of fallbacks to try."""
    exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,)
    """The exceptions on which fallbacks should be tried.

    Any exception that is not a subclass of these exceptions will be raised immediately.
    """
    exception_key: Optional[str] = None
    """If string is specified then handled exceptions will be passed to fallbacks as
        part of the input under the specified key. If None, exceptions
        will not be passed to fallbacks. If used, the base Runnable and its fallbacks
        must accept a dictionary as input."""

    class Config:
        arbitrary_types_allowed = True

    @property
    def InputType(self) -> Type[Input]:
        """Input type, delegated to the primary runnable."""
        return self.runnable.InputType

    @property
    def OutputType(self) -> Type[Output]:
        """Output type, delegated to the primary runnable."""
        return self.runnable.OutputType

    def get_input_schema(
        self, config: Optional[RunnableConfig] = None
    ) -> Type[BaseModel]:
        """Return the input schema of the primary runnable."""
        return self.runnable.get_input_schema(config)

    def get_output_schema(
        self, config: Optional[RunnableConfig] = None
    ) -> Type[BaseModel]:
        """Return the output schema of the primary runnable."""
        return self.runnable.get_output_schema(config)

    @property
    def config_specs(self) -> List[ConfigurableFieldSpec]:
        """De-duplicated config specs of the primary runnable and all fallbacks."""
        return get_unique_config_specs(
            spec
            for step in [self.runnable, *self.fallbacks]
            for spec in step.config_specs
        )

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "schema", "runnable"]

    @property
    def runnables(self) -> Iterator[Runnable[Input, Output]]:
        """Yield the primary runnable followed by each fallback, in order."""
        yield self.runnable
        yield from self.fallbacks

    def invoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        """Invoke the primary runnable, trying each fallback in order on failure.

        Args:
            input: Input passed to each runnable. Must be a dict when
                ``exception_key`` is set; the last handled exception is then
                written into it under that key before the next attempt.
            config: Optional config; each attempt runs as a child of the root run.
            **kwargs: Forwarded to each runnable's ``invoke``.

        Returns:
            The output of the first runnable that succeeds.

        Raises:
            ValueError: If ``exception_key`` is set and ``input`` is not a dict.
            BaseException: The first handled error if every runnable fails, or
                immediately any error not covered by ``exceptions_to_handle``.
        """
        if self.exception_key is not None and not isinstance(input, dict):
            raise ValueError(
                "If 'exception_key' is specified then input must be a dictionary. "
                f"However found a type of {type(input)} for input"
            )
        # setup callbacks
        config = ensure_config(config)
        callback_manager = get_callback_manager_for_config(config)
        # start the root run
        run_manager = callback_manager.on_chain_start(
            dumpd(self),
            input,
            name=config.get("run_name"),
            run_id=config.pop("run_id", None),
        )
        first_error = None
        last_error = None
        for runnable in self.runnables:
            try:
                if self.exception_key and last_error is not None:
                    input[self.exception_key] = last_error
                child_config = patch_config(config, callbacks=run_manager.get_child())
                # Run the attempt in a copied context that carries the child config.
                context = copy_context()
                context.run(_set_config_context, child_config)
                output = context.run(
                    runnable.invoke,
                    input,
                    **kwargs,
                )
            except self.exceptions_to_handle as e:
                if first_error is None:
                    first_error = e
                last_error = e
            except BaseException as e:
                # Not a handled exception type: report and propagate immediately.
                run_manager.on_chain_error(e)
                raise
            else:
                run_manager.on_chain_end(output)
                return output
        if first_error is None:
            raise ValueError("No error stored at end of fallbacks.")
        run_manager.on_chain_error(first_error)
        raise first_error

    async def ainvoke(
        self,
        input: Input,
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> Output:
        """Async counterpart of ``invoke``.

        Args:
            input: Input passed to each runnable. Must be a dict when
                ``exception_key`` is set.
            config: Optional config; each attempt runs as a child of the root run.
            **kwargs: Forwarded to each runnable's ``ainvoke``.

        Returns:
            The output of the first runnable that succeeds.

        Raises:
            ValueError: If ``exception_key`` is set and ``input`` is not a dict.
            BaseException: The first handled error if every runnable fails, or
                immediately any error not covered by ``exceptions_to_handle``.
        """
        if self.exception_key is not None and not isinstance(input, dict):
            raise ValueError(
                "If 'exception_key' is specified then input must be a dictionary. "
                f"However found a type of {type(input)} for input"
            )
        # setup callbacks
        config = ensure_config(config)
        callback_manager = get_async_callback_manager_for_config(config)
        # start the root run
        run_manager = await callback_manager.on_chain_start(
            dumpd(self),
            input,
            name=config.get("run_name"),
            run_id=config.pop("run_id", None),
        )

        first_error = None
        last_error = None
        for runnable in self.runnables:
            try:
                if self.exception_key and last_error is not None:
                    input[self.exception_key] = last_error
                child_config = patch_config(config, callbacks=run_manager.get_child())
                context = copy_context()
                context.run(_set_config_context, child_config)
                coro = runnable.ainvoke(input, child_config, **kwargs)
                if asyncio_accepts_context():
                    output = await asyncio.create_task(coro, context=context)  # type: ignore
                else:
                    output = await coro
            except self.exceptions_to_handle as e:
                if first_error is None:
                    first_error = e
                last_error = e
            except BaseException as e:
                # Not a handled exception type: report and propagate immediately.
                await run_manager.on_chain_error(e)
                raise
            else:
                await run_manager.on_chain_end(output)
                return output
        if first_error is None:
            raise ValueError("No error stored at end of fallbacks.")
        await run_manager.on_chain_error(first_error)
        raise first_error

    def batch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Optional[Any],
    ) -> List[Output]:
        """Run a batch, re-running only the failed inputs on each fallback.

        Args:
            inputs: Inputs to process. All must be dicts when ``exception_key``
                is set.
            config: A single config or one config per input.
            return_exceptions: If True, exceptions are returned in the output
                list in place of results instead of being raised.
            **kwargs: Forwarded to each runnable's ``batch``.

        Returns:
            Outputs ordered like ``inputs``.

        Raises:
            ValueError: If ``exception_key`` is set and some input is not a dict.
            BaseException: When ``return_exceptions`` is False, the first
                unhandled exception, or the handled exception of the
                lowest-index input that failed on every runnable.
        """
        from langchain_core.callbacks.manager import CallbackManager

        if self.exception_key is not None:
            # Report the type of the input that actually violates the contract.
            bad_types = [type(item) for item in inputs if not isinstance(item, dict)]
            if bad_types:
                raise ValueError(
                    "If 'exception_key' is specified then inputs must be "
                    "dictionaries. "
                    f"However found a type of {bad_types[0]} for input"
                )

        if not inputs:
            return []

        # setup callbacks
        configs = get_config_list(config, len(inputs))
        callback_managers = [
            CallbackManager.configure(
                inheritable_callbacks=cfg.get("callbacks"),
                local_callbacks=None,
                verbose=False,
                inheritable_tags=cfg.get("tags"),
                local_tags=None,
                inheritable_metadata=cfg.get("metadata"),
                local_metadata=None,
            )
            for cfg in configs
        ]
        # start the root runs, one per input
        run_managers = [
            cm.on_chain_start(
                dumpd(self),
                item if isinstance(item, dict) else {"input": item},
                name=cfg.get("run_name"),
                run_id=cfg.pop("run_id", None),
            )
            for cm, item, cfg in zip(callback_managers, inputs, configs)
        ]

        to_return: Dict[int, Any] = {}
        # Inputs (by original index) that still need a successful result.
        run_again = {i: item for i, item in enumerate(inputs)}
        handled_exceptions: Dict[int, BaseException] = {}
        first_to_raise = None
        for runnable in self.runnables:
            outputs = runnable.batch(
                [item for _, item in sorted(run_again.items())],
                [
                    # each step a child run of the corresponding root run
                    patch_config(configs[i], callbacks=run_managers[i].get_child())
                    for i in sorted(run_again)
                ],
                return_exceptions=True,
                **kwargs,
            )
            # Iterate over a copy because entries are popped from run_again.
            for (i, item), output in zip(sorted(run_again.copy().items()), outputs):
                if isinstance(output, BaseException) and not isinstance(
                    output, self.exceptions_to_handle
                ):
                    # Unhandled exception type: never retried on a fallback.
                    if not return_exceptions:
                        first_to_raise = first_to_raise or output
                    else:
                        handled_exceptions[i] = cast(BaseException, output)
                    run_again.pop(i)
                elif isinstance(output, self.exceptions_to_handle):
                    if self.exception_key:
                        item[self.exception_key] = output  # type: ignore
                    handled_exceptions[i] = cast(BaseException, output)
                else:
                    run_managers[i].on_chain_end(output)
                    to_return[i] = output
                    run_again.pop(i)
                    handled_exceptions.pop(i, None)
            if first_to_raise:
                raise first_to_raise
            if not run_again:
                break

        sorted_handled_exceptions = sorted(handled_exceptions.items())
        for i, error in sorted_handled_exceptions:
            run_managers[i].on_chain_error(error)
        if not return_exceptions and sorted_handled_exceptions:
            raise sorted_handled_exceptions[0][1]
        to_return.update(handled_exceptions)
        return [output for _, output in sorted(to_return.items())]

    async def abatch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Optional[Any],
    ) -> List[Output]:
        """Async counterpart of ``batch``.

        Args:
            inputs: Inputs to process. All must be dicts when ``exception_key``
                is set.
            config: A single config or one config per input.
            return_exceptions: If True, exceptions are returned in the output
                list in place of results instead of being raised.
            **kwargs: Forwarded to each runnable's ``abatch``.

        Returns:
            Outputs ordered like ``inputs``.

        Raises:
            ValueError: If ``exception_key`` is set and some input is not a dict.
            BaseException: When ``return_exceptions`` is False, the first
                unhandled exception, or the handled exception of the
                lowest-index input that failed on every runnable.
        """
        from langchain_core.callbacks.manager import AsyncCallbackManager

        if self.exception_key is not None:
            # Report the type of the input that actually violates the contract.
            bad_types = [type(item) for item in inputs if not isinstance(item, dict)]
            if bad_types:
                raise ValueError(
                    "If 'exception_key' is specified then inputs must be "
                    "dictionaries. "
                    f"However found a type of {bad_types[0]} for input"
                )

        if not inputs:
            return []

        # setup callbacks
        configs = get_config_list(config, len(inputs))
        callback_managers = [
            AsyncCallbackManager.configure(
                inheritable_callbacks=cfg.get("callbacks"),
                local_callbacks=None,
                verbose=False,
                inheritable_tags=cfg.get("tags"),
                local_tags=None,
                inheritable_metadata=cfg.get("metadata"),
                local_metadata=None,
            )
            for cfg in configs
        ]
        # start the root runs, one per input
        run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
            *(
                cm.on_chain_start(
                    dumpd(self),
                    item,
                    name=cfg.get("run_name"),
                    run_id=cfg.pop("run_id", None),
                )
                for cm, item, cfg in zip(callback_managers, inputs, configs)
            )
        )

        to_return: Dict[int, Any] = {}
        # Inputs (by original index) that still need a successful result.
        run_again = {i: item for i, item in enumerate(inputs)}
        handled_exceptions: Dict[int, BaseException] = {}
        first_to_raise = None
        for runnable in self.runnables:
            outputs = await runnable.abatch(
                [item for _, item in sorted(run_again.items())],
                [
                    # each step a child run of the corresponding root run
                    patch_config(configs[i], callbacks=run_managers[i].get_child())
                    for i in sorted(run_again)
                ],
                return_exceptions=True,
                **kwargs,
            )
            # Iterate over a copy because entries are popped from run_again.
            for (i, item), output in zip(sorted(run_again.copy().items()), outputs):
                if isinstance(output, BaseException) and not isinstance(
                    output, self.exceptions_to_handle
                ):
                    # Unhandled exception type: never retried on a fallback.
                    if not return_exceptions:
                        first_to_raise = first_to_raise or output
                    else:
                        handled_exceptions[i] = cast(BaseException, output)
                    run_again.pop(i)
                elif isinstance(output, self.exceptions_to_handle):
                    if self.exception_key:
                        item[self.exception_key] = output  # type: ignore
                    handled_exceptions[i] = cast(BaseException, output)
                else:
                    to_return[i] = output
                    await run_managers[i].on_chain_end(output)
                    run_again.pop(i)
                    handled_exceptions.pop(i, None)

            if first_to_raise:
                raise first_to_raise
            if not run_again:
                break

        sorted_handled_exceptions = sorted(handled_exceptions.items())
        await asyncio.gather(
            *(
                run_managers[i].on_chain_error(error)
                for i, error in sorted_handled_exceptions
            )
        )
        if not return_exceptions and sorted_handled_exceptions:
            raise sorted_handled_exceptions[0][1]
        to_return.update(handled_exceptions)
        return [output for _, output in sorted(to_return.items())]  # type: ignore

    def stream(
        self,
        input: Input,
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> Iterator[Output]:
        """Stream from the first runnable that yields a first chunk successfully.

        Fallback selection only covers creating the stream and fetching its
        first chunk; an error raised while consuming later chunks is reported
        to the root run and re-raised without trying further fallbacks.

        Args:
            input: Input passed to each runnable. Must be a dict when
                ``exception_key`` is set.
            config: Optional config; each attempt runs as a child of the root run.
            **kwargs: Forwarded to each runnable's ``stream``.

        Yields:
            Chunks produced by the selected runnable.

        Raises:
            ValueError: If ``exception_key`` is set and ``input`` is not a dict.
            BaseException: The first handled error if every runnable fails, or
                any error not covered by ``exceptions_to_handle``.
        """
        if self.exception_key is not None and not isinstance(input, dict):
            raise ValueError(
                "If 'exception_key' is specified then input must be a dictionary. "
                f"However found a type of {type(input)} for input"
            )
        # setup callbacks
        config = ensure_config(config)
        callback_manager = get_callback_manager_for_config(config)
        # start the root run
        run_manager = callback_manager.on_chain_start(
            dumpd(self),
            input,
            name=config.get("run_name"),
            run_id=config.pop("run_id", None),
        )
        first_error = None
        last_error = None
        for runnable in self.runnables:
            try:
                if self.exception_key and last_error is not None:
                    input[self.exception_key] = last_error
                child_config = patch_config(config, callbacks=run_manager.get_child())
                context = copy_context()
                context.run(_set_config_context, child_config)
                stream = context.run(
                    runnable.stream,
                    input,
                    **kwargs,
                )
                chunk: Output = context.run(next, stream)  # type: ignore
            except self.exceptions_to_handle as e:
                first_error = e if first_error is None else first_error
                last_error = e
            except BaseException as e:
                run_manager.on_chain_error(e)
                raise
            else:
                # A runnable produced its first chunk: discard earlier errors.
                first_error = None
                break
        if first_error is not None:
            run_manager.on_chain_error(first_error)
            raise first_error

        yield chunk
        # Accumulate chunks for the final on_chain_end; None once chunks
        # turn out not to support ``+``.
        output: Optional[Output] = chunk
        try:
            for chunk in stream:
                yield chunk
                try:
                    output = output + chunk  # type: ignore
                except TypeError:
                    output = None
        except BaseException as e:
            run_manager.on_chain_error(e)
            raise
        run_manager.on_chain_end(output)

    async def astream(
        self,
        input: Input,
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> AsyncIterator[Output]:
        """Async counterpart of ``stream``.

        Fallback selection only covers creating the stream and fetching its
        first chunk; an error raised while consuming later chunks is reported
        to the root run and re-raised without trying further fallbacks.

        Args:
            input: Input passed to each runnable. Must be a dict when
                ``exception_key`` is set.
            config: Optional config; each attempt runs as a child of the root run.
            **kwargs: Forwarded to each runnable's ``astream``.

        Yields:
            Chunks produced by the selected runnable.

        Raises:
            ValueError: If ``exception_key`` is set and ``input`` is not a dict.
            BaseException: The first handled error if every runnable fails, or
                any error not covered by ``exceptions_to_handle``.
        """
        if self.exception_key is not None and not isinstance(input, dict):
            raise ValueError(
                "If 'exception_key' is specified then input must be a dictionary. "
                f"However found a type of {type(input)} for input"
            )
        # setup callbacks
        config = ensure_config(config)
        callback_manager = get_async_callback_manager_for_config(config)
        # start the root run
        run_manager = await callback_manager.on_chain_start(
            dumpd(self),
            input,
            name=config.get("run_name"),
            run_id=config.pop("run_id", None),
        )
        first_error = None
        last_error = None
        for runnable in self.runnables:
            try:
                if self.exception_key and last_error is not None:
                    input[self.exception_key] = last_error
                child_config = patch_config(config, callbacks=run_manager.get_child())
                context = copy_context()
                context.run(_set_config_context, child_config)
                stream = runnable.astream(
                    input,
                    child_config,
                    **kwargs,
                )
                if asyncio_accepts_context():
                    chunk: Output = await asyncio.create_task(  # type: ignore[call-arg]
                        py_anext(stream),  # type: ignore[arg-type]
                        context=context,
                    )
                else:
                    chunk = cast(Output, await py_anext(stream))
            except self.exceptions_to_handle as e:
                first_error = e if first_error is None else first_error
                last_error = e
            except BaseException as e:
                await run_manager.on_chain_error(e)
                raise
            else:
                # A runnable produced its first chunk: discard earlier errors.
                first_error = None
                break
        if first_error is not None:
            await run_manager.on_chain_error(first_error)
            raise first_error

        yield chunk
        # Accumulate chunks for the final on_chain_end; None once chunks
        # turn out not to support ``+``.
        output: Optional[Output] = chunk
        try:
            async for chunk in stream:
                yield chunk
                try:
                    output = output + chunk  # type: ignore
                except TypeError:
                    output = None
        except BaseException as e:
            await run_manager.on_chain_error(e)
            raise
        await run_manager.on_chain_end(output)

    def __getattr__(self, name: str) -> Any:
        """Get an attribute from the wrapped Runnable and its fallbacks.

        Returns:
            If the attribute is anything other than a method that outputs a Runnable,
            returns getattr(self.runnable, name). If the attribute is a method that
            does return a new Runnable (e.g. llm.bind_tools([...]) outputs a new
            RunnableBinding) then self.runnable and each of the runnables in
            self.fallbacks is replaced with getattr(x, name).

        Example:
            .. code-block:: python

                from langchain_openai import ChatOpenAI
                from langchain_anthropic import ChatAnthropic

                gpt_4o = ChatOpenAI(model="gpt-4o")
                claude_3_sonnet = ChatAnthropic(model="claude-3-sonnet-20240229")
                llm = gpt_4o.with_fallbacks([claude_3_sonnet])

                llm.model_name
                # -> "gpt-4o"

                # .bind_tools() is called on both ChatOpenAI and ChatAnthropic
                # Equivalent to:
                # gpt_4o.bind_tools([...]).with_fallbacks([claude_3_sonnet.bind_tools([...])])
                llm.bind_tools([...])
                # -> RunnableWithFallbacks(
                    runnable=RunnableBinding(bound=ChatOpenAI(...), kwargs={"tools": [...]}),
                    fallbacks=[RunnableBinding(bound=ChatAnthropic(...), kwargs={"tools": [...]})],
                )

        """  # noqa: E501
        attr = getattr(self.runnable, name)
        if _returns_runnable(attr):

            @wraps(attr)
            def wrapped(*args: Any, **kwargs: Any) -> Any:
                new_runnable = attr(*args, **kwargs)
                # Apply the same call to every fallback so they stay in sync.
                new_fallbacks = []
                for fallback in self.fallbacks:
                    fallback_attr = getattr(fallback, name)
                    new_fallbacks.append(fallback_attr(*args, **kwargs))

                return self.__class__(
                    **{
                        **self.dict(),
                        **{"runnable": new_runnable, "fallbacks": new_fallbacks},
                    }
                )

            return wrapped

        return attr
def _returns_runnable(attr: Any) -> bool:
    """Return True if ``attr`` is a callable annotated as returning a Runnable.

    Args:
        attr: Any attribute value, typically looked up on a wrapped Runnable.

    Returns:
        True only when ``attr`` is callable, its return annotation can be
        resolved, and that annotation is a Runnable type (see
        ``_is_runnable_type``). False otherwise.
    """
    if not callable(attr):
        return False
    try:
        return_type = typing.get_type_hints(attr).get("return")
    except Exception:
        # get_type_hints raises TypeError for callables that carry no
        # annotations (e.g. functools.partial objects) and NameError for
        # forward references that cannot be resolved at runtime. Either way
        # the return type is unknown, so treat it as "not a Runnable" instead
        # of letting plain attribute access blow up.
        return False
    return bool(return_type and _is_runnable_type(return_type))


def _is_runnable_type(type_: Any) -> bool:
    """Return True if ``type_`` is a Runnable class, generic alias, or Union thereof.

    Args:
        type_: A resolved type annotation.

    Returns:
        True if ``type_`` is a subclass of ``Runnable``, a parameterized
        generic whose origin is such a subclass, or a ``typing.Union`` whose
        members all satisfy this predicate. False otherwise.
    """
    if inspect.isclass(type_):
        try:
            return issubclass(type_, Runnable)
        except TypeError:
            # On some Python versions parameterized builtin generics such as
            # ``list[int]`` pass ``inspect.isclass`` but are rejected by
            # ``issubclass``; fall through and inspect their origin instead.
            pass
    origin = getattr(type_, "__origin__", None)
    if inspect.isclass(origin):
        return issubclass(origin, Runnable)
    if origin is typing.Union:
        return all(_is_runnable_type(t) for t in type_.__args__)
    return False