
How to force an agent to call a tool

In this example, we will build a ReAct agent that always calls a certain tool first, before making any plans. Concretely, we will create an agent with a search tool, but force the agent to call that search tool at the very start (and then let it do whatever it wants afterwards). This is useful when you know you want a specific action to be taken in your application, while still letting the LLM respond flexibly to the user's query once that fixed sequence has run.

Setup

First, we need to install the required packages:

yarn add @langchain/langgraph @langchain/openai @langchain/core

Next, we need to set an API key for OpenAI (the LLM we will use). Optionally, we can also set an API key for LangSmith tracing, which will give us best-in-class observability.

// process.env.OPENAI_API_KEY = "sk_...";

// Optional, add tracing in LangSmith
// process.env.LANGCHAIN_API_KEY = "ls__...";
// process.env.LANGCHAIN_CALLBACKS_BACKGROUND = "true";
process.env.LANGCHAIN_TRACING_V2 = "true";
process.env.LANGCHAIN_PROJECT = "Force Calling a Tool First: LangGraphJS";

Set up the tools

We will first define the tools we want to use. For this simple example, we will use a mock search tool (a stand-in for a built-in search tool such as Tavily). However, it is really easy to create your own tools - see the documentation here on how to do that.

import { DynamicStructuredTool } from "@langchain/core/tools";
import { z } from "zod";

const searchTool = new DynamicStructuredTool({
  name: "search",
  description:
    "Use to surf the web, fetch current information, check the weather, and retrieve other information.",
  schema: z.object({
    query: z.string().describe("The query to use in your search."),
  }),
  func: async ({}: { query: string }) => {
    // This is a placeholder for the actual implementation
    return "Cold, with a low of 13 ℃";
  },
});

await searchTool.invoke({ query: "What's the weather like?" });
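// Returns the placeholder result defined above: "Cold, with a low of 13 ℃"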

const tools = [searchTool];

We can now wrap these tools in a ToolNode. This is a prebuilt node that takes in the tool calls generated by a LangChain chat model, invokes the corresponding tools, and returns the outputs.

import { ToolNode } from "@langchain/langgraph/prebuilt";

const toolNode = new ToolNode(tools);
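
If you want to sanity-check this node in isolation, the sketch below (not part of the original walkthrough; it assumes the standard ToolNode messages interface) invokes it directly with an AIMessage containing a tool call:

import { AIMessage } from "@langchain/core/messages";

// ToolNode reads the tool_calls on the most recent AIMessage,
// runs the matching tools, and returns the outputs as ToolMessages.
const toolResult = await toolNode.invoke({
  messages: [
    new AIMessage({
      content: "",
      tool_calls: [
        { name: "search", args: { query: "weather" }, id: "tool_demo_1" },
      ],
    }),
  ],
});
// toolResult.messages[0] is a ToolMessage containing "Cold, with a low of 13 ℃"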

Set up the model

Now we need to load the chat model we want to use. Importantly, it should satisfy two criteria:

  1. It should work with messages. We represent all agent state as messages, so it needs to handle them well.
  2. It should support OpenAI function calling. This means it should either be an OpenAI model or a model that exposes a similar interface.

Note: these model requirements are not requirements for using LangGraph in general - they only apply to this example.

import { ChatOpenAI } from "@langchain/openai";

const model = new ChatOpenAI({
  temperature: 0,
  model: "gpt-4o",
});

Having done that, we should make sure the model knows it has these tools available to call. We can do this by converting the LangChain tools into the format for OpenAI function calling and then binding them to the model class.

const boundModel = model.bindTools(tools);
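
As a quick sanity check (hypothetical; it requires a valid OPENAI_API_KEY), invoking the bound model should now yield tool calls for queries that need the search tool:

import { HumanMessage } from "@langchain/core/messages";

// The bound model can decide to call any of the tools we passed in.
const exampleResponse = await boundModel.invoke([
  new HumanMessage("what is the weather in sf"),
]);
console.log(exampleResponse.tool_calls);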

Define the agent state

The main type of graph in langgraph is the StateGraph. This graph is parameterized by a state object that is passed between nodes. Each node then returns operations that update the state.

For this example, the state we track is just a list of messages, and we want each node to simply append messages to that list. Therefore, we define the agent state as an object with a single key (messages) whose value specifies how the state is updated.

import { Annotation } from "@langchain/langgraph";
import { BaseMessage } from "@langchain/core/messages";

const AgentState = Annotation.Root({
  messages: Annotation<BaseMessage[]>({
    reducer: (x, y) => x.concat(y),
  }),
});
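
To make the reducer's behavior concrete, here is a minimal illustration (not part of the graph) of what it computes: messages returned by a node are appended to the existing list rather than replacing it.

import { AIMessage, HumanMessage } from "@langchain/core/messages";

// What the reducer computes when a node returns { messages: [update] }:
const existing = [new HumanMessage("hi")];
const update = [new AIMessage("hello")];
const next = existing.concat(update); // [HumanMessage("hi"), AIMessage("hello")]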

Define the nodes

We now need to define a few different nodes in the graph. In langgraph, a node can be either a function or a runnable. We need two main nodes:

  1. The agent: responsible for deciding what (if any) action to take.
  2. A function to invoke tools: if the agent decides to take an action, this node executes it.

We also need to define some edges, some of which may be conditional. They are conditional because, depending on a node's output, one of several paths may be taken - and which path is taken is not known until the node runs (the LLM decides).

  1. Conditional edge: after the agent is called, we either a. invoke the tool-calling function, if the agent said to take an action, or b. finish, if the agent said it was done.
  2. Normal edge: after the tools are invoked, we always return to the agent to decide what to do next.

Let's define the nodes, as well as a function that decides which conditional edge to take.

import { AIMessage, AIMessageChunk } from "@langchain/core/messages";
import { RunnableConfig } from "@langchain/core/runnables";
import { concat } from "@langchain/core/utils/stream";

// Define logic that will be used to determine which conditional edge to go down
const shouldContinue = (state: typeof AgentState.State) => {
  const { messages } = state;
  const lastMessage = messages[messages.length - 1] as AIMessage;
  // If there is no function call, then we finish
  if (!lastMessage.tool_calls || lastMessage.tool_calls.length === 0) {
    return "end";
  }
  // Otherwise if there is, we continue
  return "continue";
};

// Define the function that calls the model
const callModel = async (
  state: typeof AgentState.State,
  config?: RunnableConfig,
) => {
  const { messages } = state;
  let response: AIMessageChunk | undefined;
  for await (const message of await boundModel.stream(messages, config)) {
    if (!response) {
      response = message;
    } else {
      response = concat(response, message);
    }
  }
  // We return an object, because this will get added to the existing list
  return {
    messages: response ? [response as AIMessage] : [],
  };
};
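
Note that callModel streams the response and accumulates the chunks with concat, so that token-level streaming keeps working when you stream the graph. If you don't need that, a simpler non-streaming variant (a sketch with equivalent behavior for this example) is:

// Non-streaming variant: a single invoke call returns the full AIMessage.
const callModelSimple = async (
  state: typeof AgentState.State,
  config?: RunnableConfig,
) => {
  const response = await boundModel.invoke(state.messages, config);
  return { messages: [response] };
};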

Modification

Here we create a node that returns an AIMessage with a tool call - we will use this at the start to force the agent to call a tool.

// This is the new first node - for the first call of the model we explicitly hard-code a tool call
const firstModel = async (state: typeof AgentState.State) => {
  const humanInput = state.messages[state.messages.length - 1].content || "";
  return {
    messages: [
      new AIMessage({
        content: "",
        tool_calls: [
          {
            name: "search",
            args: {
              query: humanInput,
            },
            id: "tool_abcd123",
          },
        ],
      }),
    ],
  };
};
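
You can invoke this node on its own (a hypothetical check) to see the hard-coded tool call it produces:

import { HumanMessage } from "@langchain/core/messages";

// The node ignores the model entirely: it always returns a single "search"
// tool call whose query is the content of the latest message.
const firstResult = await firstModel({
  messages: [new HumanMessage("what is the weather in sf")],
});
console.log(firstResult.messages[0].tool_calls);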

Define the graph

We can now put it all together and define the graph!

Modification

We define a firstModel node and set it as the entrypoint.

import { END, START, StateGraph } from "@langchain/langgraph";

// Define a new graph
const workflow = new StateGraph(AgentState)
  // Define the new entrypoint
  .addNode("first_agent", firstModel)
  // Define the two nodes we will cycle between
  .addNode("agent", callModel)
  .addNode("action", toolNode)
  // Set the entrypoint as `first_agent`
  // by creating an edge from the virtual __start__ node to `first_agent`
  .addEdge(START, "first_agent")
  // We now add a conditional edge
  .addConditionalEdges(
    // First, we define the start node. We use `agent`.
    // This means these are the edges taken after the `agent` node is called.
    "agent",
    // Next, we pass in the function that will determine which node is called next.
    shouldContinue,
    // Finally we pass in a mapping.
    // The keys are strings, and the values are other nodes.
    // END is a special node marking that the graph should finish.
    // What will happen is we will call `should_continue`, and then the output of that
    // will be matched against the keys in this mapping.
    // Based on which one it matches, that node will then be called.
    {
      // If `continue`, then we call the tool node.
      continue: "action",
      // Otherwise we finish.
      end: END,
    },
  )
  // We now add a normal edge from `action` to `agent`.
  // This means that after `action` is called, the `agent` node is called next.
  .addEdge("action", "agent")
  // After we call the first agent, we know we want to go to action
  .addEdge("first_agent", "action");

// Finally, we compile it!
// This compiles it into a LangChain Runnable,
// meaning you can use it as you would any other runnable
const app = workflow.compile();
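
Optionally, you can render the compiled graph's structure to double-check the wiring (this assumes a recent @langchain/langgraph release, where compiled graphs expose getGraph()):

// Prints Mermaid source for: __start__ -> first_agent -> action <-> agent -> __end__
console.log(app.getGraph().drawMermaid());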

Use it!

We can now use it! It now exposes the same interface as all other LangChain runnables.

import { HumanMessage } from "@langchain/core/messages";

const inputs = {
  messages: [new HumanMessage("what is the weather in sf")],
};

for await (const output of await app.stream(inputs)) {
  console.log(output);
  console.log("-----\n");
}
{
  first_agent: {
    messages: [
      AIMessage {
        "content": "",
        "additional_kwargs": {},
        "response_metadata": {},
        "tool_calls": [
          {
            "name": "search",
            "args": {
              "query": "what is the weather in sf"
            },
            "id": "tool_abcd123"
          }
        ],
        "invalid_tool_calls": []
      }
    ]
  }
}
-----

{
  action: {
    messages: [
      ToolMessage {
        "content": "Cold, with a low of 13 ℃",
        "name": "search",
        "additional_kwargs": {},
        "response_metadata": {},
        "tool_call_id": "tool_abcd123"
      }
    ]
  }
}
-----

{
  agent: {
    messages: [
      AIMessageChunk {
        "id": "chatcmpl-9y562g16z0MUNBJcS6nKMsDuFMRsS",
        "content": "The current weather in San Francisco is cold, with a low of 13°C.",
        "additional_kwargs": {},
        "response_metadata": {
          "prompt": 0,
          "completion": 0,
          "finish_reason": "stop",
          "system_fingerprint": "fp_3aa7262c27fp_3aa7262c27fp_3aa7262c27fp_3aa7262c27fp_3aa7262c27fp_3aa7262c27fp_3aa7262c27fp_3aa7262c27fp_3aa7262c27fp_3aa7262c27fp_3aa7262c27fp_3aa7262c27fp_3aa7262c27fp_3aa7262c27fp_3aa7262c27fp_3aa7262c27fp_3aa7262c27fp_3aa7262c27fp_3aa7262c27"
        },
        "tool_calls": [],
        "tool_call_chunks": [],
        "invalid_tool_calls": [],
        "usage_metadata": {
          "input_tokens": 104,
          "output_tokens": 18,
          "total_tokens": 122
        }
      }
    ]
  }
}
-----