提取高基数分类特征#
假设我们构建了一个图书推荐聊天机器人,作为其中的一部分,我们希望提取并过滤作者姓名(如果它是用户输入的一部分)。用户可能会问类似这样的问题:
“有哪些关于外星人的史蒂芬·金的书?”
如果我们不小心,我们的提取系统很可能从这个输入中提取作者姓名“史蒂芬·金”。这可能会导致我们错过所有最相关的图书结果,因为用户几乎肯定是在寻找由 *斯蒂芬·金* 创作的书。
这是一个需要提取 **高基数分类** 值的情况。给定一个包含图书及其作者的数据库,存在大量但有限数量的有效作者姓名,我们需要某种方法来确保我们的提取系统即使在用户输入引用无效姓名时也能输出有效且相关的作者姓名。
我们构建了一个数据集来帮助基准测试处理这一挑战的不同方法。该数据集很简单:它包含 23 个拼写错误和纠正的人名。为了将其用于高基数分类测试,我们将生成一大组有效姓名(约 10,000 个),其中包含数据集中所有姓名的正确拼写。使用它,我们将测试各种提取系统从用户问题中提取纠正后的姓名的能力:
“有哪些关于外星人的 {拼写错误的姓名} 的书?”
对于我们数据集中的每个数据点,我们将使用拼写错误的姓名作为输入,并期望纠正后的姓名作为提取的输出。
设置#
首先,我们需要安装一些包并设置一些环境变量。
%pip install -qU langchain-benchmarks langchain-openai faker chromadb numpy scikit-learn
import getpass
import os
os.environ["LANGCHAIN_API_KEY"] = getpass.getpass()
from operator import attrgetter
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
from langsmith import Client
from langchain_benchmarks import registry
这是 langchain-benchmarket 中的 Name Correction
基准测试。
client = Client()
task = registry["Name Correction"]
task.dataset_url
'https://smith.langchain.com/public/78df83ee-ba7f-41c6-832c-2b23327d4cf7/d'
**注意**:如果您是第一次运行此笔记本,请通过取消注释以下内容将公共数据集克隆到您的 LangSmith 组织中。
# client.clone_public_dataset(task.dataset_url)
examples = list(client.list_examples(dataset_name=task.dataset_name))
for example in examples[:5]:
print(example.inputs, example.outputs)
{'name': 'Tracy Cook'} {'name': 'Traci Cook'}
{'name': 'Dan Klein'} {'name': 'Daniel Klein'}
{'name': 'Jen Mcintosh'} {'name': 'Jennifer Mcintosh'}
{'name': 'Cassie Hull'} {'name': 'Cassandra Hull'}
{'name': 'Andy Williams'} {'name': 'Andrew Williams'}
def run_on_dataset(chain, run_name):
client.run_on_dataset(
dataset_name=task.dataset_name,
llm_or_chain_factory=chain,
evaluation=task.eval_config,
project_name=run_name,
)
使用更多虚假名称进行扩充#
在我们的测试中,我们将创建一个包含 10,000 个姓名的列表,它们代表该类别所有可能的值。这将包括我们来自数据集的目标姓名。
from faker import Faker
Faker.seed(42)
fake = Faker()
fake.seed_instance(0)
incorrect_names = [example.inputs["name"] for example in examples]
correct_names = [example.outputs["name"] for example in examples]
# We'll make sure that our list of valid names contains the correct spellings
# and not the incorrect spellings from our dataset
valid_names = list(
set([fake.name() for _ in range(10_000)] + correct_names).difference(
incorrect_names
)
)
len(valid_names)
9382
valid_names[:3]
['Debra Lee', 'Kevin Harper', 'Donald Anderson']
链 1:基线#
作为基线,我们将创建一个函数调用链,它没有关于有效姓名集的信息。
class Search(BaseModel):
query: str
author: str
system = """Generate a relevant search query for a library system"""
prompt = ChatPromptTemplate.from_messages(
[
("system", "{system}"),
("human", "what are books about aliens by {name}"),
]
)
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm = llm.with_structured_output(Search)
query_analyzer_1 = (
prompt.partial(system=system) | structured_llm | {"name": attrgetter("author")}
)
run_on_dataset(query_analyzer_1, "GPT-3.5")
View the evaluation results for project 'GPT-3.5' at:
https://smith.langchain.com/o/43ae1439-dbb7-53b8-bef4-155154d3f962/datasets/1765d6b2-aa2e-46ec-9158-9f4ca8f228c6/compare?selectedSessions=f429ec84-b879-4e66-b7fb-ef7be69d1acd
View all tests for Dataset Extracting Corrected Names at:
https://smith.langchain.com/o/43ae1439-dbb7-53b8-bef4-155154d3f962/datasets/1765d6b2-aa2e-46ec-9158-9f4ca8f228c6
[------------------------------------------------->] 23/23
正如我们可能预期的那样,这给了我们一个 Correct rate: 0%
。让我们看看我们是否能做得更好 :)
在 LangSmith 中查看测试运行 这里。
链 2:提示中的所有候选者#
接下来,让我们将完整的有效姓名列表转储到系统提示中。我们将需要一个上下文窗口比 gpt-3.5-turbo-0125 的 16k 个令牌窗口更长的模型,因此我们将使用 gpt-4-0125-preview。
valid_names_str = "\n".join(valid_names)
system_2 = """Generate a relevant search query for a library system.
`author` attribute MUST be one of:
{valid_names_str}
Do NOT hallucinate author name!"""
formatted_system = system_2.format(valid_names_str=valid_names_str)
structured_llm_2 = ChatOpenAI(
model="gpt-4-0125-preview", temperature=0
).with_structured_output(Search)
query_analyzer_2 = (
prompt.partial(system=formatted_system)
| structured_llm_2
| {"name": attrgetter("author")}
)
run_on_dataset(query_analyzer_2, "GPT-4, all names in prompt")
View the evaluation results for project 'GPT-4, all names in prompt' at:
https://smith.langchain.com/o/43ae1439-dbb7-53b8-bef4-155154d3f962/datasets/1765d6b2-aa2e-46ec-9158-9f4ca8f228c6/compare?selectedSessions=8c4cfdfc-3646-438e-be47-43a40d66292a
View all tests for Dataset Extracting Corrected Names at:
https://smith.langchain.com/o/43ae1439-dbb7-53b8-bef4-155154d3f962/datasets/1765d6b2-aa2e-46ec-9158-9f4ca8f228c6
[------------------------------------------------->] 23/23
这使我们的 Correct rate: 26%
。
在 LangSmith 中查看测试运行 这里。
链 3:提示中来自向量存储的 top k 候选者#
10,000 个姓名在提示中太多了。也许我们可以通过首先使用向量搜索缩短列表来获得更好的性能,只包含与用户问题最相似的姓名。我们可以返回使用 GPT-3.5 作为结果。
from langchain_community.vectorstores import Chroma
from langchain_core.prompts import PromptTemplate
from langchain_openai import OpenAIEmbeddings
k = 10
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
vectorstore = Chroma.from_texts(valid_names, embeddings, collection_name="author_names")
retriever = vectorstore.as_retriever(search_kwargs={"k": k})
system_chain = (
(lambda name: f"what are books about aliens by {name}")
| retriever
| (
lambda docs: system_2.format(
valid_names_str="\n".join(d.page_content for d in docs)
)
)
)
query_analyzer_3 = (
RunnablePassthrough.assign(system=system_chain)
| prompt
| structured_llm
| {"name": attrgetter("author")}
)
run_on_dataset(query_analyzer_3, f"GPT-3.5, top {k} names in prompt, vecstore")
View the evaluation results for project 'GPT-3.5, top 10 names in prompt, vecstore' at:
https://smith.langchain.com/o/43ae1439-dbb7-53b8-bef4-155154d3f962/datasets/1765d6b2-aa2e-46ec-9158-9f4ca8f228c6/compare?selectedSessions=af93ec50-ccbb-4b3c-908a-70c75e5516ea
View all tests for Dataset Extracting Corrected Names at:
https://smith.langchain.com/o/43ae1439-dbb7-53b8-bef4-155154d3f962/datasets/1765d6b2-aa2e-46ec-9158-9f4ca8f228c6
[------------------------------------------------->] 23/23
这使我们的 Correct rate: 57%
。
在 LangSmith 中查看测试运行 这里。
链 4:提示中按 ngram 重叠排序的 top k 候选者#
除了使用向量搜索(需要嵌入和向量存储)之外,另一种更便宜、更快的做法是比较用户问题和有效姓名列表之间的 ngram 重叠。
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
# Function to generate character n-grams
def ngrams(string, n=3):
string = "START" + string.replace(" ", "").lower() + "END"
ngrams = zip(*[string[i:] for i in range(n)])
return ["".join(ngram) for ngram in ngrams]
# Vectorize documents using TfidfVectorizer with the custom n-grams function
vectorizer = TfidfVectorizer(analyzer=ngrams)
tfidf_matrix = vectorizer.fit_transform(valid_names)
def get_names(query):
# Vectorize query
query_tfidf = vectorizer.transform([query])
# Compute cosine similarity
cosine_similarities = cosine_similarity(query_tfidf, tfidf_matrix).flatten()
# Find the index of the most similar document
most_similar_document_indexes = np.argsort(-cosine_similarities)
return "\n".join([valid_names[i] for i in most_similar_document_indexes[:k]])
def get_system_prompt(input):
name = input["name"]
valid_names_str = get_names(f"what are books about aliens by {name}")
return system_2.format(valid_names_str=valid_names_str)
query_analyzer_4 = (
RunnablePassthrough.assign(system=get_system_prompt)
| prompt
| structured_llm
| {"name": attrgetter("author")}
)
run_on_dataset(query_analyzer_4, f"GPT-3.5, top {k} names in prompt, ngram")
View the evaluation results for project 'GPT-3.5, top 10 names in prompt, ngram' at:
https://smith.langchain.com/o/43ae1439-dbb7-53b8-bef4-155154d3f962/datasets/1765d6b2-aa2e-46ec-9158-9f4ca8f228c6/compare?selectedSessions=bc28b761-2ac9-4391-8df1-758f0a4d5100
View all tests for Dataset Extracting Corrected Names at:
https://smith.langchain.com/o/43ae1439-dbb7-53b8-bef4-155154d3f962/datasets/1765d6b2-aa2e-46ec-9158-9f4ca8f228c6
[------------------------------------------------->] 23/23
这使我们的 Correct rate: 65%
。
在 LangSmith 中查看测试运行 这里。
链 5:用来自向量存储的 top 候选者替换#
除了(或在)提取之前搜索类似的候选者之外,我们还可以通过在有效姓名上搜索来比较和纠正事后提取的值。使用 Pydantic 类,使用验证器可以轻松实现这一点。
from langchain_core.pydantic_v1 import validator
class Search(BaseModel):
query: str
author: str
@validator("author")
def double(cls, v: str) -> str:
return vectorstore.similarity_search(v, k=1)[0].page_content
structured_llm_3 = llm.with_structured_output(Search)
query_analyzer_5 = (
prompt.partial(system=system) | structured_llm_3 | {"name": attrgetter("author")}
)
run_on_dataset(query_analyzer_5, f"GPT-3.5, correct name, vecstore")
View the evaluation results for project 'GPT-3.5, correct name, vecstore' at:
https://smith.langchain.com/o/43ae1439-dbb7-53b8-bef4-155154d3f962/datasets/1765d6b2-aa2e-46ec-9158-9f4ca8f228c6/compare?selectedSessions=e3eda1e1-bc25-46e8-a4fb-db324cefd1c9
View all tests for Dataset Extracting Corrected Names at:
https://smith.langchain.com/o/43ae1439-dbb7-53b8-bef4-155154d3f962/datasets/1765d6b2-aa2e-46ec-9158-9f4ca8f228c6
[------------------------------------------------->] 23/23
这使我们的 Correct rate: 83%
。
在 LangSmith 中查看测试运行 这里。
链 6:用按 ngram 重叠排序的 top 候选者替换#
我们可以使用 ngram 重叠搜索来完成相同的事情,而不是向量搜索。
class Search(BaseModel):
query: str
author: str
@validator("author")
def double(cls, v: str) -> str:
return get_names(v).split("\n")[0]
structured_llm_4 = llm.with_structured_output(Search)
query_analyzer_6 = (
prompt.partial(system=system) | structured_llm_4 | {"name": attrgetter("author")}
)
run_on_dataset(query_analyzer_6, f"GPT-3.5, correct name, ngram")
View the evaluation results for project 'GPT-3.5, correct name, ngram' at:
https://smith.langchain.com/o/43ae1439-dbb7-53b8-bef4-155154d3f962/datasets/1765d6b2-aa2e-46ec-9158-9f4ca8f228c6/compare?selectedSessions=8f8846c8-2ada-41bc-8d2c-e1d56e7c92ce
View all tests for Dataset Extracting Corrected Names at:
https://smith.langchain.com/o/43ae1439-dbb7-53b8-bef4-155154d3f962/datasets/1765d6b2-aa2e-46ec-9158-9f4ca8f228c6
[------------------------------------------------->] 23/23
这使我们的 Correct rate: 74%
,比链 5(使用向量搜索而不是 ngram 做相同的事情)略差。
在 LangSmith 中查看测试运行 这里。
在 LangSmith 中查看所有结果#
要查看完整的数据集和所有测试结果,请前往 LangSmith:https://smith.langchain.com/public/8c0a4c25-426d-4582-96fc-d7def170be76/d