import nltk

try:
    nltk.download('averaged_perceptron_tagger_eng', quiet=True)
    nltk.download("punkt", quiet=True)
    nltk.download('punkt_tab', quiet=True)
except Exception as e:
    print(f"Warning: NLTK download failed: {e}")

import gradio as gr
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import UnstructuredFileLoader, PyPDFLoader
from langchain.vectorstores.faiss import FAISS
from langchain.vectorstores.utils import DistanceStrategy
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document
from langchain.chains import RetrievalQA
from langchain.prompts.prompt import PromptTemplate
from langchain.vectorstores.base import VectorStoreRetriever
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from transformers import TextIteratorStreamer
from threading import Thread
import os
import tempfile
# Prompt template optimized for Flan-T5
template = """Answer the question based on the context below.
Context: {context}
Question: {question}
Answer:"""

QA_PROMPT = PromptTemplate(template=template, input_variables=["question", "context"])
# Load the Flan-T5 model from the Hugging Face Hub - well suited to CPU inference and Q&A tasks
# Alternative popular CPU-friendly models you can try:
# - "google/flan-t5-small" (faster, smaller)
# - "google/flan-t5-large" (better quality, slower)
# - "microsoft/DialoGPT-medium" (conversational)
model_id = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_id, torch_dtype=torch.float32
)
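
# Optional sanity check (illustrative sketch, not part of the app flow): direct generation
# with the loaded model, handy for verifying the weights load and generate on CPU.
# Kept commented out so the Space's behaviour is unchanged.
# inputs = tokenizer("Question: What is the capital of France? Answer:", return_tensors="pt")
# outputs = model.generate(**inputs, max_new_tokens=16)
# print(tokenizer.decode(outputs[0], skip_special_tokens=True))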
# Sentence-transformers embeddings to be used in the vector store
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/msmarco-distilbert-base-v4",
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": False},
)
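
# Note (illustrative): HuggingFaceEmbeddings exposes embed_query / embed_documents, which
# FAISS.from_documents calls under the hood when the vector store is built below.
# Example (commented out): vector = embeddings.embed_query("Who directed Oppenheimer?")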
def clean_response(text):
    """Clean up the generated response"""
    # Remove excessive whitespace and newlines
    text = ' '.join(text.split())

    # Remove repetitive patterns
    words = text.split()
    cleaned_words = []
    for word in words:
        # Skip if the same word appears too many times consecutively
        if len(cleaned_words) >= 3 and all(w == word for w in cleaned_words[-3:]):
            continue
        cleaned_words.append(word)
    cleaned_text = ' '.join(cleaned_words)

    # Truncate at natural stopping points
    sentences = cleaned_text.split('.')
    if len(sentences) > 1:
        # Keep complete sentences
        good_sentences = []
        for sentence in sentences[:-1]:  # Exclude last potentially incomplete sentence
            if len(sentence.strip()) > 5:  # Avoid very short fragments
                good_sentences.append(sentence.strip())
        if good_sentences:
            return '. '.join(good_sentences) + '.'

    return cleaned_text[:500]  # Fallback: truncate to reasonable length
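
# Illustrative example of the cleanup above (not executed):
#   clean_response("The film  was  directed by  Christopher Nolan. It was rel")
#   -> "The film was directed by Christopher Nolan."
# Whitespace is collapsed and the trailing incomplete sentence is dropped.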
# Returns a faiss vector store retriever given a txt or pdf file
def prepare_vector_store_retriever(filename):
    # Load data based on file extension
    if filename.lower().endswith('.pdf'):
        loader = PyPDFLoader(filename)
    else:
        loader = UnstructuredFileLoader(filename)
    raw_documents = loader.load()

    # Split the text
    text_splitter = CharacterTextSplitter(
        separator="\n\n", chunk_size=800, chunk_overlap=0, length_function=len
    )
    documents = text_splitter.split_documents(raw_documents)

    # Creating a vectorstore
    vectorstore = FAISS.from_documents(
        documents, embeddings, distance_strategy=DistanceStrategy.DOT_PRODUCT
    )

    return VectorStoreRetriever(vectorstore=vectorstore, search_kwargs={"k": 2}), vectorstore
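
# Usage sketch (commented out, illustrative): the retriever returned above fetches the two
# chunks most similar to a query; RetrievalQA later injects them into the prompt as {context}.
# retriever, vs = prepare_vector_store_retriever("Oppenheimer-movie-wiki.txt")
# context_docs = retriever.get_relevant_documents("Who directed the movie?")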
# Retrieval QA chain
# (default_text_file, default_retriever and default_vectorstore are module-level globals
# defined further down, inside the gr.Blocks context, before this function is ever called)
def get_retrieval_qa_chain(text_file, hf_model):
    retriever = default_retriever
    vectorstore = default_vectorstore
    if text_file != default_text_file or default_text_file is None:
        if text_file is not None and os.path.exists(text_file):
            retriever, vectorstore = prepare_vector_store_retriever(text_file)
        else:
            # Create a dummy retriever if no file is available
            dummy_doc = Document(page_content="No document loaded. Please upload a file to get started.")
            dummy_vectorstore = FAISS.from_documents([dummy_doc], embeddings)
            retriever = VectorStoreRetriever(vectorstore=dummy_vectorstore, search_kwargs={"k": 1})
            vectorstore = dummy_vectorstore

    chain = RetrievalQA.from_chain_type(
        llm=hf_model,
        retriever=retriever,
        chain_type_kwargs={"prompt": QA_PROMPT},
    )
    return chain, vectorstore
# Generates a response using the question answering chain defined earlier
def generate(question, answer, text_file, max_new_tokens):
    if not question.strip():
        yield "Please enter a question."
        return

    try:
        # Create a text2text-generation pipeline for Flan-T5
        flan_t5_pipeline = pipeline(
            "text2text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=int(max_new_tokens),  # the Gradio slider may pass a float
            do_sample=False,
        )
        hf_model = HuggingFacePipeline(pipeline=flan_t5_pipeline)
        qa_chain, vectorstore = get_retrieval_qa_chain(text_file, hf_model)

        query = f"{question}"
        if len(tokenizer.tokenize(query)) >= 512:
            yield "Your question is too long! Please shorten it."
            return

        # Get the response directly (no streaming)
        try:
            result = qa_chain.invoke({"query": query})
            # Extract the answer from the result
            if isinstance(result, dict):
                response = result.get('result', str(result))
            else:
                response = str(result)
            # Clean the response
            cleaned_response = clean_response(response)
            yield cleaned_response
        except Exception as e:
            yield f"Error during generation: {str(e)}"
            return
    except Exception as e:
        yield f"Error: {str(e)}"
# Replaces the retriever in the question answering chain whenever a new file is uploaded
def upload_file(file):
    if file is not None:
        # In Gradio, the uploaded file is either a path or an object exposing .name
        file_path = file.name if hasattr(file, 'name') else file
        filename = os.path.basename(file_path)
        return filename, file_path
    return None, None
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Retrieval Augmented Generation with Flan-T5: Question Answering demo
        ### This demo uses Google's Flan-T5 language model and Retrieval Augmented Generation (RAG). It allows you to upload a txt or PDF file and ask the model questions related to the content of that file.
        ### Features:
        - Support for both PDF and text files
        - Retrieval-based question answering using document context
        - Optimized for CPU performance using the Flan-T5-Base model
        ### To get started, upload a text (.txt) or PDF (.pdf) file using the upload button below.
        The Flan-T5 model is efficient and works well on CPU, making it well suited to document Q&A tasks.
        Retrieval Augmented Generation (RAG) retrieves just the few small chunks of the document that are relevant to your query and injects them into the prompt.
        The model is then able to answer questions by incorporating knowledge from the newly provided document.
        """
    )
    default_text_file = "Oppenheimer-movie-wiki.txt"

    # Check if default file exists, if not, set to None
    if not os.path.exists(default_text_file):
        default_text_file = None
        default_retriever = None
        default_vectorstore = None
        initial_file_display = "No default file found - please upload a file"
    else:
        default_retriever, default_vectorstore = prepare_vector_store_retriever(default_text_file)
        initial_file_display = default_text_file

    text_file = gr.State(default_text_file)

    gr.Markdown("## Upload a txt or PDF file to get started")

    file_name = gr.Textbox(
        label="Loaded file", value=initial_file_display, lines=1, interactive=False
    )
    upload_button = gr.UploadButton(
        label="Click to upload a text or PDF file", file_types=[".txt", ".pdf"], file_count="single"
    )
    upload_button.upload(upload_file, upload_button, [file_name, text_file])
    gr.Markdown("## Enter your question")
    tokens_slider = gr.Slider(
        8,
        256,
        value=64,
        label="Maximum new tokens",
        info="A larger `max_new_tokens` value gives longer text responses but at the cost of a slower response time.",
    )

    with gr.Row():
        with gr.Column():
            ques = gr.Textbox(label="Question", placeholder="Enter text here", lines=3)
        with gr.Column():
            ans = gr.Textbox(label="Answer", lines=4, interactive=False)
    with gr.Row():
        with gr.Column():
            btn = gr.Button("Submit")
        with gr.Column():
            clear = gr.ClearButton([ques, ans])

    btn.click(fn=generate, inputs=[ques, ans, text_file, tokens_slider], outputs=[ans])

    examples = gr.Examples(
        examples=[
            "Who portrayed J. Robert Oppenheimer in the new Oppenheimer movie?",
            "In the plot of the movie, why did Lewis Strauss resent Robert Oppenheimer?",
            "How much money did the Oppenheimer movie make at the US and global box office?",
            "What score did the Oppenheimer movie get on Rotten Tomatoes and Metacritic?",
        ],
        inputs=[ques],
    )

demo.queue().launch()
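
# Local run note (assumed setup, not taken from the Space's own config): the app expects the
# packages imported above to be installed, e.g. something along the lines of
#   pip install gradio langchain langchain-community transformers torch sentence-transformers faiss-cpu unstructured pypdf nltk
#   python app.py
# Exact versions would normally be pinned in the Space's requirements.txt.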