# --------------------------------------------------------------
# app.py – a Gradio chat UI for maya-research/maya1
# --------------------------------------------------------------
import json
import os
from pathlib import Path
from typing import List, Tuple, Dict
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from huggingface_hub import HfApi
# ------------------- CONFIGURATION -----------------------------
MODEL_ID = "maya-research/maya1" # the model you want to use
MAX_NEW_TOKENS = 256 # generation length
TEMPERATURE = 0.7
TOP_P = 0.9
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Folder inside the Space where we keep per‑session JSON files
HISTORY_DIR = Path("history")
HISTORY_DIR.mkdir(exist_ok=True)
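# NOTE: unless persistent storage is enabled, a Space's local filesystem is
# ephemeral – files written here vanish when the container restarts, which is
# why save_history() below also pushes each JSON file back to the Space repo.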
# ----------------------------------------------------------------
# 1️⃣ Load the model once (global, reused across requests)
# ----------------------------------------------------------------
print(f"🔧 Loading {MODEL_ID} …")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype=torch.bfloat16,
device_map="auto",
)
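# device_map="auto" (via accelerate) places the weights on the available GPU(s)
# and falls back to CPU; torch.bfloat16 roughly halves memory use vs. float32.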
generator = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
max_new_tokens=MAX_NEW_TOKENS,
temperature=TEMPERATURE,
top_p=TOP_P,
do_sample=True,
)
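# Optional smoke test – uncomment to confirm generation works at startup.
# (The prompt here is purely illustrative, not the model's chat format.)
# print(generator("Hello", max_new_tokens=8)[0]["generated_text"])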
# --------------------------------------------------------------
# 2️⃣ Helper functions for history persistence
# --------------------------------------------------------------
def _history_path(session_id: str) -> Path:
"""File that stores a JSON list of (user,assistant) pairs."""
safe_id = session_id.replace("/", "_")
return HISTORY_DIR / f"{safe_id}.json"
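    # Example: _history_path("user/abc-123") → history/user_abc-123.json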
def load_history(session_id: str) -> List[Tuple[str, str]]:
"""Read JSON file → list of (user,assistant). Return [] if not present."""
p = _history_path(session_id)
if p.is_file():
try:
return json.loads(p.read_text(encoding="utf-8"))
except Exception as e:
print(f"⚠️ Failed to read history for {session_id}: {e}")
return []
def save_history(session_id: str, chat: List[Tuple[str, str]]) -> None:
"""Write JSON file and push it back to the repo."""
p = _history_path(session_id)
p.write_text(json.dumps(chat, ensure_ascii=False, indent=2), encoding="utf-8")
# -----------------------------------------------------------------
# Push the new file to the repo (so it survives container restarts)
# -----------------------------------------------------------------
# NOTE: This only works if the Space has a write token (see step 3).
try:
api = HfApi()
# `repo_id` is the full name of the Space (owner/space-name)
        repo_id = os.getenv("SPACE_ID")  # "owner/space-name", set automatically on Spaces
if repo_id:
api.upload_file(
path_or_fileobj=str(p),
path_in_repo=str(p),
repo_id=repo_id,
repo_type="space",
token=os.getenv("HF_TOKEN"),
)
except Exception as exc:
# Failing to push is not fatal – the file stays on the container.
print(f"⚠️ Could not push history to hub: {exc}")
def list_sessions() -> List[str]:
"""Return a list of all stored session IDs (file names)."""
return [f.stem for f in HISTORY_DIR.glob("*.json")]
# --------------------------------------------------------------
# 3️⃣ The generation function – called by Gradio
# --------------------------------------------------------------
def generate_reply(
user_message: str,
chat_history: List[Tuple[str, str]],
session_id: str,
) -> Tuple[List[Tuple[str, str]], str]:
"""
1. Append the user's new message.
2. Build the full `messages` list in the format expected by the model's
chat_template.
3. Use the tokenizer's `apply_chat_template(..., add_generation_prompt=True)`
to create the prompt.
4. Run the pipeline, decode, and strip the extra tokens.
5. Append the assistant answer and persist the whole chat.
"""
# ----- 1️⃣ Append user message to history -----
chat_history.append((user_message, "")) # placeholder for assistant
    # ----- 2️⃣ Build the messages list for the template -----
    # Previous exchanges come first, in chronological order, followed by the
    # new user message (system messages are not needed here).
    messages: List[Dict[str, str]] = []
    for user, assistant in chat_history[:-1]:  # exclude the placeholder
        messages.append({"role": "user", "content": user})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": user_message})
# ----- 3️⃣ Render the prompt with the model's chat template -----
prompt = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=False,
)
# Debug – uncomment if you want to see the raw prompt in logs
# print("\n--- Prompt sent to model ---\n", prompt)
# ----- 4️⃣ Generate the answer -----
result = generator(prompt, max_new_tokens=MAX_NEW_TOKENS)[0]["generated_text"]
# The pipeline returns the **whole** text (prompt + answer). Remove the prompt.
answer = result[len(prompt) :].strip()
# Some models still emit special tokens like <|eot_id|>; strip them.
    for stop in ["<|eot_id|>", "</s>"]:
answer = answer.replace(stop, "").strip()
# ----- 5️⃣ Update history and persist -----
chat_history[-1] = (user_message, answer) # replace placeholder
save_history(session_id, chat_history)
return chat_history, answer
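# Illustrative usage outside Gradio (e.g. for a quick local test):
#   history, reply = generate_reply("Hello!", [], "local-test")
#   → `history` is the updated [(user, assistant), …] list, `reply` the plain answer.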
# --------------------------------------------------------------
# 4️⃣ UI definition
# --------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Default()) as demo:
# -----------------------------------------------------------------
# Top bar – session selector + "New chat" button
# -----------------------------------------------------------------
with gr.Row():
session_dropdown = gr.Dropdown(
choices=list_sessions(),
label="🗂️ Load previous chat",
interactive=True,
allow_custom_value=False,
)
new_chat_btn = gr.Button("🆕 New chat", variant="primary")
status_txt = gr.Markdown("", visible=False)
# -----------------------------------------------------------------
# Main chat area
# -----------------------------------------------------------------
chatbot = gr.Chatbot(label="🗨️ Maya‑1 Chat", height=600)
txt = gr.Textbox(
placeholder="Type your message and hit Enter …",
label="Your message",
container=False,
)
submit_btn = gr.Button("Send", variant="secondary")
# -----------------------------------------------------------------
# Hidden state – we keep the full list of (user,assistant) tuples
# -----------------------------------------------------------------
chat_state = gr.State([]) # List[Tuple[str,str]]
session_state = gr.State("") # Current session_id (string)
# -----------------------------------------------------------------
# 5️⃣ Callbacks
# -----------------------------------------------------------------
# When the app loads, generate a fresh anonymous session ID
def init_session():
import uuid
sid = str(uuid.uuid4())
return sid, [] # session_state, chat_state
demo.load(fn=init_session, outputs=[session_state, chat_state])
# -----------------------------------------------------------------
# New chat → reset everything and give a brand‑new session ID
# -----------------------------------------------------------------
def new_chat():
import uuid
sid = str(uuid.uuid4())
return sid, [], [] # session_id, empty chat_state, empty UI
new_chat_btn.click(
fn=new_chat,
outputs=[session_state, chat_state, chatbot],
)
# -----------------------------------------------------------------
# Load a saved session from the dropdown
# -----------------------------------------------------------------
def load_session(selected: str):
if not selected:
return "", [], [] # nothing selected → blank
# The file name is the session_id we used when saving.
session_id = selected
history = load_history(session_id)
        # The JSON file stores lists; turn them back into the (user, assistant)
        # tuples that gr.Chatbot expects.
        ui_history = [(u, a) for u, a in history]
return session_id, history, ui_history
session_dropdown.change(
fn=load_session,
inputs=[session_dropdown],
outputs=[session_state, chat_state, chatbot],
)
# -----------------------------------------------------------------
# When the user hits "Enter" or clicks Send → generate a reply
# -----------------------------------------------------------------
def user_submit(user_msg: str, chat_hist: List[Tuple[str, str]], sid: str):
# Call the generation function
updated_hist, answer = generate_reply(user_msg, chat_hist, sid)
# UI expects List[Tuple[user,assistant]]
ui_hist = [(u, a) for u, a in updated_hist]
return "", ui_hist, updated_hist, answer
txt.submit(
fn=user_submit,
inputs=[txt, chat_state, session_state],
outputs=[txt, chatbot, chat_state, status_txt],
)
submit_btn.click(
fn=user_submit,
inputs=[txt, chat_state, session_state],
outputs=[txt, chatbot, chat_state, status_txt],
)
# -----------------------------------------------------------------
# Keep the session‑dropdown up‑to‑date after each new save
# -----------------------------------------------------------------
def refresh_dropdown():
        return gr.update(choices=list_sessions())
# Whenever we save a new session (i.e., after every reply) we refresh the list.
chat_state.change(fn=refresh_dropdown, inputs=None, outputs=session_dropdown)
# --------------------------------------------------------------
# 6️⃣ Run the demo
# --------------------------------------------------------------
if __name__ == "__main__":
demo.queue() # enables concurrent users
    demo.launch()