With LangChain and Gradio, a simple RAG application can be implemented in about a hundred lines of Python. This article uses the novel "The Three-Body Problem" as an example to walk through the implementation step by step.
What is RAG?
Retrieval-Augmented Generation (RAG) is a technique for enhancing an LLM's ability to reason over specific information. An LLM such as ChatGPT is trained on a particular organization's dataset at a particular point in time, so when we ask about domain-specific or confidential information, it struggles to give accurate answers. For example, if we ask ChatGPT to describe the actions of Zhang Beihai in "The Three-Body Problem," it cannot give a reliable answer.
Fine-tuning a large model on such data is costly and difficult, and the result cannot keep up with new information in real time. RAG solves this problem well.
How to Implement RAG?
Implementing a RAG application generally involves two stages: building an index, and retrieval and generation.
Building an Index
Use an embedding model to convert the source data into vectors and save them in a vector database. This usually involves the following steps (a condensed sketch follows the list):
- Load: First, use a DocumentLoader to read the documents, which may come in different formats.
- Split: Then, divide the documents into smaller chunks according to certain rules, so that the model can better handle the context.
- Store: Finally, use the embedding model to map the chunks to vectors and store them in the vector database for later retrieval.
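A condensed sketch of these three steps with LangChain (it uses the same classes, paths, and model id as the full project code later in this article):

from langchain.document_loaders import DirectoryLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import ModelScopeEmbeddings
from langchain.vectorstores import Chroma

# Load: read every .txt file under ./docs
pages = DirectoryLoader("./docs", glob="**/*.txt", loader_cls=TextLoader).load()
# Split: cut the documents into 512-character chunks
splits = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=0).split_documents(pages)
# Store: embed each chunk and persist the vectors in a local Chroma database
embedding = ModelScopeEmbeddings(model_id="damo/nlp_gte_sentence-embedding_chinese-base")
db = Chroma.from_documents(documents=splits, embedding=embedding, persist_directory="docs/chroma/")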
Retrieval and Generation
This stage, in turn, involves two steps (sketched after the list):
- Retrieve: Convert the user's question into a vector and retrieve the most relevant chunks from the vector database.
- Generate: Use a chat model or LLM, such as ChatGPT, to generate an answer from the retrieved chunks, guided by the user's question and a specific prompt.
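A minimal sketch of these two steps, reusing the db vector store built in the sketch above (RetrievalQA is one of several LangChain chains that can do this; the project code below uses ConversationalRetrievalChain instead so that chat history is also taken into account):

from langchain.chains import RetrievalQA
from langchain.chat_models import AzureChatOpenAI

llm = AzureChatOpenAI(openai_api_version="2023-05-15", azure_deployment="gpt35-16k")
# Retrieve: as_retriever() embeds the question and looks up the most similar chunks
# Generate: the "stuff" chain type pastes those chunks into the prompt for the LLM
qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=db.as_retriever())
print(qa.run("Who is Luo Ji?"))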
What is LangChain?
LangChain is a powerful framework designed to help developers build end-to-end applications on top of language models. It provides a set of tools, components, and interfaces that simplify creating applications powered by large language models (LLMs) and chat models. LangChain makes it easy to manage interactions with language models, link multiple components together into chains, and integrate external resources such as APIs and databases.
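For example, the smallest useful LangChain unit links a prompt template to a model (a hedged sketch; the Azure deployment name is simply the one used later in this article):

from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.chat_models import AzureChatOpenAI

prompt = PromptTemplate.from_template("Answer in one sentence: {question}")
llm = AzureChatOpenAI(openai_api_version="2023-05-15", azure_deployment="gpt35-16k")
chain = LLMChain(llm=llm, prompt=prompt)  # chains the prompt and the model together
print(chain.run(question="What is RAG?"))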
What is Gradio?
Gradio is an open-source Python library for quickly building interactive applications. It lets developers wrap machine learning models in friendly web interfaces, making the models easier to use and understand.
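As a minimal sketch, a complete Gradio app fits in a few lines:

import gradio as gr

def greet(name):
    return f"Hello, {name}!"

# One textbox in, one textbox out, served as a local web app
gr.Interface(fn=greet, inputs="text", outputs="text").launch()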
Using LangChain + Gradio to Quickly Implement a Three-Body Question-Answering Chatbot
The project source code has been published on GitHub: https://github.com/zivenyang/3body-chatbot
Here is the core code (the Azure OpenAI API is used because it was difficult to register an OpenAI account in China at the time):
from dotenv import load_dotenv
from langchain.chat_models import AzureChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from langchain.vectorstores import Chroma
from langchain.embeddings import ModelScopeEmbeddings
from langchain.document_loaders import TextLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
import gradio as gr
import os
# Load environment variables AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_API_KEY from .env
load_dotenv()
# Embedding model: a Chinese sentence-embedding model from Alibaba's DAMO Academy, chosen because the novel is in Chinese
MODEL_ID = "damo/nlp_gte_sentence-embedding_chinese-base"
# Path to store the vector database
PERSIST_DIRECTORY = 'docs/chroma/'
# With return_source_documents=True the chain returns both the answer and the
# source documents, but ConversationBufferMemory expects a single output key
# and raises an error, so this subclass saves only the answer to the history
class AnswerConversationBufferMemory(ConversationBufferMemory):
    def save_context(self, inputs, outputs) -> None:
        # Keep only the "answer" key when writing to memory
        return super().save_context(inputs, {"response": outputs["answer"]})
def create_db():
"""Read local files and generate word vectors to store in the vector database"""
# Read local files, i.e., the Three-Body Problem novel
    text_loader_kwargs = {'autodetect_encoding': True}
loader = DirectoryLoader("./docs", glob="**/*.txt", loader_cls=TextLoader, loader_kwargs=text_loader_kwargs)
pages = loader.load()
    # Split the files into chunks; chunk_size also depends on the GPU running the embedding model: the more memory, the finer the split can be
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=512,
        chunk_overlap=0,
        length_function=len,
    )
splits = text_splitter.split_documents(pages)
# Generate vectors (embeddings) and store them in the database
embedding = ModelScopeEmbeddings(model_id=MODEL_ID)
db = Chroma.from_documents(
documents=splits,
embedding=embedding,
persist_directory=PERSIST_DIRECTORY
)
# Persist the database
db.persist()
return db
def querying(query, history):
    """Answer one question; history is supplied by gr.ChatInterface, while the chain below keeps its own memory"""
    db = None
if not os.path.exists(PERSIST_DIRECTORY):
# Create the vector database if it does not exist
db = create_db()
else:
# Load the existing vector database
embedding = ModelScopeEmbeddings(model_id=MODEL_ID)
db = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=embedding)
# Chat model
llm = AzureChatOpenAI(
openai_api_version="2023-05-15",
azure_deployment="gpt35-16k",
model_version="0613",
temperature=0
)
    # Chat buffer for the chain's history (recreated on every call, so each question is effectively answered independently)
memory = AnswerConversationBufferMemory(memory_key="chat_history", return_messages=True)
# Chat
qa_chain = ConversationalRetrievalChain.from_llm(
llm=llm,
retriever=db.as_retriever(search_kwargs={"k": 7}),
chain_type='stuff',
memory=memory,
return_source_documents=True,
)
    result = qa_chain({"question": query})
    # Print the full result, including the retrieved source documents, to the console
    print(result)
    return result["answer"].strip()
# Gradio chat UI
iface = gr.ChatInterface(
fn = querying,
chatbot=gr.Chatbot(height=1000),
textbox=gr.Textbox(placeholder="Who is Luo Ji?", container=False, scale=7),
title="Three-Body Question-Answering Chatbot",
theme="soft",
examples=["Briefly describe the Dark Forest Theory.",
"Who does Cheng Xin end up with in the end?"],
cache_examples=True,
retry_btn="Retry",
undo_btn="Undo",
clear_btn="Clear",
submit_btn="Submit"
)
iface.launch(share=True)