Add configurable communication type: Unix vs TCP sockets

- Add comm_type configuration option (unix/tcp, default unix)
- Command line argument --comm-type for runtime selection
- Web configuration page includes communication type selection
- SocketCommunicator and SocketServer support both Unix and TCP
- Updated all processes to use configured communication type
- Documentation updated to reflect both socket types
- Unix sockets provide better performance for local communication
parent 2a1990a5
Pipeline #186 canceled with stages
...@@ -71,6 +71,9 @@ vidai-training-* ...@@ -71,6 +71,9 @@ vidai-training-*
# Result files # Result files
/tmp/vidai_results/ /tmp/vidai_results/
# Unix socket files
/tmp/vidai_*.sock
# Config (but keep structure) # Config (but keep structure)
/home/*/.config/vidai/ /home/*/.config/vidai/
~/.config/vidai/ ~/.config/vidai/
\ No newline at end of file
...@@ -70,6 +70,7 @@ Options: ...@@ -70,6 +70,7 @@ Options:
- `--flash`: Enable Flash Attention 2 - `--flash`: Enable Flash Attention 2
- `--analysis-backend {cuda,rocm}`: Backend for analysis - `--analysis-backend {cuda,rocm}`: Backend for analysis
- `--training-backend {cuda,rocm}`: Backend for training - `--training-backend {cuda,rocm}`: Backend for training
- `--comm-type {unix,tcp}`: Communication type for inter-process communication (default: unix)
- `--host HOST`: Host to bind server to (default: 0.0.0.0) - `--host HOST`: Host to bind server to (default: 0.0.0.0)
- `--port PORT`: Port to bind server to (default: 5000) - `--port PORT`: Port to bind server to (default: 5000)
- `--debug`: Enable debug mode - `--debug`: Enable debug mode
...@@ -101,10 +102,15 @@ Use the built executables from `dist/` directory. ...@@ -101,10 +102,15 @@ Use the built executables from `dist/` directory.
## API ## API
The backend communicates via TCP sockets: The backend communicates via configurable socket types for inter-process communication:
- Web interface: localhost:5001 **Unix Domain Sockets (default, recommended for performance):**
- Workers: localhost:5002 - Web interface: `/tmp/vidai_web.sock`
- Workers: `/tmp/vidai_workers.sock`
**TCP Sockets (for compatibility):**
- Web interface: `localhost:5001`
- Workers: `localhost:5002`
Message format: JSON with `msg_type`, `msg_id`, and `data` fields. Message format: JSON with `msg_type`, `msg_id`, and `data` fields.
......
...@@ -36,7 +36,7 @@ The Video AI Analysis Tool is designed as a multi-process application to provide ...@@ -36,7 +36,7 @@ The Video AI Analysis Tool is designed as a multi-process application to provide
### Communication Protocol ### Communication Protocol
All inter-process communication uses TCP sockets with JSON messages: All inter-process communication uses Unix domain sockets with JSON messages for optimal local performance:
```json ```json
{ {
......
...@@ -34,7 +34,8 @@ from vidai.config import ( ...@@ -34,7 +34,8 @@ from vidai.config import (
get_config, set_config, get_default_model, set_default_model, get_config, set_config, get_default_model, set_default_model,
get_analysis_backend, set_analysis_backend, get_training_backend, set_training_backend, get_analysis_backend, set_analysis_backend, get_training_backend, set_training_backend,
get_optimize, set_optimize, get_ffmpeg, set_ffmpeg, get_flash, set_flash, get_optimize, set_optimize, get_ffmpeg, set_ffmpeg, get_flash, set_flash,
get_host, set_host, get_port, set_port, get_debug, set_debug, get_allowed_dir, set_allowed_dir get_host, set_host, get_port, set_port, get_debug, set_debug, get_allowed_dir, set_allowed_dir,
get_comm_type, set_comm_type
) )
def main(): def main():
...@@ -108,6 +109,13 @@ Examples: ...@@ -108,6 +109,13 @@ Examples:
help=f'Backend for training (default: {default_training_backend})' help=f'Backend for training (default: {default_training_backend})'
) )
parser.add_argument(
'--comm-type',
choices=['unix', 'tcp'],
default=get_comm_type(),
help='Communication type for inter-process communication (default: unix)'
)
parser.add_argument( parser.add_argument(
'--host', '--host',
default=default_host, default=default_host,
...@@ -138,6 +146,7 @@ Examples: ...@@ -138,6 +146,7 @@ Examples:
set_flash(args.flash) set_flash(args.flash)
set_analysis_backend(args.analysis_backend) set_analysis_backend(args.analysis_backend)
set_training_backend(args.training_backend) set_training_backend(args.training_backend)
set_comm_type(args.comm_type)
set_host(args.host) set_host(args.host)
set_port(args.port) set_port(args.port)
set_debug(args.debug) set_debug(args.debug)
......
...@@ -22,7 +22,7 @@ Manages request routing between web interface and worker processes. ...@@ -22,7 +22,7 @@ Manages request routing between web interface and worker processes.
import time import time
import threading import threading
from .comm import SocketServer, Message from .comm import SocketServer, Message
from .config import get_analysis_backend, get_training_backend, set_analysis_backend, set_training_backend from .config import get_analysis_backend, get_training_backend, set_analysis_backend, set_training_backend, get_comm_type
worker_sockets = {} # type: dict worker_sockets = {} # type: dict
...@@ -105,12 +105,24 @@ def backend_process() -> None: ...@@ -105,12 +105,24 @@ def backend_process() -> None:
"""Main backend process loop.""" """Main backend process loop."""
print("Starting Video AI Backend...") print("Starting Video AI Backend...")
# Start web server on port 5001 comm_type = get_comm_type()
web_server = SocketServer(port=5001) print(f"Using {comm_type} sockets for communication")
if comm_type == 'unix':
# Start web server on Unix socket
web_server = SocketServer(socket_path='/tmp/vidai_web.sock', comm_type='unix')
web_server.start(handle_web_message)
# Start worker server on Unix socket
worker_server = SocketServer(socket_path='/tmp/vidai_workers.sock', comm_type='unix')
worker_server.start(worker_message_handler)
else:
# Start web server on TCP
web_server = SocketServer(host='localhost', port=5001, comm_type='tcp')
web_server.start(handle_web_message) web_server.start(handle_web_message)
# Start worker server on port 5002 # Start worker server on TCP
worker_server = SocketServer(port=5002) worker_server = SocketServer(host='localhost', port=5002, comm_type='tcp')
worker_server.start(worker_message_handler) worker_server.start(worker_message_handler)
try: try:
......
...@@ -36,15 +36,23 @@ class Message: ...@@ -36,15 +36,23 @@ class Message:
class SocketCommunicator: class SocketCommunicator:
"""Handles socket-based communication.""" """Handles socket-based communication using Unix domain or TCP sockets."""
def __init__(self, host: str = 'localhost', port: int = 5001): def __init__(self, socket_path: str = '/tmp/vidai_web.sock', host: str = 'localhost', port: int = 5001, comm_type: str = 'unix'):
self.comm_type = comm_type
if comm_type == 'unix':
self.socket_path = socket_path
else:
self.host = host self.host = host
self.port = port self.port = port
self.sock: Optional[socket.socket] = None self.sock: Optional[socket.socket] = None
def connect(self) -> None: def connect(self) -> None:
"""Connect to the server.""" """Connect to the server."""
if self.comm_type == 'unix':
self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.sock.connect(self.socket_path)
else:
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.sock.connect((self.host, self.port)) self.sock.connect((self.host, self.port))
...@@ -81,9 +89,13 @@ class SocketCommunicator: ...@@ -81,9 +89,13 @@ class SocketCommunicator:
class SocketServer: class SocketServer:
"""Simple socket server for handling connections.""" """Simple socket server for handling connections (Unix or TCP)."""
def __init__(self, host: str = 'localhost', port: int = 5001): def __init__(self, socket_path: str = '/tmp/vidai_backend.sock', host: str = 'localhost', port: int = 5001, comm_type: str = 'unix'):
self.comm_type = comm_type
if comm_type == 'unix':
self.socket_path = socket_path
else:
self.host = host self.host = host
self.port = port self.port = port
self.server_sock: Optional[socket.socket] = None self.server_sock: Optional[socket.socket] = None
...@@ -93,9 +105,22 @@ class SocketServer: ...@@ -93,9 +105,22 @@ class SocketServer:
def start(self, message_handler) -> None: def start(self, message_handler) -> None:
"""Start the server.""" """Start the server."""
self.message_handler = message_handler self.message_handler = message_handler
if self.comm_type == 'unix':
# Clean up any existing socket file
try:
os.unlink(self.socket_path)
except OSError:
pass
self.server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.server_sock.bind(self.socket_path)
# Set permissions so other users can connect
os.chmod(self.socket_path, 0o666)
else:
self.server_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.server_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.server_sock.bind((self.host, self.port)) self.server_sock.bind((self.host, self.port))
self.server_sock.listen(5) self.server_sock.listen(5)
self.running = True self.running = True
threading.Thread(target=self._accept_loop, daemon=True).start() threading.Thread(target=self._accept_loop, daemon=True).start()
...@@ -146,3 +171,9 @@ class SocketServer: ...@@ -146,3 +171,9 @@ class SocketServer:
self.running = False self.running = False
if self.server_sock: if self.server_sock:
self.server_sock.close() self.server_sock.close()
if self.comm_type == 'unix':
# Clean up socket file
try:
os.unlink(self.socket_path)
except OSError:
pass
\ No newline at end of file
...@@ -143,6 +143,16 @@ def set_allowed_dir(dir_path: str) -> None: ...@@ -143,6 +143,16 @@ def set_allowed_dir(dir_path: str) -> None:
set_config('allowed_dir', dir_path) set_config('allowed_dir', dir_path)
def get_comm_type() -> str:
"""Get communication type."""
return get_config('comm_type', 'unix')
def set_comm_type(comm_type: str) -> None:
"""Set communication type."""
set_config('comm_type', comm_type)
def get_all_settings() -> dict: def get_all_settings() -> dict:
"""Get all configuration settings.""" """Get all configuration settings."""
config = get_all_config() config = get_all_config()
...@@ -158,5 +168,6 @@ def get_all_settings() -> dict: ...@@ -158,5 +168,6 @@ def get_all_settings() -> dict:
'port': int(config.get('port', '5000')), 'port': int(config.get('port', '5000')),
'debug': config.get('debug', 'false').lower() == 'true', 'debug': config.get('debug', 'false').lower() == 'true',
'allowed_dir': config.get('allowed_dir', ''), 'allowed_dir': config.get('allowed_dir', ''),
'comm_type': config.get('comm_type', 'unix'),
'system_prompt': get_system_prompt_content() 'system_prompt': get_system_prompt_content()
} }
\ No newline at end of file
...@@ -71,7 +71,8 @@ def init_db(conn: sqlite3.Connection) -> None: ...@@ -71,7 +71,8 @@ def init_db(conn: sqlite3.Connection) -> None:
'host': '0.0.0.0', 'host': '0.0.0.0',
'port': '5000', 'port': '5000',
'debug': 'false', 'debug': 'false',
'allowed_dir': '' 'allowed_dir': '',
'comm_type': 'unix'
} }
for key, value in defaults.items(): for key, value in defaults.items():
......
...@@ -26,13 +26,17 @@ import json ...@@ -26,13 +26,17 @@ import json
import uuid import uuid
import time import time
from .comm import SocketCommunicator, Message from .comm import SocketCommunicator, Message
from .config import get_system_prompt_content, set_system_prompt_content, get_all_settings, set_analysis_backend, set_training_backend, set_default_model, set_frame_interval from .config import get_system_prompt_content, set_system_prompt_content, get_all_settings, set_analysis_backend, set_training_backend, set_default_model, set_frame_interval, get_comm_type, set_comm_type
app = Flask(__name__) app = Flask(__name__)
os.makedirs('static', exist_ok=True) os.makedirs('static', exist_ok=True)
# Communicator to backend # Communicator to backend
comm = SocketCommunicator(port=5001) comm_type = get_comm_type()
if comm_type == 'unix':
comm = SocketCommunicator(socket_path='/tmp/vidai_web.sock', comm_type='unix')
else:
comm = SocketCommunicator(host='localhost', port=5001, comm_type='tcp')
comm.connect() comm.connect()
def send_to_backend(msg_type: str, data: dict) -> str: def send_to_backend(msg_type: str, data: dict) -> str:
...@@ -192,6 +196,7 @@ def config(): ...@@ -192,6 +196,7 @@ def config():
# Update local config # Update local config
set_analysis_backend(request.form.get('analysis_backend', 'cuda')) set_analysis_backend(request.form.get('analysis_backend', 'cuda'))
set_training_backend(request.form.get('training_backend', 'cuda')) set_training_backend(request.form.get('training_backend', 'cuda'))
set_comm_type(request.form.get('comm_type', 'unix'))
set_default_model(request.form.get('default_model', 'Qwen/Qwen2.5-VL-7B-Instruct')) set_default_model(request.form.get('default_model', 'Qwen/Qwen2.5-VL-7B-Instruct'))
set_frame_interval(int(request.form.get('frame_interval', 10))) set_frame_interval(int(request.form.get('frame_interval', 10)))
...@@ -239,6 +244,12 @@ def config(): ...@@ -239,6 +244,12 @@ def config():
<option value="rocm" {"selected" if current_config['training_backend'] == 'rocm' else ""}>ROCm</option> <option value="rocm" {"selected" if current_config['training_backend'] == 'rocm' else ""}>ROCm</option>
</select> </select>
</label> </label>
<label>Communication Type:
<select name="comm_type">
<option value="unix" {"selected" if current_config['comm_type'] == 'unix' else ""}>Unix Socket</option>
<option value="tcp" {"selected" if current_config['comm_type'] == 'tcp' else ""}>TCP Socket</option>
</select>
</label>
<label>Default Model: <input type="text" name="default_model" value="{current_config['default_model']}"></label> <label>Default Model: <input type="text" name="default_model" value="{current_config['default_model']}"></label>
<label>Frame Interval (seconds): <input type="number" name="frame_interval" value="{current_config['frame_interval']}" min="1"></label> <label>Frame Interval (seconds): <input type="number" name="frame_interval" value="{current_config['frame_interval']}" min="1"></label>
<input type="submit" value="Save Configuration"> <input type="submit" value="Save Configuration">
......
...@@ -29,7 +29,7 @@ import json ...@@ -29,7 +29,7 @@ import json
import cv2 import cv2
import time import time
from .comm import SocketCommunicator, Message from .comm import SocketCommunicator, Message
from .config import get_system_prompt_content from .config import get_system_prompt_content, get_comm_type
# Set PyTorch CUDA memory management # Set PyTorch CUDA memory management
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
...@@ -174,7 +174,11 @@ def worker_process(backend_type: str): ...@@ -174,7 +174,11 @@ def worker_process(backend_type: str):
"""Main worker process.""" """Main worker process."""
print(f"Starting Analysis Worker for {backend_type}...") print(f"Starting Analysis Worker for {backend_type}...")
comm = SocketCommunicator(port=5002) comm_type = get_comm_type()
if comm_type == 'unix':
comm = SocketCommunicator(socket_path='/tmp/vidai_workers.sock', comm_type='unix')
else:
comm = SocketCommunicator(host='localhost', port=5002, comm_type='tcp')
comm.connect() comm.connect()
# Register with backend # Register with backend
......
...@@ -27,6 +27,7 @@ import shutil ...@@ -27,6 +27,7 @@ import shutil
import json import json
import time import time
from .comm import SocketCommunicator, Message from .comm import SocketCommunicator, Message
from .config import get_comm_type
def train_model(train_path, output_model, description): def train_model(train_path, output_model, description):
"""Perform training.""" """Perform training."""
...@@ -46,7 +47,11 @@ def worker_process(backend_type: str): ...@@ -46,7 +47,11 @@ def worker_process(backend_type: str):
"""Main worker process.""" """Main worker process."""
print(f"Starting Training Worker for {backend_type}...") print(f"Starting Training Worker for {backend_type}...")
comm = SocketCommunicator(port=5002) comm_type = get_comm_type()
if comm_type == 'unix':
comm = SocketCommunicator(socket_path='/tmp/vidai_workers.sock', comm_type='unix')
else:
comm = SocketCommunicator(host='localhost', port=5002, comm_type='tcp')
comm.connect() comm.connect()
# Register with backend # Register with backend
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment