{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "51466c8d-8ce4-4b3d-be4e-18fdbeda5f53",
   "metadata": {},
   "source": [
    "# How to use Pydantic model as state\n",
    "\n",
    "**Prerequisites**\n",
    "\n",
    "This guide assumes familiarity with the following:\n",
    "\n",
    "A [StateGraph](https://langchain-ai.github.io/langgraph/reference/graphs/#langgraph.graph.StateGraph) accepts a `state_schema` argument on initialization that specifies the \"shape\" of the state that the nodes in the graph can access and update.\n",
    "\n",
    "In our examples, we typically use a Python-native `TypedDict` for `state_schema` (or in the case of [MessageGraph](https://langchain-ai.github.io/langgraph/reference/graphs/#messagegraph), a [list](https://docs.python.org/3/library/stdtypes.html#list)), but `state_schema` can be any [type](https://docs.python.org/3/library/stdtypes.html#type-objects).\n",
    "\n",
    "In this how-to guide, we'll see how a [Pydantic BaseModel](https://docs.pydantic.dev/latest/api/base_model/) can be used for `state_schema` to add run-time validation on **inputs**.\n",
    "\n",
    "**Known Limitations**"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7cbd446a-808f-4394-be92-d45ab818953c",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "First, we need to install the required packages:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "af4ce0ba-7596-4e5f-8bf8-0b0bd6e62833",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture --no-stderr\n",
    "%pip install --quiet -U langgraph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "01456d57-4064-4ccb-baf9-98df39c6b8e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import getpass\n",
    "import os\n",
    "\n",
    "\n",
    "def _set_env(var: str):\n",
    "    if not os.environ.get(var):\n",
    "        os.environ[var] = getpass.getpass(f\"{var}: \")\n",
    "\n",
    "\n",
    "_set_env(\"OPENAI_API_KEY\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f385bde-e013-4365-88f3-813c632d4b7c",
   "metadata": {},
   "source": [
    "**Set up LangSmith for LangGraph development**\n",
    "\n",
    "Sign up for LangSmith to quickly spot issues and improve the performance of your LangGraph projects. LangSmith lets you use trace data to debug, test, and monitor your LLM apps built with LangGraph. Read more about how to get started here."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e20dd648-df7a-40f5-9b32-afbdcf1ee4d8",
   "metadata": {},
   "source": [
    "## Input Validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "efc46b36-425c-49c3-9f9e-d9785c70b034",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'a': 'goodbye'}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from langgraph.graph import StateGraph, START, END\n",
    "from pydantic import BaseModel\n",
    "\n",
    "\n",
    "# The overall state of the graph (this is the public state shared across nodes)\n",
    "class OverallState(BaseModel):\n",
    "    a: str\n",
    "\n",
    "\n",
    "def node(state: OverallState):\n",
    "    return {\"a\": \"goodbye\"}\n",
    "\n",
    "\n",
    "# Build the state graph\n",
    "builder = StateGraph(OverallState)\n",
    "builder.add_node(node)  # Registered under the function's name, \"node\"\n",
    "builder.add_edge(START, \"node\")  # Start the graph with `node`\n",
    "builder.add_edge(\"node\", END)  # End the graph after `node`\n",
    "graph = builder.compile()\n",
    "\n",
    "# Test the graph with a valid input\n",
    "graph.invoke({\"a\": \"hello\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25b594c2-8198-4f76-9606-ea47151ff9d1",
   "metadata": {},
   "source": [
    "Invoke the graph with an **invalid** input:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "05d7d43b-0b71-4e25-af6f-61d1560a46cb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "An exception was raised because `a` is an integer rather than a string.\n",
      "1 validation error for OverallState\n",
      "a\n",
      "  Input should be a valid string [type=string_type, input_value=123, input_type=int]\n",
      "    For further information visit https://errors.pydantic.dev/2.9/v/string_type\n"
     ]
    }
   ],
   "source": [
    "try:\n",
    "    graph.invoke({\"a\": 123})  # Should be a string\n",
    "except Exception as e:\n",
    "    print(\"An exception was raised because `a` is an integer rather than a string.\")\n",
    "    print(e)"
   ]
  },
  {
"cell_type": "markdown", "id": "0aafc180-17b5-4364-b1df-fb41aa575067", "metadata": {}, "source": [ "## Multiple Nodes\n", "\n", "Run-time validation will also work in a multi-node graph. In the example below `bad_node` updates `a` to an integer. \n", "\n", "Because run-time validation occurs on **inputs**, the validation error will occur when `ok_node` is called (not when `bad_node` returns an update to the state which is inconsistent with the schema)." ] }, { "cell_type": "code", "execution_count": 6, "id": "25336b0d-2fe6-45c8-8204-f962c3995df7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "An exception was raised because bad_node sets `a` to an integer.\n", "1 validation error for OverallState\n", "a\n", " Input should be a valid string [type=string_type, input_value=123, input_type=int]\n", " For further information visit https://errors.pydantic.dev/2.9/v/string_type\n" ] } ], "source": [ "from langgraph.graph import StateGraph, START, END\n", "from typing_extensions import TypedDict\n", "\n", "from pydantic import BaseModel\n", "\n", "\n", "# The overall state of the graph (this is the public state shared across nodes)\n", "class OverallState(BaseModel):\n", " a: str\n", "\n", "\n", "def bad_node(state: OverallState):\n", " return {\n", " \"a\": 123 # Invalid\n", " }\n", "\n", "\n", "def ok_node(state: OverallState):\n", " return {\"a\": \"goodbye\"}\n", "\n", "\n", "# Build the state graph\n", "builder = StateGraph(OverallState)\n", "builder.add_node(bad_node)\n", "builder.add_node(ok_node)\n", "builder.add_edge(START, \"bad_node\")\n", "builder.add_edge(\"bad_node\", \"ok_node\")\n", "builder.add_edge(\"ok_node\", END)\n", "graph = builder.compile()\n", "\n", "# Test the graph with a valid input\n", "try:\n", " graph.invoke({\"a\": \"hello\"})\n", "except Exception as e:\n", " print(\"An exception was raised because bad_node sets `a` to an integer.\")\n", " print(e)" ] } ], "metadata": { "kernelspec": { "display_name": 
"Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.4" } }, "nbformat": 4, "nbformat_minor": 5 }