fix: harden multimodal client I/O handling

parent 05d3ae28
...@@ -169,18 +169,18 @@ def test_build_request_spec_for_transcription_uses_multipart_file(tmp_path): ...@@ -169,18 +169,18 @@ def test_build_request_spec_for_transcription_uses_multipart_file(tmp_path):
spec = build_request_spec(config) spec = build_request_spec(config)
assert spec == { assert spec["method"] == "POST"
"method": "POST", assert spec["url"] == "http://127.0.0.1:6745/v1/audio/transcriptions"
"url": "http://127.0.0.1:6745/v1/audio/transcriptions", assert spec["headers"] == {"Accept": "application/json"}
"headers": {"Accept": "application/json"}, assert spec["data"] == {
"data": { "model": "audio:test",
"model": "audio:test", "prompt": "Transcribe carefully",
"prompt": "Transcribe carefully",
},
"files": {
"file": ("sample.wav", b"wav-bytes"),
},
} }
uploaded_name, uploaded_file = spec["files"]["file"]
assert uploaded_name == "sample.wav"
assert uploaded_file.read() == b"wav-bytes"
assert uploaded_file.closed is False
uploaded_file.close()
def test_build_request_spec_for_audio_generation_uses_json_payload(tmp_path): def test_build_request_spec_for_audio_generation_uses_json_payload(tmp_path):
...@@ -406,6 +406,27 @@ def test_task5_handle_response_payload_returns_llm_text_without_artifact(tmp_pat ...@@ -406,6 +406,27 @@ def test_task5_handle_response_payload_returns_llm_text_without_artifact(tmp_pat
assert result["payload"] == payload assert result["payload"] == payload
def test_task5_handle_response_payload_flattens_structured_chat_content(tmp_path):
    """List-form chat content is flattened into newline-joined text.

    Parts typed ``text``/``input_text`` contribute their ``text`` value; the
    remaining part is expected to appear as key-sorted JSON. No artifact is
    produced and the raw payload is passed through unchanged.
    """
    content_parts = [
        {"type": "text", "text": "hello"},
        {"type": "input_text", "text": "from model"},
        {"type": "tool_result", "value": 7},
    ]
    payload = {"choices": [{"message": {"content": content_parts}}]}

    result = handle_response_payload("llm", DummyResponse(payload), tmp_path)

    expected_text = 'hello\nfrom model\n{"type": "tool_result", "value": 7}'
    assert result["text"] == expected_text
    assert result["artifact_path"] is None
    assert result["payload"] == payload
def test_task5_handle_response_payload_downloads_url_artifact(monkeypatch, tmp_path): def test_task5_handle_response_payload_downloads_url_artifact(monkeypatch, tmp_path):
payload = { payload = {
"data": [{"url": "http://example.invalid/audio.wav", "text": "generated audio summary"}] "data": [{"url": "http://example.invalid/audio.wav", "text": "generated audio summary"}]
......
...@@ -4,6 +4,7 @@ import argparse ...@@ -4,6 +4,7 @@ import argparse
import base64 import base64
import json import json
import time import time
from contextlib import ExitStack
from pathlib import Path from pathlib import Path
import requests import requests
...@@ -113,6 +114,7 @@ def build_request_spec(config: dict) -> dict: ...@@ -113,6 +114,7 @@ def build_request_spec(config: dict) -> dict:
if mode == "transcription": if mode == "transcription":
audio_path = _require_file(config.get("audio_file"), "--audio-file") audio_path = _require_file(config.get("audio_file"), "--audio-file")
file_stack = ExitStack()
return { return {
"method": "POST", "method": "POST",
"url": f"{config['url']}/v1/audio/transcriptions", "url": f"{config['url']}/v1/audio/transcriptions",
...@@ -122,8 +124,9 @@ def build_request_spec(config: dict) -> dict: ...@@ -122,8 +124,9 @@ def build_request_spec(config: dict) -> dict:
"prompt": config["prompt"], "prompt": config["prompt"],
}, },
"files": { "files": {
"file": (audio_path.name, audio_path.read_bytes()), "file": (audio_path.name, file_stack.enter_context(audio_path.open("rb"))),
}, },
"_close": file_stack.close,
} }
if mode == "audio-generation": if mode == "audio-generation":
...@@ -220,12 +223,28 @@ def _write_artifact(output_dir: Path, mode: str, payload: bytes) -> Path: ...@@ -220,12 +223,28 @@ def _write_artifact(output_dir: Path, mode: str, payload: bytes) -> Path:
return artifact_path return artifact_path
def _stringify_chat_content(content) -> str:
if isinstance(content, str):
return content
if isinstance(content, list):
parts = []
for item in content:
if isinstance(item, dict) and item.get("type") in {"text", "input_text"} and isinstance(item.get("text"), str):
parts.append(item["text"])
else:
parts.append(json.dumps(item, sort_keys=True) if isinstance(item, (dict, list)) else str(item))
return "\n".join(parts)
if isinstance(content, (dict, list)):
return json.dumps(content, sort_keys=True)
return str(content)
def handle_response_payload(mode: str, response, output_dir: Path) -> dict: def handle_response_payload(mode: str, response, output_dir: Path) -> dict:
response.raise_for_status() response.raise_for_status()
payload = response.json() payload = response.json()
if mode in {"llm", "video-doubt", "music-audio-doubt"}: if mode in {"llm", "video-doubt", "music-audio-doubt"}:
text = payload["choices"][0]["message"]["content"] text = _stringify_chat_content(payload["choices"][0]["message"]["content"])
return {"text": text, "artifact_path": None, "payload": payload} return {"text": text, "artifact_path": None, "payload": payload}
if mode == "transcription": if mode == "transcription":
...@@ -249,5 +268,10 @@ def handle_response_payload(mode: str, response, output_dir: Path) -> dict: ...@@ -249,5 +268,10 @@ def handle_response_payload(mode: str, response, output_dir: Path) -> dict:
def execute_request(spec: dict): def execute_request(spec: dict):
method = spec["method"] method = spec["method"]
kwargs = {key: value for key, value in spec.items() if key not in {"method", "url"}} cleanup = spec.get("_close")
return requests.request(method=method, url=spec["url"], timeout=300, **kwargs) kwargs = {key: value for key, value in spec.items() if key not in {"method", "url", "_close"}}
try:
return requests.request(method=method, url=spec["url"], timeout=300, **kwargs)
finally:
if cleanup is not None:
cleanup()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment