
How to use Pydantic model as graph state

Prerequisites

This guide assumes familiarity with the following

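A StateGraph accepts a state_schema argument on initialization that specifies the "shape" of the state that the nodes in the graph can access and update.

In our examples, we typically use a Python-native TypedDict as the state_schema (or a list in the case of MessageGraph), but state_schema can be any type.

In this how-to guide, we'll see how to use a Pydantic BaseModel as the state_schema to add run-time validation on inputs.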
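
The run-time validation a BaseModel adds is something a TypedDict cannot provide on its own. The minimal sketch below uses only the standard library and Pydantic (no graph is built): constructing a TypedDict with the wrong value type silently succeeds, while the equivalent BaseModel rejects it. The class names TypedState and PydanticState are illustrative.

```python
from typing import TypedDict

from pydantic import BaseModel, ValidationError


class TypedState(TypedDict):
    a: str


class PydanticState(BaseModel):
    a: str


# A TypedDict is only a static type hint: at runtime it builds a plain dict
# and never checks value types.
typed = TypedState(a=123)  # type: ignore[typeddict-item]
print(type(typed).__name__, typed)  # dict {'a': 123}

# A BaseModel validates its fields when it is constructed.
try:
    PydanticState(a=123)
    error_type = None
except ValidationError as e:
    error_type = e.errors()[0]["type"]
print(error_type)  # string_type
```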

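Known Limitations

  • This guide uses Pydantic v2 BaseModel, which requires langchain-core >= 0.3. Using langchain-core < 0.3 will result in errors due to mixing Pydantic v1 and v2 BaseModels.
  • Currently, the `output` of the graph will NOT be an instance of a Pydantic model.
  • Run-time validation only occurs on the inputs to nodes, not on their outputs.
  • The validation error trace from Pydantic does not show which node the error arises in.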
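
The first limitation can be checked before building a graph. The sketch below reads installed versions with the standard-library importlib.metadata and compares only the leading MAJOR.MINOR numbers against langchain-core 0.3 and Pydantic 2.0. The names at_least and check_pydantic_state_support are hypothetical helpers, and the simple regex comparison is an assumption that ignores pre-release tags; the packaging library is the more robust choice in a real project.

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Tuple


def at_least(installed: str, minimum: Tuple[int, int]) -> bool:
    """Return True if the leading MAJOR.MINOR of `installed` is >= `minimum`."""
    match = re.match(r"(\d+)\.(\d+)", installed)
    if match is None:
        raise ValueError(f"Unrecognized version string: {installed!r}")
    return (int(match.group(1)), int(match.group(2))) >= minimum


def check_pydantic_state_support() -> None:
    """Print whether installed packages meet the minimums listed above."""
    for package, minimum in (("langchain-core", (0, 3)), ("pydantic", (2, 0))):
        try:
            installed = version(package)
        except PackageNotFoundError:
            print(f"{package}: not installed")
            continue
        status = "ok" if at_least(installed, minimum) else "too old"
        print(f"{package} {installed}: {status}")


check_pydantic_state_support()
```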

Setup

First, let's install the required packages

pip install --quiet -U langgraph
import getpass
import os


def _set_env(var: str):
    if not os.environ.get(var):
        os.environ[var] = getpass.getpass(f"{var}: ")


_set_env("OPENAI_API_KEY")

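Set up LangSmith for LangGraph development

Sign up for LangSmith to quickly spot issues and improve the performance of your LangGraph projects. LangSmith lets you use trace data to debug, test, and monitor your LLM apps built with LangGraph. Read more about how to get started here.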

Input Validation

API Reference: StateGraph | START | END

from langgraph.graph import StateGraph, START, END
from pydantic import BaseModel


# The overall state of the graph (this is the public state shared across nodes)
class OverallState(BaseModel):
    a: str


def node(state: OverallState):
    # `state` arrives as a validated OverallState instance; return a partial update
    return {"a": "goodbye"}


# Build the state graph
builder = StateGraph(OverallState)
builder.add_node(node)  # Registered under its function name, "node"
builder.add_edge(START, "node")  # Start the graph with "node"
builder.add_edge("node", END)  # End the graph after "node"
graph = builder.compile()

# Test the graph with a valid input
graph.invoke({"a": "hello"})
{'a': 'goodbye'}

Invoke the graph with an invalid input

try:
    graph.invoke({"a": 123})  # Should be a string
except Exception as e:
    print("An exception was raised because `a` is an integer rather than a string.")
    print(e)
An exception was raised because `a` is an integer rather than a string.
1 validation error for OverallState
a
  Input should be a valid string [type=string_type, input_value=123, input_type=int]
    For further information visit https://errors.pydantic.dev/2.9/v/string_type
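
To react to a validation failure in code rather than print it, the structured details are available from Pydantic itself. The sketch below reproduces the same failure by calling OverallState.model_validate directly (no graph involved) and reads ValidationError.errors(). It assumes the exception raised by graph.invoke is a pydantic.ValidationError, which the message format above suggests but this guide does not state; validation_details is a hypothetical helper name.

```python
from typing import Any, List, Tuple

from pydantic import BaseModel, ValidationError


class OverallState(BaseModel):
    a: str


def validation_details(data: dict) -> List[Tuple[tuple, str, Any]]:
    """Hypothetical helper: return (location, error type, input) for each error."""
    try:
        OverallState.model_validate(data)
    except ValidationError as e:
        # `loc` is a tuple path to the failing field; `type` is Pydantic's error code.
        return [(err["loc"], err["type"], err["input"]) for err in e.errors()]
    return []


print(validation_details({"a": "hello"}))  # []
print(validation_details({"a": 123}))  # [(('a',), 'string_type', 123)]
print(validation_details({}))  # a "missing" error for field `a`
```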

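Multiple Nodes

Run-time validation also works in a multi-node graph. In the example below, `bad_node` updates `a` to an integer.

Because run-time validation happens on inputs, the validation error is raised when `ok_node` is called, not when `bad_node` returns the state update that is inconsistent with the schema.

API Reference: StateGraph | START | END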

from langgraph.graph import StateGraph, START, END
from pydantic import BaseModel


# The overall state of the graph (this is the public state shared across nodes)
class OverallState(BaseModel):
    a: str


def bad_node(state: OverallState):
    return {
        "a": 123  # Invalid
    }


def ok_node(state: OverallState):
    return {"a": "goodbye"}


# Build the state graph
builder = StateGraph(OverallState)
builder.add_node(bad_node)
builder.add_node(ok_node)
builder.add_edge(START, "bad_node")
builder.add_edge("bad_node", "ok_node")
builder.add_edge("ok_node", END)
graph = builder.compile()

# Test the graph with a valid input
try:
    graph.invoke({"a": "hello"})
except Exception as e:
    print("An exception was raised because bad_node sets `a` to an integer.")
    print(e)
An exception was raised because bad_node sets `a` to an integer.
1 validation error for OverallState
a
  Input should be a valid string [type=string_type, input_value=123, input_type=int]
    For further information visit https://errors.pydantic.dev/2.9/v/string_type
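
Because validation only runs on node inputs, the error above surfaces when `ok_node` reads the state, and the trace does not name `bad_node`. One possible workaround, sketched below, is a hypothetical validate_output decorator (not part of LangGraph's API) that re-validates each node's update against the schema as soon as the node returns. It assumes every field is simply overwritten by updates, with no reducers such as add_messages. The sketch calls the wrapped nodes directly; since functools.wraps preserves `__name__`, a wrapped node could presumably still be registered with builder.add_node under its original name.

```python
import functools
from typing import Any, Callable, Dict, Type

from pydantic import BaseModel, ValidationError


class OverallState(BaseModel):
    a: str


NodeFn = Callable[[BaseModel], Dict[str, Any]]


def validate_output(schema: Type[BaseModel]) -> Callable[[NodeFn], NodeFn]:
    """Hypothetical helper: check a node's update against `schema` immediately."""

    def decorator(node: NodeFn) -> NodeFn:
        @functools.wraps(node)
        def wrapper(state: BaseModel) -> Dict[str, Any]:
            update = node(state)
            # Overlay the update on the current state and validate the result.
            # Assumes plain overwrite semantics (no reducers on any field).
            try:
                schema.model_validate({**state.model_dump(), **update})
            except ValidationError as e:
                raise ValueError(
                    f"node {node.__name__!r} returned an invalid update: {e}"
                ) from e
            return update

        return wrapper

    return decorator


@validate_output(OverallState)
def bad_node(state: OverallState) -> Dict[str, Any]:
    return {"a": 123}  # Invalid


@validate_output(OverallState)
def ok_node(state: OverallState) -> Dict[str, Any]:
    return {"a": "goodbye"}


print(ok_node(OverallState(a="hello")))  # {'a': 'goodbye'}
try:
    bad_node(OverallState(a="hello"))
except ValueError as e:
    print(e)  # The message names 'bad_node'
```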

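Advanced Pydantic Model Usage

This section covers more advanced topics for using Pydantic models with LangGraph.

Serialization Behavior

When using Pydantic models as state schemas, it is important to understand how serialization works, especially when:

  • Passing Pydantic objects as input
  • Receiving output from the graph
  • Working with nested Pydantic models

Let's see these behaviors in action.

API Reference: StateGraph | START | END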

from langgraph.graph import StateGraph, START, END
from pydantic import BaseModel


class NestedModel(BaseModel):
    value: str


class ComplexState(BaseModel):
    text: str
    count: int
    nested: NestedModel


def process_node(state: ComplexState):
    # Node receives a validated Pydantic object
    print(f"Input state type: {type(state)}")
    print(f"Nested type: {type(state.nested)}")

    # Return a dictionary update
    return {"text": state.text + " processed", "count": state.count + 1}


# Build the graph
builder = StateGraph(ComplexState)
builder.add_node("process", process_node)
builder.add_edge(START, "process")
builder.add_edge("process", END)
graph = builder.compile()

# Create a Pydantic instance for input
input_state = ComplexState(text="hello", count=0, nested=NestedModel(value="test"))
print(f"Input object type: {type(input_state)}")

# Invoke graph with a Pydantic instance
result = graph.invoke(input_state)
print(f"Output type: {type(result)}")
print(f"Output content: {result}")

# Convert back to Pydantic model if needed
output_model = ComplexState(**result)
print(f"Converted back to Pydantic: {type(output_model)}")
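
The conversion at the end of this example relies on Pydantic validating nested fields recursively. The standalone sketch below (Pydantic only; the graph result is replaced by hand-written dicts) shows that model_validate, the Pydantic v2 counterpart of ComplexState(**result), accepts the nested field either as a NestedModel instance or as a plain dict, and that model_dump and model_dump_json turn the state back into plain data. Which of the two nested forms a given graph returns is not asserted here.

```python
from pydantic import BaseModel


class NestedModel(BaseModel):
    value: str


class ComplexState(BaseModel):
    text: str
    count: int
    nested: NestedModel


# Two plain-dict results that differ only in how `nested` is represented.
as_instance = {"text": "hello processed", "count": 1, "nested": NestedModel(value="test")}
as_plain = {"text": "hello processed", "count": 1, "nested": {"value": "test"}}

model_a = ComplexState.model_validate(as_instance)
model_b = ComplexState.model_validate(as_plain)
print(model_a == model_b)  # True

# model_dump() goes the other way: nested models become plain dicts.
print(model_a.model_dump())
# model_dump_json() produces a JSON string suitable for sending over the wire.
print(model_a.model_dump_json())
```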

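Runtime Type Coercion

Pydantic performs runtime type coercion for certain data types, for example converting the string "42" to the integer 42. This can be helpful, but it can also lead to unexpected behavior if you are not aware of it.

API Reference: StateGraph | START | END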

from langgraph.graph import StateGraph, START, END
from pydantic import BaseModel


class CoercionExample(BaseModel):
    # Pydantic will coerce string numbers to integers
    number: int
    # Pydantic will parse string booleans to bool
    flag: bool


def inspect_node(state: CoercionExample):
    print(f"number: {state.number} (type: {type(state.number)})")
    print(f"flag: {state.flag} (type: {type(state.flag)})")
    return {}


builder = StateGraph(CoercionExample)
builder.add_node("inspect", inspect_node)
builder.add_edge(START, "inspect")
builder.add_edge("inspect", END)
graph = builder.compile()

# Demonstrate coercion with string inputs that will be converted
result = graph.invoke({"number": "42", "flag": "true"})

# This would fail with a validation error
try:
    graph.invoke({"number": "not-a-number", "flag": "true"})
except Exception as e:
    print(f"\nExpected validation error: {e}")
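
If this coercion is not what you want, Pydantic can be told to reject it. The sketch below uses only Pydantic (no graph) and compares the default lax mode with ConfigDict(strict=True). Whether LangGraph's own input validation honors a strict model_config is an assumption this guide does not confirm, so check it against your installed version before relying on it.

```python
from pydantic import BaseModel, ConfigDict, ValidationError


class LaxState(BaseModel):
    number: int
    flag: bool


class StrictState(BaseModel):
    # strict=True turns off coercion such as "42" -> 42 and "true" -> True
    model_config = ConfigDict(strict=True)

    number: int
    flag: bool


data = {"number": "42", "flag": "true"}

lax = LaxState.model_validate(data)
print(lax.number, lax.flag)  # 42 True

try:
    StrictState.model_validate(data)
    strict_error_fields = []
except ValidationError as e:
    # Collect the name of each field that failed strict validation
    strict_error_fields = [err["loc"][0] for err in e.errors()]
print(strict_error_fields)  # ['number', 'flag']
```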

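Working with Message Models

When using LangChain message types in your state schema, there are important considerations for serialization. Use AnyMessage (rather than BaseMessage) for proper serialization and deserialization when sending message objects over the wire. AnyMessage is a union of the concrete message classes, discriminated by each message's type field, so Pydantic can rebuild each message as the right concrete class (for example HumanMessage or AIMessage).

API Reference: StateGraph | START | END | HumanMessage | AIMessage | AnyMessage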

from langgraph.graph import StateGraph, START, END
from pydantic import BaseModel
from langchain_core.messages import HumanMessage, AIMessage, AnyMessage
from typing import List


class ChatState(BaseModel):
    messages: List[AnyMessage]
    context: str


def add_message(state: ChatState):
    return {"messages": state.messages + [AIMessage(content="Hello there!")]}


builder = StateGraph(ChatState)
builder.add_node("add_message", add_message)
builder.add_edge(START, "add_message")
builder.add_edge("add_message", END)
graph = builder.compile()

# Create input with a message
initial_state = ChatState(
    messages=[HumanMessage(content="Hi")], context="Customer support chat"
)

result = graph.invoke(initial_state)
print(f"Output: {result}")

# Convert back to Pydantic model to see message types
output_model = ChatState(**result)
for i, msg in enumerate(output_model.messages):
    print(f"Message {i}: {type(msg).__name__} - {msg.content}")
