banner
子文

子文

世界,你好鸭~
x
github

百行のコードで三体問答ロボットを実現する

langchain + gradio を使用して、わずか 100 行の Python コードで簡単な RAG アプリケーションを実装できます。この記事では、「三体」を例に、実装手順を解説します。

RAG とは何ですか?#

Retrieval Augmented Generation (RAG)は、特定の情報推論能力に焦点を当てた LLM の強化手法です。LLM(例:chatgpt)は、特定のデータセットを使用してトレーニングされたモデルであり、特定のドメインの情報や機密情報に対して質問応答を行う場合、正確な回答を提供することが難しくなります。例えば、chatgpt に「三体」の章北海について説明させると、以下の結果が得られます:

image

このような場合、既存のデータを使用して大規模なモデルを再調整することは非常にコストがかかり、リアルタイムの効果を得ることは困難ですが、RAG はこの問題をうまく解決します。

RAG の実装方法は?#

RAG アプリケーションの実装は、通常 2 つのステップに分かれます:インデックスの構築と検索生成です。

インデックスの構築#

ソースデータを埋め込みモデルを使用してベクトルに変換し、ベクトルデータベースに保存します。一般的な手順は次のとおりです:

  1. Load: DocumentLoader を使用してさまざまなタイプのドキュメントデータを読み込みます。
  2. Split: ドキュメントを一定のルールに従って小さなチャンクに分割します。これにより、モデルが文脈をより良く理解できるようになります。
  3. Store: 最後に、分割されたチャンクを埋め込みモデルを使用してベクトルにマッピングし、ベクトルデータベースに保存します。

image

検索生成#

実際には、2 つのステップがあります:

  1. Retrieve: ユーザーの入力質問もベクトルに変換し、ベクトルデータベースで関連性が最も高いチャンクを検索します。
  2. Generate: ChatModel または LLM(例:chatgpt)を使用して、ユーザーの質問に基づいて検索された内容に基づいて要約を生成します。

image

langchain とは何ですか?#

LangChainは、開発者が言語モデルを使用してエンドツーエンドのアプリケーションを構築するのを支援する強力なフレームワークです。LangChain は、大規模な言語モデル(LLM)やチャットモデルをサポートするアプリケーションを作成するプロセスを簡素化するためのツール、コンポーネント、インターフェースを提供します。LangChain は、言語モデルとの対話を簡単に管理し、複数のコンポーネントをリンクし、API やデータベースなどの追加リソースを統合することができます。

gradio とは何ですか?#

Gradioは、インタラクティブなアプリケーションを簡単に構築するためのオープンソースの Python ライブラリです。Gradio は、機械学習モデルをユーザーフレンドリーなインターフェースに統合し、モデルの使用と理解を容易にします。

langchain+gradio を使用して三体の質問応答ボットを簡単に実装する#

プロジェクトのソースコードは GitHub に公開されています:https://github.com/zivenyang/3body-chatbot
以下は、主要なコードです(当時、国内で openai アカウントを申請するのが難しかったため、Azure Openai の API を使用しています)。

from dotenv import load_dotenv
from langchain.chat_models import AzureChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from langchain.vectorstores import Chroma
from langchain.embeddings import ModelScopeEmbeddings
from langchain.document_loaders import TextLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
import gradio as gr
import os

# .envファイルから変数を読み込む、AZURE_OPENAI_ENDPOINTとAZURE_OPENAI_API_KEY
load_dotenv()

# 埋め込みモデル、中国の小説なので、中国語の埋め込みモデルを使用します
MODEL_ID = "damo/nlp_gte_sentence-embedding_chinese-base"
# ベクトルデータベースの保存先
PERSIST_DIRECTORY = 'docs/chroma/'

# 索引ドキュメントと回答を同時にコンソールに出力するためのクラス
class AnswerConversationBufferMemory(ConversationBufferMemory):
    def save_context(self, inputs, outputs) -> None:
        return super(AnswerConversationBufferMemory, self).save_context(inputs,{'response': outputs['answer']})

def create_db():
    """ローカルファイルを読み込んでベクトルに変換し、ベクトルデータベースに保存する"""

    # ローカルファイルを読み込む、つまり三体の小説
    text_loader_kwargs={'autodetect_encoding': True}
    loader = DirectoryLoader("./docs", glob="**/*.txt", loader_cls=TextLoader, loader_kwargs=text_loader_kwargs)
    pages = loader.load()

    # ファイルをチャンクに分割する、chunk_sizeとオーバーラップは、GPUのパフォーマンスにも依存します。メモリが大きいほど、より細かく分割されます。
    text_splitter = RecursiveCharacterTextSplitter(
    chunk_size = 512,
    chunk_overlap = 0,
    length_function = len,
        )
    splits = text_splitter.split_documents(pages)

    # ベクトル(埋め込み)を生成し、データベースに保存する
    embedding = ModelScopeEmbeddings(model_id=MODEL_ID)
    db = Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        persist_directory=PERSIST_DIRECTORY
    )
    # データベースを永続化する
    db.persist()
    return db

def querying(query, history):
    db = None
    if not os.path.exists(PERSIST_DIRECTORY):
        # ベクトルデータベースが存在しない場合は作成する
        db = create_db()
    else:
        # 既存のベクトルデータベースをロードする
        embedding = ModelScopeEmbeddings(model_id=MODEL_ID)
        db = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=embedding)

    # chatモデル
    llm = AzureChatOpenAI(
        openai_api_version="2023-05-15",
        azure_deployment="gpt35-16k",
        model_version="0613",
        temperature=0
    )

    # chatのキャッシュ、チャット履歴を保持するため
    memory = AnswerConversationBufferMemory(memory_key="chat_history", return_messages=True)

    # chat
    qa_chain = ConversationalRetrievalChain.from_llm(
      llm=llm,
      retriever=db.as_retriever(search_kwargs={"k": 7}),
      chain_type='stuff',
      memory=memory,
      return_source_documents=True,
  )
    result = qa_chain({"question": query})
    print(result)
    return result["answer"].strip()

# gradio
iface = gr.ChatInterface(
    fn = querying,
    chatbot=gr.Chatbot(height=1000),
    textbox=gr.Textbox(placeholder="ロジックは誰ですか?", container=False, scale=7),
    title="三体の質問応答ボット",
    theme="soft",
    examples=["暗黒森林法則について簡単に説明してください",
              "程心は最後に誰と一緒になりましたか?"],
    cache_examples=True,
    retry_btn="リトライ",
    undo_btn="元に戻す",
    clear_btn="クリア",
    submit_btn="送信"
    )

iface.launch(share=True)

実装結果#

image

参考資料#

読み込み中...
文章は、創作者によって署名され、ブロックチェーンに安全に保存されています。