# CodexTrouter / ProTalk_Stable.py
# Author: prelington
# Commit 5f8a165 (verified): "Create ProTalk_Stable.py"
!pip install -q transformers torch accelerate
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# --- Model / runtime configuration -------------------------------------------
# Hub checkpoint used for both the tokenizer and the causal LM weights.
model_name = "distilgpt2"

# Run on the GPU when PyTorch reports one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load tokenizer and weights, then place the model on the selected device.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model = model.to(device)

# --- Conversation state -------------------------------------------------------
# Persona text placed at the top of every prompt.
system_prompt = (
    "You are ProTalk, a professional AI assistant. "
    "Answer politely, be witty, and remember the conversation context."
)

# Window size: only this many most-recent transcript lines are put in the
# prompt, which keeps it short and limits repetition.
MAX_HISTORY = 6
# Full running transcript of "User: ..." / "ProTalk: ..." lines.
chat_history = []
# Upper bound on tokens generated per reply.
MAX_NEW_TOKENS = 100

# Interactive chat loop: read a line, generate a reply, repeat until "exit".
while True:
    try:
        user_input = input("User: ").strip()
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D / Ctrl-C or a closed stdin ends the session instead of
        # crashing with a traceback.
        break
    if user_input.lower() == "exit":
        break
    if not user_input:
        # Nothing to answer; don't spend a generate() call on a blank line.
        continue
    chat_history.append(f"User: {user_input}")

    # Only the last MAX_HISTORY transcript lines are included in the prompt.
    relevant_history = chat_history[-MAX_HISTORY:]
    prompt = system_prompt + "\n" + "\n".join(relevant_history) + "\nProTalk:"

    inputs = tokenizer(prompt, return_tensors="pt")
    # The model has a fixed number of position embeddings; prompt plus new
    # tokens must fit, or generate() fails on a long paste. Keep the most
    # recent tokens: this may clip the start of the system prompt, but the
    # latest turn and the trailing "ProTalk:" cue are preserved.
    context_size = getattr(model.config, "max_position_embeddings", None)
    if context_size and context_size > MAX_NEW_TOKENS:
        max_prompt_tokens = context_size - MAX_NEW_TOKENS
        if inputs["input_ids"].shape[1] > max_prompt_tokens:
            inputs = {
                name: tensor[:, -max_prompt_tokens:]
                for name, tensor in inputs.items()
            }
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    outputs = model.generate(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.2,
        pad_token_id=tokenizer.eos_token_id,
    )

    # For a causal LM, generate() returns the prompt tokens followed by the
    # continuation. Decode only the continuation: decoding everything and
    # str.replace()-ing the prompt breaks whenever the decode round-trip does
    # not reproduce the prompt text exactly.
    prompt_length = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    # Continuing a transcript, the model may run on and invent further turns;
    # keep only ProTalk's own line.
    for marker in ("\nUser:", "\nProTalk:"):
        response = response.split(marker, 1)[0]
    response = response.strip()

    print(f"ProTalk: {response}")
    if response:
        # Don't pollute the prompt window with empty assistant turns.
        chat_history.append(f"ProTalk: {response}")