Refactor of the main.py in multiple files. 0.99.65

parent 05082603
...@@ -54,7 +54,7 @@ from .auth.qwen import QwenOAuth2 ...@@ -54,7 +54,7 @@ from .auth.qwen import QwenOAuth2
from .handlers import RequestHandler, RotationHandler, AutoselectHandler from .handlers import RequestHandler, RotationHandler, AutoselectHandler
from .utils import count_messages_tokens, split_messages_into_chunks, get_max_request_tokens_for_model from .utils import count_messages_tokens, split_messages_into_chunks, get_max_request_tokens_for_model
__version__ = "0.99.64" __version__ = "0.99.65"
__all__ = [ __all__ = [
# Config # Config
"config", "config",
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
"""
Jinja2 template setup, proxy-aware URL helpers, and ProxyHeadersMiddleware.
Extracted from main.py.
"""
import hashlib
import logging
from pathlib import Path
from fastapi import Request
from fastapi.templating import Jinja2Templates
from starlette.middleware.base import BaseHTTPMiddleware
logger = logging.getLogger(__name__)
class ProxyHeadersMiddleware(BaseHTTPMiddleware):
"""Handle X-Forwarded-* proxy headers."""
async def dispatch(self, request: Request, call_next):
forwarded_proto = request.headers.get("X-Forwarded-Proto")
forwarded_host = request.headers.get("X-Forwarded-Host")
forwarded_port = request.headers.get("X-Forwarded-Port")
forwarded_prefix = request.headers.get("X-Forwarded-Prefix") or request.headers.get("X-Script-Name")
forwarded_for = request.headers.get("X-Forwarded-For")
if forwarded_proto or forwarded_host or forwarded_prefix:
logger.debug(f"Proxy headers detected - Proto: {forwarded_proto}, Host: {forwarded_host}, Prefix: {forwarded_prefix}")
if forwarded_proto:
request.scope["scheme"] = forwarded_proto
if forwarded_host:
if ":" in forwarded_host and not forwarded_port:
host_parts = forwarded_host.split(":", 1)
request.scope["server"] = (host_parts[0], int(host_parts[1]))
else:
port = int(forwarded_port) if forwarded_port else (443 if forwarded_proto == "https" else 80)
request.scope["server"] = (forwarded_host, port)
elif forwarded_port:
current_host = request.scope.get("server", ("localhost", 80))[0]
request.scope["server"] = (current_host, int(forwarded_port))
if forwarded_prefix:
forwarded_prefix = forwarded_prefix.rstrip("/")
request.scope["root_path"] = forwarded_prefix
original_path = request.scope.get("path", "")
if original_path.startswith(forwarded_prefix):
request.scope["path"] = original_path[len(forwarded_prefix):] or "/"
if forwarded_for:
client_ip = forwarded_for.split(",")[0].strip()
request.scope["client"] = (client_ip, request.scope.get("client", ("", 0))[1])
return await call_next(request)
def get_base_url(request: Request) -> str:
scheme = request.scope.get("scheme", "http")
server = request.scope.get("server", ("localhost", 80))
host, port = server[0], server[1]
root_path = request.scope.get("root_path", "")
if (scheme == "http" and port == 80) or (scheme == "https" and port == 443):
return f"{scheme}://{host}{root_path}"
return f"{scheme}://{host}:{port}{root_path}"
def url_for(request: Request, path: str) -> str:
root_path = request.scope.get("root_path", "")
if not path.startswith("/"):
path = "/" + path
is_behind_proxy = "x-forwarded-host" in request.headers or "x-forwarded-proto" in request.headers
if is_behind_proxy:
return (root_path + path) if (root_path and root_path != "/") else path
return f"{get_base_url(request)}{path}"
def create_templates(template_dir: str) -> Jinja2Templates:
templates = Jinja2Templates(directory=template_dir)
templates.env.loader.searchpath.insert(0, template_dir)
return templates
def setup_template_globals(templates: Jinja2Templates, version: str):
from aisbf import __version__
def md5_filter(s):
if not s:
return hashlib.md5(b'').hexdigest().lower()
return hashlib.md5(s.encode('utf-8')).hexdigest().lower()
templates.env.filters['md5'] = md5_filter
templates.env.globals['url_for'] = url_for
templates.env.globals['get_base_url'] = get_base_url
templates.env.globals['__version__'] = version
templates.env.cache.clear()
def patch_template_response(templates: Jinja2Templates):
"""Inject is_aisbf_cloud / welcome_shown into every TemplateResponse automatically."""
original = templates.TemplateResponse
def patched(*args, **kwargs):
if 'context' in kwargs and 'request' in kwargs['context']:
req = kwargs['context']['request']
if hasattr(req.state, 'is_aisbf_cloud'):
kwargs['context']['is_aisbf_cloud'] = req.state.is_aisbf_cloud
if hasattr(req.state, 'welcome_shown'):
kwargs['context']['welcome_shown'] = req.state.welcome_shown
return original(*args, **kwargs)
templates.TemplateResponse = patched
...@@ -27,6 +27,7 @@ import re ...@@ -27,6 +27,7 @@ import re
import uuid import uuid
import hashlib import hashlib
import threading import threading
import time
import time as time_module import time as time_module
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
...@@ -45,7 +46,6 @@ from .context import ContextManager, get_context_config_for_model ...@@ -45,7 +46,6 @@ from .context import ContextManager, get_context_config_for_model
from .classifier import content_classifier from .classifier import content_classifier
from .classifier import SemanticClassifier from .classifier import SemanticClassifier
from .cache import get_response_cache from .cache import get_response_cache
import time as time_module
from .analytics import get_analytics from .analytics import get_analytics
from .streaming_optimization import ( from .streaming_optimization import (
get_streaming_optimizer, get_streaming_optimizer,
......
...@@ -916,21 +916,6 @@ class BaseProviderHandler: ...@@ -916,21 +916,6 @@ class BaseProviderHandler:
logger.info(f"[{self.provider_id}] API key present, validation passed") logger.info(f"[{self.provider_id}] API key present, validation passed")
return True return True
# Check if API key is provided
if not self.api_key:
logger.error(f"[{self.provider_id}] API key required but not provided")
return False
# Check for placeholder/empty API key
if isinstance(self.api_key, str):
stripped = self.api_key.strip()
if not stripped or stripped.startswith('YOUR_') or 'placeholder' in stripped.lower():
logger.error(f"[{self.provider_id}] Invalid API key format")
return False
logger.info(f"[{self.provider_id}] API key present, validation passed")
return True
def parse_429_response(self, response_data: Union[Dict, str], headers: Dict = None) -> Optional[int]: def parse_429_response(self, response_data: Union[Dict, str], headers: Dict = None) -> Optional[int]:
""" """
Parse 429 rate limit response to extract wait time in seconds. Parse 429 rate limit response to extract wait time in seconds.
......
This diff is collapsed.
...@@ -93,6 +93,18 @@ class CodexProviderHandler(BaseProviderHandler): ...@@ -93,6 +93,18 @@ class CodexProviderHandler(BaseProviderHandler):
self._use_api_key_mode = bool(api_key or _cfg_api_key) self._use_api_key_mode = bool(api_key or _cfg_api_key)
self._account_id = None # Will be extracted from ID token in OAuth2 mode self._account_id = None # Will be extracted from ID token in OAuth2 mode
# Base URL for API requests
_endpoint = (provider_config.get('endpoint') if isinstance(provider_config, dict)
else getattr(provider_config, 'endpoint', None)) if provider_config else None
self.base_url = (_endpoint or 'https://chatgpt.com/backend-api').rstrip('/')
# Initialize OpenAI client for API key mode
if self._use_api_key_mode:
effective_key = api_key or _cfg_api_key
self.client = OpenAI(api_key=effective_key, base_url=self.base_url)
else:
self.client = None
def validate_credentials(self) -> bool: def validate_credentials(self) -> bool:
""" """
Validate Codex credentials. Validate Codex credentials.
......
...@@ -124,7 +124,6 @@ class KiloProviderHandler(BaseProviderHandler): ...@@ -124,7 +124,6 @@ class KiloProviderHandler(BaseProviderHandler):
endpoint = 'https://kilo.ai/api/openrouter/v1' endpoint = 'https://kilo.ai/api/openrouter/v1'
self._kilo_endpoint = endpoint self._kilo_endpoint = endpoint
self.client = OpenAI(base_url=endpoint, api_key=api_key or "placeholder") self.client = OpenAI(base_url=endpoint, api_key=api_key or "placeholder")
def validate_credentials(self) -> bool: def validate_credentials(self) -> bool:
......
...@@ -28,6 +28,7 @@ import os ...@@ -28,6 +28,7 @@ import os
import json import json
import uuid import uuid
import logging import logging
from pathlib import Path
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from ...config import config from ...config import config
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" ...@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "aisbf" name = "aisbf"
version = "0.99.64" version = "0.99.65"
description = "AISBF - AI Service Broker Framework || AI Should Be Free - A modular proxy server for managing multiple AI provider integrations" description = "AISBF - AI Service Broker Framework || AI Should Be Free - A modular proxy server for managing multiple AI provider integrations"
readme = "README.md" readme = "README.md"
license = "GPL-3.0-or-later" license = "GPL-3.0-or-later"
...@@ -50,6 +50,9 @@ Documentation = "https://git.nexlab.net/nexlab/aisbf.git" ...@@ -50,6 +50,9 @@ Documentation = "https://git.nexlab.net/nexlab/aisbf.git"
[tool.setuptools] [tool.setuptools]
packages = [ packages = [
"aisbf", "aisbf",
"aisbf.app",
"aisbf.routes",
"aisbf.routes.dashboard",
"aisbf.auth", "aisbf.auth",
"aisbf.providers", "aisbf.providers",
"aisbf.providers.kiro", "aisbf.providers.kiro",
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment