跳转至

Item 12: Implement RAG Pattern with ContextProvider

问题

你想让 Agent 基于你的文档回答问题,但 Agent 不知道你的文档内容:

# 你的文档:公司内部报销规则(100 页 PDF)
# Agent:只知道通用知识,不知道你们的报销规则

result = await agent.run("国内出差机票报销标准是什么?")
# Agent: "抱歉,我不知道你公司的具体报销标准..."

或者你尝试把整个文档塞进 Instructions:

agent = Agent(
    instructions="""
    # 你的文档内容(100 页 PDF,转成文本 ~50000 tokens)
    第一章 报销规则...
    第二章 差旅标准...
    ...(太长了,context window 装不下)
    """
)
# 问题:token 消耗巨大,大部分用不上

深入解释

RAG = Retrieval Augmented Generation(检索增强生成),核心思想是按需检索,而非全量加载

┌─────────────────────────────────────────────────────────────┐
│              全量加载(Inefficient)                           │
├─────────────────────────────────────────────────────────────┤
│  用户问题:"机票报销标准"                                     │
│                     ↓                                        │
│  把 50000 tokens 的文档全部塞进 context                      │
│                     ↓                                        │
│  LLM 处理 50000 tokens,99% 无关                            │
│                     ↓                                        │
│  回答 "机票报销标准"                                         │
└─────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────┐
│              RAG(Efficient)                                 │
├─────────────────────────────────────────────────────────────┤
│  用户问题:"机票报销标准"                                     │
│                     ↓                                        │
│  检索相关段落(~500 tokens)                                 │
│                     ↓                                        │
│  把 500 tokens 注入 context                                 │
│                     ↓                                        │
│  LLM 处理 500 tokens,精准回答                               │
└─────────────────────────────────────────────────────────────┘

ContextProvider 是实现 RAG 的最佳位置——它在 before_run() 中拦截请求,检索相关文档,注入 context。

心智模型:RAG 就像图书馆的索引系统——不问读者把整本书都背下来,而是按问题检索相关章节。

推荐做法

简化版:关键词检索 RAG

from agent_framework import ContextProvider

class SimpleRAGProvider(ContextProvider):
    """基于关键词匹配的简化 RAG 实现"""

    def __init__(self, documents: list[str]):
        super().__init__("simple-rag")
        self.documents = documents

    async def before_run(self, *, agent, session, context, state):
        # 1. 从用户输入提取查询
        query = self._extract_query(context)
        if not query:
            return

        # 2. 检索相关文档(简化:关键词匹配)
        relevant = self._keyword_search(query, top_k=3)
        if not relevant:
            return

        # 3. 注入检索结果到 context
        context.extend_instructions(
            self.source_id,
            f"\n【参考文档】\n{relevant}\n"
        )

    def _extract_query(self, context) -> str:
        """从 context 中提取用户查询"""
        for msg in context.input_messages:
            if hasattr(msg, 'contents'):
                for c in msg.contents:
                    if hasattr(c, 'text') and c.text:
                        return c.text
        return ""

    def _keyword_search(self, query: str, top_k: int) -> str:
        """简化搜索:关键词匹配"""
        query_words = set(query.lower().split())

        scored = []
        for i, doc in enumerate(self.documents):
            doc_words = set(doc.lower().split())
            score = len(query_words & doc_words)
            if score > 0:
                scored.append((score, i, doc))

        scored.sort(reverse=True)
        results = [doc for _, _, doc in scored[:top_k]]
        return "\n---\n".join(results)


# 使用
documents = [
    "国内出差报销标准:机票经济舱实报实销,高铁二等座实报实销...",
    "国外出差报销标准:机票需 VP 审批,住宿参照外事标准...",
    "日常报销标准:单笔 1000 元以下无需审批...",
]

agent = Agent(
    client=client,
    instructions="你是报销助手,基于提供的参考文档回答问题。",
    context_providers=[SimpleRAGProvider(documents)],
)

result = await agent.run("国内出差机票报销标准是什么?")
# Agent 会基于 "国内出差报销标准..." 文档回答

生产版:向量检索 RAG

from agent_framework import ContextProvider

class VectorRAGProvider(ContextProvider):
    """使用向量相似度搜索的 RAG 实现"""

    def __init__(self, documents: list[str], embeddings_client):
        super().__init__("vector-rag")
        self.documents = documents
        self.embeddings_client = embeddings_client
        self.doc_embeddings = self._index_documents()

    def _index_documents(self):
        """为文档生成向量索引"""
        # 实际使用:Pinecone/Milvus/Chroma 等向量数据库
        return [self.embeddings_client.embed(doc) for doc in self.documents]

    def _search(self, query: str, top_k: int = 3) -> list[tuple[float, int, str]]:
        """向量相似度搜索"""
        query_embedding = self.embeddings_client.embed(query)

        scored = []
        for i, doc_emb in enumerate(self.doc_embeddings):
            similarity = self._cosine_similarity(query_embedding, doc_emb)
            scored.append((similarity, i, self.documents[i]))

        scored.sort(reverse=True)
        return scored[:top_k]

    def _cosine_similarity(self, a: list, b: list) -> float:
        import math
        dot = sum(x * y for x, y in zip(a, b))
        norm_a = math.sqrt(sum(x * x for x in a))
        norm_b = math.sqrt(sum(x * x for x in b))
        return dot / (norm_a * norm_b)

    async def before_run(self, *, agent, session, context, state):
        query = self._extract_query(context)
        if not query:
            return

        results = self._search(query, top_k=3)
        if not results:
            return

        relevant_docs = "\n---\n".join(doc for _, _, doc in results)
        context.extend_instructions(
            self.source_id,
            f"\n【相关文档】\n{relevant_docs}\n"
        )

    def _extract_query(self, context) -> str:
        for msg in context.input_messages:
            if hasattr(msg, 'contents'):
                for c in msg.contents:
                    if hasattr(c, 'text') and c.text:
                        return c.text
        return ""

