In [1]
已复制!
%%capture --no-stderr
%pip install --quiet -U langgraph langchain_openai numpy
%%capture --no-stderr %pip install --quiet -U langgraph langchain_openai numpy
In [2]
已复制!
import getpass
import os
def _set_env(var: str):
    """Make sure environment variable *var* has a value, prompting if needed.

    A non-empty existing value is left untouched. Otherwise the value is read
    with ``getpass`` (so it is not echoed) and stored in ``os.environ``.
    """
    if os.environ.get(var):
        return
    os.environ[var] = getpass.getpass(f"{var}: ")
# Ask for the OpenAI API key unless it is already present in the environment.
_set_env("OPENAI_API_KEY")
import getpass import os def _set_env(var: str): if not os.environ.get(var): os.environ[var] = getpass.getpass(f"{var}: ") _set_env("OPENAI_API_KEY")
定义工具
让我们考虑一个玩具示例:为标准普尔 500 指数中的每家上市公司提供一个工具。每个工具根据作为参数传入的年份,获取该公司的特定信息。
我们首先构建一个注册表,为每个工具关联一个唯一标识符。每个工具都是一个 StructuredTool,可以转换为 JSON 模式,从而直接绑定到支持工具调用的聊天模型。
In [10]
已复制!
import re
import uuid
from langchain_core.tools import StructuredTool
def create_tool(company: str) -> "StructuredTool":
    """Create a placeholder tool that answers questions about one company.

    Args:
        company: Display name of the company, e.g. ``"Yum! Brands"``.

    Returns:
        A ``StructuredTool`` whose name is a sanitized form of ``company``
        (e.g. ``"Yum_Brands"``). It takes a ``year`` argument and returns a
        fixed placeholder revenue string.

    Raises:
        ValueError: If no usable characters remain in ``company`` after
            sanitization.
    """
    # OpenAI function names may only contain a-z, A-Z, 0-9, underscores or
    # dashes, with at most 64 characters. ``\w`` alone would also keep
    # non-ASCII letters such as "é", so restrict to ASCII explicitly. Then join
    # the words with underscores, collapsing any run of whitespace.
    cleaned = re.sub(r"[^A-Za-z0-9_\s]", "", company)
    formatted_company = "_".join(cleaned.split())[:64]
    if not formatted_company:
        raise ValueError(f"Cannot derive a tool name from company {company!r}")

    def company_tool(year: int) -> str:
        # Placeholder implementation: the same static revenue for every
        # company and year.
        return f"{company} had revenues of $100 in {year}."

    return StructuredTool.from_function(
        company_tool,
        name=formatted_company,
        description=f"Information about {company}",
    )
# A small sample of S&P 500 constituents, enough to demonstrate tool retrieval.
s_and_p_500_companies = [
    "3M",
    "A.O. Smith",
    "Abbott",
    "Accenture",
    "Advanced Micro Devices",
    "Yum! Brands",
    "Zebra Technologies",
    "Zimmer Biomet",
    "Zoetis",
]

# Registry keyed by a freshly generated random UUID string per tool; these
# keys are the identifiers used to refer to tools elsewhere in the notebook.
tool_registry = {
    str(uuid.uuid4()): company_tool
    for company_tool in map(create_tool, s_and_p_500_companies)
}
import re import uuid from langchain_core.tools import StructuredTool def create_tool(company: str) -> dict: """创建占位符工具的模式.""" # 删除非字母数字字符并将空格替换为下划线,以生成工具名称 formatted_company = re.sub(r"[^\w\s]", "", company).replace(" ", "_") def company_tool(year: int) -> str: # 占位符函数,返回公司和年份的静态收入信息 return f"{company} 在 {year} 年的收入为 100 美元。" return StructuredTool.from_function( company_tool, name=formatted_company, description=f"关于 {company} 的信息", ) # 用于演示的标准普尔 500 公司的缩略列表 s_and_p_500_companies = [ "3M", "A.O. Smith", "Abbott", "Accenture", "Advanced Micro Devices", "Yum! Brands", "Zebra Technologies", "Zimmer Biomet", "Zoetis", ] # 为每家公司创建一个工具,并将其存储在带有唯一 UUID 作为键的注册表中 tool_registry = { str(uuid.uuid4()): create_tool(company) for company in s_and_p_500_companies }
定义图
工具选择
我们将构建一个节点,该节点根据状态中的信息(例如最近的用户消息)检索可用工具的子集。一般来说,各种检索方案都适用于此步骤。作为一个简单的方案,我们在向量存储中为工具描述的嵌入建立索引,并通过语义搜索将用户查询与工具相关联。
In [11]
已复制!
from langchain_core.documents import Document
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_openai import OpenAIEmbeddings
# One Document per tool:
# - the tool description is the text that gets embedded;
# - the Document ID is the tool's registry key, so search hits map straight
#   back into `tool_registry`;
# - the tool name is kept as metadata.
tool_documents = [
    Document(
        page_content=tool.description,
        id=tool_id,  # renamed from `id` to avoid shadowing the builtin
        metadata={"tool_name": tool.name},
    )
    for tool_id, tool in tool_registry.items()
]

# In-memory vector index over the tool-description embeddings.
vector_store = InMemoryVectorStore(embedding=OpenAIEmbeddings())
document_ids = vector_store.add_documents(tool_documents)
from langchain_core.documents import Document from langchain_core.vectorstores import InMemoryVectorStore from langchain_openai import OpenAIEmbeddings tool_documents = [ Document( page_content=tool.description, id=id, metadata={"tool_name": tool.name}, ) for id, tool in tool_registry.items() ] vector_store = InMemoryVectorStore(embedding=OpenAIEmbeddings()) document_ids = vector_store.add_documents(tool_documents)
In [12]
已复制!
from typing import Annotated
from langchain_openai import ChatOpenAI
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
# Graph state shared by all nodes, defined with TypedDict.
# - `messages` is annotated with the `add_messages` reducer, so node updates
#   are merged into the running conversation instead of replacing it.
# - `selected_tools` holds the registry IDs (keys of `tool_registry`) of the
#   tools currently offered to the model. It has no reducer, so each update
#   overwrites it.
class State(TypedDict):
    messages: Annotated[list, add_messages]  # conversation history
    selected_tools: list[str]  # registry IDs of the tools to bind
builder = StateGraph(State)

# Every registered tool, in registry order.
tools = [*tool_registry.values()]

# Shared chat model; the agent node binds tools to it on each call.
llm = ChatOpenAI()
def agent(state: State):
    """Call the chat model with only the tools chosen by ``select_tools``.

    Args:
        state: Graph state. ``state["messages"]`` is the conversation so far;
            ``state["selected_tools"]`` holds registry IDs of the retrieved
            tools (treated as empty if absent).

    Returns:
        A partial state update ``{"messages": [ai_message]}``. The
        ``add_messages`` reducer merges it into the conversation.
    """
    # Map the selected registry IDs back to tool objects.
    selected_tools = [
        tool_registry[tool_id] for tool_id in state.get("selected_tools", [])
    ]
    # Binding an empty list would send an empty `tools` array, which the
    # OpenAI API rejects. When nothing was selected, use the plain model.
    model = llm.bind_tools(selected_tools) if selected_tools else llm
    return {"messages": [model.invoke(state["messages"])]}
def select_tools(state: State):
    """Pick candidate tools by semantic search over the tool descriptions.

    The content of the newest message in the state is used verbatim as the
    search query.

    Returns:
        ``{"selected_tools": [...]}``: the IDs of the matching tool
        documents, in the order the vector store returned them.
    """
    query = state["messages"][-1].content
    matches = vector_store.similarity_search(query)
    return {"selected_tools": [match.id for match in matches]}
