Use llama.cpp's create_chat_completion for proper chat template handling

- Add generate_chat() and generate_chat_stream() methods to VulkanBackend
- These use create_chat_completion() which properly applies model's chat template
- Fallback to manual formatting if create_chat_completion fails
- Update API endpoints to pass messages dict directly instead of formatted prompt
- Fixes garbled output with Qwen3 and other models that use custom chat templates
parent eea67af6
......@@ -1009,6 +1009,73 @@ class VulkanBackend(ModelBackend):
return output["choices"][0]["text"]
def generate_chat(self, messages: List[Dict], max_tokens: Optional[int] = None,
temperature: float = 0.7, top_p: float = 1.0,
stop: Optional[List[str]] = None, tools: Optional[List] = None) -> str:
"""Generate chat completion using llama-cpp's create_chat_completion."""
if max_tokens is None:
max_tokens = 512
try:
response = self.model.create_chat_completion(
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
stop=stop or [],
tools=tools,
)
return response["choices"][0]["message"].get("content", "")
except Exception as e:
print(f"Warning: create_chat_completion failed ({e}), falling back to text generation")
# Fallback: format messages manually and use text generation
prompt = self._manual_format_messages(messages)
return self.generate(prompt, max_tokens, temperature, top_p, stop)
async def generate_chat_stream(self, messages: List[Dict], max_tokens: Optional[int] = None,
temperature: float = 0.7, top_p: float = 1.0,
stop: Optional[List[str]] = None, tools: Optional[List] = None) -> AsyncGenerator[str, None]:
"""Generate chat completion streaming using llama-cpp."""
if max_tokens is None:
max_tokens = 512
try:
stream = self.model.create_chat_completion(
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
stop=stop or [],
tools=tools,
stream=True,
)
for chunk in stream:
delta = chunk["choices"][0].get("delta", {})
content = delta.get("content", "")
if content:
yield content
except Exception as e:
print(f"Warning: create_chat_completion stream failed ({e}), falling back to text generation")
# Fallback: format messages manually and use text generation
prompt = self._manual_format_messages(messages)
async for chunk in self.generate_stream(prompt, max_tokens, temperature, top_p, stop):
yield chunk
def _manual_format_messages(self, messages: List[Dict]) -> str:
"""Manual fallback for formatting messages when create_chat_completion fails."""
formatted = []
for msg in messages:
role = msg.get("role", "")
content = msg.get("content", "")
if role == "system":
formatted.append(f"<|im_start|>system\n{content}<|im_end|>")
elif role == "user":
formatted.append(f"<|im_start|>user\n{content}<|im_end|>")
elif role == "assistant":
formatted.append(f"<|im_start|>assistant\n{content}<|im_end|>")
formatted.append("<|im_start|>assistant\n")
return "\n".join(formatted)
async def generate_stream(self, prompt: str, max_tokens: Optional[int] = None,
temperature: float = 0.7, top_p: float = 1.0,
stop: Optional[List[str]] = None) -> AsyncGenerator[str, None]:
......@@ -1107,6 +1174,20 @@ class ModelManager:
raise RuntimeError("No model loaded")
return self.backend.generate(prompt, max_tokens, temperature, top_p, stop)
def generate_chat(self, messages: List[Dict], max_tokens: Optional[int] = None,
temperature: float = 0.7, top_p: float = 1.0,
stop: Optional[List[str]] = None, tools: Optional[List] = None) -> str:
"""Generate chat completion non-streaming."""
if self.backend is None:
raise RuntimeError("No model loaded")
# Use generate_chat if available (Vulkan backend), otherwise format and use generate
if hasattr(self.backend, 'generate_chat'):
return self.backend.generate_chat(messages, max_tokens, temperature, top_p, stop, tools)
else:
# Fallback for NVIDIA backend
prompt = self.format_messages([ChatMessage(**m) for m in messages])
return self.backend.generate(prompt, max_tokens, temperature, top_p, stop)
async def generate_stream(self, prompt: str, max_tokens: Optional[int] = None,
temperature: float = 0.7, top_p: float = 1.0,
stop: Optional[List[str]] = None) -> AsyncGenerator[str, None]:
......@@ -1116,6 +1197,22 @@ class ModelManager:
async for chunk in self.backend.generate_stream(prompt, max_tokens, temperature, top_p, stop):
yield chunk
async def generate_chat_stream(self, messages: List[Dict], max_tokens: Optional[int] = None,
temperature: float = 0.7, top_p: float = 1.0,
stop: Optional[List[str]] = None, tools: Optional[List] = None) -> AsyncGenerator[str, None]:
"""Generate chat completion streaming."""
if self.backend is None:
raise RuntimeError("No model loaded")
# Use generate_chat_stream if available (Vulkan backend), otherwise format and use generate_stream
if hasattr(self.backend, 'generate_chat_stream'):
async for chunk in self.backend.generate_chat_stream(messages, max_tokens, temperature, top_p, stop, tools):
yield chunk
else:
# Fallback for NVIDIA backend
prompt = self.format_messages([ChatMessage(**m) for m in messages])
async for chunk in self.backend.generate_stream(prompt, max_tokens, temperature, top_p, stop):
yield chunk
@property
def model_name(self) -> str:
if self.backend is None:
......@@ -1324,9 +1421,6 @@ async def chat_completions(request: ChatCompletionRequest):
if request.tools:
messages = format_tools_for_prompt(request.tools, messages)
# Convert messages to prompt
prompt = model_manager.format_messages(messages)
# Prepare stop sequences
stop_sequences = []
if request.stop:
......@@ -1335,39 +1429,67 @@ async def chat_completions(request: ChatCompletionRequest):
else:
stop_sequences = request.stop
# Convert messages to dict format for chat completion
messages_dict = []
for msg in messages:
msg_dict = {"role": msg.role}
if msg.content:
msg_dict["content"] = msg.content
if msg.tool_calls:
msg_dict["tool_calls"] = msg.tool_calls
if msg.name:
msg_dict["name"] = msg.name
if msg.tool_call_id:
msg_dict["tool_call_id"] = msg.tool_call_id
messages_dict.append(msg_dict)
# Convert tools to dict format if present
tools_dict = None
if request.tools:
tools_dict = []
for tool in request.tools:
tools_dict.append({
"type": tool.type,
"function": {
"name": tool.function.name,
"description": tool.function.description,
"parameters": tool.function.parameters
}
})
if request.stream:
return StreamingResponse(
stream_chat_response(
prompt,
messages_dict,
request.model,
request.max_tokens,
request.temperature,
request.top_p,
stop_sequences,
request.tools,
tools_dict,
),
media_type="text/event-stream",
)
else:
return await generate_chat_response(
prompt,
messages_dict,
request.model,
request.max_tokens,
request.temperature,
request.top_p,
stop_sequences,
request.tools,
tools_dict,
)
async def stream_chat_response(
prompt: str,
messages: List[Dict],
model_name: str,
max_tokens: Optional[int],
temperature: float,
top_p: float,
stop: List[str],
tools: Optional[List[Tool]],
tools: Optional[List[Dict]],
) -> AsyncGenerator[str, None]:
"""Stream chat completion response."""
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
......@@ -1376,12 +1498,14 @@ async def stream_chat_response(
generated_text = ""
try:
async for chunk in model_manager.generate_stream(
prompt=prompt,
# Use generate_chat_stream for proper chat template handling
async for chunk in model_manager.generate_chat_stream(
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
stop=stop,
tools=tools,
):
# Filter malformed content from each chunk
filtered_chunk = filter_malformed_content(chunk)
......@@ -1405,7 +1529,17 @@ async def stream_chat_response(
# Check for tool calls in complete output
if tools:
tool_calls = model_manager.tool_parser.extract_tool_calls(generated_text, tools)
# Convert tools back to Tool objects for parsing
from typing import cast
tool_objects = []
for t in tools:
tool_func = ToolFunction(
name=t["function"]["name"],
description=t["function"].get("description"),
parameters=t["function"].get("parameters")
)
tool_objects.append(Tool(type=t.get("type", "function"), function=tool_func))
tool_calls = model_manager.tool_parser.extract_tool_calls(generated_text, tool_objects)
if tool_calls:
data = {
"id": completion_id,
......@@ -1443,25 +1577,27 @@ async def stream_chat_response(
async def generate_chat_response(
prompt: str,
messages: List[Dict],
model_name: str,
max_tokens: Optional[int],
temperature: float,
top_p: float,
stop: List[str],
tools: Optional[List[Tool]],
tools: Optional[List[Dict]],
) -> Dict:
"""Generate non-streaming chat completion response."""
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
created = int(time.time())
try:
generated_text = model_manager.generate(
prompt=prompt,
# Use generate_chat for proper chat template handling
generated_text = model_manager.generate_chat(
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
stop=stop,
tools=tools,
)
# Filter out malformed content from generated text
......@@ -1476,20 +1612,25 @@ async def generate_chat_response(
# Check for tool calls
if tools:
tool_calls = model_manager.tool_parser.extract_tool_calls(generated_text, tools)
# Convert tools back to Tool objects for parsing
tool_objects = []
for t in tools:
tool_func = ToolFunction(
name=t["function"]["name"],
description=t["function"].get("description"),
parameters=t["function"].get("parameters")
)
tool_objects.append(Tool(type=t.get("type", "function"), function=tool_func))
tool_calls = model_manager.tool_parser.extract_tool_calls(generated_text, tool_objects)
if tool_calls:
response_message["tool_calls"] = tool_calls
response_message["content"] = None
finish_reason = "tool_calls"
# Calculate token counts if tokenizer available
if model_manager.tokenizer:
prompt_tokens = len(model_manager.tokenizer.encode(prompt))
completion_tokens = len(model_manager.tokenizer.encode(generated_text))
else:
# Rough estimate for Vulkan backend
prompt_tokens = len(prompt.split())
completion_tokens = len(generated_text.split())
# Calculate token counts - rough estimate since we don't have direct access to tokenizer
prompt_text = "\n".join([m.get("content", "") for m in messages])
prompt_tokens = len(prompt_text.split())
completion_tokens = len(generated_text.split()) if generated_text else 0
return {
"id": completion_id,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment