Add progress reporting, token counting, and configurable VRAM overhead for models

parent 11d83f87
@@ -108,6 +108,15 @@ def handle_web_message(message: Message, client_sock=None) -> Message:
             return result
         else:
             return Message('result_pending', message.msg_id, {'status': 'pending'})
+    elif message.msg_type == 'get_progress':
+        # Check for pending progress update
+        job_id = message.data.get('job_id')
+        progress_key = f"progress_{job_id}"
+        if progress_key in pending_results:
+            progress = pending_results[progress_key]
+            return progress
+        else:
+            return Message('progress_pending', message.msg_id, {'status': 'no_progress'})
     return Message('error', message.msg_id, {'error': 'Unknown message type'})
@@ -118,10 +127,26 @@ def handle_worker_message(message: Message, client_sock) -> None:
         if worker_type:
             worker_sockets[worker_type] = client_sock
             print(f"Worker {worker_type} registered")
+    elif message.msg_type == 'progress':
+        # Store progress update for web to poll
+        progress_key = f"progress_{message.data.get('job_id')}"
+        pending_results[progress_key] = message
     elif message.msg_type in ['analyze_response', 'train_response']:
         # Store result for web to poll
         pending_results[message.msg_id] = message
+        # Update token usage in database if provided
+        tokens_used = message.data.get('tokens_used')
+        if tokens_used:
+            try:
+                from .database import update_queue_status
+                # Extract job_id from message_id (format: job_xxx)
+                job_id = message.msg_id
+                if job_id.startswith('job_'):
+                    update_queue_status(job_id, 'completed', used_tokens=tokens_used)
+            except Exception as e:
+                print(f"Error updating token usage: {e}")

 def worker_message_handler(message: Message, client_sock) -> None:
     """Handler for worker messages."""
...
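How a web client is expected to consume the new progress messages, as a minimal polling sketch. The send_to_server() helper and the two-second default interval are assumptions for illustration, not part of this commit; Message is the project's message class shown in the diff above.

    import time

    def poll_progress(job_id, send_to_server, interval=2.0):
        """Poll the server until the job reports 100% progress.

        send_to_server(message) -> Message is a hypothetical transport helper;
        only the 'get_progress' / 'progress' message shapes come from this commit.
        """
        while True:
            reply = send_to_server(Message('get_progress', job_id, {'job_id': job_id}))
            if reply.msg_type == 'progress':
                d = reply.data
                print(f"[{d['stage']}] {d['progress']}% - {d['message']}")
                if d.get('progress', 0) >= 100:
                    return
            # 'progress_pending' means no update has been stored yet; keep waiting
            time.sleep(interval)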
@@ -50,8 +50,8 @@ class BaseModel(ABC):
         pass

     @abstractmethod
-    def generate(self, inputs: Dict[str, Any], **kwargs) -> str:
-        """Generate response from inputs."""
+    def generate(self, inputs: Dict[str, Any], **kwargs) -> tuple[str, int]:
+        """Generate response from inputs. Returns (result, tokens_used)"""
         pass

     @abstractmethod
@@ -115,7 +115,7 @@ class VisionLanguageModel(BaseModel):
         self.processor = AutoProcessor.from_pretrained(proc_path)

-    def generate(self, inputs: Dict[str, Any], **kwargs) -> str:
+    def generate(self, inputs: Dict[str, Any], **kwargs) -> tuple[str, int]:
         """Generate response for vision-language inputs."""
         if not self.model or not self.processor:
             raise RuntimeError("Model not loaded")
@@ -147,7 +147,13 @@ class VisionLanguageModel(BaseModel):
             generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
         )
-        return output_text[0] if output_text else ""
+        result = output_text[0] if output_text else ""
+
+        # Estimate tokens used (input + output)
+        input_tokens = model_inputs['input_ids'].shape[1] if 'input_ids' in model_inputs else 0
+        output_tokens = generated_ids_trimmed[0].shape[0] if generated_ids_trimmed else 0
+        total_tokens = input_tokens + output_tokens
+
+        return result, total_tokens

 class TextOnlyModel(BaseModel):
@@ -173,7 +179,7 @@ class TextOnlyModel(BaseModel):
         self.model = AutoModelForCausalLM.from_pretrained(self.model_path, **kwargs)
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)

-    def generate(self, inputs: Dict[str, Any], **kwargs) -> str:
+    def generate(self, inputs: Dict[str, Any], **kwargs) -> tuple[str, int]:
         """Generate response for text inputs."""
         if not self.model or not self.tokenizer:
             raise RuntimeError("Model not loaded")
@@ -199,7 +205,13 @@ class TextOnlyModel(BaseModel):
             generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
         )
-        return output_text[0] if output_text else ""
+        result = output_text[0] if output_text else ""
+
+        # Estimate tokens used (input + output)
+        input_tokens = model_inputs['input_ids'].shape[1] if 'input_ids' in model_inputs else 0
+        output_tokens = generated_ids_trimmed[0].shape[0] if generated_ids_trimmed else 0
+        total_tokens = input_tokens + output_tokens
+
+        return result, total_tokens

 class ModelFactory:
...
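Because generate() now returns a (result, tokens_used) tuple, every call site has to unpack two values. A minimal usage sketch follows; only release_model() is visible in this diff, so the load_model() loader and the model path are hypothetical stand-ins.

    # Sketch only: 'load_model' stands in for the project's reference-counted loader.
    model = load_model("/models/example-vl-model")  # hypothetical path
    messages = [{"role": "user", "content": [{"type": "text", "text": "Describe the scene"}]}]

    # generate() now returns (result, tokens_used) instead of a bare string
    result, tokens_used = model.generate({"messages": messages}, max_new_tokens=128)
    print(f"{tokens_used} tokens: {result}")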
@@ -152,7 +152,10 @@ def analyze_single_image(image_path, prompt, model):
         }
     ]
-    return model.generate({"messages": messages}, max_new_tokens=128)
+    # generate() now returns (result, tokens_used); forward the backend's token count
+    result, tokens_used = model.generate({"messages": messages}, max_new_tokens=128)
+    return result, tokens_used

 def check_job_cancelled(job_id):
     """Check if a job has been cancelled."""
@@ -163,10 +166,11 @@ def check_job_cancelled(job_id):
     except:
         return False

-def analyze_media(media_path, prompt, model_path, interval=10, job_id=None):
+def analyze_media(media_path, prompt, model_path, interval=10, job_id=None, comm=None):
     """Analyze media using dynamic model loading."""
     print(f"DEBUG: Starting analyze_media for job {job_id}, media_path={media_path}")
     torch.cuda.empty_cache()
+    total_tokens = 0

     # Get model with reference counting
     print(f"DEBUG: Loading model {model_path} for job {job_id}")
@@ -187,6 +191,17 @@ def analyze_media(media_path, prompt, model_path, interval=10, job_id=None, comm=None):
         print(f"DEBUG: Detected video, extracting frames for job {job_id}")
         frames, output_dir = extract_frames(media_path, interval, optimize=True)
         total_frames = len(frames)
+
+        # Send progress update for frame extraction
+        if comm:
+            progress_msg = Message('progress', f'progress_{job_id}', {
+                'job_id': job_id,
+                'stage': 'frame_extraction',
+                'progress': 10,
+                'message': f'Extracted {total_frames} frames'
+            })
+            comm.send_message(progress_msg)

         descriptions = []
         for i, (frame_path, ts) in enumerate(frames):
@@ -206,48 +221,83 @@ def analyze_media(media_path, prompt, model_path, interval=10, job_id=None, comm=None):
                        shutil.rmtree(output_dir)
                    except:
                        pass
-                return "Job cancelled by user"
+                return "Job cancelled by user", total_tokens

-            desc = analyze_single_image(frame_path, full_prompt, model)
+            desc, tokens = analyze_single_image(frame_path, full_prompt, model)
+            total_tokens += tokens
             print(f"DEBUG: Frame {i+1} analyzed for job {job_id}")
             descriptions.append(f"At {ts:.2f}s: {desc}")
             os.unlink(frame_path)
+
+            # Send progress update
+            if comm:
+                progress_percent = 10 + int((i + 1) / total_frames * 70)  # 10-80% for frame processing
+                progress_msg = Message('progress', f'progress_{job_id}', {
+                    'job_id': job_id,
+                    'stage': 'frame_analysis',
+                    'progress': progress_percent,
+                    'message': f'Analyzed frame {i+1}/{total_frames}'
+                })
+                comm.send_message(progress_msg)

         if output_dir:
             import shutil
             shutil.rmtree(output_dir)

         print(f"DEBUG: All frames processed, generating summary for job {job_id}")
+
+        # Send progress update for summary generation
+        if comm:
+            progress_msg = Message('progress', f'progress_{job_id}', {
+                'job_id': job_id,
+                'stage': 'summary_generation',
+                'progress': 85,
+                'message': 'Generating video summary'
+            })
+            comm.send_message(progress_msg)

         # Check for cancellation before summary
         if job_id and check_job_cancelled(job_id):
             print(f"DEBUG: Job {job_id} cancelled before summary")
-            return "Job cancelled by user"
+            return "Job cancelled by user", total_tokens

         # Generate summary
         if model.supports_vision():
             # Use vision model for summary
             summary_prompt = f"Summarize the video based on frame descriptions: {' '.join(descriptions)}"
             messages = [{"role": "user", "content": [{"type": "text", "text": summary_prompt}]}]
-            summary = model.generate({"messages": messages}, max_new_tokens=256)
+            summary, summary_tokens = model.generate({"messages": messages}, max_new_tokens=256)
         else:
             # Use text-only model for summary
-            summary = model.generate(f"Summarize the video based on frame descriptions: {' '.join(descriptions)}", max_new_tokens=256)
+            summary, summary_tokens = model.generate(f"Summarize the video based on frame descriptions: {' '.join(descriptions)}", max_new_tokens=256)
+        total_tokens += summary_tokens

         print(f"DEBUG: Summary generated for job {job_id}")
+
+        # Send final progress update
+        if comm:
+            progress_msg = Message('progress', f'progress_{job_id}', {
+                'job_id': job_id,
+                'stage': 'completed',
+                'progress': 100,
+                'message': 'Analysis completed'
+            })
+            comm.send_message(progress_msg)

         result = f"Frame Descriptions:\n" + "\n".join(descriptions) + f"\n\nSummary:\n{summary}"
-        return result
+        return result, total_tokens
     else:
         print(f"DEBUG: Detected image, analyzing for job {job_id}")
         # Check for cancellation before processing image
         if job_id and check_job_cancelled(job_id):
             print(f"DEBUG: Job {job_id} cancelled before image analysis")
-            return "Job cancelled by user"
+            return "Job cancelled by user", total_tokens

-        result = analyze_single_image(media_path, full_prompt, model)
+        result, tokens = analyze_single_image(media_path, full_prompt, model)
+        total_tokens += tokens
         print(f"DEBUG: Image analysis completed for job {job_id}")

         torch.cuda.empty_cache()
-        return result
+        return result, total_tokens

 def worker_process(backend_type: str):
     """Main worker process."""
     print(f"DEBUG: Starting Analysis Worker for {backend_type}...")
@@ -282,14 +332,14 @@ def worker_process(backend_type: str):
             interval = data.get('interval', 10)
             job_id = message.msg_id  # Use message ID for job identification
             print(f"DEBUG: Starting analysis of {media_path} with model {model_path} for job {job_id}")
-            result = analyze_media(media_path, prompt, model_path, interval, job_id)
-            print(f"DEBUG: Analysis completed for job {message.msg_id}")
+            result, tokens_used = analyze_media(media_path, prompt, model_path, interval, job_id, comm)
+            print(f"DEBUG: Analysis completed for job {message.msg_id}, used {tokens_used} tokens")

             # Release model reference (don't unload yet, per requirements)
             release_model(model_path)

             # Send result back
-            response = Message('analyze_response', message.msg_id, {'result': result})
+            response = Message('analyze_response', message.msg_id, {'result': result, 'tokens_used': tokens_used})
             print(f"DEBUG: Sending analyze_response for job {message.msg_id}")
             comm.send_message(response)
...
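The progress percentages sent by the worker follow a fixed schedule: 10% after frame extraction, 10-80% spread across frame analysis, 85% before the summary, and 100% on completion. The same mapping expressed as a small helper, purely illustrative and not part of the commit:

    def progress_for(stage, frame_index=0, total_frames=1):
        """Mirror the percentages used in analyze_media's progress messages."""
        if stage == 'frame_extraction':
            return 10
        if stage == 'frame_analysis':
            return 10 + int((frame_index + 1) / total_frames * 70)
        if stage == 'summary_generation':
            return 85
        return 100  # 'completed'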