Implement dynamic model management system

- Add vidai/models.py with abstract base classes for different model types
- Support multiple HuggingFace transformers models (vision-language, text-only)
- Auto-detect model types from model names/paths (see the usage sketch after this list)
- Add model type configuration options (--model-type)
- Update worker_analysis.py to use dynamic model loading instead of hardcoded Qwen
- Maintain backward compatibility with existing model configurations
parent 9f967d46
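
The change set adds a programmatic entry point, vidai.models.get_model(), alongside the new --model-type flag. Below is a minimal usage sketch, assuming the vidai package and its transformers dependencies are importable; the checkpoint names and frame path are illustrative placeholders, not values taken from this commit.

# Sketch only: exercises the dynamic model API introduced in vidai/models.py.
# Checkpoint names are placeholders; any local path or HuggingFace repo id should work.
from vidai.models import get_model, unload_all_models

# With model_type omitted, the factory infers the type from the name:
# a Qwen VL checkpoint maps to VisionLanguageModel, everything else to TextOnlyModel.
vl_model = get_model("Qwen/Qwen2.5-VL-7B-Instruct")
txt_model = get_model("mistralai/Mistral-7B-Instruct-v0.2", model_type="mistral")

print(vl_model.supports_vision())   # True
print(txt_model.supports_vision())  # False

# Vision-language models take chat-style messages with image content...
messages = [{"role": "user", "content": [
    {"type": "image", "image": "frame_0001.jpg"},
    {"type": "text", "text": "Describe this frame."},
]}]
print(vl_model.generate({"messages": messages}, max_new_tokens=128))

# ...while text-only models also accept a plain prompt string.
print(txt_model.generate("Summarize: a car drives past a red building.", max_new_tokens=64))

unload_all_models()  # free GPU memory when finished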
@@ -33,6 +33,7 @@ sys.path.insert(0, os.path.dirname(__file__))
 from vidai.config import (
     get_config, set_config, get_default_model, set_default_model,
+    get_default_model_type, set_default_model_type,
     get_analysis_backend, set_analysis_backend, get_training_backend, set_training_backend,
     get_optimize, set_optimize, get_ffmpeg, set_ffmpeg, get_flash, set_flash,
     get_host, set_host, get_port, set_port, get_debug, set_debug, get_allowed_dir, set_allowed_dir,
@@ -58,6 +59,7 @@ Examples:
 # Read defaults from config
 default_model = get_default_model()
+default_model_type = get_default_model_type()
 default_analysis_backend = get_analysis_backend()
 default_training_backend = get_training_backend()
 default_optimize = get_optimize()
@@ -78,6 +80,13 @@ Examples:
     help=f'Default model path or HuggingFace model name (default: {default_model})'
 )
+parser.add_argument(
+    '--model-type',
+    choices=['auto', 'qwen2.5-vl', 'qwen-vl', 'text-only', 'llama', 'mistral', 'gpt'],
+    default=default_model_type,
+    help=f'Default model type for auto-detection (default: {default_model_type})'
+)
 parser.add_argument(
     '--dir',
     default=get_config('allowed_dir', ''),
@@ -210,6 +219,7 @@ Examples:
 # Update config with command line values
 set_default_model(args.model)
+set_default_model_type(args.model_type)
 set_allowed_dir(args.dir)
 set_optimize(args.optimize)
 set_ffmpeg(args.ffmpeg)
...
@@ -103,6 +103,16 @@ def set_default_model(model: str) -> None:
     set_config('default_model', model)
 
 
+def get_default_model_type() -> str:
+    """Get the default model type."""
+    return get_config('default_model_type', 'auto')
+
+
+def set_default_model_type(model_type: str) -> None:
+    """Set the default model type."""
+    set_config('default_model_type', model_type)
+
+
 def get_frame_interval() -> int:
     """Get the default frame interval."""
     return int(get_config('frame_interval', '10'))
...
# Video AI Model Management Module
# Copyright (C) 2024 Stefy Lanza <stefy@sexhack.me>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
Dynamic model management for Video AI.
Supports multiple HuggingFace transformers models with different architectures.
"""
import os
import torch
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List
from transformers import AutoProcessor, AutoModelForCausalLM, AutoTokenizer

class BaseModel(ABC):
    """Abstract base class for AI models."""

    def __init__(self, model_path: str, device: str = "auto", **kwargs):
        self.model_path = model_path
        self.device = device
        self.model = None
        self.processor = None
        self.tokenizer = None
        self.kwargs = kwargs

    @abstractmethod
    def load_model(self) -> None:
        """Load the model and processor."""
        pass

    @abstractmethod
    def generate(self, inputs: Dict[str, Any], **kwargs) -> str:
        """Generate response from inputs."""
        pass

    @abstractmethod
    def supports_vision(self) -> bool:
        """Check if model supports vision inputs."""
        return False

    def unload_model(self) -> None:
        """Unload model to free memory."""
        if self.model:
            del self.model
            self.model = None
        if self.processor:
            del self.processor
            self.processor = None
        if self.tokenizer:
            del self.tokenizer
            self.tokenizer = None
        torch.cuda.empty_cache()


class VisionLanguageModel(BaseModel):
    """Vision-language model like Qwen2.5-VL."""

    def supports_vision(self) -> bool:
        return True

    def load_model(self) -> None:
        """Load Qwen2.5-VL model."""
        from transformers import Qwen2_5_VLForConditionalGeneration

        kwargs = {"device_map": "auto", "low_cpu_mem_usage": True, **self.kwargs}

        # Try Flash Attention if requested
        if os.environ.get('VIDAI_FLASH', '').lower() == 'true':
            try:
                import flash_attn
                kwargs["attn_implementation"] = "flash_attention_2"
                kwargs["dtype"] = torch.float16
            except ImportError:
                pass
        if os.path.exists(self.model_path):
            try:
                self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                    self.model_path, **kwargs
                )
                proc_path = self.model_path
            except Exception:
                # Local path exists but is not a loadable checkpoint; fall back to the default hub model.
                self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                    "Qwen/Qwen2.5-VL-7B-Instruct", **kwargs
                )
                proc_path = "Qwen/Qwen2.5-VL-7B-Instruct"
        else:
            self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                self.model_path, **kwargs
            )
            proc_path = self.model_path

        self.processor = AutoProcessor.from_pretrained(proc_path)
    def generate(self, inputs: Dict[str, Any], **kwargs) -> str:
        """Generate response for vision-language inputs."""
        if not self.model or not self.processor:
            raise RuntimeError("Model not loaded")

        # Prepare inputs
        if isinstance(inputs, dict) and 'messages' in inputs:
            # Chat format
            messages = inputs['messages']
            model_inputs = self.processor.apply_chat_template(
                messages, tokenize=True, add_generation_prompt=True,
                return_dict=True, return_tensors="pt"
            )
        else:
            # Direct inputs
            model_inputs = inputs

        model_inputs = {k: v.to(self.model.device) for k, v in model_inputs.items()}

        # Generate
        gen_kwargs = {"max_new_tokens": 128, **kwargs}
        generated_ids = self.model.generate(**model_inputs, **gen_kwargs)

        # Extract new tokens
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(model_inputs['input_ids'], generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        return output_text[0] if output_text else ""


class TextOnlyModel(BaseModel):
    """Text-only language model."""

    def supports_vision(self) -> bool:
        return False

    def load_model(self) -> None:
        """Load text-only model."""
        kwargs = {"device_map": "auto", "low_cpu_mem_usage": True, **self.kwargs}

        if os.environ.get('VIDAI_FLASH', '').lower() == 'true':
            try:
                import flash_attn
                kwargs["attn_implementation"] = "flash_attention_2"
                kwargs["dtype"] = torch.float16
            except ImportError:
                pass

        self.model = AutoModelForCausalLM.from_pretrained(self.model_path, **kwargs)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)

    def generate(self, inputs: Dict[str, Any], **kwargs) -> str:
        """Generate response for text inputs."""
        if not self.model or not self.tokenizer:
            raise RuntimeError("Model not loaded")

        if isinstance(inputs, str):
            # Simple text input
            model_inputs = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
        elif isinstance(inputs, dict):
            model_inputs = {k: v.to(self.model.device) if torch.is_tensor(v) else v
                            for k, v in inputs.items()}
        else:
            raise ValueError("Unsupported input format")

        gen_kwargs = {"max_new_tokens": 128, **kwargs}
        generated_ids = self.model.generate(**model_inputs, **gen_kwargs)

        # Extract new tokens
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(model_inputs['input_ids'], generated_ids)
        ]
        output_text = self.tokenizer.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        return output_text[0] if output_text else ""


class ModelFactory:
    """Factory for creating model instances."""

    _model_types = {
        'qwen2.5-vl': VisionLanguageModel,
        'qwen-vl': VisionLanguageModel,  # Backward compatibility
        'text-only': TextOnlyModel,
        'llama': TextOnlyModel,
        'mistral': TextOnlyModel,
        'gpt': TextOnlyModel,
    }

    @classmethod
    def create_model(cls, model_type: str, model_path: str, **kwargs) -> BaseModel:
        """Create a model instance based on type."""
        if model_type not in cls._model_types:
            # Try to infer from model path/name
            if 'qwen' in model_path.lower() and ('vl' in model_path.lower() or 'vision' in model_path.lower()):
                model_type = 'qwen2.5-vl'
            else:
                model_type = 'text-only'
        model_class = cls._model_types[model_type]
        return model_class(model_path, **kwargs)

    @classmethod
    def register_model_type(cls, name: str, model_class: type) -> None:
        """Register a new model type."""
        cls._model_types[name] = model_class


# Global model cache
_model_cache: Dict[str, BaseModel] = {}


def get_model(model_path: str, model_type: str = None, **kwargs) -> BaseModel:
    """Get or create a cached model instance."""
    cache_key = f"{model_type or 'auto'}:{model_path}"
    if cache_key not in _model_cache:
        _model_cache[cache_key] = ModelFactory.create_model(model_type, model_path, **kwargs)
    model = _model_cache[cache_key]
    if not model.model:  # Load if not loaded
        model.load_model()
    return model


def unload_all_models() -> None:
    """Unload all cached models."""
    for model in _model_cache.values():
        model.unload_model()
    _model_cache.clear()
\ No newline at end of file
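
ModelFactory above is the intended extension point. The following is a hypothetical sketch of registering an extra text-only architecture and pulling it through the cache; the 'phi' key and the checkpoint id are illustrative and not part of this commit.

# Hypothetical extension sketch; names below are illustrative, not part of this commit.
from vidai.models import TextOnlyModel, ModelFactory, get_model

# Any causal LM that AutoModelForCausalLM can load can be exposed under a new key.
ModelFactory.register_model_type('phi', TextOnlyModel)

model = get_model("microsoft/phi-2", model_type="phi")  # cached under "phi:microsoft/phi-2"
print(model.supports_vision())                          # False
print(model.generate("Hello, world.", max_new_tokens=16))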
@@ -22,13 +22,13 @@ Handles image/video analysis requests.
 import os
 import sys
 import torch
-from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
 import tempfile
 import subprocess
 import json
 import cv2
 import time
 from .comm import SocketCommunicator, Message
+from .models import get_model
 from .config import get_system_prompt_content, get_comm_type
 
 # Set PyTorch CUDA memory management
@@ -93,7 +93,8 @@ def extract_frames(video_path, interval=10, optimize=False):
 def is_video(file_path):
     return file_path.lower().endswith(('.mp4', '.avi', '.mov', '.mkv'))
 
-def analyze_single_image(image_path, prompt, model, processor):
+def analyze_single_image(image_path, prompt, model):
+    """Analyze a single image using the dynamic model."""
     messages = [
         {
             "role": "user",
@@ -104,72 +105,54 @@ def analyze_single_image(image_path, prompt, model):
         }
     ]
-    inputs = processor.apply_chat_template(
-        messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt"
-    )
-    inputs = {k: v.to(model.device) for k, v in inputs.items()}
-    generated_ids = model.generate(**inputs, max_new_tokens=128)
-    generated_ids_trimmed = [
-        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs['input_ids'], generated_ids)
-    ]
-    output_text = processor.batch_decode(
-        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
-    )
-    return output_text[0]
+    return model.generate({"messages": messages}, max_new_tokens=128)
 
 def analyze_media(media_path, prompt, model_path, interval=10):
+    """Analyze media using dynamic model loading."""
     torch.cuda.empty_cache()
-    if model_path not in model_cache:
-        kwargs = {"device_map": "auto", "low_cpu_mem_usage": True}
-        if os.path.exists(model_path):
-            try:
-                model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, **kwargs)
-                proc_path = model_path
-            except:
-                model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", **kwargs)
-                proc_path = "Qwen/Qwen2.5-VL-7B-Instruct"
-        else:
-            model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", **kwargs)
-            proc_path = "Qwen/Qwen2.5-VL-7B-Instruct"
-        model_cache[model_path] = model
-        processor_cache[model_path] = AutoProcessor.from_pretrained(proc_path)
-    else:
-        model = model_cache[model_path]
-        proc_path = model_path if os.path.exists(model_path) else "Qwen/Qwen2.5-VL-7B-Instruct"
-    processor = processor_cache[model_path]
-    system_prompt = get_system_prompt_content()
-    full_prompt = system_prompt + " " + prompt if system_prompt else prompt
+    # Get model dynamically
+    model = get_model(model_path, model_type=None)  # Auto-detect type
+
+    # Get system prompt
+    try:
+        from .config import get_system_prompt_content
+        system_prompt = get_system_prompt_content()
+        full_prompt = system_prompt + " " + prompt if system_prompt else prompt
+    except Exception:
+        full_prompt = prompt
+
     if is_video(media_path):
         frames, output_dir = extract_frames(media_path, interval, optimize=True)
         total_frames = len(frames)
         descriptions = []
         for i, (frame_path, ts) in enumerate(frames):
-            desc = analyze_single_image(frame_path, full_prompt, model, processor)
+            desc = analyze_single_image(frame_path, full_prompt, model)
             descriptions.append(f"At {ts:.2f}s: {desc}")
             os.unlink(frame_path)
         if output_dir:
             import shutil
             shutil.rmtree(output_dir)
-        summary_prompt = f"Summarize the video based on frame descriptions: {' '.join(descriptions)}"
-        messages = [{"role": "user", "content": [{"type": "text", "text": summary_prompt}]}]
-        inputs = processor.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt")
-        inputs = {k: v.to(model.device) for k, v in inputs.items()}
-        generated_ids = model.generate(**inputs, max_new_tokens=256)
-        generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs['input_ids'], generated_ids)]
-        output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
-        summary = output_text[0]
+        # Generate summary
+        if model.supports_vision():
+            # Use vision model for summary
+            summary_prompt = f"Summarize the video based on frame descriptions: {' '.join(descriptions)}"
+            messages = [{"role": "user", "content": [{"type": "text", "text": summary_prompt}]}]
+            summary = model.generate({"messages": messages}, max_new_tokens=256)
+        else:
+            # Use text-only model for summary
+            summary = model.generate(f"Summarize the video based on frame descriptions: {' '.join(descriptions)}", max_new_tokens=256)
         result = f"Frame Descriptions:\n" + "\n".join(descriptions) + f"\n\nSummary:\n{summary}"
         return result
     else:
-        result = analyze_single_image(media_path, full_prompt, model, processor)
+        result = analyze_single_image(media_path, full_prompt, model)
         torch.cuda.empty_cache()
         return result
 
-model_cache = {}
-processor_cache = {}
 
 def worker_process(backend_type: str):
     """Main worker process."""
     print(f"Starting Analysis Worker for {backend_type}...")
...
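
For context, here is a rough sketch of how a worker loop might drive the refactored helpers; the job-dict fields and the vidai.worker_analysis module path are assumptions for illustration, not taken from this diff.

# Rough sketch of a caller; field names in `job` are assumptions for illustration.
from vidai.worker_analysis import analyze_media   # assumed module path
from vidai.models import unload_all_models

def handle_job(job: dict) -> str:
    """Run one analysis request and return the formatted result."""
    return analyze_media(
        media_path=job["path"],
        prompt=job["prompt"],
        model_path=job.get("model", "Qwen/Qwen2.5-VL-7B-Instruct"),
        interval=job.get("interval", 10),
    )

# Models stay cached across jobs; release GPU memory explicitly when the worker goes idle.
unload_all_models()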