My AI Journey Day 2: Playing with LangChain RAG
Experimenting with LangChain-based RAG.
Steps to execute:
- Initialize OllamaEmbeddings to use a local Ollama embedding model
- Initialize OllamaLLM to use a local Ollama LLM
- Initialize ChromaDB as the local vector store
- Use PyPDFLoader to extract text from the PDF document
- Load and split the PDF document
- Initialize RecursiveCharacterTextSplitter for text chunking (a toy splitter example follows the script below)
- Chunk each extracted page and give every chunk a stable, content-derived ID
- Add only the new, not-yet-stored chunks to the local vector store
- Run a similarity search query against the vector store (a standalone retrieval check also follows the script)
- Initialize ChatPromptTemplate with a predefined prompt
- Build the final prompt from the ChatPromptTemplate and the search context from the vector store
- Send the prompt to the LLM
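The complete script below ties all of these steps together: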
```python
#!/usr/bin/env python3
import asyncio
import hashlib
from uuid import uuid4

from langchain_community.document_loaders import PyPDFLoader
from langchain_ollama import OllamaEmbeddings, OllamaLLM
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain.prompts import ChatPromptTemplate
from langchain_text_splitters import RecursiveCharacterTextSplitter


class FAQChat:
    def __init__(self,
                 db_path="./",
                 collection_name="example_collection",
                 embedding_llm="mxbai-embed-large:latest",
                 generating_llm="tinyllama:latest",
                 pdf_file="") -> None:
        self.db_path = db_path
        self.collection_name = collection_name
        self.embedding_llm = embedding_llm
        # Two local Ollama models: one for embeddings, one for answer generation.
        self.embeddings = OllamaEmbeddings(model=self.embedding_llm)
        self.pdf_file = pdf_file
        self.model = OllamaLLM(model=generating_llm)
        # Chroma persists the vectors on disk under db_path.
        self.vector_store = Chroma(
            collection_name=self.collection_name,
            embedding_function=self.embeddings,
            persist_directory=self.db_path,
        )
        self.prompt_template = """
        Use the following pieces of context to answer the question at the end.
        If you don't know the answer, just say that you are unsure.
        Don't try to make up an answer.

        {context}

        Question: {question}

        Answer:
        """

    async def pdf_extract_and_vectorize(self):
        """Vectorize the PDF page by page, without chunking."""
        loader = PyPDFLoader(self.pdf_file)
        pages = []
        documents_data = []
        async for page in loader.alazy_load():
            pages.append(page)
        for i, page in enumerate(pages):
            document_page = Document(
                page_content=page.page_content,
                metadata={"source": "faq"},
                id=str(i),
            )
            documents_data.append(document_page)
            print(page.page_content)
        uuids = [str(uuid4()) for _ in range(len(documents_data))]
        self.vector_store.add_documents(documents=documents_data, ids=uuids)

    async def pdf_extract_and_vectorize_chunks(self):
        """Split the PDF into overlapping chunks and persist only new ones."""
        loader = PyPDFLoader(self.pdf_file)
        pages = loader.load()
        chunk_splitter = RecursiveCharacterTextSplitter(
            chunk_size=400,
            chunk_overlap=200,
            length_function=len,
            is_separator_regex=False,
        )
        page_chunks = chunk_splitter.split_documents(pages)
        for chunk in page_chunks:
            # Derive a deterministic ID from source, page and content. A random
            # uuid4 per chunk would change on every run, so the "only new
            # chunks" filter below would never skip anything across reruns.
            raw = f"{chunk.metadata.get('source')}:{chunk.metadata.get('page')}:{chunk.page_content}"
            chunk.metadata["chunk_id"] = hashlib.md5(raw.encode("utf-8")).hexdigest()
        present_ids = set(self.vector_store.get()["ids"])
        new_chunks = [c for c in page_chunks if c.metadata["chunk_id"] not in present_ids]
        if new_chunks:
            self.vector_store.add_documents(
                new_chunks, ids=[c.metadata["chunk_id"] for c in new_chunks]
            )
        else:
            print("Nothing to persist")

    async def chat_thru_vectors(self, query_text="Who has written the programming language Python?"):
        """Retrieve the best-matching chunk and let the LLM answer from it."""
        context = self.vector_store.similarity_search_with_score(query_text, k=1)
        context_text = "\n\n---\n\n".join(doc.page_content for doc, _score in context)
        prompt_template = ChatPromptTemplate.from_template(self.prompt_template)
        prompt = prompt_template.format(context=context_text, question=query_text)
        response_text = self.model.invoke(prompt)
        print(response_text)


async def main():
    faq_chat = FAQChat(
        "./chroma_langchain_db",
        "example_collection",
        "mxbai-embed-large:latest",
        "tinyllama:latest",
        "./Ethernet_FAQ.pdf",
    )
    # await faq_chat.pdf_extract_and_vectorize()  # page-level alternative, no chunking
    print("Vectorizing")
    await faq_chat.pdf_extract_and_vectorize_chunks()
    print("Retrieving")
    await faq_chat.chat_thru_vectors(
        "What is a network heartbeat, give me a fast, simple and short answer?"
    )


if __name__ == '__main__':
    asyncio.run(main())
```
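To run this, Ollama must be serving locally with both models pulled (`ollama pull mxbai-embed-large` and `ollama pull tinyllama`), and the Python packages langchain, langchain-community, langchain-ollama, langchain-chroma, langchain-text-splitters and pypdf installed. The PDF path and model tags above are just the ones I used; swap in your own.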
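To get a feel for how chunk_size and chunk_overlap interact before pointing the splitter at a real PDF, here is a minimal standalone sketch. The sample text and the smaller 100/40 sizes are made up purely for visibility; the script above uses 400/200:

```python
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Made-up sample text, repeated so it is long enough to split.
text = "Ethernet is a family of wired computer networking technologies. " * 20

splitter = RecursiveCharacterTextSplitter(
    chunk_size=100,   # deliberately small so the overlap is easy to see
    chunk_overlap=40,
    length_function=len,
)

for i, chunk in enumerate(splitter.split_text(text)):
    # Consecutive chunks share up to 40 characters of text.
    print(f"chunk {i:02d} ({len(chunk)} chars): {chunk!r}")
```

With 400/200 the script keeps half of every chunk as overlap, which is generous but helps when an answer spans a chunk boundary.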
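Before blaming the LLM for a bad answer, it is worth checking what retrieval actually returns. A minimal sketch, assuming the persisted store created by the script above (same path, collection name and embedding model):

```python
from langchain_chroma import Chroma
from langchain_ollama import OllamaEmbeddings

# Reopen the store that pdf_extract_and_vectorize_chunks() populated.
store = Chroma(
    collection_name="example_collection",
    embedding_function=OllamaEmbeddings(model="mxbai-embed-large:latest"),
    persist_directory="./chroma_langchain_db",
)

for doc, score in store.similarity_search_with_score("What is a network heartbeat?", k=3):
    # With Chroma's default distance metric, lower scores mean closer matches.
    print(f"{score:.3f}  {doc.page_content[:80]}...")
```

chat_thru_vectors uses k=1 to keep the prompt small for tinyllama; bumping k gives the model more context at the cost of a longer prompt.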