# Register the three nodes: tool selection, the model call, and tool execution.
builder.add_node("agent", agent)
builder.add_node("select_tools", select_tools)
# The ToolNode is given every registered tool (`tools`), so it can execute any
# tool the model calls, even though the model is only bound to the selected
# subset.
tool_node = ToolNode(tools=tools)
builder.add_node("tools", tool_node)
# After the agent runs, `tools_condition` routes to "tools" when the latest AI
# message contains tool calls, and ends the run otherwise.
builder.add_conditional_edges(
    "agent",
    tools_condition,
)
# In this first graph, tool results go straight back to the agent, so tool
# selection happens only once, at the start of the run.
builder.add_edge("tools", "agent")
builder.add_edge("select_tools", "agent")
builder.add_edge(START, "select_tools")
graph = builder.compile()
from typing import Annotated from langchain_openai import ChatOpenAI from typing_extensions import TypedDict from langgraph.graph import StateGraph, START from langgraph.graph.message import add_messages from langgraph.prebuilt import ToolNode, tools_condition # 使用 TypedDict 定义状态结构。 # 它包含一个消息列表(由 add_messages 处理) # 和一个选定工具 ID 列表。 class State(TypedDict): messages: Annotated[list, add_messages] selected_tools: list[str] builder = StateGraph(State) # 从工具注册表中检索所有可用工具。 tools = list(tool_registry.values()) llm = ChatOpenAI() # 代理函数处理当前状态 # 通过将选定的工具绑定到 LLM。 def agent(state: State): # 根据状态的 selected_tools 列表将工具 ID 映射到实际工具。 selected_tools = [tool_registry[id] for id in state["selected_tools"]] # 将选定的工具绑定到 LLM 以进行当前交互。 llm_with_tools = llm.bind_tools(selected_tools) # 使用当前消息调用 LLM 并返回更新的消息列表。 return {"messages": [llm_with_tools.invoke(state["messages"])]} # select_tools 函数根据用户的最后一条消息内容选择工具。 def select_tools(state: State): last_user_message = state["messages"][-1] query = last_user_message.content tool_documents = vector_store.similarity_search(query) return {"selected_tools": [document.id for document in tool_documents]} builder.add_node("agent", agent) builder.add_node("select_tools", select_tools) tool_node = ToolNode(tools=tools) builder.add_node("tools", tool_node) builder.add_conditional_edges( "agent", tools_condition, ) builder.add_edge("tools", "agent") builder.add_edge("select_tools", "agent") builder.add_edge(START, "select_tools") graph = builder.compile()
In [13]
已复制!
from IPython.display import Image, display

# Render the compiled graph as a Mermaid PNG inline in the notebook.
try:
    display(Image(graph.get_graph().draw_mermaid_png()))
except Exception:
    # Drawing needs optional extra dependencies. Any failure is deliberately
    # ignored because the diagram is only illustrative.
    pass
from IPython.display import Image, display try: display(Image(graph.get_graph().draw_mermaid_png())) except Exception: # 这需要一些额外的依赖项,并且是可选的 pass
In [14]
已复制!
# Run the graph end to end on a single user question. `result` is the final
# graph state, read below via its "selected_tools" and "messages" keys.
user_input = "Can you give me some information about AMD in 2022?"
result = graph.invoke({"messages": [("user", user_input)]})
user_input = "你能告诉我一些关于 AMD 在 2022 年的信息吗?" result = graph.invoke({"messages": [("user", user_input)]})
In [15]
已复制!
# Show the registry IDs that select_tools stored in the final state.
print(result["selected_tools"])
print(result["selected_tools"])
['ab9c0d59-3d16-448d-910c-73cf10a26020', 'f5eff8f6-7fb9-47b6-b54f-19872a52db84', '2962e168-9ef4-48dc-8b7c-9227e7956d39', '24a9fb82-19fe-4a88-944e-47bc4032e94a']
In [16]
已复制!
# Pretty-print every message in the final conversation history.
for message in result["messages"]:
    message.pretty_print()
for message in result["messages"]: message.pretty_print()
================================ Human Message ================================= Can you give me some information about AMD in 2022? ================================== Ai Message ================================== Tool Calls: Advanced_Micro_Devices (call_CRxQ0oT7NY7lqf35DaRNTJ35) Call ID: call_CRxQ0oT7NY7lqf35DaRNTJ35 Args: year: 2022 ================================= Tool Message ================================= Name: Advanced_Micro_Devices Advanced Micro Devices had revenues of $100 in 2022. ================================== Ai Message ================================== In 2022, Advanced Micro Devices (AMD) had revenues of $100.
重复工具选择
为了处理错误的工具选择所导致的问题,我们可以重新执行 `select_tools` 节点。一种实现方式是修改 `select_tools`,使其利用状态中的全部消息(例如借助聊天模型)生成向量存储查询,并添加一条从 `tools` 到 `select_tools` 的路由边。
我们在下面实现了此更改。为了演示,我们在 `select_tools` 节点中加入了 `hack_remove_tool_condition`:它在第一次迭代时删除正确的工具,以模拟初始工具选择出错的情况。请注意,在第二次迭代中,代理可以访问正确的工具,因此顺利完成了运行。
在 LangChain 中使用 Pydantic
此笔记本使用 Pydantic v2 的 `BaseModel`,需要 `langchain-core >= 0.3`。使用 `langchain-core < 0.3` 会因混用 Pydantic v1 和 v2 的 `BaseModel` 而导致错误。
In [46]
已复制!
from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage
from langgraph.pregel.retry import RetryPolicy
from pydantic import BaseModel, Field
# Schema bound as the only tool in the second `select_tools`. Forcing the
# model to call it yields a search query for the tool vector store.
# NOTE: the class docstring and the field description are sent to the model
# as the tool and parameter descriptions. Treat them as prompt text.
class QueryForTools(BaseModel):
    """Generate a query for additional tools."""
    # Free-text query embedded and matched against the tool descriptions.
    query: str = Field(..., description="Query for additional tools.")
