rkihacker committed on
Commit
3a333bb
·
verified ·
1 Parent(s): 54de3fd

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +27 -25
main.py CHANGED
@@ -16,7 +16,7 @@ if not REPLICATE_API_TOKEN:
16
  raise ValueError("REPLICATE_API_TOKEN environment variable not set.")
17
 
18
  # FastAPI Init
19
- app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="7.1.0 (Streaming Space Fix)")
20
 
21
  # --- Pydantic Models ---
22
  class ModelCard(BaseModel):
@@ -40,8 +40,7 @@ SUPPORTED_MODELS = {
40
  def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
41
  """
42
  Formats the input for Replicate's API, flattening the message history into a
43
- single 'prompt' string and handling images separately. This is the required
44
- format for all their current chat/vision models.
45
  """
46
  payload = {}
47
  prompt_parts = []
@@ -81,7 +80,7 @@ def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, A
81
  return payload
82
 
83
  async def stream_replicate_sse(replicate_model_id: str, input_payload: dict):
84
- """Handles the full streaming lifecycle with robust token parsing."""
85
  url = f"https://api.replicate.com/v1/models/{replicate_model_id}/predictions"
86
  headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"}
87
 
@@ -105,37 +104,40 @@ async def stream_replicate_sse(replicate_model_id: str, input_payload: dict):
105
  return
106
 
107
  try:
108
- async with client.stream("GET", stream_url, headers={"Accept": "text/event-stream"}, timeout=None) as sse:
109
  current_event = None
110
  async for line in sse.aiter_lines():
111
  if line.startswith("event:"):
112
  current_event = line[len("event:"):].strip()
113
  elif line.startswith("data:"):
114
- data = line[len("data:"):].strip()
 
 
 
 
 
 
 
 
 
115
  if current_event == "output":
116
- # --- START OF STREAMING FIX ---
117
- # Replicate streams tokens that can be plain text or JSON-encoded strings.
118
- # We need to robustly parse them to preserve spaces correctly.
119
  content_token = ""
120
  try:
121
- # Attempt to parse data as JSON. This handles tokens like "\" Hello\""
122
- decoded_data = json.loads(data)
123
- if isinstance(decoded_data, str):
124
- content_token = decoded_data
125
- else:
126
- # It's some other JSON type, convert to string
127
- content_token = str(decoded_data)
128
  except json.JSONDecodeError:
129
- # It's not valid JSON, so it's a plain text token.
130
  content_token = data
131
-
132
- if content_token:
133
- chunk = {
134
- "id": prediction_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": replicate_model_id,
135
- "choices": [{"index": 0, "delta": {"content": content_token}, "finish_reason": None}]
136
- }
137
- yield json.dumps(chunk)
138
- # --- END OF STREAMING FIX ---
139
  elif current_event == "done":
140
  break
141
  except httpx.ReadTimeout:
 
16
  raise ValueError("REPLICATE_API_TOKEN environment variable not set.")
17
 
18
  # FastAPI Init
19
+ app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="8.0.0 (Definitive Spacing Fix)")
20
 
21
  # --- Pydantic Models ---
22
  class ModelCard(BaseModel):
 
40
  def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
41
  """
42
  Formats the input for Replicate's API, flattening the message history into a
43
+ single 'prompt' string and handling images separately.
 
44
  """
45
  payload = {}
46
  prompt_parts = []
 
80
  return payload
81
 
82
  async def stream_replicate_sse(replicate_model_id: str, input_payload: dict):
83
+ """Handles the full streaming lifecycle with correct whitespace preservation."""
84
  url = f"https://api.replicate.com/v1/models/{replicate_model_id}/predictions"
85
  headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"}
86
 
 
104
  return
105
 
106
  try:
107
+ async with client.stream("GET", stream_url, headers={"Accept": "text/event-stream"}, timeout=None) as sse:
108
  current_event = None
109
  async for line in sse.aiter_lines():
110
  if line.startswith("event:"):
111
  current_event = line[len("event:"):].strip()
112
  elif line.startswith("data:"):
113
+ # --- START OF DEFINITIVE SPACING FIX ---
114
+ # The .strip() method was the bug. It removed crucial whitespace.
115
+ # This new logic correctly implements the SSE spec.
116
+ raw_data = line[len("data:"):]
117
+ if raw_data.startswith(" "):
118
+ # Remove only the single, optional leading space
119
+ data = raw_data[1:]
120
+ else:
121
+ data = raw_data
122
+
123
  if current_event == "output":
124
+ # The data is now guaranteed to have its whitespace preserved.
125
+ # Replicate sometimes sends tokens as JSON strings (e.g., "\" a\""),
126
+ # so we still need to decode them.
127
  content_token = ""
128
  try:
129
+ content_token = json.loads(data)
 
 
 
 
 
 
130
  except json.JSONDecodeError:
131
+ # Not a JSON string, use the raw data
132
  content_token = data
133
+
134
+ # We must send content_token even if it's just a space
135
+ chunk = {
136
+ "id": prediction_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": replicate_model_id,
137
+ "choices": [{"index": 0, "delta": {"content": content_token}, "finish_reason": None}]
138
+ }
139
+ yield json.dumps(chunk)
140
+ # --- END OF DEFINITIVE SPACING FIX ---
141
  elif current_event == "done":
142
  break
143
  except httpx.ReadTimeout: