import asyncio
import builtins
import copy
import datetime
import json
import os
import threading
import traceback
from pathlib import Path
from queue import Empty, Queue
from typing import Any, Callable, Dict, Iterator, Optional, Tuple, cast

import numpy as np
import torch
from fastapi import FastAPI, WebSocket
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from starlette.websockets import WebSocketDisconnect, WebSocketState

from vibevoice.modular.modeling_vibevoice_streaming_inference import (
    VibeVoiceStreamingForConditionalGenerationInference,
)
from vibevoice.modular.streamer import AudioStreamer
from vibevoice.processor.vibevoice_streaming_processor import (
    VibeVoiceStreamingProcessor,
)

BASE = Path(__file__).parent
SAMPLE_RATE = 24_000


def get_timestamp() -> str:
    # Current time converted to UTC+8, formatted with millisecond precision.
    timestamp = (
        datetime.datetime.now(datetime.timezone.utc)
        .astimezone(datetime.timezone(datetime.timedelta(hours=8)))
        .strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
    )
    return timestamp


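# Wraps the VibeVoice streaming model: loads the processor, model, and cached
# voice presets once at startup, then serves requests as a generator of
# float32 audio chunks.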
class StreamingTTSService:
    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        inference_steps: int = 5,
    ) -> None:
        self.model_path = Path(model_path)
        self.inference_steps = inference_steps
        self.sample_rate = SAMPLE_RATE

        self.processor: Optional[VibeVoiceStreamingProcessor] = None
        self.model: Optional[VibeVoiceStreamingForConditionalGenerationInference] = None
        self.voice_presets: Dict[str, Path] = {}
        self.default_voice_key: Optional[str] = None
        # Maps preset name -> prefilled prompt outputs loaded from the .pt file.
        self._voice_cache: Dict[str, object] = {}

        if device == "mpx":
            print("Note: device 'mpx' detected, treating it as 'mps'.")
            device = "mps"
        if device == "mps" and not torch.backends.mps.is_available():
            print("Warning: MPS not available. Falling back to CPU.")
            device = "cpu"
        self.device = device
        self._torch_device = torch.device(device)

    def load(self) -> None:
        print(f"[startup] Loading processor from {self.model_path}")
        self.processor = VibeVoiceStreamingProcessor.from_pretrained(str(self.model_path))

        # Pick dtype and attention implementation per device.
        if self.device == "mps":
            load_dtype = torch.float32
            device_map = None
            attn_impl_primary = "sdpa"
        elif self.device == "cuda":
            load_dtype = torch.bfloat16
            device_map = "cuda"
            attn_impl_primary = "flash_attention_2"
        else:
            load_dtype = torch.float32
            device_map = "cpu"
            attn_impl_primary = "sdpa"
        print(f"Using device: {device_map}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}")

        try:
            self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                str(self.model_path),
                torch_dtype=load_dtype,
                device_map=device_map,
                attn_implementation=attn_impl_primary,
            )
            if self.device == "mps":
                self.model.to("mps")
        except Exception as e:
            if attn_impl_primary == "flash_attention_2":
                print(
                    "Error loading the model with flash_attention_2; falling back to SDPA. "
                    "Note that only flash_attention_2 has been fully tested, and SDPA may "
                    "result in lower audio quality."
                )
                self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                    str(self.model_path),
                    torch_dtype=load_dtype,
                    device_map=self.device,
                    attn_implementation="sdpa",
                )
                print("Loaded model with SDPA successfully.")
            else:
                raise e

        self.model.eval()

        # Switch the diffusion noise scheduler to SDE-DPM-Solver++ and set the
        # number of denoising steps used during streaming inference.
        self.model.model.noise_scheduler = self.model.model.noise_scheduler.from_config(
            self.model.model.noise_scheduler.config,
            algorithm_type="sde-dpmsolver++",
            beta_schedule="squaredcos_cap_v2",
        )
        self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)

        self.voice_presets = self._load_voice_presets()
        preset_name = os.environ.get("VOICE_PRESET")
        self.default_voice_key = self._determine_voice_key(preset_name)
        self._ensure_voice_cached(self.default_voice_key)

    def _load_voice_presets(self) -> Dict[str, Path]:
        voices_dir = BASE.parent / "voices" / "streaming_model"
        if not voices_dir.exists():
            raise RuntimeError(f"Voices directory not found: {voices_dir}")

        presets: Dict[str, Path] = {}
        for pt_path in voices_dir.glob("*.pt"):
            presets[pt_path.stem] = pt_path

        if not presets:
            raise RuntimeError(f"No voice preset (.pt) files found in {voices_dir}")

        print(f"[startup] Found {len(presets)} voice presets")
        return dict(sorted(presets.items()))

    def _determine_voice_key(self, name: Optional[str]) -> str:
        if name and name in self.voice_presets:
            return name

        default_key = "en-WHTest_man"
        if default_key in self.voice_presets:
            return default_key

        first_key = next(iter(self.voice_presets))
        print(f"[startup] Using fallback voice preset: {first_key}")
        return first_key

    def _ensure_voice_cached(self, key: str) -> object:
        if key not in self.voice_presets:
            raise RuntimeError(f"Voice preset {key!r} not found")

        if key not in self._voice_cache:
            preset_path = self.voice_presets[key]
            print(f"[startup] Loading prefilled prompt for voice preset {key} from {preset_path}")
            prefilled_outputs = torch.load(
                preset_path,
                map_location=self._torch_device,
                weights_only=False,
            )
            self._voice_cache[key] = prefilled_outputs

        return self._voice_cache[key]

    def _get_voice_resources(self, requested_key: Optional[str]) -> Tuple[str, object]:
        key = requested_key if requested_key and requested_key in self.voice_presets else self.default_voice_key
        if key is None:
            key = next(iter(self.voice_presets))
            self.default_voice_key = key

        prefilled_outputs = self._ensure_voice_cached(key)
        return key, prefilled_outputs

    def _prepare_inputs(self, text: str, prefilled_outputs: object):
        if not self.processor or not self.model:
            raise RuntimeError("StreamingTTSService not initialized")

        processor_kwargs = {
            "text": text.strip(),
            "cached_prompt": prefilled_outputs,
            "padding": True,
            "return_tensors": "pt",
            "return_attention_mask": True,
        }

        processed = self.processor.process_input_with_cached_prompt(**processor_kwargs)

        # Move tensor-valued entries onto the target device; leave everything else as-is.
        prepared = {
            key: value.to(self._torch_device) if hasattr(value, "to") else value
            for key, value in processed.items()
        }
        return prepared

    def _run_generation(
        self,
        inputs,
        audio_streamer: AudioStreamer,
        errors,
        cfg_scale: float,
        do_sample: bool,
        temperature: float,
        top_p: float,
        refresh_negative: bool,
        prefilled_outputs,
        stop_event: threading.Event,
    ) -> None:
        try:
            self.model.generate(
                **inputs,
                max_new_tokens=None,
                cfg_scale=cfg_scale,
                tokenizer=self.processor.tokenizer,
                generation_config={
                    "do_sample": do_sample,
                    "temperature": temperature if do_sample else 1.0,
                    "top_p": top_p if do_sample else 1.0,
                },
                audio_streamer=audio_streamer,
                stop_check_fn=stop_event.is_set,
                verbose=False,
                refresh_negative=refresh_negative,
                all_prefilled_outputs=copy.deepcopy(prefilled_outputs),
            )
        except Exception as exc:
            errors.append(exc)
            traceback.print_exc()
            # End the stream so the consumer loop is not left waiting.
            audio_streamer.end()

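    # stream() runs generation on a background thread and yields mono float32
    # audio chunks at SAMPLE_RATE as they become available.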
    def stream(
        self,
        text: str,
        cfg_scale: float = 1.5,
        do_sample: bool = False,
        temperature: float = 0.9,
        top_p: float = 0.9,
        refresh_negative: bool = True,
        inference_steps: Optional[int] = None,
        voice_key: Optional[str] = None,
        log_callback: Optional[Callable[..., None]] = None,
        stop_event: Optional[threading.Event] = None,
    ) -> Iterator[np.ndarray]:
        if not text.strip():
            return
        text = text.replace("’", "'")
        selected_voice, prefilled_outputs = self._get_voice_resources(voice_key)

        def emit(event: str, **payload: Any) -> None:
            if log_callback:
                try:
                    log_callback(event, **payload)
                except Exception as exc:
                    print(f"[log_callback] Error while emitting {event}: {exc}")

        # Override the number of denoising steps for this request if a valid value was given.
        steps_to_use = self.inference_steps
        if inference_steps is not None:
            try:
                parsed_steps = int(inference_steps)
                if parsed_steps > 0:
                    steps_to_use = parsed_steps
            except (TypeError, ValueError):
                pass
        if self.model:
            self.model.set_ddpm_inference_steps(num_steps=steps_to_use)
            self.inference_steps = steps_to_use

        inputs = self._prepare_inputs(text, prefilled_outputs)
        audio_streamer = AudioStreamer(batch_size=1, stop_signal=None, timeout=None)
        errors: list = []
        stop_signal = stop_event or threading.Event()

        # Generate on a worker thread so chunks can be consumed while the model runs.
        thread = threading.Thread(
            target=self._run_generation,
            kwargs={
                "inputs": inputs,
                "audio_streamer": audio_streamer,
                "errors": errors,
                "cfg_scale": cfg_scale,
                "do_sample": do_sample,
                "temperature": temperature,
                "top_p": top_p,
                "refresh_negative": refresh_negative,
                "prefilled_outputs": prefilled_outputs,
                "stop_event": stop_signal,
            },
            daemon=True,
        )
        thread.start()

        generated_samples = 0

        try:
            stream = audio_streamer.get_stream(0)
            for audio_chunk in stream:
                if torch.is_tensor(audio_chunk):
                    audio_chunk = audio_chunk.detach().cpu().to(torch.float32).numpy()
                else:
                    audio_chunk = np.asarray(audio_chunk, dtype=np.float32)

                if audio_chunk.ndim > 1:
                    audio_chunk = audio_chunk.reshape(-1)

                # Normalize only when the chunk peaks above full scale.
                peak = np.max(np.abs(audio_chunk)) if audio_chunk.size else 0.0
                if peak > 1.0:
                    audio_chunk = audio_chunk / peak

                generated_samples += int(audio_chunk.size)
                emit(
                    "model_progress",
                    generated_sec=generated_samples / self.sample_rate,
                    chunk_sec=audio_chunk.size / self.sample_rate,
                )

                yield audio_chunk.astype(np.float32, copy=False)
        finally:
            stop_signal.set()
            audio_streamer.end()
            thread.join()
            if errors:
                emit("generation_error", message=str(errors[0]))
                raise errors[0]

    def chunk_to_pcm16(self, chunk: np.ndarray) -> bytes:
        # Convert float audio in [-1.0, 1.0] to 16-bit signed PCM bytes.
        chunk = np.clip(chunk, -1.0, 1.0)
        pcm = (chunk * 32767.0).astype(np.int16)
        return pcm.tobytes()


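# FastAPI application: the model is loaded once at startup and shared across
# requests via app.state.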
app = FastAPI()


@app.on_event("startup")
async def _startup() -> None:
    # MODEL_PATH is required; MODEL_DEVICE defaults to "cuda".
    model_path = os.environ.get("MODEL_PATH")
    if not model_path:
        raise RuntimeError("MODEL_PATH not set in environment")

    device = os.environ.get("MODEL_DEVICE", "cuda")

    service = StreamingTTSService(
        model_path=model_path,
        device=device,
    )
    service.load()

    app.state.tts_service = service
    app.state.model_path = model_path
    app.state.device = device
    app.state.websocket_lock = asyncio.Lock()
    print("[startup] Model ready.")


def streaming_tts(text: str, **kwargs) -> Iterator[np.ndarray]:
    service: StreamingTTSService = app.state.tts_service
    yield from service.stream(text, **kwargs)


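# WebSocket endpoint: query parameters select the text, cfg scale, inference
# steps, and voice preset. Audio is sent as binary PCM16 frames; progress and
# status updates are sent as JSON text messages of type "log".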
@app.websocket("/stream")
async def websocket_stream(ws: WebSocket) -> None:
    await ws.accept()
    text = ws.query_params.get("text", "")
    print(f"Client connected, text={text!r}")
    cfg_param = ws.query_params.get("cfg")
    steps_param = ws.query_params.get("steps")
    voice_param = ws.query_params.get("voice")

    # Parse optional numeric parameters, falling back to defaults on bad input.
    try:
        cfg_scale = float(cfg_param) if cfg_param is not None else 1.5
    except ValueError:
        cfg_scale = 1.5
    if cfg_scale <= 0:
        cfg_scale = 1.5
    try:
        inference_steps = int(steps_param) if steps_param is not None else None
        if inference_steps is not None and inference_steps <= 0:
            inference_steps = None
    except ValueError:
        inference_steps = None

    service: StreamingTTSService = app.state.tts_service
    lock: asyncio.Lock = app.state.websocket_lock

    # Only one generation runs at a time; reject new connections while busy.
    if lock.locked():
        busy_message = {
            "type": "log",
            "event": "backend_busy",
            "data": {"message": "Please wait for the other requests to complete."},
            "timestamp": get_timestamp(),
        }
        print("Rejecting request: another generation is still in progress.")
        try:
            await ws.send_text(json.dumps(busy_message))
        except Exception:
            pass
        await ws.close(code=1013, reason="Service busy")
        return

    acquired = False
    try:
        await lock.acquire()
        acquired = True

        log_queue: "Queue[Dict[str, Any]]" = Queue()

        def enqueue_log(event: str, **data: Any) -> None:
            log_queue.put({"event": event, "data": data})

        async def flush_logs() -> None:
            # Drain queued log events and forward them to the client as JSON.
            while True:
                try:
                    entry = log_queue.get_nowait()
                except Empty:
                    break
                message = {
                    "type": "log",
                    "event": entry.get("event"),
                    "data": entry.get("data", {}),
                    "timestamp": get_timestamp(),
                }
                try:
                    await ws.send_text(json.dumps(message))
                except Exception:
                    break

        enqueue_log(
            "backend_request_received",
            text_length=len(text or ""),
            cfg_scale=cfg_scale,
            inference_steps=inference_steps,
            voice=voice_param,
        )

        stop_signal = threading.Event()

        iterator = streaming_tts(
            text,
            cfg_scale=cfg_scale,
            inference_steps=inference_steps,
            voice_key=voice_param,
            log_callback=enqueue_log,
            stop_event=stop_signal,
        )
        sentinel = object()
        first_ws_send_logged = False

        await flush_logs()

        try:
            while ws.client_state == WebSocketState.CONNECTED:
                await flush_logs()
                # next() blocks on the generator, so run it off the event loop.
                chunk = await asyncio.to_thread(next, iterator, sentinel)
                if chunk is sentinel:
                    break
                chunk = cast(np.ndarray, chunk)
                payload = service.chunk_to_pcm16(chunk)
                await ws.send_bytes(payload)
                if not first_ws_send_logged:
                    first_ws_send_logged = True
                    enqueue_log("backend_first_chunk_sent")
                await flush_logs()
        except WebSocketDisconnect:
            print("Client disconnected (WebSocketDisconnect)")
            enqueue_log("client_disconnected")
            stop_signal.set()
        finally:
            stop_signal.set()
            enqueue_log("backend_stream_complete")
            await flush_logs()
            try:
                iterator_close = getattr(iterator, "close", None)
                if callable(iterator_close):
                    iterator_close()
            except Exception:
                pass

            while not log_queue.empty():
                try:
                    log_queue.get_nowait()
                except Empty:
                    break
            if ws.client_state == WebSocketState.CONNECTED:
                await ws.close()
            print("WS handler exit")
    finally:
        if acquired:
            lock.release()


@app.get("/")
def index():
    return FileResponse(BASE / "index.html")


@app.get("/config")
def get_config():
    service: StreamingTTSService = app.state.tts_service
    voices = sorted(service.voice_presets.keys())
    return {
        "voices": voices,
        "default_voice": service.default_voice_key,
    }