Detect chat template from model and use appropriate formatting - avoid Jinja...

Detect chat template from model and use appropriate formatting - avoid Jinja errors by using manual formatting when template detection fails
parent 576a6cfe
...@@ -1062,6 +1062,65 @@ class VulkanBackend(ModelBackend): ...@@ -1062,6 +1062,65 @@ class VulkanBackend(ModelBackend):
self.n_ctx = 2048 self.n_ctx = 2048
self.verbose = True self.verbose = True
self.main_gpu = 0 # Default to first GPU self.main_gpu = 0 # Default to first GPU
self.chat_template = None # Detected chat template name
self._detect_chat_template()
def _detect_chat_template(self):
"""Detect the chat template used by the model."""
try:
# Try to get the chat template from the model
# llama.cpp models have a chat_template attribute
from llama_cpp.llama_chat_format import ChatFormatterResponse
# We'll detect it when the model is loaded
self.chat_template = "unknown"
print("DEBUG: Chat template detection will happen after model load")
except Exception as e:
print(f"DEBUG: Could not initialize chat template detection: {e}")
self.chat_template = None
def _finalize_chat_template_detection(self):
"""Finalize chat template detection after model is loaded."""
try:
# Try to get the chat template name from the model's chat formatter
if hasattr(self.model, 'tokenizer') and self.model.tokenizer:
tokenizer = self.model.tokenizer
# Check if there's a chat_template attribute
if hasattr(tokenizer, 'chat_template'):
template = tokenizer.chat_template
if template:
# Detect common templates
template_str = str(template)
if 'qwen' in template_str.lower():
self.chat_template = "qwen"
elif 'phi' in template_str.lower():
self.chat_template = "phi"
elif 'llama3' in template_str.lower() or 'llama-3' in template_str.lower():
self.chat_template = "llama3"
elif 'chatml' in template_str.lower():
self.chat_template = "chatml"
else:
self.chat_template = "default"
print(f"DEBUG: Detected chat template: {self.chat_template}")
return
# Try a test message to see what format works
test_messages = [{"role": "user", "content": "test"}]
try:
self.model.create_chat_completion(messages=test_messages, max_tokens=1)
self.chat_template = "default"
print("DEBUG: Chat template detected via test: default")
except Exception as e:
error_str = str(e).lower()
if 'jinja' in error_str:
# Jinja template issue - try without tools
self.chat_template = "jinja_fallback"
print("DEBUG: Chat template detected: jinja_fallback (will use manual formatting)")
else:
self.chat_template = "unknown"
print(f"DEBUG: Chat template detection failed: {e}")
except Exception as e:
self.chat_template = "unknown"
print(f"DEBUG: Final chat template detection error: {e}")
def list_vulkan_devices(self): def list_vulkan_devices(self):
"""List available Vulkan GPU devices.""" """List available Vulkan GPU devices."""
...@@ -1196,6 +1255,10 @@ class VulkanBackend(ModelBackend): ...@@ -1196,6 +1255,10 @@ class VulkanBackend(ModelBackend):
self.model = Llama(**llama_kwargs) self.model = Llama(**llama_kwargs)
self.model_name = model_name self.model_name = model_name
print("\nModel loaded successfully with Vulkan!") print("\nModel loaded successfully with Vulkan!")
# Detect the chat template after model load
self._finalize_chat_template_detection()
print(f"DEBUG: Chat template: {self.chat_template}")
except Exception as e: except Exception as e:
print(f"Error loading model with Vulkan: {e}") print(f"Error loading model with Vulkan: {e}")
print("Make sure Vulkan drivers are installed:") print("Make sure Vulkan drivers are installed:")
...@@ -1262,6 +1325,14 @@ class VulkanBackend(ModelBackend): ...@@ -1262,6 +1325,14 @@ class VulkanBackend(ModelBackend):
if max_tokens is None: if max_tokens is None:
max_tokens = 512 max_tokens = 512
# Check if we should use manual formatting based on detected template
use_manual = self.chat_template in ("unknown", "jinja_fallback", None) or tools is None
if use_manual:
print(f"DEBUG: Using manual message formatting (template: {self.chat_template})")
prompt = self._manual_format_messages(messages)
return self.generate(prompt, max_tokens, temperature, top_p, stop)
try: try:
response = self.model.create_chat_completion( response = self.model.create_chat_completion(
messages=messages, messages=messages,
...@@ -1293,6 +1364,16 @@ class VulkanBackend(ModelBackend): ...@@ -1293,6 +1364,16 @@ class VulkanBackend(ModelBackend):
total_content = "" total_content = ""
chunk_count = 0 chunk_count = 0
# Check if we should use manual formatting based on detected template
use_manual = self.chat_template in ("unknown", "jinja_fallback", None) or tools is None
if use_manual:
print(f"DEBUG: Using manual message formatting for streaming (template: {self.chat_template})")
prompt = self._manual_format_messages(messages)
async for chunk in self.generate_stream(prompt, max_tokens, temperature, top_p, stop):
yield chunk
return
# Collect all chunks synchronously then yield them # Collect all chunks synchronously then yield them
# This avoids issues with generators across thread boundaries # This avoids issues with generators across thread boundaries
def collect_chunks(): def collect_chunks():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment