Commit 895e94ca authored by Your Name

Remove all handling of the GGML_VK_VISIBLE_DEVICES environment variable; the user is expected to set it externally

parent a4674d60
......@@ -27,10 +27,6 @@ from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field, validator, field_validator, ConfigDict
from pydantic_core import PydanticCustomError
from threading import Thread
# Per-model semaphores for request concurrency control.
# NOTE(review): presumably keyed by model name/ID — confirm against the code that fills this dict.
model_semaphores: dict = {}
# Load mode shared by the whole module; the initial value is "ondemand".
# NOTE(review): held in a dict, presumably so other code can change "mode" in place — confirm.
load_mode = {"mode": "ondemand"} # Track load mode globally
......@@ -61,14 +57,10 @@ def get_cached_model_path(url: str) -> Optional[str]:
print(f"Using cached model: {cached_path}")
return cached_path
return None
def is_huggingface_model_id(path: str) -> bool:
    """Tell whether *path* looks like a Hugging Face model ID.

    A model ID such as 'Qwen/Qwen3-4B-Instruct-2507-Q3_K_S' contains a slash
    but does not start with an HTTP(S) scheme. Any other string containing a
    slash (including a local filesystem path) is also reported as an ID.
    """
    # URLs contain '/' too, so rule them out before the slash test.
    if path.startswith(('http://', 'https://')):
        return False
    return '/' in path
