Commit c152ee28 authored by Your Name's avatar Your Name

Add --vision-model, fix --file-path to return URL by default, add HTTPS support

- Add --vision-model for image/video-to-text models
- When --file-path is set, return URL by default, base64 only if explicitly requested
- Add --https flag with auto-certificate generation
- Add --privkey and --pubkey for custom certificates
parent 29d2ed78
...@@ -2128,6 +2128,7 @@ class MultiModelManager: ...@@ -2128,6 +2128,7 @@ class MultiModelManager:
self.audio_models: List[str] = [] # List of audio model names self.audio_models: List[str] = [] # List of audio model names
self.tts_model: Optional[str] = None self.tts_model: Optional[str] = None
self.image_models: List[str] = [] # List of image model names self.image_models: List[str] = [] # List of image model names
self.vision_models: List[str] = [] # List of vision (image/video to text) model names
self.tool_parser = ToolCallParser() self.tool_parser = ToolCallParser()
self.current_model_key: Optional[str] = None self.current_model_key: Optional[str] = None
# Configuration for each model type # Configuration for each model type
...@@ -2149,6 +2150,11 @@ class MultiModelManager: ...@@ -2149,6 +2150,11 @@ class MultiModelManager:
def image_model(self) -> Optional[str]: def image_model(self) -> Optional[str]:
"""Get the first/default image model.""" """Get the first/default image model."""
return self.image_models[0] if self.image_models else None return self.image_models[0] if self.image_models else None
@property
def vision_model(self) -> Optional[str]:
"""Get the first/default vision model."""
return self.vision_models[0] if self.vision_models else None
def set_load_mode(self, mode: str): def set_load_mode(self, mode: str):
"""Set the load mode: 'ondemand', 'loadall', or 'loadswap'.""" """Set the load mode: 'ondemand', 'loadall', or 'loadswap'."""
...@@ -2176,6 +2182,12 @@ class MultiModelManager: ...@@ -2176,6 +2182,12 @@ class MultiModelManager:
self.image_models.append(model_name) self.image_models.append(model_name)
self.config[f"image:{model_name}"] = config or {} self.config[f"image:{model_name}"] = config or {}
def set_vision_model(self, model_name: str, config: Dict = None):
"""Add a vision (image/video to text) model."""
if model_name not in self.vision_models:
self.vision_models.append(model_name)
self.config[f"vision:{model_name}"] = config or {}
def set_model_alias(self, alias: str, model_name: str): def set_model_alias(self, alias: str, model_name: str):
"""Register an alias for a model.""" """Register an alias for a model."""
self.model_aliases[alias] = model_name self.model_aliases[alias] = model_name
...@@ -3131,7 +3143,11 @@ def get_cfg_scale(): ...@@ -3131,7 +3143,11 @@ def get_cfg_scale():
# Helper function to save generated images and return response dict # Helper function to save generated images and return response dict
def save_image_response(img, request_format="base64"): def save_image_response(img, request_format="base64"):
""" """
Save image to file path if configured, return response dict with b64_json and optional url. Save image to file path if configured, return response dict.
If --file-path is set and request_format is url (not base64), return only URL.
If --file-path is set and request_format is base64, return both URL and base64.
If --file-path is not set, return base64 as usual.
""" """
import base64 import base64
import io import io
...@@ -3143,13 +3159,7 @@ def save_image_response(img, request_format="base64"): ...@@ -3143,13 +3159,7 @@ def save_image_response(img, request_format="base64"):
if not isinstance(img, Image.Image): if not isinstance(img, Image.Image):
img = Image.fromarray(img) img = Image.fromarray(img)
# Convert to base64 result = {}
buffered = io.BytesIO()
img.save(buffered, format="PNG")
img_bytes = buffered.getvalue()
img_base64 = base64.b64encode(img_bytes).decode('utf-8')
result = {"b64_json": img_base64}
# Save to file path if configured # Save to file path if configured
if global_file_path: if global_file_path:
...@@ -3160,6 +3170,22 @@ def save_image_response(img, request_format="base64"): ...@@ -3160,6 +3170,22 @@ def save_image_response(img, request_format="base64"):
img.save(file_path, format="PNG") img.save(file_path, format="PNG")
# Add URL to response # Add URL to response
result["url"] = f"/v1/files/{filename}" result["url"] = f"/v1/files/{filename}"
# If client explicitly requested base64, include it
# Otherwise, only return URL when file-path is set
if request_format == "base64":
buffered = io.BytesIO()
img.save(buffered, format="PNG")
img_bytes = buffered.getvalue()
img_base64 = base64.b64encode(img_bytes).decode('utf-8')
result["b64_json"] = img_base64
else:
# No file-path, return base64 as usual
buffered = io.BytesIO()
img.save(buffered, format="PNG")
img_bytes = buffered.getvalue()
img_base64 = base64.b64encode(img_bytes).decode('utf-8')
result["b64_json"] = img_base64
return result return result
...@@ -4275,6 +4301,23 @@ def parse_args(): ...@@ -4275,6 +4301,23 @@ def parse_args():
default=8000, default=8000,
help="Port to bind to (default: 8000)", help="Port to bind to (default: 8000)",
) )
parser.add_argument(
"--https",
action="store_true",
help="Enable HTTPS with auto-generated certificate",
)
parser.add_argument(
"--privkey",
type=str,
default=None,
help="Path to HTTPS private key file",
)
parser.add_argument(
"--pubkey",
type=str,
default=None,
help="Path to HTTPS certificate file",
)
parser.add_argument( parser.add_argument(
"--offload-dir", "--offload-dir",
type=str, type=str,
...@@ -4376,6 +4419,13 @@ def parse_args(): ...@@ -4376,6 +4419,13 @@ def parse_args():
default=None, default=None,
help="Model for image generation (e.g., stable-diffusion-xl-base-1.0). Can be specified multiple times for multiple models.", help="Model for image generation (e.g., stable-diffusion-xl-base-1.0). Can be specified multiple times for multiple models.",
) )
parser.add_argument(
"--vision-model",
type=str,
action="append",
default=None,
help="Model for image/video-to-text (e.g., llava-1.5, LLaVA). Supports vulkan and cuda backends.",
)
parser.add_argument( parser.add_argument(
"--image-1", "--image-1",
action="store_true", action="store_true",
...@@ -4683,9 +4733,10 @@ def main(): ...@@ -4683,9 +4733,10 @@ def main():
# Validate: must have at least one model specified # Validate: must have at least one model specified
audio_models = args.audio_model if args.audio_model else [] audio_models = args.audio_model if args.audio_model else []
image_models = args.image_model if args.image_model else [] image_models = args.image_model if args.image_model else []
vision_models = args.vision_model if args.vision_model else []
if not model_names and not audio_models and not image_models and args.tts_model is None: if not model_names and not audio_models and not image_models and not vision_models and args.tts_model is None:
print("Error: At least one of --model, --audio-model, --image-model, or --tts-model must be specified.") print("Error: At least one of --model, --audio-model, --image-model, --vision-model, or --tts-model must be specified.")
print("") print("")
print("For NVIDIA backend (HuggingFace models):") print("For NVIDIA backend (HuggingFace models):")
print(" - microsoft/DialoGPT-medium") print(" - microsoft/DialoGPT-medium")
...@@ -5750,6 +5801,48 @@ def main(): ...@@ -5750,6 +5801,48 @@ def main():
models = multi_model_manager.list_models() models = multi_model_manager.list_models()
print(f"Available models: {[m.id for m in models]}") print(f"Available models: {[m.id for m in models]}")
uvicorn.run(app, host=args.host, port=args.port) # Run server with or without HTTPS
if args.https:
import ssl
import os
# Determine SSL context
ssl_keyfile = None
ssl_certfile = None
if args.privkey and args.pubkey:
# Use provided certificates
ssl_keyfile = args.privkey
ssl_certfile = args.pubkey
print(f"Using HTTPS with custom certificates: {args.pubkey}")
else:
# Auto-generate self-signed certificate
print("Generating self-signed HTTPS certificate...")
import subprocess
try:
# Generate self-signed cert
cert_path = "./cert.pem"
key_path = "./key.pem"
subprocess.run([
"openssl", "req", "-x509", "-newkey", "rsa:4096",
"-keyout", key_path, "-out", cert_path,
"-days", "365", "-nodes",
"-subj", "/CN=localhost"
], check=True, capture_output=True)
ssl_keyfile = key_path
ssl_certfile = cert_path
print(f"Generated self-signed certificate: {cert_path}")
except Exception as e:
print(f"Warning: Could not generate certificate: {e}")
print("Falling back to HTTP...")
uvicorn.run(app, host=args.host, port=args.port)
return
# Run with HTTPS
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ssl_context.load_cert_chain(ssl_certfile, ssl_keyfile)
uvicorn.run(app, host=args.host, port=args.port, ssl=ssl_context)
else:
uvicorn.run(app, host=args.host, port=args.port)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment