# --------------------------------------------------------------
# app.py – a Gradio chat UI for maya-research/maya1
# --------------------------------------------------------------
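#
# Assumed dependencies – a minimal requirements.txt sketch inferred from the
# imports below ("accelerate" is needed for device_map="auto"); treat the
# exact set and any version pins as assumptions:
#
#   gradio
#   torch
#   transformers
#   accelerate
#   huggingface_hub
# --------------------------------------------------------------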

import json
import os
import uuid
from pathlib import Path
from typing import List, Tuple

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from huggingface_hub import HfApi

# ------------------- CONFIGURATION -----------------------------
MODEL_ID = "maya-research/maya1"           # the model you want to use
MAX_NEW_TOKENS = 256                       # generation length
TEMPERATURE = 0.7
TOP_P = 0.9
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"   # informational – placement is handled by device_map="auto"

# Folder inside the Space where we keep per‑session JSON files
HISTORY_DIR = Path("history")
HISTORY_DIR.mkdir(exist_ok=True)

# ----------------------------------------------------------------
# 1️⃣  Load the model once (global, reused across requests)
# ----------------------------------------------------------------
print(f"🔧 Loading {MODEL_ID} …")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=MAX_NEW_TOKENS,
    temperature=TEMPERATURE,
    top_p=TOP_P,
    do_sample=True,
)
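
# Optional sanity check (a quick sketch – uncomment to verify the model loads
# and generates at startup; it adds a little boot time):
#
#   print(generator("Hello!", max_new_tokens=8)[0]["generated_text"])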

# --------------------------------------------------------------
# 2️⃣  Helper functions for history persistence
# --------------------------------------------------------------
def _history_path(session_id: str) -> Path:
    """File that stores a JSON list of (user,assistant) pairs."""
    safe_id = session_id.replace("/", "_")
    return HISTORY_DIR / f"{safe_id}.json"


def load_history(session_id: str) -> List[Tuple[str, str]]:
    """Read JSON file → list of (user,assistant). Return [] if not present."""
    p = _history_path(session_id)
    if p.is_file():
        try:
            return json.loads(p.read_text(encoding="utf-8"))
        except Exception as e:
            print(f"⚠️ Failed to read history for {session_id}: {e}")
    return []


def save_history(session_id: str, chat: List[Tuple[str, str]]) -> None:
    """Write JSON file and push it back to the repo."""
    p = _history_path(session_id)
    p.write_text(json.dumps(chat, ensure_ascii=False, indent=2), encoding="utf-8")
    # -----------------------------------------------------------------
    # Push the new file to the repo (so it survives container restarts)
    # -----------------------------------------------------------------
    # NOTE: this only works if the Space has an HF_TOKEN secret with write
    # access to the Space repo configured in its settings.
    try:
        api = HfApi()
        # `repo_id` is the full name of the Space (owner/space-name);
        # the Hub sets it automatically via the SPACE_ID env variable.
        repo_id = os.getenv("SPACE_ID")
        if repo_id:
            api.upload_file(
                path_or_fileobj=str(p),
                path_in_repo=str(p),
                repo_id=repo_id,
                repo_type="space",
                token=os.getenv("HF_TOKEN"),
            )
    except Exception as exc:
        # Failing to push is not fatal – the file stays on the container.
        print(f"⚠️ Could not push history to hub: {exc}")


def list_sessions() -> List[str]:
    """Return a list of all stored session IDs (file names)."""
    return [f.stem for f in HISTORY_DIR.glob("*.json")]


# --------------------------------------------------------------
# 3️⃣  The generation function – called by Gradio
# --------------------------------------------------------------
def generate_reply(
    user_message: str,
    chat_history: List[Tuple[str, str]],
    session_id: str,
) -> Tuple[List[Tuple[str, str]], str]:
    """
    1. Append the user's new message.
    2. Build the full `messages` list in the format expected by the model's
       chat_template.
    3. Use the tokenizer's `apply_chat_template(..., add_generation_prompt=True)`
       to create the prompt.
    4. Run the pipeline, decode, and strip the extra tokens.
    5. Append the assistant answer and persist the whole chat.
    """

    # ----- 1️⃣ Append user message to history -----
    chat_history.append((user_message, ""))   # placeholder for assistant

    # ----- 2️⃣ Build the messages list for the template -----
    # Previous exchanges first (no system message needed here),
    # then the new user message last.
    messages = []
    for user, assistant in chat_history[:-1]:  # exclude the placeholder
        messages.append({"role": "user", "content": user})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": user_message})

    # ----- 3️⃣ Render the prompt with the model's chat template -----
    prompt = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
    )
    # Debug – uncomment if you want to see the raw prompt in logs
    # print("\n--- Prompt sent to model ---\n", prompt)

    # ----- 4️⃣ Generate the answer -----
    result = generator(prompt, max_new_tokens=MAX_NEW_TOKENS)[0]["generated_text"]
    # The pipeline returns the **whole** text (prompt + answer). Remove the prompt.
    answer = result[len(prompt) :].strip()
    # Some models still emit special tokens like <|eot_id|>; strip them.
    for stop in ("<|eot_id|>", "</s>"):
        answer = answer.replace(stop, "").strip()
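    # Alternative (a sketch, not used here): the text-generation pipeline also
    # accepts return_full_text=False, which returns only the newly generated
    # text and makes the prompt-slicing above unnecessary:
    #
    #   answer = generator(prompt, max_new_tokens=MAX_NEW_TOKENS,
    #                      return_full_text=False)[0]["generated_text"].strip()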

    # ----- 5️⃣ Update history and persist -----
    chat_history[-1] = (user_message, answer)  # replace placeholder
    save_history(session_id, chat_history)

    return chat_history, answer


# --------------------------------------------------------------
# 4️⃣  UI definition
# --------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Default()) as demo:
    # -----------------------------------------------------------------
    # Top bar – session selector + "New chat" button
    # -----------------------------------------------------------------
    with gr.Row():
        session_dropdown = gr.Dropdown(
            choices=list_sessions(),
            label="🗂️  Load previous chat",
            interactive=True,
            allow_custom_value=False,
        )
        new_chat_btn = gr.Button("🆕 New chat", variant="primary")
        status_txt = gr.Markdown("", visible=False)

    # -----------------------------------------------------------------
    # Main chat area
    # -----------------------------------------------------------------
    chatbot = gr.Chatbot(label="🗨️  Maya‑1 Chat", height=600)
    txt = gr.Textbox(
        placeholder="Type your message and hit Enter …",
        label="Your message",
        container=False,
    )
    submit_btn = gr.Button("Send", variant="secondary")

    # -----------------------------------------------------------------
    # Hidden state – we keep the full list of (user,assistant) tuples
    # -----------------------------------------------------------------
    chat_state = gr.State([])               # List[Tuple[str,str]]
    session_state = gr.State("")            # Current session_id (string)

    # -----------------------------------------------------------------
    # 5️⃣  Callbacks
    # -----------------------------------------------------------------
    # When the app loads, generate a fresh anonymous session ID
    def init_session():
        sid = str(uuid.uuid4())
        return sid, []                     # session_state, chat_state

    demo.load(fn=init_session, outputs=[session_state, chat_state])

    # -----------------------------------------------------------------
    # New chat → reset everything and give a brand‑new session ID
    # -----------------------------------------------------------------
    def new_chat():
        sid = str(uuid.uuid4())
        return sid, [], []                 # session_id, empty chat_state, empty UI

    new_chat_btn.click(
        fn=new_chat,
        outputs=[session_state, chat_state, chatbot],
    )

    # -----------------------------------------------------------------
    # Load a saved session from the dropdown
    # -----------------------------------------------------------------
    def load_session(selected: str):
        if not selected:
            return "", [], []               # nothing selected → blank
        # The file name is the session_id we used when saving.
        session_id = selected
        history = load_history(session_id)
        # Convert List[Tuple] → format expected by Gradio.Chatbot
        ui_history = [(u, a) for u, a in history]
        return session_id, history, ui_history

    session_dropdown.change(
        fn=load_session,
        inputs=[session_dropdown],
        outputs=[session_state, chat_state, chatbot],
    )

    # -----------------------------------------------------------------
    # When the user hits "Enter" or clicks Send → generate a reply
    # -----------------------------------------------------------------
    def user_submit(user_msg: str, chat_hist: List[Tuple[str, str]], sid: str):
        # Call the generation function
        updated_hist, answer = generate_reply(user_msg, chat_hist, sid)
        # UI expects List[Tuple[user,assistant]]
        ui_hist = [(u, a) for u, a in updated_hist]
        return "", ui_hist, updated_hist, answer

    txt.submit(
        fn=user_submit,
        inputs=[txt, chat_state, session_state],
        outputs=[txt, chatbot, chat_state, status_txt],
    )
    submit_btn.click(
        fn=user_submit,
        inputs=[txt, chat_state, session_state],
        outputs=[txt, chatbot, chat_state, status_txt],
    )

    # -----------------------------------------------------------------
    # Keep the session‑dropdown up‑to‑date after each new save
    # -----------------------------------------------------------------
    def refresh_dropdown():
        # gr.update(...) works on both Gradio 3.x and 4.x
        # (gr.Dropdown.update was removed in Gradio 4).
        return gr.update(choices=list_sessions())

    # Whenever the chat state changes (i.e., after every reply) we refresh the list.
    # Note: this relies on gr.State exposing a .change event, which newer
    # Gradio versions provide.
    chat_state.change(fn=refresh_dropdown, inputs=None, outputs=session_dropdown)

# --------------------------------------------------------------
# 6️⃣  Run the demo
# --------------------------------------------------------------
if __name__ == "__main__":
    demo.queue()               # enables concurrent users
    demo.launch()
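
# --------------------------------------------------------------
# Local usage (a sketch): with the dependencies listed at the top installed,
#   python app.py
# starts the UI on Gradio's default port (7860). On a Hugging Face Space,
# the Hub runs this file automatically.
# --------------------------------------------------------------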