How to disable streaming for models that don't support it¶
Some chat models, including OpenAI's new O1 models (depending on when you're reading this), do not support streaming. This can cause issues when using the astream_events API, since it calls models in streaming mode and expects streaming to work properly.
In this guide we'll show you how to disable streaming for models that don't support it, ensuring they are never called in streaming mode, even when invoked via the astream_events API.
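Conceptually, disable_streaming makes a chat model answer stream-style calls with a single non-streaming invocation instead. Here is a minimal pure-Python sketch of that fallback behavior, using a hypothetical `Model` class (not the actual langchain-core implementation):

```python
class Model:
    """Toy model illustrating the disable_streaming fallback (hypothetical)."""

    def __init__(self, disable_streaming=False):
        self.disable_streaming = disable_streaming

    def invoke(self, prompt):
        # One blocking call that returns the full answer.
        return "full answer"

    def stream(self, prompt):
        if self.disable_streaming:
            # Fall back to a single non-streaming call,
            # yielded as one chunk.
            yield self.invoke(prompt)
        else:
            # Normal token-by-token streaming.
            for token in ["full", " ", "answer"]:
                yield token


print(list(Model().stream("hi")))                        # ['full', ' ', 'answer']
print(list(Model(disable_streaming=True).stream("hi")))  # ['full answer']
```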
In [4]
from langchain_openai import ChatOpenAI
from langgraph.graph import MessagesState, StateGraph, START, END

llm = ChatOpenAI(model="o1-preview", temperature=1)

graph_builder = StateGraph(MessagesState)


def chatbot(state: MessagesState):
    return {"messages": [llm.invoke(state["messages"])]}


graph_builder.add_node("chatbot", chatbot)
graph_builder.add_edge(START, "chatbot")
graph_builder.add_edge("chatbot", END)
graph = graph_builder.compile()
In [5]
from IPython.display import Image, display
display(Image(graph.get_graph().draw_mermaid_png()))
Without disabling streaming¶
Now that we've defined our graph, let's try calling astream_events without disabling streaming. Because the o1 model does not natively support streaming, this should throw an error.
In [6]
input = {"messages": {"role": "user", "content": "how many r's are in strawberry?"}}
try:
    async for event in graph.astream_events(input, version="v2"):
        if event["event"] == "on_chat_model_end":
            print(event["data"]["output"].content, end="", flush=True)
except:
    print("Streaming not supported!")
Streaming not supported!
Disabling streaming¶
Now, without making any other changes to the graph, let's disable streaming on the model by setting disable_streaming=True.
In [7]
llm = ChatOpenAI(model="o1-preview", temperature=1, disable_streaming=True)

graph_builder = StateGraph(MessagesState)


def chatbot(state: MessagesState):
    return {"messages": [llm.invoke(state["messages"])]}


graph_builder.add_node("chatbot", chatbot)
graph_builder.add_edge(START, "chatbot")
graph_builder.add_edge("chatbot", END)
graph = graph_builder.compile()
Now, rerunning with the same input, we should see no errors.
In [8]
input = {"messages": {"role": "user", "content": "how many r's are in strawberry?"}}
async for event in graph.astream_events(input, version="v2"):
    if event["event"] == "on_chat_model_end":
        print(event["data"]["output"].content, end="", flush=True)
There are three "r"s in the word "strawberry".
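With streaming disabled, no token-level on_chat_model_stream events are emitted; the loop above instead picks the final answer out of the single on_chat_model_end event. A minimal sketch of that filtering pattern, run over a hypothetical hard-coded event list rather than a live astream_events stream:

```python
# Hypothetical event list: no on_chat_model_stream events are present,
# because streaming was disabled on the model.
events = [
    {"event": "on_chain_start", "data": {}},
    {"event": "on_chat_model_start", "data": {}},
    {
        "event": "on_chat_model_end",
        "data": {"output": 'There are three "r"s in the word "strawberry".'},
    },
    {"event": "on_chain_end", "data": {}},
]

# Keep only the final model outputs.
final_answers = [
    e["data"]["output"] for e in events if e["event"] == "on_chat_model_end"
]
print(final_answers[0])  # There are three "r"s in the word "strawberry".
```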