Commit e7f781f3 authored by Your Name

Fix get_reasoning_stop_tokens to return 3 values

parent 8e072ebb
...@@ -22,20 +22,87 @@ def get_model_family(model_name: str) -> str: ...@@ -22,20 +22,87 @@ def get_model_family(model_name: str) -> str:
return 'llama' return 'llama'
if 'mistral' in model_lower: if 'mistral' in model_lower:
return 'mistral' return 'mistral'
if 'deepseek' in model_lower:
return 'deepseek'
if 'gemma' in model_lower:
return 'gemma'
if 'yi' in model_lower:
return 'yi'
if 'hermes' in model_lower:
return 'hermes'
return 'generic' return 'generic'
# ChatML-style markers. Shared by the 'qwen', 'yi' and 'hermes' families and
# used as the fallback for any family that has no entry of its own.
_CHATML_REASONING_TOKENS = (
    "<|im_start|>assistant\n",
    "<|im_end|>",
    ("<|im_end|>", "<|endoftext|>"),
)

# model family -> (start_token, end_token, additional_stops)
# additional_stops are stored as tuples so the table itself can never be
# mutated through a value handed out to a caller.
_REASONING_TOKENS_BY_FAMILY = {
    'qwen': _CHATML_REASONING_TOKENS,
    'deepseek': (
        "<|Assistant|>",
        "<|endofassistant|>",
        ("<|endofassistant|>", "<|User|>", "<|endoftext|>"),
    ),
    'llama3': (
        "<|start_header_id|>assistant<|end_header_id|>\n\n<thought>\n",
        "</thought>",
        ("</thought>", "<|eot_id|>", "<|end_of_text|>"),
    ),
    'llama': (
        "<|start_header_id|>assistant<|end_header_id|>\n\n",
        "<|eot_id|>",
        ("<|eot_id|>", "<|end_of_text|>"),
    ),
    'mistral': (
        "[/INST] <thought>\n",
        "</thought>",
        # "[/INST]" is the bracketed closing marker already used by the start
        # token above; the earlier "</INST>" spelling matched neither that
        # marker nor "[INST]", so it could never act as a stop string.
        ("</thought>", "[/INST]", "[INST]"),
    ),
    'gemma': (
        "<start_of_turn>model\n<thought>\n",
        "</thought>",
        ("</thought>", "<end_of_turn>", "<start_of_turn>"),
    ),
    'yi': _CHATML_REASONING_TOKENS,
    'hermes': _CHATML_REASONING_TOKENS,
}


def get_reasoning_stop_tokens(model_family: str) -> tuple:
    """Get stop tokens for reasoning mode based on model family.

    Args:
        model_family: Family identifier such as 'qwen', 'deepseek', 'llama3',
            'llama', 'mistral', 'gemma', 'yi' or 'hermes'. Any other value
            falls back to the ChatML-style defaults.

    Returns:
        A 3-tuple ``(start_token, end_token, additional_stops)`` where the
        first two items are strings and ``additional_stops`` is a list of
        strings. The list is newly created on every call, so callers may
        modify it freely.
    """
    start_token, end_token, extra_stops = _REASONING_TOKENS_BY_FAMILY.get(
        model_family, _CHATML_REASONING_TOKENS
    )
    return (start_token, end_token, list(extra_stops))
def get_reasoning_system_prompt(model_family: str) -> str:
    """Get system prompt injection for forcing reasoning on non-native models.

    Args:
        model_family: Family identifier such as 'qwen', 'deepseek', 'llama3',
            'llama', 'mistral', 'gemma', 'yi' or 'hermes'.

    Returns:
        The instruction text for the given family. Families without a
        dedicated entry get a generic step-by-step instruction that does not
        mention any tag.
    """
    # These families all receive the same <thought>-tag instruction.
    if model_family in ('qwen', 'deepseek', 'llama3', 'llama', 'mistral'):
        return "You must reason step-by-step inside <thought> tags before every response."
    if model_family == 'gemma':
        return "You must reason step-by-step inside <thought> tags before every response. Use <start_of_turn>model for your response."
    if model_family in ('yi', 'hermes'):
        # NOTE(review): this wording names a chat-template role marker rather
        # than a <thought> tag pair; kept verbatim - confirm it is intended.
        return "You must reason step-by-step inside <|im_start|>assistant tags before every response."
    return "You must reason step-by-step before every response."
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment