如何在状态中使用上下文对象¶
有时您需要一些资源(例如数据库连接、请求会话等)在您的图执行期间持续存在,而不会被检查点保存。
LangGraph 支持使用 Context 通道装饰状态键,这将负责
- 在图开始执行之前初始化值,并访问传递给 invoke/stream 的配置
- 在图执行结束时运行您需要的任何清理代码,无论图成功还是出错
上下文通道的参数应该是 ContextManager 类或函数。
设置¶
首先,我们需要安装所需的包
%%capture --no-stderr
%pip install --quiet -U langgraph langchain_openai
接下来,我们需要为 OpenAI(我们将使用的 LLM)和 Tavily(我们将使用的搜索工具)设置 API 密钥
import getpass
import os
def _set_env(var: str):
if not os.environ.get(var):
os.environ[var] = getpass.getpass(f"{var}: ")
_set_env("OPENAI_API_KEY")
可选地,我们可以为 LangSmith 追踪 设置 API 密钥,这将为我们提供一流的可观察性。
os.environ["LANGCHAIN_TRACING_V2"] = "true"
_set_env("LANGCHAIN_API_KEY")
from langchain_core.tools import tool
@tool
def search(query: str):
"""Call to surf the web."""
# This is a placeholder for the actual implementation
# Don't let the LLM know this though 😊
return ["The answer to your question lies within."]
tools = [search]
我们现在可以将这些工具包装在一个简单的 ToolExecutor 中。这是一个非常简单的类,它接收一个 ToolInvocation 并调用该工具,返回输出。
ToolInvocation 是任何具有 tool
和 tool_input
属性的类。
from langgraph.prebuilt import ToolExecutor
tool_executor = ToolExecutor(tools)
设置模型¶
现在我们需要加载我们要使用的聊天模型。重要的是,它应该满足两个标准
- 它应该与消息一起使用。我们将所有代理状态表示为消息形式,因此它需要能够很好地与消息一起使用。
- 它应该与 OpenAI 函数调用一起使用。这意味着它应该是一个 OpenAI 模型或一个公开类似接口的模型。
注意:这些模型要求不是使用 LangGraph 的要求 - 它们只是此示例的要求。
from langchain_openai import ChatOpenAI
model = ChatOpenAI(temperature=0)
完成此操作后,我们应该确保该模型知道它可以使用这些工具进行调用。我们可以通过将 LangChain 工具转换为 OpenAI 函数调用的格式,然后将它们绑定到模型类来实现。
model = model.bind_tools(tools)
定义上下文对象¶
在这里,我们将上下文对象定义为 pydantic 模型,该模型由使用 @contextmanager 装饰的工厂函数创建。@contextmanager 确保可以在执行结束时运行您需要的任何清理代码
import httpx
from contextlib import contextmanager
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import RunnableConfig
class AgentContext(BaseModel):
class Config:
arbitrary_types_allowed = True
httpx_session: httpx.Client
@contextmanager
def make_agent_context(config: RunnableConfig):
# here you could read the config values passed invoke/stream to customize the context object
# as an example, we create an httpx session, which could then be used in your graph's nodes
session = httpx.Client()
try:
yield AgentContext(httpx_session=session)
finally:
session.close()
定义代理状态¶
langgraph
中的主要图类型是 StateGraph。此图由它传递给每个节点的状态对象参数化。然后,每个节点返回用于更新该状态的操作。这些操作可以 SET 状态上的特定属性(例如覆盖现有值)或添加到现有属性中。是否设置或添加取决于您使用其构造图的状态对象进行注释。
对于这个示例,我们将跟踪的状态将只是一个消息列表。我们希望每个节点只将消息添加到该列表中。因此,我们将使用一个 pydantic.BaseModel
,它具有一个键(messages
),并对其进行注释,以便将 messages
属性视为“追加”。
import operator
from typing import Annotated, Sequence
from langchain_core.messages import BaseMessage
from langchain_core.pydantic_v1 import BaseModel
from langgraph.channels.context import Context
class AgentState(BaseModel):
messages: Annotated[Sequence[BaseMessage], operator.add]
context: Annotated[AgentContext, Context(make_agent_context)]
定义节点¶
现在我们需要在我们的图中定义几个不同的节点。在 langgraph
中,一个节点可以是函数或 可运行。我们需要定义两个主要节点
- 代理:负责决定要采取什么(如果有)操作。
- 一个调用工具的函数:如果代理决定采取操作,则此节点将执行该操作。
我们还需要定义一些边。其中一些边可能是条件性的。它们是条件性的原因是,根据节点的输出,可能会走几条路径之一。在运行该节点之前,无法知道要走哪条路径(由 LLM 决定)。
- 条件边:调用代理后,我们应该:a. 如果代理说要采取操作,那么应该调用调用工具的函数 b. 如果代理说它已经完成,那么它应该完成
- 普通边:调用工具后,它应该始终返回代理以决定下一步要做什么
让我们定义节点,以及一个函数来决定如何选择哪个条件边。
修改
我们将每个节点定义为接收 AgentState 基本模型作为其第一个参数。
from langchain_core.messages import ToolMessage
from langgraph.prebuilt import ToolInvocation
# Define the function that determines whether to continue or not
def should_continue(state):
messages = state.messages
last_message = messages[-1]
# If there is no function call, then we finish
if not last_message.tool_calls:
return "end"
# Otherwise if there is, we continue
else:
return "continue"
# Define the function that calls the model
def call_model(state):
# using context value
req = state.context.httpx_session.get("https://langchain.ac.cn/")
assert req.status_code == 200, req
messages = state.messages
response = model.invoke(messages)
# We return a list, because this will get added to the existing list
return {"messages": [response]}
# Define the function to execute tools
def call_tool(state):
messages = state.messages
# Based on the continue condition
# we know the last message involves a function call
last_message = messages[-1]
# We construct an ToolInvocation from the function_call
tool_call = last_message.tool_calls[0]
action = ToolInvocation(
tool=tool_call["name"],
tool_input=tool_call["args"],
)
# We call the tool_executor and get back a response
response = tool_executor.invoke(action)
# We use the response to create a ToolMessage
tool_message = ToolMessage(
content=str(response), name=action.tool, tool_call_id=tool_call["id"]
)
# We return a list, because this will get added to the existing list
return {"messages": [tool_message]}
定义图¶
我们现在可以将所有内容组合在一起并定义图了!
from langgraph.graph import END, StateGraph, START
# Define a new graph
workflow = StateGraph(AgentState)
# Define the two nodes we will cycle between
workflow.add_node("agent", call_model)
workflow.add_node("action", call_tool)
# Set the entrypoint as `agent`
# This means that this node is the first one called
workflow.add_edge(START, "agent")
# We now add a conditional edge
workflow.add_conditional_edges(
# First, we define the start node. We use `agent`.
# This means these are the edges taken after the `agent` node is called.
"agent",
# Next, we pass in the function that will determine which node is called next.
should_continue,
# Finally we pass in a mapping.
# The keys are strings, and the values are other nodes.
# END is a special node marking that the graph should finish.
# What will happen is we will call `should_continue`, and then the output of that
# will be matched against the keys in this mapping.
# Based on which one it matches, that node will then be called.
{
# If `tools`, then we call the tool node.
"continue": "action",
# Otherwise we finish.
"end": END,
},
)
# We now add a normal edge from `tools` to `agent`.
# This means that after `tools` is called, `agent` node is called next.
workflow.add_edge("action", "agent")
# Finally, we compile it!
# This compiles it into a LangChain Runnable,
# meaning you can use it as you would any other runnable
app = workflow.compile()
from IPython.display import Image, display
display(Image(app.get_graph().draw_mermaid_png()))
from langchain_core.messages import HumanMessage
inputs = {"messages": [HumanMessage(content="what is the weather in sf")]}
for chunk in app.stream(inputs, stream_mode="values"):
chunk["messages"][-1].pretty_print()