# --------------------------------------------------------------
# app.py – a Gradio chat UI for maya-research/maya1
# --------------------------------------------------------------
import json
import os
import uuid
from pathlib import Path
from typing import List, Tuple
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from huggingface_hub import HfApi
# ------------------- CONFIGURATION -----------------------------
MODEL_ID = "maya-research/maya1" # the model you want to use
MAX_NEW_TOKENS = 256 # generation length
TEMPERATURE = 0.7
TOP_P = 0.9
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Folder inside the Space where we keep per‑session JSON files
HISTORY_DIR = Path("history")
HISTORY_DIR.mkdir(exist_ok=True)
# ----------------------------------------------------------------
# 1️⃣ Load the model once (global, reused across requests)
# ----------------------------------------------------------------
print(f"🔧 Loading {MODEL_ID} …")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=MAX_NEW_TOKENS,
    temperature=TEMPERATURE,
    top_p=TOP_P,
    do_sample=True,
)
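
# Quick smoke test (illustrative, commented out) – handy when checking that
# the model loads and generates; the prompt is an arbitrary example:
# print(generator("Hello! Who are you?")[0]["generated_text"])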
# --------------------------------------------------------------
# 2️⃣ Helper functions for history persistence
# --------------------------------------------------------------
def _history_path(session_id: str) -> Path:
    """File that stores a JSON list of (user, assistant) pairs."""
    safe_id = session_id.replace("/", "_")
    return HISTORY_DIR / f"{safe_id}.json"


def load_history(session_id: str) -> List[Tuple[str, str]]:
    """Read JSON file → list of (user, assistant). Return [] if not present."""
    p = _history_path(session_id)
    if p.is_file():
        try:
            return json.loads(p.read_text(encoding="utf-8"))
        except Exception as e:
            print(f"⚠️ Failed to read history for {session_id}: {e}")
    return []


def save_history(session_id: str, chat: List[Tuple[str, str]]) -> None:
    """Write the JSON file and push it back to the repo."""
    p = _history_path(session_id)
    p.write_text(json.dumps(chat, ensure_ascii=False, indent=2), encoding="utf-8")
    # -----------------------------------------------------------------
    # Push the new file to the repo (so it survives container restarts)
    # -----------------------------------------------------------------
    # NOTE: This only works if the Space has a write token (see step 3).
    try:
        api = HfApi()
        # `repo_id` is the full name of the Space (owner/space-name);
        # Spaces expose it automatically as the SPACE_ID env variable.
        repo_id = os.getenv("SPACE_ID")
        if repo_id:
            api.upload_file(
                path_or_fileobj=str(p),
                path_in_repo=str(p),
                repo_id=repo_id,
                repo_type="space",
                token=os.getenv("HF_TOKEN"),
            )
    except Exception as exc:
        # Failing to push is not fatal – the file stays on the container.
        print(f"⚠️ Could not push history to hub: {exc}")


def list_sessions() -> List[str]:
    """Return a list of all stored session IDs (file names)."""
    return [f.stem for f in HISTORY_DIR.glob("*.json")]
# --------------------------------------------------------------
# 3️⃣ The generation function – called by Gradio
# --------------------------------------------------------------
def generate_reply(
    user_message: str,
    chat_history: List[Tuple[str, str]],
    session_id: str,
) -> Tuple[List[Tuple[str, str]], str]:
    """
    1. Append the user's new message.
    2. Build the full `messages` list in the format expected by the model's
       chat_template.
    3. Use the tokenizer's `apply_chat_template(..., add_generation_prompt=True)`
       to create the prompt.
    4. Run the pipeline, decode, and strip the extra tokens.
    5. Append the assistant answer and persist the whole chat.
    """
    # ----- 1️⃣ Append user message to history -----
    chat_history.append((user_message, ""))  # placeholder for assistant

    # ----- 2️⃣ Build the messages list for the template -----
    # Previous exchanges come first, in chronological order; the new user
    # message goes last (system messages are not needed here).
    messages = []
    for user, assistant in chat_history[:-1]:  # exclude the placeholder
        messages.append({"role": "user", "content": user})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": user_message})

    # ----- 3️⃣ Render the prompt with the model's chat template -----
    prompt = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
    )
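
    # For a two‑turn chat, `messages` looks roughly like this (illustrative;
    # the special tags rendered into `prompt` depend on the chat template):
    #   [{"role": "user", "content": "Hi"},
    #    {"role": "assistant", "content": "Hello!"},
    #    {"role": "user", "content": "How are you?"}]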
    # Debug – uncomment if you want to see the raw prompt in the logs
    # print("\n--- Prompt sent to model ---\n", prompt)

    # ----- 4️⃣ Generate the answer -----
    result = generator(prompt, max_new_tokens=MAX_NEW_TOKENS)[0]["generated_text"]
    # The pipeline returns the **whole** text (prompt + answer). Remove the prompt.
    answer = result[len(prompt):].strip()
    # Some models still emit special tokens like <|eot_id|> or </s>; strip them.
    for stop in ["<|eot_id|>", "</s>"]:
        answer = answer.replace(stop, "").strip()
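
    # A model‑agnostic fallback (a sketch, not in the original logic): also
    # strip whatever end‑of‑sequence token the tokenizer itself declares.
    if tokenizer.eos_token:
        answer = answer.replace(tokenizer.eos_token, "").strip()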

    # ----- 5️⃣ Update history and persist -----
    chat_history[-1] = (user_message, answer)  # replace placeholder
    save_history(session_id, chat_history)
    return chat_history, answer
# --------------------------------------------------------------
# 4️⃣ UI definition
# --------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Default()) as demo:
    # -----------------------------------------------------------------
    # Top bar – session selector + "New chat" button
    # -----------------------------------------------------------------
    with gr.Row():
        session_dropdown = gr.Dropdown(
            choices=list_sessions(),
            label="🗂️ Load previous chat",
            interactive=True,
            allow_custom_value=False,
        )
        new_chat_btn = gr.Button("🆕 New chat", variant="primary")
    status_txt = gr.Markdown("", visible=False)

    # -----------------------------------------------------------------
    # Main chat area
    # -----------------------------------------------------------------
    chatbot = gr.Chatbot(label="🗨️ Maya‑1 Chat", height=600)
    txt = gr.Textbox(
        placeholder="Type your message and hit Enter …",
        label="Your message",
        container=False,
    )
    submit_btn = gr.Button("Send", variant="secondary")

    # -----------------------------------------------------------------
    # Hidden state – we keep the full list of (user, assistant) tuples
    # -----------------------------------------------------------------
    chat_state = gr.State([])     # List[Tuple[str, str]]
    session_state = gr.State("")  # current session_id (string)

    # -----------------------------------------------------------------
    # 5️⃣ Callbacks
    # -----------------------------------------------------------------
    # When the app loads, generate a fresh anonymous session ID.
    def init_session():
        sid = str(uuid.uuid4())
        return sid, []  # session_state, chat_state

    demo.load(fn=init_session, outputs=[session_state, chat_state])

    # -----------------------------------------------------------------
    # New chat → reset everything and give a brand‑new session ID
    # -----------------------------------------------------------------
    def new_chat():
        sid = str(uuid.uuid4())
        return sid, [], []  # session_id, empty chat_state, empty UI

    new_chat_btn.click(
        fn=new_chat,
        outputs=[session_state, chat_state, chatbot],
    )

    # -----------------------------------------------------------------
    # Load a saved session from the dropdown
    # -----------------------------------------------------------------
    def load_session(selected: str):
        if not selected:
            return "", [], []  # nothing selected → blank
        # The file name is the session_id we used when saving.
        session_id = selected
        history = load_history(session_id)
        # JSON stores tuples as lists; convert back to the (user, assistant)
        # tuple format expected by gr.Chatbot.
        ui_history = [(u, a) for u, a in history]
        return session_id, history, ui_history

    session_dropdown.change(
        fn=load_session,
        inputs=[session_dropdown],
        outputs=[session_state, chat_state, chatbot],
    )

    # -----------------------------------------------------------------
    # When the user hits "Enter" or clicks Send → generate a reply
    # -----------------------------------------------------------------
    def user_submit(user_msg: str, chat_hist: List[Tuple[str, str]], sid: str):
        # Call the generation function
        updated_hist, answer = generate_reply(user_msg, chat_hist, sid)
        # The UI expects List[Tuple[user, assistant]]
        ui_hist = [(u, a) for u, a in updated_hist]
        return "", ui_hist, updated_hist, answer

    txt.submit(
        fn=user_submit,
        inputs=[txt, chat_state, session_state],
        outputs=[txt, chatbot, chat_state, status_txt],
    )
    submit_btn.click(
        fn=user_submit,
        inputs=[txt, chat_state, session_state],
        outputs=[txt, chatbot, chat_state, status_txt],
    )

    # -----------------------------------------------------------------
    # Keep the session‑dropdown up‑to‑date after each new save
    # -----------------------------------------------------------------
    def refresh_dropdown():
        # `gr.Dropdown.update` was removed in Gradio 4; `gr.update` is the
        # current way to patch a component's properties.
        return gr.update(choices=list_sessions())

    # Whenever we save a new session (i.e., after every reply) we refresh the list.
    chat_state.change(fn=refresh_dropdown, inputs=None, outputs=session_dropdown)
# --------------------------------------------------------------
# 6️⃣ Run the demo
# --------------------------------------------------------------
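# Minimal local setup (a sketch; versions not pinned, GPU optional):
#   pip install gradio torch transformers huggingface_hub
#   python app.py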
if __name__ == "__main__":
    demo.queue()  # enable request queuing so several users can chat concurrently
    demo.launch()