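"""Gradio demo app for MERaLiON-2-10B.

Takes an uploaded audio clip and a text instruction, lazily loads the
MERaLiON-2-10B speech sequence-to-sequence model from the Hugging Face Hub,
and returns a text response (e.g. a transcription or an answer about the audio).

Assumed dependencies (not pinned here): gradio, torch, librosa, transformers,
and flash-attn (since flash_attention_2 is requested below).
"""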
import gradio as gr
import torch
import librosa
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq

# Global model cache
model = None
processor = None
device = "cuda" if torch.cuda.is_available() else "cpu"


def load_model():
    global model, processor
    if model is None:
        repo_id = "MERaLiON/MERaLiON-2-10B"
        print("Loading MERaLiON-2-10B model...")
        processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
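        # Note: flash_attention_2 assumes a CUDA GPU with the flash-attn package
        # installed; on CPU-only setups, drop the argument or use "eager" instead.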
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            repo_id,
            use_safetensors=True,
            trust_remote_code=True,
            attn_implementation="flash_attention_2",
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        print("Model loaded successfully!")
    return model, processor


def meralion_inference(prompt, uploaded_file):
    global model, processor

    if uploaded_file is None:
        return "Please upload an audio file."

    # Load model on first run
    model, processor = load_model()

    try:
        # Resolve the uploaded file to a path; depending on the Gradio version,
        # gr.File passes either a filepath string or a tempfile-like object.
        audio_path = uploaded_file if isinstance(uploaded_file, str) else uploaded_file.name

        # Load audio and resample to 16 kHz, the rate the processor expects
        audio_array, sr = librosa.load(audio_path, sr=16000)

        # Prompt template
        prompt_template = "Instruction: {query}\nFollow the text instruction based on the following audio: <SpeechHere>"
        conversation = [
            {"role": "user", "content": prompt_template.format(query=prompt)}
        ]
        chat_prompt = processor.tokenizer.apply_chat_template(
            conversation=conversation, tokenize=False, add_generation_prompt=True
        )

        # Process inputs, then move tensors to the model's device and cast float
        # features to bfloat16 to match the model weights (assumed necessary when
        # the model is loaded in bfloat16 on GPU)
        inputs = processor(text=chat_prompt, audios=audio_array)
        for key, value in inputs.items():
            if isinstance(value, torch.Tensor):
                inputs[key] = value.to(device)
                if value.dtype == torch.float32:
                    inputs[key] = inputs[key].to(torch.bfloat16)

        # Generate
        with torch.no_grad():
            outputs = model.generate(
                **inputs, max_new_tokens=256, do_sample=True, temperature=0.7
            )
        generated_ids = outputs[:, inputs["input_ids"].size(1) :]
        response = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        return response

    except Exception as e:
        return f"Error during inference: {str(e)}"


with gr.Blocks() as demo:
    gr.Markdown("# MERaLiON-2-10B Audio Demo")
    with gr.Row():
        prompt_input = gr.Textbox(
            label="Enter Prompt", value="Please transcribe this speech.", lines=2
        )
        file_input = gr.File(
            label="Upload Audio File (WAV/MP3, max 300s)",
            file_types=[".wav", ".mp3", ".m4a"],
        )
    output_text = gr.Textbox(label="Model Output", lines=8)

    submit_btn = gr.Button("Run Inference", variant="primary")
    submit_btn.click(
        meralion_inference, inputs=[prompt_input, file_input], outputs=output_text
    )

demo.launch()