Commit ef949827 authored by Your Name's avatar Your Name

API: validate requested models against CLI-registered models

- Add get_all_allowed_identifiers() to MultiModelManager returning all valid
  model identifiers (default model + short name + aliases, audio, tts, image,
  vision models, and custom aliases)
- Rewrite is_allowed_model() to check against the full allowed set with
  support for prefixed forms and short-name matching
- Add validation in request_model() that rejects unknown models with an error
  message listing all available models
- Fix get_model_for_request() to reject loading arbitrary models not in the
  allowed set
- Update all API endpoints (text, images, tts, transcriptions) to check for
  the error key and return HTTP 404 when a disallowed model is requested
parent b0a633c7
...@@ -492,6 +492,10 @@ async def create_image_generation(request: ImageGenerationRequest, http_request: ...@@ -492,6 +492,10 @@ async def create_image_generation(request: ImageGenerationRequest, http_request:
model_type="image" model_type="image"
) )
# Check if the model was rejected as not allowed
if model_info.get('error'):
raise HTTPException(status_code=404, detail=model_info['error'])
model_name = model_info['model_name'] model_name = model_info['model_name']
model_key = model_info['model_key'] model_key = model_info['model_key']
pipeline = model_info['model_object'] pipeline = model_info['model_object']
......
...@@ -303,6 +303,10 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request ...@@ -303,6 +303,10 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request
model_type="text" model_type="text"
) )
# Check if the model was rejected as not allowed
if model_info.get('error'):
raise HTTPException(status_code=404, detail=model_info['error'])
# Try to get the appropriate model (request_model handles VRAM cleanup) # Try to get the appropriate model (request_model handles VRAM cleanup)
mm = multi_model_manager.get_model_for_request(requested_model) mm = multi_model_manager.get_model_for_request(requested_model)
...@@ -1702,6 +1706,10 @@ async def completions(request: CompletionRequest): ...@@ -1702,6 +1706,10 @@ async def completions(request: CompletionRequest):
model_type="text" model_type="text"
) )
# Check if the model was rejected as not allowed
if model_info.get('error'):
raise HTTPException(status_code=404, detail=model_info['error'])
# Try to get the appropriate model (request_model handles VRAM cleanup) # Try to get the appropriate model (request_model handles VRAM cleanup)
mm = multi_model_manager.get_model_for_request(requested_model) mm = multi_model_manager.get_model_for_request(requested_model)
......
...@@ -60,6 +60,10 @@ async def create_transcription( ...@@ -60,6 +60,10 @@ async def create_transcription(
model_type="audio" model_type="audio"
) )
# Check if the model was rejected as not allowed
if model_info.get('error'):
raise HTTPException(status_code=404, detail=model_info['error'])
model_name = model_info['model_name'] model_name = model_info['model_name']
model_key = model_info['model_key'] model_key = model_info['model_key']
whisper_model = model_info['model_object'] whisper_model = model_info['model_object']
......
...@@ -59,6 +59,10 @@ async def create_speech(request: TTSRequest): ...@@ -59,6 +59,10 @@ async def create_speech(request: TTSRequest):
model_type="tts" model_type="tts"
) )
# Check if the model was rejected as not allowed
if model_info.get('error'):
raise HTTPException(status_code=404, detail=model_info['error'])
model_name = model_info['model_name'] model_name = model_info['model_name']
model_key = model_info['model_key'] model_key = model_info['model_key']
kokoro_model = model_info['model_object'] kokoro_model = model_info['model_object']
......
...@@ -652,6 +652,100 @@ class MultiModelManager: ...@@ -652,6 +652,100 @@ class MultiModelManager:
"""Register an alias for a model.""" """Register an alias for a model."""
self.model_aliases[alias] = model_name self.model_aliases[alias] = model_name
def get_all_allowed_identifiers(self) -> set:
"""
Return the set of all model names, aliases, and identifiers that are
valid for API requests. This includes every identifier that
``list_models()`` would return as well as the raw model paths/names
registered via the command line.
"""
allowed = set()
# Default / text model
if self.default_model:
allowed.add(self.default_model)
short = self.default_model.split("/")[-1] if "/" in self.default_model else self.default_model
allowed.add(short)
allowed.add("default")
# Audio models
if self.audio_models:
allowed.add("audio")
for m in self.audio_models:
allowed.add(m)
allowed.add(f"audio:{m}")
# TTS model
if self.tts_model:
allowed.add("tts")
allowed.add(self.tts_model)
allowed.add(f"tts:{self.tts_model}")
# Image models
if self.image_models:
allowed.add("image")
for m in self.image_models:
allowed.add(m)
allowed.add(f"image:{m}")
# Vision models
if self.vision_models:
allowed.add("vision")
for m in self.vision_models:
allowed.add(m)
allowed.add(f"vision:{m}")
# Custom aliases
for alias in self.model_aliases:
allowed.add(alias)
return allowed
def is_allowed_model(self, requested_or_resolved: str, model_type: str = None) -> bool:
"""
Check if a model name (raw request value *or* resolved name) is one of
the models registered via the command line or their aliases.
Args:
requested_or_resolved: The model name/path/alias to check.
model_type: Optional type hint ("image", "text", "audio", "tts",
"vision"). When provided the check is scoped to that
type first; if it fails it still falls back to the
full allowed set.
Returns:
True if the model is registered/allowed, False otherwise.
"""
if not requested_or_resolved:
return False
# Quick check against the full set of allowed identifiers
allowed = self.get_all_allowed_identifiers()
if requested_or_resolved in allowed:
return True
# Also accept prefixed forms that the caller may have built
if model_type and model_type != "text":
prefixed = f"{model_type}:{requested_or_resolved}"
if prefixed in allowed:
return True
# Short-name match: compare the last path component of the request
# against all allowed identifiers
req_short = requested_or_resolved.split("/")[-1] if "/" in requested_or_resolved else None
if req_short and req_short in allowed:
return True
# Reverse short-name match: check if any allowed identifier ends with
# the same filename as the request
if req_short:
for a in allowed:
a_short = a.split("/")[-1] if "/" in a else a
if a_short == req_short:
return True
return False
def get_model_for_request(self, requested_model: str): def get_model_for_request(self, requested_model: str):
"""Get the appropriate model manager for a request based on model name.""" """Get the appropriate model manager for a request based on model name."""
global global_args global global_args
...@@ -744,7 +838,13 @@ class MultiModelManager: ...@@ -744,7 +838,13 @@ class MultiModelManager:
self.current_model_key = key self.current_model_key = key
return model return model
# Model not found - try to load it as a new model # Validate the model is allowed before attempting to load it.
# This prevents loading arbitrary models not registered via command line.
if not self.is_allowed_model(requested_model):
print(f"Model '{requested_model}' is not an allowed model. Rejecting request.")
return None
# Model not found but allowed - try to load it
return self._load_model_by_name(requested_model) return self._load_model_by_name(requested_model)
def resolve_model_name(self, requested_model: str) -> Optional[str]: def resolve_model_name(self, requested_model: str) -> Optional[str]:
...@@ -1079,6 +1179,25 @@ class MultiModelManager: ...@@ -1079,6 +1179,25 @@ class MultiModelManager:
'already_loaded': False, 'already_loaded': False,
} }
# Step 1b: Validate that the resolved model is an allowed/registered model.
# This prevents API callers from requesting arbitrary models that were not
# specified on the command line (or registered as aliases).
if not self.is_allowed_model(resolved_name, model_type):
# Also try the original requested_model value (before alias resolution)
# in case the caller used a valid alias that resolved to something we
# didn't recognise above (shouldn't happen, but be safe).
allowed_ids = sorted(self.get_all_allowed_identifiers())
print(f"Model validation failed: '{resolved_name}' is not an allowed model. "
f"Allowed models: {allowed_ids}")
return {
'model_key': None,
'model_name': None,
'model_object': None,
'config': {},
'already_loaded': False,
'error': f"Model '{resolved_name}' is not available. Use one of: {', '.join(allowed_ids)}",
}
# Step 2: Build the model key (prefixed with type) # Step 2: Build the model key (prefixed with type)
if model_type and model_type != "text": if model_type and model_type != "text":
model_key = f"{model_type}:{resolved_name}" model_key = f"{model_type}:{resolved_name}"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment