Commit ef03dee8 authored by Your Name's avatar Your Name

Enhance --force-reasoning with stop/inject options and add reasoning extraction

- Added --force-reasoning with choices: 'stop', 'inject', 'both' (default)
- Add model-family detection for reasoning stop tokens
- Get appropriate stop tokens for Qwen, DeepSeek, Llama3, Mistral, Gemma, Hermes/Yi
- Add system prompt injection for forcing reasoning on non-native models
- Add extract_reasoning_content() function to parsers for extracting thinking tags
parent ed8397a0
Help on function generate_image in module stable_diffusion_cpp.stable_diffusion:
generate_image(
self,
prompt: str,
negative_prompt: str = '',
clip_skip: int = -1,
init_image: Union[PIL.Image.Image, str, NoneType] = None,
ref_images: Optional[List[Union[PIL.Image.Image, str]]] = None,
auto_resize_ref_image: bool = True,
increase_ref_index: bool = False,
mask_image: Union[PIL.Image.Image, str, NoneType] = None,
width: int = 512,
height: int = 512,
cfg_scale: float = 7.0,
image_cfg_scale: Optional[float] = None,
guidance: float = 3.5,
scheduler: Union[str, stable_diffusion_cpp.stable_diffusion_cpp.Scheduler, int, float, NoneType] = 'default',
sample_method: Union[str, stable_diffusion_cpp.stable_diffusion_cpp.SampleMethod, int, float, NoneType] = 'default',
sample_steps: int = 20,
eta: float = 0.0,
timestep_shift: int = 0,
sigmas: Optional[str] = None,
skip_layers: List[int] = [7, 8, 9],
skip_layer_start: float = 0.01,
skip_layer_end: float = 0.2,
slg_scale: float = 0.0,
strength: float = 0.75,
seed: int = 42,
batch_count: int = 1,
control_image: Union[PIL.Image.Image, str, NoneType] = None,
control_strength: float = 0.9,
pm_id_embed_path: str = '',
pm_id_images: Optional[List[Union[PIL.Image.Image, str]]] = None,
pm_style_strength: float = 20.0,
vae_tiling: bool = False,
vae_tile_overlap: float = 0.5,
vae_tile_size: Union[int, str, NoneType] = '0x0',
vae_relative_tile_size: Union[float, str, NoneType] = '0x0',
cache_mode: Union[str, stable_diffusion_cpp.stable_diffusion_cpp.SDCacheMode, int, float, NoneType] = 'disabled',
cache_reuse_threshold: float = 1.0,
cache_start_percent: float = 0.15,
cache_end_percent: float = 0.95,
cache_error_decay_rate: float = 1.0,
cache_use_relative_threshold: bool = True,
cache_reset_error_on_compute: bool = True,
cache_Fn_compute_blocks: int = 8,
cache_Bn_compute_blocks: int = 0,
cache_residual_diff_threshold: float = 0.08,
cache_max_warmup_steps: int = 8,
cache_max_continuous_cached_steps: int = -1,
cache_taylorseer_n_derivatives: int = 1,
cache_taylorseer_skip_interval: int = 1,
scm_mask: str = '',
scm_policy: Literal['dynamic', 'static'] = 'dynamic',
canny: bool = False,
upscale_factor: int = 1,
preview_method: Union[str, stable_diffusion_cpp.stable_diffusion_cpp.Preview, int, float] = 'none',
preview_noisy: bool = False,
preview_interval: int = 1,
preview_callback: Optional[Callable] = None,
progress_callback: Optional[Callable] = None
) -> List[PIL.Image.Image]
Generate images from a text prompt and or input images.
Args:
prompt: The prompt to render.
negative_prompt: The negative prompt.
clip_skip: Ignore last layers of CLIP network (1 ignores none, 2 ignores one layer, <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x).
init_image: An input image path or Pillow Image to direct the generation.
ref_images: A list of input image paths or Pillow Images for Flux Kontext models (can be used multiple times).
auto_resize_ref_image: Automatically resize reference images.
increase_ref_index: Automatically increase the indices of reference images based on the order they are listed (starting with 1).
mask_image: The inpainting mask image path or Pillow Image.
width: Image width, in pixel space.
height: Image height, in pixel space.
cfg_scale: Unconditional guidance scale.
image_cfg_scale: Image guidance scale for inpaint or instruct-pix2pix models.
guidance: Distilled guidance scale for models with guidance input.
scheduler: Denoiser sigma scheduler (default: discrete).
sample_method: Sampling method (default: euler for Flux/SD3/Wan, euler_a otherwise).
sample_steps: Number of sample steps.
eta: Eta in DDIM, only for DDIM and TCD.
timestep_shift: Shift timestep for NitroFusion models, default: 0, recommended N for NitroSD-Realism around 250 and 500 for NitroSD-Vibrant.
sigmas: Custom sigma values for the sampler, comma-separated (e.g. "14.61,7.8,3.5,0.0").
skip_layers: Layers to skip for SLG steps (SLG will be enabled at step int([STEPS]x[START]) and disabled at int([STEPS]x[END])).
skip_layer_start: SLG enabling point.
skip_layer_end: SLG disabling point.
slg_scale: Skip layer guidance (SLG) scale, only for DiT models.
strength: Strength for noising/unnoising.
seed: RNG seed (uses random seed for < 0).
batch_count: Number of images to generate.
control_image: A control condition image path or Pillow Image (Control Net).
control_strength: Strength to apply Control Net.
pm_id_embed_path: Path to PhotoMaker v2 id embed.
pm_id_images: A list of input image paths or Pillow Images for PhotoMaker input identity.
pm_style_strength: Strength for keeping PhotoMaker input identity.
vae_tiling: Process vae in tiles to reduce memory usage.
vae_tile_overlap: Tile overlap for vae tiling, in fraction of tile size.
vae_tile_size: Tile size for vae tiling ([X]x[Y] format).
vae_relative_tile_size: Relative tile size for vae tiling, in fraction of image size if < 1, in number of tiles per dim if >=1 ([X]x[Y] format) (overrides `vae_tile_size`).
cache_mode: The caching method to use (default: disabled).
scm_mask: SCM steps mask for cache-dit: comma-separated 0/1 (e.g., "1,1,1,0,0,1,0,0,1,0") - 1=compute, 0=can cache.
scm_policy: SCM policy 'dynamic' or 'static'.
canny: Apply canny edge detection preprocessor to the `control_image`.
upscale_factor: Run the ESRGAN upscaler this many times.
preview_method: The preview method to use (default: none).
preview_noisy: Enables previewing noisy inputs of the models rather than the denoised outputs.
preview_interval: Interval in denoising steps between consecutive updates of the image preview (default: 1, meaning update at every step)
preview_callback: Callback function to call on each preview frame.
progress_callback: Callback function to call on each step end.
Returns:
A list of Pillow Images.
# codai module - AI model parsing utilities
from .models.parser import (
ModelParserDispatcher,
BaseParser,
QwenParser,
DeepSeekParser,
LlamaParser,
MistralParser,
ClaudeParser,
CommandRParser,
GemmaParser,
GrokParser,
PhiParser,
ApexBig50Parser,
)
from .models.templates import AgenticTemplateManager
# OpenAI-compatible backends
from .openai.litellm import (
LiteLLMBackend,
get_litellm_backend,
set_litellm_backend,
LITELLM_AVAILABLE,
)
__all__ = [
'ModelParserDispatcher',
'BaseParser',
'QwenParser',
'DeepSeekParser',
'LlamaParser',
'MistralParser',
'ClaudeParser',
'CommandRParser',
'GemmaParser',
'GrokParser',
'PhiParser',
'ApexBig50Parser',
'AgenticTemplateManager',
'LiteLLMBackend',
'get_litellm_backend',
'set_litellm_backend',
'LITELLM_AVAILABLE',
]
...@@ -19,7 +19,45 @@ import uuid ...@@ -19,7 +19,45 @@ import uuid
import re import re
import time import time
from difflib import get_close_matches from difflib import get_close_matches
from typing import Dict, List, Any, Optional from typing import Dict, List, Any, Optional, Tuple
def extract_reasoning_content(text: str, model_family: str = None) -> Tuple[str, str]:
"""Extract reasoning/thinking content from model output.
Returns tuple of (reasoning_content, clean_text).
"""
reasoning_content = ""
clean_text = text
# Define reasoning patterns for different model families
patterns = [
(r'<thought>(.*?)</thought>', 'qwen'),
(r'<think>(.*?)</think>', 'qwen'),
(r'<thought>(.*?)</thought>', 'deepseek'),
(r'<thought>(.*?)</thought>', 'llama3'),
(r'<thought>(.*?)</thought>', 'mistral'),
(r'<thought>(.*?)</thought>', 'gemma'),
(r'<\|im_start\|>assistant\n<thought>(.*?)</thought>', 'hermes'),
(r'<thought>(.*?)</thought>', 'generic'),
]
for pattern, _ in patterns:
try:
matches = re.findall(pattern, text, re.DOTALL | re.IGNORECASE)
if matches:
reasoning_content = '\n'.join([m.strip() for m in matches if m.strip()])
clean_text = re.sub(pattern, '', text, flags=re.DOTALL | re.IGNORECASE).strip()
break
except:
continue
# Cleanup
for p in [r'<thought>.*?</thought>', r'<think>.*?</think>']:
clean_text = re.sub(p, '', clean_text, flags=re.DOTALL | re.IGNORECASE)
return reasoning_content, clean_text
# Try to import litellm for response formatting # Try to import litellm for response formatting
# Fall back to plain dicts if litellm is not available or doesn't export these # Fall back to plain dicts if litellm is not available or doesn't export these
......
...@@ -3667,6 +3667,106 @@ def check_hf_chat_template(model_type: str = "text", model_name: str = None) -> ...@@ -3667,6 +3667,106 @@ def check_hf_chat_template(model_type: str = "text", model_name: str = None) ->
# None = don't inject, True = use default, string = use custom text # None = don't inject, True = use default, string = use custom text
global_system_prompt = None global_system_prompt = None
def get_model_family(model_name: str) -> str:
"""Detect model family from model name."""
model_lower = model_name.lower()
if 'qwen' in model_lower:
return 'qwen'
elif 'deepseek' in model_lower:
return 'deepseek'
elif 'llama-3' in model_lower or 'llama3' in model_lower or 'meta-llama-3' in model_lower:
return 'llama3'
elif 'llama' in model_lower or 'meta-llama' in model_lower:
return 'llama'
elif 'mistral' in model_lower or 'mixtral' in model_lower:
return 'mistral'
elif 'gemma' in model_lower:
return 'gemma'
elif 'yi' in model_lower:
return 'yi'
elif 'hermes' in model_lower:
return 'hermes'
else:
return 'unknown'
def get_reasoning_stop_tokens(model_family: str) -> tuple:
"""Get stop tokens for reasoning mode based on model family.
Returns tuple of (start_token, end_token, additional_stops)
"""
if model_family == 'qwen':
# Qwen uses <|im_start|> format with </tool_call> for thinking
return (
"<|im_start|>assistant\n",
"<|im_end|>",
["<|im_end|>", "<|endoftext|>"]
)
elif model_family == 'deepseek':
return (
"<Assistant>",
"<endofassistant>",
["<endofassistant>", "<User>", "<endoftext>"]
)
elif model_family == 'llama3':
return (
"<|start_header_id|>assistant<|end_header_id|>\n\n<thought>\n",
"</thought>",
["</thought>", "<|eot_id|>", "<|end_of_text|>"]
)
elif model_family == 'llama':
return (
"<|start_header_id|>assistant<|end_header_id|>\n\n",
"<|eot_id|>",
["<|eot_id|>", "<|end_of_text|>"]
)
elif model_family == 'mistral':
return (
"[/INST] <thought>\n",
"</thought>",
["</thought>", "</INST>", "[INST]"]
)
elif model_family == 'gemma':
return (
"<start_of_turn>model\n<thought>\n",
"</thought>",
["</thought>", "<end_of_turn>", "<start_of_turn>"]
)
elif model_family == 'yi' or model_family == 'hermes':
return (
"<|im_start|>assistant\n",
"<|im_end|>",
["<|im_end|>", "<|endoftext|>"]
)
else:
# Default fallback - try common tokens
return (
"<|im_start|>assistant\n",
"<|im_end|>",
["<|im_end|>", "<|endoftext|>"]
)
def get_reasoning_system_prompt(model_family: str) -> str:
"""Get system prompt injection for forcing reasoning on non-native models."""
if model_family == 'qwen':
return "You must reason step-by-step inside <thought> tags before every response."
elif model_family == 'deepseek':
return "You must reason step-by-step inside <0x00></think> tags before every response."
elif model_family in ('llama3', 'llama'):
return "You must reason step-by-step inside <thought> tags before every response."
elif model_family == 'mistral':
return "You must reason step-by-step inside <thought> tags before every response."
elif model_family == 'gemma':
return "You must reason step-by-step inside <thought> tags before every response. Use <start_of_turn>model for your response."
elif model_family in ('yi', 'hermes'):
return "You must reason step-by-step inside <|im_start|>assistant tags before every response."
else:
return "You must reason step-by-step before every response."
# Global debug flag # Global debug flag
global_debug = False global_debug = False
global_file_path = None global_file_path = None
...@@ -5419,31 +5519,57 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request ...@@ -5419,31 +5519,57 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request
messages = [ChatMessage(role="system", content=system_text)] + list(messages) messages = [ChatMessage(role="system", content=system_text)] + list(messages)
# Enable thinking/reasoning mode if requested via API parameter OR CLI flag # Enable thinking/reasoning mode if requested via API parameter OR CLI flag
force_reasoning = getattr(global_args, 'force_reasoning', False) if global_args else False force_reasoning_mode = getattr(global_args, 'force_reasoning', None) if global_args else None
enable_thinking = getattr(request, 'enable_thinking', False) or force_reasoning enable_thinking_api = getattr(request, 'enable_thinking', False)
if enable_thinking:
from codai.models.templates import AgenticTemplateManager # Determine if reasoning should be enabled
template_manager = AgenticTemplateManager(request.model) # Force reasoning if: API param is true OR CLI flag is set (not None)
# Get the current system prompt if exists reasoning_enabled = enable_thinking_api or (force_reasoning_mode is not None)
system_content = None
for msg in messages: # Get model family for reasoning tokens
if msg.role == "system": model_family = get_model_family(request.model)
system_content = msg.content
break # Determine what to do: stop, inject, or both
if system_content: if reasoning_enabled:
# Inject agentic instructions # CLI flag takes precedence if set, otherwise check API param
system_content = template_manager.get_agent_system_prompt(system_content) if force_reasoning_mode:
reasoning_action = force_reasoning_mode # "stop", "inject", or "both"
else: else:
system_content = template_manager.get_agent_system_prompt("You are a helpful assistant.") reasoning_action = "inject" # Default to inject if only API param is set
# Update or add system message
system_found = False # Handle inject (system prompt injection)
for i, msg in enumerate(messages): if reasoning_action in ("inject", "both"):
if msg.role == "system": from codai.models.templates import AgenticTemplateManager
messages[i] = ChatMessage(role="system", content=system_content) template_manager = AgenticTemplateManager(request.model)
system_found = True # Get the current system prompt if exists
break system_content = None
if not system_found: for msg in messages:
messages = [ChatMessage(role="system", content=system_content)] + list(messages) if msg.role == "system":
system_content = msg.content
break
if system_content:
# Inject agentic instructions
system_content = template_manager.get_agent_system_prompt(system_content)
else:
system_content = template_manager.get_agent_system_prompt("You are a helpful assistant.")
# Update or add system message
system_found = False
for i, msg in enumerate(messages):
if msg.role == "system":
messages[i] = ChatMessage(role="system", content=system_content)
system_found = True
break
if not system_found:
messages = [ChatMessage(role="system", content=system_content)] + list(messages)
# Handle stop tokens - add to stop_sequences for generation
if reasoning_action in ("stop", "both"):
_, _, additional_stops = get_reasoning_stop_tokens(model_family)
# Add model-specific stop tokens to the existing stop sequences
for stop_token in additional_stops:
if stop_token not in stop_sequences:
stop_sequences.append(stop_token)
print(f"DEBUG: Added reasoning stop tokens for model family '{model_family}': {additional_stops}")
# Format messages with tools if provided # Format messages with tools if provided
if request.tools: if request.tools:
...@@ -6475,8 +6601,11 @@ def parse_args(): ...@@ -6475,8 +6601,11 @@ def parse_args():
) )
parser.add_argument( parser.add_argument(
"--force-reasoning", "--force-reasoning",
action="store_true", nargs="?",
help="Force reasoning/thinking mode for models that support it (e.g., Qwen3, DeepSeek R1). Enables extraction of reasoning content.", const="both",
default=None,
choices=["both", "stop", "inject"],
help="Force reasoning/thinking mode. Values: 'stop' (add stop tokens), 'inject' (add system prompt), 'both' (default, does both). Use for models like Qwen3, DeepSeek R1, Llama3.1, etc.",
) )
return parser.parse_args() return parser.parse_args()
def main(): def main():
......
This diff is collapsed.
# FastAPI and server dependencies
# CLI dependencies
# PyTorch - Uncomment the appropriate version for your system.
# IMPORTANT: Use quotes around version specifiers to prevent shell interpretation!
# The >= operator will be interpreted as output redirection without quotes!
#
# Option 1: Use exact versions (recommended for requirements.txt)
# Option 2: Use quotes: pip install "torch>=2.0.0"
# For NVIDIA (CUDA):
# torch==2.0.0
torchvision
torchaudio
# For AMD (ROCm) - see available versions at https://pytorch.org/get-started/locally/
# rocm6.0 is recommended for newer AMD GPUs, rocm5.6 for older ones
# --index-url https://download.pytorch.org/whl/rocm6.0
# torch==2.0.0
# torchvision==0.15.0
# torchaudio==2.0.0
# For CPU only:
torch
# ML dependencies
transformers
accelerate
# System resource detection
psutil
# Optional: for better performance
bitsandbytes>=0.41.0 # for 4-bit/8-bit quantization
sentencepiece>=0.1.99 # for some tokenizers
protobuf>=3.20.0 # for some models
# Optional: Flash Attention 2 for faster inference on supported GPUs
# Requires specific CUDA/ROCm versions and may need manual installation
# Install with: pip install flash-attn --no-build-isolation
#flash-attn>=2.5.0
# Installation instructions:
# IMPORTANT: Always use quotes or exact versions to avoid shell redirection issues!
#
# 1. For NVIDIA GPUs (CUDA 12.1):
# pip install torch torchvision torchaudio
#
# 2. For AMD GPUs (ROCm 6.0 recommended):
# pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.0
#
# 3. For CPU only:
# pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
#
# If you see "No such file or directory: '0.0'" errors, you forgot to use quotes!
# The shell interprets >= as redirection. Fix: pip install "torch>=2.0.0" (with quotes)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment