Update main.py
Browse files
main.py
CHANGED
|
@@ -16,7 +16,7 @@ if not REPLICATE_API_TOKEN:
|
|
| 16 |
raise ValueError("REPLICATE_API_TOKEN environment variable not set.")
|
| 17 |
|
| 18 |
# FastAPI Init
|
| 19 |
-
app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="
|
| 20 |
|
| 21 |
# --- Pydantic Models ---
|
| 22 |
class ModelCard(BaseModel):
|
|
@@ -40,8 +40,7 @@ SUPPORTED_MODELS = {
|
|
| 40 |
def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
|
| 41 |
"""
|
| 42 |
Formats the input for Replicate's API, flattening the message history into a
|
| 43 |
-
single 'prompt' string and handling images separately.
|
| 44 |
-
format for all their current chat/vision models.
|
| 45 |
"""
|
| 46 |
payload = {}
|
| 47 |
prompt_parts = []
|
|
@@ -81,7 +80,7 @@ def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, A
|
|
| 81 |
return payload
|
| 82 |
|
| 83 |
async def stream_replicate_sse(replicate_model_id: str, input_payload: dict):
|
| 84 |
-
"""Handles the full streaming lifecycle with
|
| 85 |
url = f"https://api.replicate.com/v1/models/{replicate_model_id}/predictions"
|
| 86 |
headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"}
|
| 87 |
|
|
@@ -105,37 +104,40 @@ async def stream_replicate_sse(replicate_model_id: str, input_payload: dict):
|
|
| 105 |
return
|
| 106 |
|
| 107 |
try:
|
| 108 |
-
async with client.stream("GET", stream_url, headers={"Accept": "text
|
| 109 |
current_event = None
|
| 110 |
async for line in sse.aiter_lines():
|
| 111 |
if line.startswith("event:"):
|
| 112 |
current_event = line[len("event:"):].strip()
|
| 113 |
elif line.startswith("data:"):
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
if current_event == "output":
|
| 116 |
-
#
|
| 117 |
-
# Replicate
|
| 118 |
-
#
|
| 119 |
content_token = ""
|
| 120 |
try:
|
| 121 |
-
|
| 122 |
-
decoded_data = json.loads(data)
|
| 123 |
-
if isinstance(decoded_data, str):
|
| 124 |
-
content_token = decoded_data
|
| 125 |
-
else:
|
| 126 |
-
# It's some other JSON type, convert to string
|
| 127 |
-
content_token = str(decoded_data)
|
| 128 |
except json.JSONDecodeError:
|
| 129 |
-
#
|
| 130 |
content_token = data
|
| 131 |
-
|
| 132 |
-
if
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
elif current_event == "done":
|
| 140 |
break
|
| 141 |
except httpx.ReadTimeout:
|
|
|
|
| 16 |
raise ValueError("REPLICATE_API_TOKEN environment variable not set.")
|
| 17 |
|
| 18 |
# FastAPI Init
|
| 19 |
+
app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="8.0.0 (Definitive Spacing Fix)")
|
| 20 |
|
| 21 |
# --- Pydantic Models ---
|
| 22 |
class ModelCard(BaseModel):
|
|
|
|
| 40 |
def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
|
| 41 |
"""
|
| 42 |
Formats the input for Replicate's API, flattening the message history into a
|
| 43 |
+
single 'prompt' string and handling images separately.
|
|
|
|
| 44 |
"""
|
| 45 |
payload = {}
|
| 46 |
prompt_parts = []
|
|
|
|
| 80 |
return payload
|
| 81 |
|
| 82 |
async def stream_replicate_sse(replicate_model_id: str, input_payload: dict):
|
| 83 |
+
"""Handles the full streaming lifecycle with correct whitespace preservation."""
|
| 84 |
url = f"https://api.replicate.com/v1/models/{replicate_model_id}/predictions"
|
| 85 |
headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"}
|
| 86 |
|
|
|
|
| 104 |
return
|
| 105 |
|
| 106 |
try:
|
| 107 |
+
async with client.stream("GET", stream_url, headers={"Accept": "text-event-stream"}, timeout=None) as sse:
|
| 108 |
current_event = None
|
| 109 |
async for line in sse.aiter_lines():
|
| 110 |
if line.startswith("event:"):
|
| 111 |
current_event = line[len("event:"):].strip()
|
| 112 |
elif line.startswith("data:"):
|
| 113 |
+
# --- START OF DEFINITIVE SPACING FIX ---
|
| 114 |
+
# The .strip() method was the bug. It removed crucial whitespace.
|
| 115 |
+
# This new logic correctly implements the SSE spec.
|
| 116 |
+
raw_data = line[len("data:"):]
|
| 117 |
+
if raw_data.startswith(" "):
|
| 118 |
+
# Remove only the single, optional leading space
|
| 119 |
+
data = raw_data[1:]
|
| 120 |
+
else:
|
| 121 |
+
data = raw_data
|
| 122 |
+
|
| 123 |
if current_event == "output":
|
| 124 |
+
# The data is now guaranteed to have its whitespace preserved.
|
| 125 |
+
# Replicate sometimes sends tokens as JSON strings (e.g., "\" a\""),
|
| 126 |
+
# so we still need to decode them.
|
| 127 |
content_token = ""
|
| 128 |
try:
|
| 129 |
+
content_token = json.loads(data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
except json.JSONDecodeError:
|
| 131 |
+
# Not a JSON string, use the raw data
|
| 132 |
content_token = data
|
| 133 |
+
|
| 134 |
+
# We must send content_token even if it's just a space
|
| 135 |
+
chunk = {
|
| 136 |
+
"id": prediction_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": replicate_model_id,
|
| 137 |
+
"choices": [{"index": 0, "delta": {"content": content_token}, "finish_reason": None}]
|
| 138 |
+
}
|
| 139 |
+
yield json.dumps(chunk)
|
| 140 |
+
# --- END OF DEFINITIVE SPACING FIX ---
|
| 141 |
elif current_event == "done":
|
| 142 |
break
|
| 143 |
except httpx.ReadTimeout:
|