def select_tools(state: State):
    """Select tools for the agent based on the latest message in the state.

    On the first pass the last message is the user's ``HumanMessage``, and its
    content is used directly as the vector-store query. After a tool has run,
    the last message is a ``ToolMessage``. The LLM is then forced to call
    ``QueryForTools`` to write a short query for any additional tools it needs.

    For demonstration, the selection made from a ``HumanMessage`` deliberately
    drops the ``Advanced_Micro_Devices`` tool. This simulates a retrieval error
    that a later pass can recover from.

    Args:
        state: Graph state containing ``messages`` and, after the first pass,
            ``selected_tools``.

    Returns:
        ``{"selected_tools": [...]}`` with the registry IDs of the retrieved
        tools. If the LLM returns a blank query (no further information
        needed), the previous selection is kept unchanged.

    Raises:
        TypeError: If the last message is neither a ``HumanMessage`` nor a
            ``ToolMessage``.
    """
    last_message = state["messages"][-1]
    # Demo switch: when True, the correct tool is removed from the results.
    hack_remove_tool_condition = False

    if isinstance(last_message, HumanMessage):
        query = last_message.content
        hack_remove_tool_condition = True  # Simulate wrong tool selection
    elif isinstance(last_message, ToolMessage):
        # QueryForTools only has a `query` field, so the prompt must not
        # refer to any other field.
        system = SystemMessage(
            "Given this conversation, generate a query for additional tools. "
            "The query should be a short string containing what type of information "
            "is needed. If no further information is needed, "
            "populate a blank string for the query."
        )
        input_messages = [system] + state["messages"]
        # tool_choice=True forces the model to call the single bound tool.
        response = llm.bind_tools([QueryForTools], tool_choice=True).invoke(
            input_messages
        )
        query = (
            response.tool_calls[0]["args"].get("query") or ""
            if response.tool_calls
            else ""
        )
        if not query.strip():
            # Nothing more to look up. Keep the current selection instead of
            # running a semantic search on an empty string.
            return {"selected_tools": state.get("selected_tools", [])}
    else:
        # An explicit exception (unlike `assert`) survives `python -O`.
        raise TypeError(
            "select_tools expects the last message to be a HumanMessage or "
            f"ToolMessage, got {type(last_message).__name__}"
        )

    # Search the tool vector store using the query.
    tool_documents = vector_store.similarity_search(query)
    if hack_remove_tool_condition:
        # Simulate an error by removing the correct tool from the selection.
        selected_tools = [
            document.id
            for document in tool_documents
            if document.metadata["tool_name"] != "Advanced_Micro_Devices"
        ]
    else:
        selected_tools = [document.id for document in tool_documents]
    return {"selected_tools": selected_tools}
# Rebuild the graph with the new select_tools. Unlike the first graph, tool
# results are routed back through "select_tools", which gives the model a
# chance to retrieve tools that were missed initially.
graph_builder = StateGraph(State)
graph_builder.add_node("agent", agent)
# select_tools makes LLM and embedding calls, so failed runs are retried up
# to 3 attempts in total (subject to RetryPolicy's default retry_on filter).
# NOTE(review): newer LangGraph releases name this keyword `retry_policy` and
# export RetryPolicy from `langgraph.types`. Confirm against the installed
# version.
graph_builder.add_node("select_tools", select_tools, retry=RetryPolicy(max_attempts=3))
# The ToolNode can execute any registered tool the model calls.
tool_node = ToolNode(tools=tools)
graph_builder.add_node("tools", tool_node)
# Route to "tools" when the AI message has tool calls; otherwise finish.
graph_builder.add_conditional_edges(
    "agent",
    tools_condition,
)
# Loop: tools -> select_tools -> agent, so selection repeats after each tool run.
graph_builder.add_edge("tools", "select_tools")
graph_builder.add_edge("select_tools", "agent")
graph_builder.add_edge(START, "select_tools")
graph = graph_builder.compile()
from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage from langgraph.pregel.retry import RetryPolicy from pydantic import BaseModel, Field class QueryForTools(BaseModel): """生成其他工具的查询.""" query: str = Field(..., description="其他工具的查询。") def select_tools(state: State): """根据对话状态中的最后一条消息选择工具。如果最后一条消息来自人类,则直接使用消息内容作为查询。否则,使用系统消息构造查询并调用 LLM 生成工具建议。 """ last_message = state["messages"][-1] hack_remove_tool_condition = False # 模拟第一次工具选择中的错误 if isinstance(last_message, HumanMessage): query = last_message.content hack_remove_tool_condition = True # 模拟错误的工具选择 else: assert isinstance(last_message, ToolMessage) system = SystemMessage( "给定此对话,生成其他工具的查询。 " "查询应该是包含所需信息类型的简短字符串。如果不需要更多信息, " "将 more_information_needed 设置为 False 并为查询填充一个空字符串。" ) input_messages = [system] + state["messages"] response = llm.bind_tools([QueryForTools], tool_choice=True).invoke( input_messages ) query = response.tool_calls[0]["args"]["query"] # 使用生成的查询搜索工具向量存储 tool_documents = vector_store.similarity_search(query) if hack_remove_tool_condition: # 通过从选择中删除正确的工具来模拟错误 selected_tools = [ document.id for document in tool_documents if document.metadata["tool_name"] != "Advanced_Micro_Devices" ] else: selected_tools = [document.id for document in tool_documents] return {"selected_tools": selected_tools} graph_builder = StateGraph(State) graph_builder.add_node("agent", agent) graph_builder.add_node("select_tools", select_tools, retry=RetryPolicy(max_attempts=3)) tool_node = ToolNode(tools=tools) graph_builder.add_node("tools", tool_node) graph_builder.add_conditional_edges( "agent", tools_condition, ) graph_builder.add_edge("tools", "select_tools") graph_builder.add_edge("select_tools", "agent") graph_builder.add_edge(START, "select_tools") graph = graph_builder.compile()
In [47]
已复制!
from IPython.display import Image, display

# Render the rebuilt graph (with the tools -> select_tools loop) as a Mermaid PNG.
try:
    display(Image(graph.get_graph().draw_mermaid_png()))
except Exception:
    # Drawing needs optional extra dependencies. Any failure is deliberately
    # ignored because the diagram is only illustrative.
    pass
from IPython.display import Image, display try: display(Image(graph.get_graph().draw_mermaid_png())) except Exception: # 这需要一些额外的依赖项,并且是可选的 pass
In [48]
已复制!
# Ask the same question against the graph with repeated tool selection.
# `result` is the final graph state.
user_input = "Can you give me some information about AMD in 2022?"
result = graph.invoke({"messages": [("user", user_input)]})
user_input = "你能告诉我一些关于 AMD 在 2022 年的信息吗?" result = graph.invoke({"messages": [("user", user_input)]})
In [49]
已复制!
# Pretty-print the full conversation, including every tool call the agent
# made across the repeated selection rounds.
for message in result["messages"]:
    message.pretty_print()
for message in result["messages"]: message.pretty_print()
================================ Human Message ================================= Can you give me some information about AMD in 2022? ================================== Ai Message ================================== Tool Calls: Accenture (call_qGmwFnENwwzHOYJXiCAaY5Mx) Call ID: call_qGmwFnENwwzHOYJXiCAaY5Mx Args: year: 2022 ================================= Tool Message ================================= Name: Accenture Accenture had revenues of $100 in 2022. ================================== Ai Message ================================== Tool Calls: Advanced_Micro_Devices (call_u9e5UIJtiieXVYi7Y9GgyDpn) Call ID: call_u9e5UIJtiieXVYi7Y9GgyDpn Args: year: 2022 ================================= Tool Message ================================= Name: Advanced_Micro_Devices Advanced Micro Devices had revenues of $100 in 2022. ================================== Ai Message ================================== In 2022, AMD had revenues of $100.