跳到内容

如何使用摘要管理长上下文

在现代 LLM 应用中,无论是构建包含多轮对话的聊天机器人,还是包含大量工具调用的代理系统,上下文大小都可能迅速增长并达到提供商的限制。

处理此问题的一个有效策略是,在较早的消息达到特定阈值后对其进行总结。本指南演示了如何在 LangGraph 应用中,使用 LangMem 预置的 summarize_messagesSummarizationNode 来实现此方法。

在简单聊天机器人中使用

下面是一个带摘要功能的简单多轮聊天机器人示例

API: ChatOpenAI | StateGraph | START | summarize_messages | RunningSummary

from langgraph.graph import StateGraph, START, MessagesState
from langgraph.checkpoint.memory import InMemorySaver
from langmem.short_term import summarize_messages, RunningSummary
from langchain_openai import ChatOpenAI

model = ChatOpenAI(model="gpt-4o")
summarization_model = model.bind(max_tokens=128)  # (1)!

# We will keep track of our running summary in the graph state
class SummaryState(MessagesState):
    summary: RunningSummary | None

# Define the node that will be calling the LLM
def call_model(state: SummaryState) -> SummaryState:
    summarization_result = summarize_messages(  # (2)!
        state["messages"],
        # IMPORTANT: Pass running summary, if any
        running_summary=state.get("summary"),  # (3)!
        token_counter=model.get_num_tokens_from_messages,
        model=summarization_model, 
        max_tokens=256,  # (4)!
        max_tokens_before_summary=256,  # (5)!
        max_summary_tokens=128
    )
    response = model.invoke(summarization_result.messages)
    state_update = {"messages": [response]}
    if summarization_result.running_summary:  # (6)!
        state_update["summary"] = summarization_result.running_summary
    return state_update


checkpointer = InMemorySaver()
builder = StateGraph(SummaryState)
builder.add_node(call_model)
builder.add_edge(START, "call_model")
graph = builder.compile(checkpointer=checkpointer)  # (7)!

# Invoke the graph
config = {"configurable": {"thread_id": "1"}}
graph.invoke({"messages": "hi, my name is bob"}, config)
graph.invoke({"messages": "write a short poem about cats"}, config)
graph.invoke({"messages": "now do the same but for dogs"}, config)
graph.invoke({"messages": "what's my name?"}, config)
  1. 我们还为摘要模型设置了最大输出令牌数。为了更好地估算令牌预算,此值应与 `summarize_messages` 中的 `max_summary_tokens` 相匹配。
  2. 我们将在调用 LLM 之前尝试总结消息。如果 `state["messages"]` 中的消息符合 `max_tokens_before_summary` 预算,我们将直接返回这些消息。否则,我们将进行总结并返回 `[summary_message] + remaining_messages`。
  3. 如果存在运行中的摘要,请传递它。这使得 `summarize_messages` 能够避免在每轮对话中重复总结相同的消息。
  4. 这是总结后生成的消息列表的最大令牌预算。
  5. 这是触发总结功能的令牌阈值。默认为 `max_tokens`。
  6. 如果生成了摘要,则将其作为状态更新添加,并覆盖之前生成的摘要(如果有)。
  7. 使用检查点编译图是很重要的,否则图将无法记住之前的对话轮次。

在用户界面中使用

一个重要的问题是如何在你的应用用户界面中向用户展示消息。我们建议渲染完整、未修改的消息历史。你可以选择额外渲染摘要以及传递给 LLM 的消息。我们还建议对完整消息历史(例如,"messages")和总结结果(例如,"summary")使用单独的 LangGraph 状态键。在 SummarizationNode 中,总结结果存储在一个名为 context 的单独状态键中(见下方示例)。

使用 SummarizationNode

你也可以将总结功能分离到一个专门的节点中。让我们看看如何修改上面的示例,使用 SummarizationNode 来达到同样的效果

API: ChatOpenAI | StateGraph | START | SummarizationNode | RunningSummary

from typing import Any, TypedDict

from langchain_openai import ChatOpenAI
from langchain_core.messages import AnyMessage
from langgraph.graph import StateGraph, START, MessagesState
from langgraph.checkpoint.memory import InMemorySaver
from langmem.short_term import SummarizationNode, RunningSummary

model = ChatOpenAI(model="gpt-4o")
summarization_model = model.bind(max_tokens=128)