好 vs 坏对比

# 坏做法 1:把所有文档塞进 Instructions
agent = Agent(
    instructions="""
    # 50000 tokens 的文档内容...
    """
)
# token 消耗巨大,大部分用不上,context window 溢出

# 坏做法 2:让 Agent 自己"想"答案
result = await agent.run("你公司的报销标准是什么?")
# Agent 瞎编,不知道你公司的具体规则

# 好做法:用 RAG 按需检索
documents = load_your_documents()
agent = Agent(
    instructions="你是报销助手,基于提供的参考文档回答。",
    context_providers=[SimpleRAGProvider(documents)],
)
result = await agent.run("国内出差机票报销标准?")
# Agent 基于检索到的相关段落回答,精准且省 token

扩展讨论

RAG vs Skills vs Fine-tuning

特性 RAG Skills Fine-tuning
知识更新 实时(换文档即可) 需改 Skill 定义 需重新训练
token 消耗 按需检索 按需加载 固定(模型权重)
适用场景 动态文档/知识库 规则/流程/手册 风格/领域适应
准确性 依赖检索质量 高(精确控制) 中(依赖训练数据)
延迟 增加检索延迟 无(推理时)

检索策略选择

# 1. 关键词检索(BM25)- 简单场景
from rank_bm25 import BM25Ok

class BM25RAGProvider(ContextProvider):
    def __init__(self, documents: list[str]):
        super().__init__("bm25-rag")
        self.documents = documents
        self.bm25 = BM25Ok([doc.split() for doc in documents])

    def search(self, query: str, top_k: int = 3) -> list[str]:
        tokens = query.split()
        scores = self.bm25.get_scores(tokens)
        top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_k]
        return [self.documents[i] for i in top_indices]

# 2. 向量检索 - 生产环境
# 使用:Pinecone / Milvus / Weaviate / ChromaDB

混合检索

class HybridRAGProvider(ContextProvider):
    """关键词 + 向量混合检索"""

    def __init__(self, documents: list[str], embeddings_client):
        super().__init__("hybrid-rag")
        self.documents = documents
        self.embeddings_client = embeddings_client
        self.bm25 = BM25Ok([doc.split() for doc in documents])
        self.doc_embeddings = [embeddings_client.embed(doc) for doc in documents]

    def search(self, query: str, top_k: int = 3) -> list[str]:
        # 关键词得分
        keyword_scores = self.bm25.get_scores(query.split())

        # 向量得分
        query_emb = self.embeddings_client.embed(query)
        vector_scores = [self._cosine_similarity(query_emb, doc_emb)
                        for doc_emb in self.doc_embeddings]

        # 融合
        final_scores = [0.4 * kw + 0.6 * vec for kw, vec in zip(keyword_scores, vector_scores)]

        top_indices = sorted(range(len(final_scores)), key=lambda i: final_scores[i], reverse=True)[:top_k]
        return [self.documents[i] for i in top_indices]

企业级考虑

# 1. 文档更新时重建索引
class VersionedRAGProvider(ContextProvider):
    def __init__(self, docs_path: str, embeddings_client):
        self.embeddings_client = embeddings_client
        self.docs_path = Path(docs_path)
        self._reindex()

    def _reindex(self):
        """文档变更时重建索引"""
        self.documents = list(self.docs_path.glob("*.txt"))
        self.doc_embeddings = [self.embeddings_client.embed(doc.read_text())
                             for doc in self.documents]

# 2. 查询缓存
class CachedRAGProvider(ContextProvider):
    def __init__(self, documents: list[str]):
        super().__init__("cached-rag")
        self.documents = documents
        self.cache = {}

    def _search(self, query: str) -> str:
        if query in self.cache:
            return self.cache[query]

        result = self._do_search(query)
        self.cache[query] = result
        return result

# 3. 检索结果重排序
class RerankedRAGProvider(ContextProvider):
    def __init__(self, documents: list[str], reranker):
        self.documents = documents
        self.reranker = reranker  # Cross-Encoder 重排序模型

    def search(self, query: str, top_k: int = 10, rerank_top: int = 3):
        # 第一阶段:向量检索出 top_k
        candidates = self._vector_search(query, top_k)

        # 第二阶段:Cross-Encoder 重排序
        reranked = self.reranker.rerank(query, candidates)[:rerank_top]
        return reranked

Things to Remember

  • RAG = 检索 + 生成,不是让 Agent 自己编答案
  • 检索在 before_run(),按需将相关文档注入 context
  • ContextProvider 是实现 RAG 的标准位置
  • context.extend_instructions() 将检索结果注入
  • RAG 适合动态文档/知识库,Skills 适合结构化规则
  • 简化版可用关键词匹配(BM25),生产用向量检索(Pinecone/Milvus)
  • 检索质量直接影响生成质量——好的检索策略很重要
  • 文档更新时需重建索引,向量数据库支持增量更新
  • 混合检索(关键词 + 向量)通常效果更好
  • 检索结果需要重排序(Re-ranker)以提升相关性