def download_huggingface_model(model_id: str, cache_dir: str, file_pattern: str = '.gguf') -> Optional[str]:
"""Download a model from Hugging Face by model ID. Returns cached path or None on failure."""
try:
......@@ -92,8 +84,6 @@ def download_huggingface_model(model_id: str, cache_dir: str, file_pattern: str
except Exception as e:
print(f"Error downloading from Hugging Face: {e}")
return None
def download_model(url: str, cache_dir: str) -> str:
"""Download a model from URL with progress reporting. Returns cached path."""
import requests
......@@ -155,8 +145,6 @@ def download_model(url: str, cache_dir: str) -> str:
print(f"File size: {total_mb:.1f} MB")
return model_path
# =============================================================================
# Backend Detection and Imports
# =============================================================================
......@@ -181,8 +169,6 @@ def detect_available_backends():
pass
return backends
# =============================================================================
# Flash Attention Detection (for NVIDIA backend)
# =============================================================================
......@@ -194,8 +180,6 @@ def check_flash_attn_availability() -> bool:
return True
except ImportError:
return False
# =============================================================================
# Pydantic Models for API
# =============================================================================
......@@ -204,13 +188,9 @@ class ToolFunction(BaseModel):
name: str
description: Optional[str] = None
parameters: Optional[Dict] = None
class Tool(BaseModel):
    """A tool definition: a type tag plus the wrapped function description."""

    # Kind of tool; defaults to "function".
    type: str = "function"
    # The function definition (required), as a ToolFunction.
    function: ToolFunction
class ChatMessage(BaseModel):
role: str
content: Optional[Union[str, List[Dict]]] = None
......@@ -241,8 +221,6 @@ class ChatMessage(BaseModel):
parts.append(str(item))
return '\n'.join(parts)
return str(v)
class ChatCompletionRequest(BaseModel):
model: str
messages: List[ChatMessage]
......@@ -264,8 +242,6 @@ class ChatCompletionRequest(BaseModel):
user: Optional[str] = None
model_config = ConfigDict(extra="allow") # Allow extra fields to prevent 422 errors
class CompletionRequest(BaseModel):
model: str
prompt: Union[str, List[str]]
......@@ -286,20 +262,14 @@ class CompletionRequest(BaseModel):
user: Optional[str] = None
model_config = ConfigDict(extra="allow") # Allow extra fields to prevent 422 errors
class ModelInfo(BaseModel):
    """Description of a single model, used as an entry of ``ModelList.data``."""

    # Model identifier (required).
    id: str
    # Object type tag; defaults to "model".
    object: str = "model"
    # Integer Unix timestamp; when not supplied it is taken from time.time()
    # at the moment the instance is created.
    created: int = Field(default_factory=lambda: int(time.time()))
    # Owner label; defaults to "huggingface".
    owned_by: str = "huggingface"
class ModelList(BaseModel):
    """Container holding a list of ``ModelInfo`` entries."""

    # Object type tag; defaults to "list".
    object: str = "list"
    # The model entries (required).
    data: List[ModelInfo]
# =============================================================================
# Audio Transcription Models
# =============================================================================
......@@ -315,13 +285,9 @@ class TranscriptionRequest(BaseModel):
timestamp_granularities: Optional[List[str]] = None
model_config = ConfigDict(extra="allow")
class TranscriptionResponse(BaseModel):
    """Response body of an audio transcription: the transcribed text."""

    # The transcribed text (required).
    text: str
    # Fields that are not declared above are allowed rather than rejected.
    model_config = ConfigDict(extra="allow")
# =============================================================================
# Image Generation Models
# =============================================================================
......@@ -338,14 +304,10 @@ class ImageGenerationRequest(BaseModel):
user: Optional[str] = None
model_config = ConfigDict(extra="allow")
class ImageGenerationResponse(BaseModel):
    """Response body of an image generation request."""

    # Integer timestamp (required; no default is provided here).
    created: int
    # One dict per result.
    # NOTE(review): the dict schema is not visible in this excerpt — confirm against the endpoint.
    data: List[Dict]
    # Fields that are not declared above are allowed rather than rejected.
    model_config = ConfigDict(extra="allow")
# =============================================================================
# Content Filtering Utility
# =============================================================================
......@@ -374,8 +336,6 @@ def filter_malformed_content(text: str) -> str:
# Don't strip single newlines or whitespace - they might be valid content
return filtered
# =============================================================================
# Tool Parsing
# =============================================================================
......@@ -593,8 +553,6 @@ class ToolCallParser:
text = re.sub(r'\n{3,}', '\n\n', text)
return text.strip()
def format_tools_for_prompt(tools: List[Tool], messages: List[ChatMessage]) -> List[ChatMessage]:
"""Format tools into the system message or add a tool description."""
if not tools:
......@@ -639,8 +597,6 @@ def format_tools_for_prompt(tools: List[Tool], messages: List[ChatMessage]) -> L
new_messages.insert(0, ChatMessage(role="system", content=tools_text))
return new_messages
# =============================================================================
# Abstract Model Backend
# =============================================================================
......@@ -681,8 +637,6 @@ class ModelBackend(ABC):
def cleanup(self) -> None:
"""Cleanup resources."""
pass
# =============================================================================
# NVIDIA/HuggingFace Backend
# =============================================================================
......@@ -1272,8 +1226,6 @@ class NvidiaBackend(ModelBackend):
self.tokenizer = None
if torch.cuda.is_available():
torch.cuda.empty_cache()
# =============================================================================
# Vulkan Backend (llama-cpp-python)
# =============================================================================
......@@ -1532,12 +1484,6 @@ class VulkanBackend(ModelBackend):
print(f"DEBUG: Detected {num_devices} Vulkan GPU devices")
# Also try to set GGML_VK_VISIBLE_DEVICES env var to force the device
# This affects which GPU does the actual computation
if main_gpu >= 0:
os.environ['GGML_VK_VISIBLE_DEVICES'] = str(main_gpu)
print(f"DEBUG: Set GGML_VK_VISIBLE_DEVICES={main_gpu}")
if single_gpu:
# Build tensor_split to force all layers onto one GPU
# tensor_split is a list where index = GPU device, value = weight (0.0 = don't use)
......@@ -1844,8 +1790,6 @@ class VulkanBackend(ModelBackend):
if self.model is not None:
del self.model
self.model = None
# =============================================================================
# Model Manager
# =============================================================================
......@@ -1976,8 +1920,6 @@ class ModelManager:
if self.backend is not None:
self.backend.cleanup()
self.backend = None
# =============================================================================
# Whisper Server Manager - manages whisper-server subprocess
# =============================================================================
......@@ -1987,8 +1929,6 @@ import signal
import requests
import time
import threading
class WhisperServerManager:
"""Manages whisper-server subprocess for audio transcription with model swapping support."""
......@@ -2167,8 +2107,6 @@ class WhisperServerManager:
"model": self.current_model,
"url": self.base_url
}
# =============================================================================
# Multi-Model Manager (supports audio transcription and image generation)
# =============================================================================
......@@ -2467,12 +2405,8 @@ class MultiModelManager:
for model in self.models.values():
model.cleanup()
self.models.clear()
# Global multi-model manager
multi_model_manager = MultiModelManager()
# Global model manager (for backward compatibility)
model_manager = ModelManager()
......@@ -2545,12 +2479,8 @@ class QueueManager:
return keys.index(request_id) + 1
except ValueError:
return 0
# Global queue manager
queue_manager = QueueManager()
# =============================================================================
# FastAPI Application
# =============================================================================
......@@ -2566,8 +2496,6 @@ async def lifespan(app: FastAPI):
# Stop whisper-server if running
if multi_model_manager.whisper_server:
multi_model_manager.whisper_server.stop()
app = FastAPI(
title="OpenAI-Compatible API",
description="OpenAI-compatible API supporting NVIDIA (CUDA) and Vulkan backends",
......@@ -2705,15 +2633,11 @@ async def log_requests(request: Request, call_next):
finally:
if request.url.path in ["/v1/chat/completions", "/v1/completions"]:
pass # End logging already done above for successful responses
@app.get("/v1/models", response_model=ModelList)
async def list_models():
    """Return the models reported by the multi-model manager, wrapped in a ``ModelList``."""
    return ModelList(data=multi_model_manager.list_models())
# =============================================================================
# Audio Transcription Endpoint
# =============================================================================
......@@ -2778,18 +2702,17 @@ async def create_transcription(
# Check if Vulkan is available for whispercpp
whisper_vulkan_available = False
whisper_vulkan_device = os.environ.get('GGML_VK_VISIBLE_DEVICES', '0')
try:
# Check if whispercpp is installed and has Vulkan support
import whispercpp
# Try to detect Vulkan support by checking if we can list devices
# whispercpp doesn't have a direct Vulkan check, but we can verify by environment
if os.environ.get('GGML_VK_VISIBLE_DEVICES') or os.environ.get('VK_DEVICE_SELECT_DEVICE'):
if os.environ.get('VK_DEVICE_SELECT_DEVICE'):
whisper_vulkan_available = True
print(f"Whisper Vulkan: Using GPU device {whisper_vulkan_device}")
print(f"Whisper Vulkan: Using configured Vulkan device")
elif os.path.exists('/dev/dri'): # Linux DRM devices exist = AMD/Intel GPU
whisper_vulkan_available = True
print(f"Whisper Vulkan: Auto-detected GPU, using device {whisper_vulkan_device}")
print(f"Whisper Vulkan: Auto-detected GPU")
except ImportError:
pass
......@@ -3131,8 +3054,6 @@ async def create_transcription(
finally:
# Cleanup temp file
os.unlink(tmp_path)
# =============================================================================
# Image Generation Endpoint
# =============================================================================
......@@ -3378,8 +3299,6 @@ async def create_image_generation(request: ImageGenerationRequest):
import traceback
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"Image generation error: {str(e)}")
# =============================================================================
# Text-to-Speech Endpoint
# =============================================================================
......@@ -3392,13 +3311,9 @@ class TTSRequest(BaseModel):
speed: Optional[float] = 1.0
model_config = ConfigDict(extra="allow")
class TTSResponse(BaseModel):
    """Response model for text-to-speech output."""

    # Base64-encoded audio (required).
    audio: str  # base64 encoded audio
    # Fields that are not declared above are allowed rather than rejected.
    model_config = ConfigDict(extra="allow")
@app.post("/v1/audio/speech")
async def create_speech(request: TTSRequest):
"""
......@@ -3515,8 +3430,6 @@ async def create_speech(request: TTSRequest):
import traceback
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"TTS error: {str(e)}")
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
"""Chat completions endpoint with streaming and tool support."""
......@@ -3877,8 +3790,6 @@ async def stream_chat_response(
finally:
# Always clean up queue state
await queue_manager.finish_processing()
async def generate_chat_response(
messages: List[Dict],
model_name: str,
......@@ -3958,8 +3869,6 @@ async def generate_chat_response(
except Exception as e:
print(f"Error during generation: {e}")
raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}")
@app.post("/v1/completions")
async def completions(request: CompletionRequest):
"""Text completions endpoint."""
......@@ -4007,8 +3916,6 @@ async def completions(request: CompletionRequest):
stop_sequences,
current_manager,
)
async def stream_completion_response(
prompt: str,
model_name: str,
......@@ -4050,8 +3957,6 @@ async def stream_completion_response(
print(f"Error during streaming completion: {e}")
yield f"data: {json.dumps({'choices': [{'finish_reason': 'stop'}]})}\n\n"
yield "data: [DONE]\n\n"
async def generate_completion_response(
prompt: str,
model_name: str,
......@@ -4102,8 +4007,6 @@ async def generate_completion_response(
except Exception as e:
print(f"Error during completion: {e}")
raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}")
# =============================================================================
# Main Entry Point
# =============================================================================
......@@ -4384,8 +4287,6 @@ def parse_args():
help="Enable debug mode - dumps full request/response to stdout for troubleshooting",
)
return parser.parse_args()
def main():
"""Main entry point."""
global global_system_prompt, model_manager, multi_model_manager, global_debug, global_args
......@@ -4814,10 +4715,9 @@ def main():
print(f"llama.cpp load error: {llama_error}")
print(f"Trying stable-diffusion-cpp-python fallback...")
# Try stable-diffusion-cpp-python as fallback
# Set Vulkan device for image models (GGML_VK_VISIBLE_DEVICES=1 for GPU1)
if args.image_vulkan_device is not None:
os.environ['GGML_VK_VISIBLE_DEVICES'] = str(args.image_vulkan_device)
print(f"Setting GGML_VK_VISIBLE_DEVICES={args.image_vulkan_device} for image model (sd.cpp)")
try:
from stable_diffusion_cpp import StableDiffusion
......@@ -5137,10 +5037,10 @@ def main():
# Check if Vulkan is available for whispercpp
whisper_vulkan_available = False
whisper_vulkan_device = os.environ.get('GGML_VK_VISIBLE_DEVICES', '0')
whisper_vulkan_device = os.environ.get('VK_DEVICE_SELECT_DEVICE', '0')
try:
import whispercpp
if os.environ.get('GGML_VK_VISIBLE_DEVICES') or os.environ.get('VK_DEVICE_SELECT_DEVICE'):
if os.environ.get('VK_DEVICE_SELECT_DEVICE'):
whisper_vulkan_available = True
print(f"Whisper Vulkan: Will use GPU device {whisper_vulkan_device}")
elif os.path.exists('/dev/dri'):
......@@ -5346,10 +5246,9 @@ def main():
print(f"llama.cpp load error: {llama_error}")
print(f"Trying stable-diffusion-cpp-python fallback...")
# Try stable-diffusion-cpp-python as fallback
# Set Vulkan device for image models (GGML_VK_VISIBLE_DEVICES=1 for GPU1)
if args.image_vulkan_device is not None:
os.environ['GGML_VK_VISIBLE_DEVICES'] = str(args.image_vulkan_device)
print(f"Setting GGML_VK_VISIBLE_DEVICES={args.image_vulkan_device} for image model (sd.cpp)")
try:
from stable_diffusion_cpp import StableDiffusion
......@@ -5570,7 +5469,5 @@ def main():
print(f"Available models: {[m.id for m in models]}")
uvicorn.run(app, host=args.host, port=args.port)
if __name__ == "__main__":
main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment