Spaces:
Sleeping
Sleeping
Commit
·
831dfea
1
Parent(s):
f27624c
modify the pipeline
Browse files
app.py
CHANGED
|
@@ -31,11 +31,11 @@ import tempfile
|
|
| 31 |
# Prompt template
|
| 32 |
template = """Context: {context}
|
| 33 |
|
| 34 |
-
Question: {question}
|
| 35 |
|
| 36 |
-
Answer:
|
|
|
|
| 37 |
|
| 38 |
-
QA_PROMPT = PromptTemplate(template=template, input_variables=["question", "context"])
|
| 39 |
|
| 40 |
# Load Phi-2 model from hugging face hub
|
| 41 |
model_id = "microsoft/phi-2"
|
|
@@ -148,18 +148,16 @@ def generate(question, answer, text_file, max_new_tokens):
|
|
| 148 |
)
|
| 149 |
phi2_pipeline = pipeline(
|
| 150 |
"text-generation",
|
| 151 |
-
tokenizer=tokenizer,
|
| 152 |
model=model,
|
|
|
|
| 153 |
max_new_tokens=max_new_tokens,
|
|
|
|
| 154 |
pad_token_id=tokenizer.eos_token_id,
|
| 155 |
eos_token_id=tokenizer.eos_token_id,
|
| 156 |
-
do_sample=True,
|
| 157 |
-
temperature=0.7,
|
| 158 |
-
top_p=0.9,
|
| 159 |
-
repetition_penalty=1.1,
|
| 160 |
streamer=streamer,
|
| 161 |
)
|
| 162 |
|
|
|
|
| 163 |
hf_model = HuggingFacePipeline(pipeline=phi2_pipeline)
|
| 164 |
qa_chain, vectorstore = get_retrieval_qa_chain(text_file, hf_model)
|
| 165 |
|
|
|
|
| 31 |
# Prompt template
|
| 32 |
template = """Context: {context}
|
| 33 |
|
| 34 |
+
Question: {query}
|
| 35 |
|
| 36 |
+
Answer:"""
|
| 37 |
+
QA_PROMPT = PromptTemplate(template=template, input_variables=["query", "context"])
|
| 38 |
|
|
|
|
| 39 |
|
| 40 |
# Load Phi-2 model from hugging face hub
|
| 41 |
model_id = "microsoft/phi-2"
|
|
|
|
| 148 |
)
|
| 149 |
phi2_pipeline = pipeline(
|
| 150 |
"text-generation",
|
|
|
|
| 151 |
model=model,
|
| 152 |
+
tokenizer=tokenizer,
|
| 153 |
max_new_tokens=max_new_tokens,
|
| 154 |
+
do_sample=False, # ← greedy
|
| 155 |
pad_token_id=tokenizer.eos_token_id,
|
| 156 |
eos_token_id=tokenizer.eos_token_id,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
streamer=streamer,
|
| 158 |
)
|
| 159 |
|
| 160 |
+
|
| 161 |
hf_model = HuggingFacePipeline(pipeline=phi2_pipeline)
|
| 162 |
qa_chain, vectorstore = get_retrieval_qa_chain(text_file, hf_model)
|
| 163 |
|