跳到内容

如何通过摘要管理长上下文

在现代的 LLM 应用中,无论是构建多轮对话的聊天机器人,还是具有大量工具调用的代理系统,上下文大小都可能迅速增长并触及服务提供商的限制。

处理此问题的一个有效策略是,在消息达到某个阈值后,对较早的消息进行摘要。本指南将演示如何在您的 LangGraph 应用中使用 LangMem 预构建的 summarize_messagesSummarizationNode 来实现此方法。

在简单聊天机器人中使用

以下是一个带有摘要功能的简单多轮聊天机器人的示例。

API: ChatOpenAI | StateGraph | START | summarize_messages | RunningSummary

from langgraph.graph import StateGraph, START, MessagesState
from langgraph.checkpoint.memory import InMemorySaver
from langmem.short_term import summarize_messages, RunningSummary
from langchain_openai import ChatOpenAI

model = ChatOpenAI(model="gpt-4o")
summarization_model = model.bind(max_tokens=128)  # (1)!

# We will keep track of our running summary in the graph state
class SummaryState(MessagesState):
    summary: RunningSummary | None

# Define the node that will be calling the LLM
def call_model(state: SummaryState) -> SummaryState:
    summarization_result = summarize_messages(  # (2)!
        state["messages"],
        # IMPORTANT: Pass running summary, if any
        running_summary=state.get("summary"),  # (3)!
        token_counter=model.get_num_tokens_from_messages,
        model=summarization_model, 
        max_tokens=256,  # (4)!
        max_tokens_before_summary=256,  # (5)!
        max_summary_tokens=128
    )
    response = model.invoke(summarization_result.messages)
    state_update = {"messages": [response]}
    if summarization_result.running_summary:  # (6)!
        state_update["summary"] = summarization_result.running_summary
    return state_update


checkpointer = InMemorySaver()
builder = StateGraph(SummaryState)
builder.add_node(call_model)
builder.add_edge(START, "call_model")
graph = builder.compile(checkpointer=checkpointer)  # (7)!

# Invoke the graph
config = {"configurable": {"thread_id": "1"}}
graph.invoke({"messages": "hi, my name is bob"}, config)
graph.invoke({"messages": "write a short poem about cats"}, config)
graph.invoke({"messages": "now do the same but for dogs"}, config)
graph.invoke({"messages": "what's my name?"}, config)
  1. 我们还为摘要模型设置了最大输出词元数。这应与 summarize_messages 中的 max_summary_tokens 匹配,以便更好地估算词元预算。
  2. 我们将在调用 LLM 之前尝试对消息进行摘要。如果 state["messages"] 中的消息符合 max_tokens_before_summary 预算,我们将简单地返回这些消息。否则,我们将进行摘要并返回 [summary_message] + remaining_messages
  3. 传递运行中的摘要(如果有)。这使得 summarize_messages 能够避免在每个对话轮次中重复摘要相同的消息。
  4. 这是摘要后生成的消息列表的最大词元预算。
  5. 这是触发摘要的词元阈值。默认为 max_tokens
  6. 如果我们生成了摘要,则将其作为状态更新添加,并覆盖之前生成的任何摘要。
  7. 使用检查点编译图非常重要,否则图将不会记住之前的对话轮次。

在 UI 中使用

一个重要的问题是如何在您的应用程序 UI 中向用户呈现消息。我们建议渲染完整、未经修改的消息历史。您可以选择额外渲染摘要和传递给 LLM 的消息。我们还建议为完整的消息历史(例如,"messages")和摘要结果(例如,"summary")使用不同的 LangGraph 状态键。在 SummarizationNode 中,摘要结果存储在一个名为 context 的独立状态键中(见下例)。

使用 SummarizationNode

您也可以将摘要分离到一个专用节点中。让我们探讨如何修改上面的示例以使用 SummarizationNode 来实现相同的结果。

API: ChatOpenAI | StateGraph | START | SummarizationNode | RunningSummary

from typing import Any, TypedDict

from langchain_openai import ChatOpenAI
from langchain_core.messages import AnyMessage
from langgraph.graph import StateGraph, START, MessagesState
from langgraph.checkpoint.memory import InMemorySaver
from langmem.short_term import SummarizationNode, RunningSummary

model = ChatOpenAI(model="gpt-4o")
summarization_model = model.bind(max_tokens=128)


class State(MessagesState):
    context: dict[str, Any]  # (1)!


class LLMInputState(TypedDict):  # (2)!
    summarized_messages: list[AnyMessage]
    context: dict[str, Any]

summarization_node = SummarizationNode(  # (3)!
    token_counter=model.get_num_tokens_from_messages,
    model=summarization_model,
    max_tokens=256,
    max_tokens_before_summary=256,
    max_summary_tokens=128,
)

# IMPORTANT: we're passing a private input state here to isolate the summarization
def call_model(state: LLMInputState):  # (4)!
    response = model.invoke(state["summarized_messages"])
    return {"messages": [response]}

checkpointer = InMemorySaver()
builder = StateGraph(State)
builder.add_node(call_model)
builder.add_node("summarize", summarization_node)
builder.add_edge(START, "summarize")
builder.add_edge("summarize", "call_model")
graph = builder.compile(checkpointer=checkpointer)

# Invoke the graph
config = {"configurable": {"thread_id": "1"}}
graph.invoke({"messages": "hi, my name is bob"}, config)
graph.invoke({"messages": "write a short poem about cats"}, config)
graph.invoke({"messages": "now do the same but for dogs"}, config)
graph.invoke({"messages": "what's my name?"}, config)
  1. 我们将在 context 字段中跟踪我们的运行摘要(这是 SummarizationNode 所期望的)。
  2. 定义仅用于过滤 call_model 节点输入的私有状态。
  3. SummarizationNode 在底层使用 summarize_messages,并自动处理现有摘要的传递,而在上面的示例中我们必须手动完成此操作。
  4. 现在,模型调用节点只是一个单一的 LLM 调用。

在 ReAct Agent 中使用

一个常见的用例是在工具调用代理中摘要消息历史。下面的示例演示了如何在 ReAct 风格的 LangGraph 代理中实现这一点。

API: ChatOpenAI | tool | StateGraph | START | END | ToolNode | SummarizationNode | RunningSummary

from typing import Any, TypedDict

from langchain_openai import ChatOpenAI
from langchain_core.messages import AnyMessage
from langchain_core.tools import tool
from langgraph.graph import StateGraph, START, END, MessagesState
from langgraph.prebuilt import ToolNode
from langgraph.checkpoint.memory import InMemorySaver
from langmem.short_term import SummarizationNode, RunningSummary

class State(MessagesState):
    context: dict[str, Any]

def search(query: str):
    """Search the web."""
    if "weather" in query.lower():
        return "The weather is sunny in New York, with a high of 104 degrees."
    elif "broadway" in query.lower():
        return "Hamilton is always on!"
    else:
        raise "Not enough information"

tools = [search]

model = ChatOpenAI(model="gpt-4o")
summarization_model = model.bind(max_tokens=128)

summarization_node = SummarizationNode(
    token_counter=model.get_num_tokens_from_messages,
    model=summarization_model,
    max_tokens=256,
    max_tokens_before_summary=1024,
    max_summary_tokens=128,
)

class LLMInputState(TypedDict):
    summarized_messages: list[AnyMessage]
    context: dict[str, Any]

def call_model(state: LLMInputState):
    response = model.bind_tools(tools).invoke(state["summarized_messages"])
    return {"messages": [response]}

# Define a router that determines whether to execute tools or exit
def should_continue(state: MessagesState):
    messages = state["messages"]
    last_message = messages[-1]
    if not last_message.tool_calls:
        return END
    else:
        return "tools"

checkpointer = InMemorySaver()
builder = StateGraph(State)
builder.add_node("summarize_node", summarization_node)
builder.add_node("call_model", call_model)
builder.add_node("tools", ToolNode(tools))
builder.set_entry_point("summarize_node")
builder.add_edge("summarize_node", "call_model")
builder.add_conditional_edges("call_model", should_continue, path_map=["tools", END])
builder.add_edge("tools", "summarize_node")  # (1)!
graph = builder.compile(checkpointer=checkpointer)

# Invoke the graph
config = {"configurable": {"thread_id": "1"}}
graph.invoke({"messages": "hi, i am bob"}, config)
graph.invoke({"messages": "what's the weather in nyc this weekend"}, config)
graph.invoke({"messages": "what's new on broadway?"}, config)
  1. 在执行工具后,我们首先返回到摘要节点,而不是直接返回给 LLM。

评论