class State(MessagesState):
    context: dict[str, Any]  # (1)!


class LLMInputState(TypedDict):  # (2)!
    summarized_messages: list[AnyMessage]
    context: dict[str, Any]

summarization_node = SummarizationNode(  # (3)!
    token_counter=model.get_num_tokens_from_messages,
    model=summarization_model,
    max_tokens=256,
    max_tokens_before_summary=256,
    max_summary_tokens=128,
)

# IMPORTANT: we're passing a private input state here to isolate the summarization
def call_model(state: LLMInputState):  # (4)!
    response = model.invoke(state["summarized_messages"])
    return {"messages": [response]}

checkpointer = InMemorySaver()
builder = StateGraph(State)
builder.add_node(call_model)
builder.add_node("summarize", summarization_node)
builder.add_edge(START, "summarize")
builder.add_edge("summarize", "call_model")
graph = builder.compile(checkpointer=checkpointer)

# Invoke the graph
config = {"configurable": {"thread_id": "1"}}
graph.invoke({"messages": "hi, my name is bob"}, config)
graph.invoke({"messages": "write a short poem about cats"}, config)
graph.invoke({"messages": "now do the same but for dogs"}, config)
graph.invoke({"messages": "what's my name?"}, config)
  1. 我们将在 context 字段中跟踪我们的运行摘要(这是 SummarizationNode 期望的)。
  2. 定义私有状态,该状态仅用于过滤 call_model 节点的输入。
  3. SummarizationNode 内部使用了 summarize_messages,并自动处理了我们在上面示例中不得不手动完成的现有摘要传播。
  4. 现在,调用模型的节点只是一个简单的 LLM 调用。

在 ReAct 代理中使用

一个常见的用例是在调用工具的代理中总结消息历史。下面的示例演示了如何在 ReAct 风格的 LangGraph 代理中实现这一点

API: ChatOpenAI | tool | StateGraph | START | END | ToolNode | SummarizationNode | RunningSummary

from typing import Any, TypedDict

from langchain_openai import ChatOpenAI
from langchain_core.messages import AnyMessage
from langchain_core.tools import tool
from langgraph.graph import StateGraph, START, END, MessagesState
from langgraph.prebuilt import ToolNode
from langgraph.checkpoint.memory import InMemorySaver
from langmem.short_term import SummarizationNode, RunningSummary

class State(MessagesState):
    context: dict[str, Any]

def search(query: str):
    """Search the web."""
    if "weather" in query.lower():
        return "The weather is sunny in New York, with a high of 104 degrees."
    elif "broadway" in query.lower():
        return "Hamilton is always on!"
    else:
        raise "Not enough information"

tools = [search]

model = ChatOpenAI(model="gpt-4o")
summarization_model = model.bind(max_tokens=128)

summarization_node = SummarizationNode(
    token_counter=model.get_num_tokens_from_messages,
    model=summarization_model,
    max_tokens=256,
    max_tokens_before_summary=1024,
    max_summary_tokens=128,
)

class LLMInputState(TypedDict):
    summarized_messages: list[AnyMessage]
    context: dict[str, Any]

def call_model(state: LLMInputState):
    response = model.bind_tools(tools).invoke(state["summarized_messages"])
    return {"messages": [response]}

# Define a router that determines whether to execute tools or exit
def should_continue(state: MessagesState):
    messages = state["messages"]
    last_message = messages[-1]
    if not last_message.tool_calls:
        return END
    else:
        return "tools"

checkpointer = InMemorySaver()
builder = StateGraph(State)
builder.add_node("summarize_node", summarization_node)
builder.add_node("call_model", call_model)
builder.add_node("tools", ToolNode(tools))
builder.set_entry_point("summarize_node")
builder.add_edge("summarize_node", "call_model")
builder.add_conditional_edges("call_model", should_continue, path_map=["tools", END])
builder.add_edge("tools", "summarize_node")  # (1)!
graph = builder.compile(checkpointer=checkpointer)

# Invoke the graph
config = {"configurable": {"thread_id": "1"}}
graph.invoke({"messages": "hi, i am bob"}, config)
graph.invoke({"messages": "what's the weather in nyc this weekend"}, config)
graph.invoke({"messages": "what's new on broadway?"}, config)
  1. 在执行工具后,我们不是返回到 LLM,而是首先返回到总结节点。

评论