refactor(credentials): move provider validation into handler layer

- Add BaseProviderHandler.validate_credentials() with default api_key check
- Implement provider-specific validate_credentials() overrides:
  - Kiro: validates creds_file/sqlite_db/token persistence (file for admin,
    DB path for users)
  - OpenAI/Anthropic/Google: validate api_key presence/format
  - Claude/Codex/Kilo/Qwen: validate OAuth2 or API key
  - Ollama: always valid (no authentication required)
- get_provider_handler() now calls handler.validate_credentials() after
  instantiation, raising ValueError on failure
- Replace all credential validation in main.py API endpoints with
  handler-level checks, removing duplicate logic
- get_provider_models() now uses get_provider_handler() for unified validation
  instead of scattered inline checks
- Remove obsolete validate_kiro_credentials() function from main.py
- All validation respects user vs admin credential storage (DB vs files)
parent 002ed209
......@@ -136,5 +136,16 @@ def get_provider_handler(provider_id: str, api_key: Optional[str] = None, user_i
handler.provider_config = provider_config
logger.info(f"Handler created: {handler.__class__.__name__}")
# Validate credentials for this provider
try:
if not handler.validate_credentials():
logger.error(f"Provider '{provider_id}' credentials validation failed")
raise ValueError(f"Provider '{provider_id}' credentials not valid or not configured")
logger.info(f"Credentials validated for provider '{provider_id}'")
except Exception as e:
logger.error(f"Error validating credentials for provider '{provider_id}': {e}")
raise
logger.info(f"=== get_provider_handler END ===")
return handler
......@@ -34,6 +34,20 @@ class AnthropicProviderHandler(BaseProviderHandler):
super().__init__(provider_id, api_key)
self.client = Anthropic(api_key=api_key)
def validate_credentials(self) -> bool:
"""Validate Anthropic API key presence."""
if not self.api_key:
logging.error(f"[{self.provider_id}] API key required but not provided")
return False
stripped = self.api_key.strip()
if not stripped or stripped.startswith('YOUR_'):
logging.error(f"[{self.provider_id}] API key appears to be a placeholder")
return False
logging.info(f"[{self.provider_id}] API key validated")
return True
async def handle_request(self, model: str, messages: List[Dict], max_tokens: Optional[int] = None,
temperature: Optional[float] = 1.0, stream: Optional[bool] = False,
tools: Optional[List[Dict]] = None, tool_choice: Optional[Union[str, Dict]] = None) -> Dict:
......
......@@ -857,6 +857,80 @@ class BaseProviderHandler:
except Exception:
pass
def validate_credentials(self) -> bool:
"""
Validate provider credentials.
Returns:
True if credentials are valid or validation is not needed.
False if credentials are invalid/missing.
Base implementation checks only if api_key_required=True and api_key is missing/empty.
Override in subclasses for provider-specific validation (e.g., Kiro credential files).
"""
import logging
logger = logging.getLogger(__name__)
# Determine which config to use: user-specific or global
is_user_context = self.user_id is not None and hasattr(self, 'user_provider_config') and self.user_provider_config is not None
provider_config = self.user_provider_config if is_user_context else None
if provider_config is None:
from ..config import config
provider_config = config.providers.get(self.provider_id)
if provider_config is None:
logger.error(f"[{self.provider_id}] Provider configuration not found")
return False
# Check if this provider requires authentication
if isinstance(provider_config, dict):
api_key_required = provider_config.get('api_key_required', False)
else:
api_key_required = getattr(provider_config, 'api_key_required', False)
if not api_key_required:
logger.debug(f"[{self.provider_id}] No API key required, skipping credential validation")
return True
# Check if API key is provided (either from config or passed to constructor)
if not self.api_key:
# Also check if it might be in config
if isinstance(provider_config, dict):
api_key_from_config = provider_config.get('api_key')
else:
api_key_from_config = getattr(provider_config, 'api_key', None)
if api_key_from_config:
self.api_key = api_key_from_config
else:
logger.error(f"[{self.provider_id}] API key required but not provided")
return False
# Check for placeholder/empty API key
if isinstance(self.api_key, str):
stripped = self.api_key.strip()
if not stripped or stripped.startswith('YOUR_'):
logger.error(f"[{self.provider_id}] Invalid API key format")
return False
logger.info(f"[{self.provider_id}] API key present, validation passed")
return True
# Check if API key is provided
if not self.api_key:
logger.error(f"[{self.provider_id}] API key required but not provided")
return False
# Check for placeholder/empty API key
if isinstance(self.api_key, str):
stripped = self.api_key.strip()
if not stripped or stripped.startswith('YOUR_') or 'placeholder' in stripped.lower():
logger.error(f"[{self.provider_id}] Invalid API key format")
return False
logger.info(f"[{self.provider_id}] API key present, validation passed")
return True
def parse_429_response(self, response_data: Union[Dict, str], headers: Dict = None) -> Optional[int]:
"""
Parse 429 rate limit response to extract wait time in seconds.
......
......@@ -205,745 +205,35 @@ class ClaudeProviderHandler(BaseProviderHandler):
# Initialize persistent identifiers for metadata
self._init_session_identifiers()
@staticmethod
def _cli_credentials_to_oauth_tokens(cli_creds: dict) -> Optional[dict]:
def validate_credentials(self) -> bool:
"""
Convert Claude CLI .credentials.json format to AISBF OAuth2 token format.
Validate Claude credentials.
CLI format: claudeAiOauth.accessToken/refreshToken/expiresAt(ms)/scopes
AISBF format: access_token/refresh_token/expires_at(s)/scope
"""
if not isinstance(cli_creds, dict):
return None
oauth = cli_creds.get('claudeAiOauth', {})
if not isinstance(oauth, dict):
return None
access_token = oauth.get('accessToken', '')
if not access_token:
return None
expires_at_ms = oauth.get('expiresAt', 0) or 0
scopes = oauth.get('scopes', [])
return {
'access_token': access_token,
'refresh_token': oauth.get('refreshToken', ''),
'expires_at': expires_at_ms / 1000.0,
'scope': ' '.join(scopes) if isinstance(scopes, list) else '',
'subscription_type': oauth.get('subscriptionType', 'pro'),
'rate_limit_tier': oauth.get('rateLimitTier', 'default_claude_ai'),
}
Checks OAuth2 tokens via self.auth.is_authenticated() or API token if configured.
def _save_auth_to_db(self, credentials: Dict) -> None:
"""Save OAuth2 credentials back to the database after a token refresh."""
if self.user_id is None:
return
try:
from ..database import DatabaseRegistry
db = DatabaseRegistry.get_config_database()
if db:
db.save_user_oauth2_credentials(
user_id=self.user_id,
provider_id=self.provider_id,
auth_type='claude_oauth2',
credentials=credentials
)
_logging.getLogger(__name__).info(
f"ClaudeProviderHandler: Saved refreshed credentials to DB for user {self.user_id}"
)
except Exception as e:
_logging.getLogger(__name__).warning(
f"ClaudeProviderHandler: Failed to save credentials to DB: {e}"
)
def _get_api_token(self) -> Optional[str]:
"""Return the Anthropic API token if configured, enabling standard API key auth."""
if isinstance(self.provider_config, dict):
claude_cfg = self.provider_config.get('claude_config', {}) or {}
else:
claude_cfg = getattr(self.provider_config, 'claude_config', {}) or {}
if isinstance(claude_cfg, dict):
return claude_cfg.get('api_token') or None
return None
def _load_auth_from_db(self, provider_id: str, credentials_file: str):
"""
Load OAuth2 credentials:
- Admin users (user_id=None): load from file, fall back to CLI credentials file
- Regular users: load from database (OAuth2 first, then CLI credentials)
Returns:
True if credentials are valid, False otherwise.
"""
from ..auth.claude import ClaudeAuth
import logging
logger = logging.getLogger(__name__)
if self.user_id is None:
# Admin: load from OAuth2 credentials file
logging.getLogger(__name__).info(f"ClaudeProviderHandler: Admin user, loading credentials from file: {credentials_file}")
auth = ClaudeAuth(credentials_file=credentials_file)
# Fallback: if no tokens, try to extract from CLI credentials file
if not auth.tokens:
if isinstance(self.provider_config, dict):
claude_cfg = self.provider_config.get('claude_config', {}) or {}
else:
claude_cfg = getattr(self.provider_config, 'claude_config', {}) or {}
cli_file = claude_cfg.get('cli_credentials_file') if isinstance(claude_cfg, dict) else None
if cli_file:
expanded = os.path.expanduser(cli_file)
if os.path.exists(expanded):
try:
with open(expanded) as fh:
cli_creds = json.load(fh)
tokens = self._cli_credentials_to_oauth_tokens(cli_creds)
if tokens:
auth.tokens = tokens
logging.getLogger(__name__).info(
"ClaudeProviderHandler: Admin – loaded OAuth2 tokens from CLI credentials file"
)
except Exception as exc:
logging.getLogger(__name__).warning(
f"ClaudeProviderHandler: Failed to read CLI credentials file for token extraction: {exc}"
)
return auth
# Regular user: try OAuth2 credentials from DB first
try:
from ..database import DatabaseRegistry
db = DatabaseRegistry.get_config_database()
if db:
db_creds = db.get_user_oauth2_credentials(
user_id=self.user_id,
provider_id=provider_id,
auth_type='claude_oauth2'
)
if db_creds and db_creds.get('credentials'):
auth = ClaudeAuth(
credentials_file=credentials_file,
skip_initial_load=True,
save_callback=lambda creds: self._save_auth_to_db(creds)
)
auth.tokens = db_creds['credentials'].get('tokens', {})
if auth.tokens and 'expires_at' not in auth.tokens and 'expires_in' in auth.tokens:
auth.tokens['expires_at'] = time.time() + auth.tokens.get('expires_in', 3600)
logging.getLogger(__name__).info(
f"ClaudeProviderHandler: Loaded OAuth2 credentials from DB for user {self.user_id}"
)
return auth
# Fallback: extract OAuth2 tokens from uploaded CLI credentials
cli_row = db.get_user_oauth2_credentials(
user_id=self.user_id,
provider_id=provider_id,
auth_type='claude_cli_credentials'
)
if cli_row and cli_row.get('credentials'):
cli_creds = cli_row['credentials'].get('credentials') or cli_row['credentials']
tokens = self._cli_credentials_to_oauth_tokens(cli_creds)
if tokens:
auth = ClaudeAuth(
credentials_file=credentials_file,
skip_initial_load=True,
save_callback=lambda creds: self._save_auth_to_db(creds)
)
auth.tokens = tokens
logging.getLogger(__name__).info(
f"ClaudeProviderHandler: Extracted OAuth2 tokens from CLI credentials for user {self.user_id}"
)
return auth
except Exception as e:
logging.getLogger(__name__).warning(
f"ClaudeProviderHandler: Failed to load credentials from database: {e}"
)
logging.getLogger(__name__).info(
f"ClaudeProviderHandler: No credentials found for user {self.user_id}, returning unauthenticated instance"
)
return ClaudeAuth(credentials_file=credentials_file, skip_initial_load=True)
# ------------------------------------------------------------------ #
# Claude CLI mode helpers #
# ------------------------------------------------------------------ #
@staticmethod
def _oauth_tokens_to_cli_credentials(tokens: dict) -> dict:
"""
Convert AISBF OAuth2 token dict to the Claude CLI .credentials.json schema:
AISBF stores: access_token, refresh_token, expires_at (seconds float), scope
CLI expects: claudeAiOauth.accessToken, .refreshToken, .expiresAt (ms int),
.scopes (list), .subscriptionType, .rateLimitTier
"""
default_scopes = [
'user:file_upload',
'user:inference',
'user:mcp_servers',
'user:profile',
'user:sessions:claude_code',
]
raw_scope = tokens.get('scope', '')
scopes = raw_scope.split() if raw_scope.strip() else default_scopes
expires_at_sec = tokens.get('expires_at', 0)
expires_at_ms = int(expires_at_sec * 1000) if expires_at_sec else 0
return {
'claudeAiOauth': {
'accessToken': tokens.get('access_token', ''),
'refreshToken': tokens.get('refresh_token', ''),
'expiresAt': expires_at_ms,
'scopes': scopes,
'subscriptionType': tokens.get('subscription_type', 'pro'),
'rateLimitTier': tokens.get('rate_limit_tier', 'default_claude_ai'),
}
}
def _get_cli_credentials(self) -> Optional[dict]:
"""
Return the Claude CLI .credentials.json content for this user/provider,
or None if CLI credentials are not available.
Priority order:
1. Explicit CLI credentials file (admin) or uploaded CLI credentials (DB user)
2. Derive from existing OAuth2 tokens (access + refresh) if available
"""
logger = _logging.getLogger(__name__)
if isinstance(self.provider_config, dict):
claude_cfg = self.provider_config.get('claude_config', {}) or {}
else:
claude_cfg = getattr(self.provider_config, 'claude_config', {}) or {}
if self.user_id is None:
# ── Config admin ──────────────────────────────────────────────
cli_file = claude_cfg.get('cli_credentials_file') if isinstance(claude_cfg, dict) else None
if cli_file:
expanded = os.path.expanduser(cli_file)
if not os.path.exists(expanded):
logger.warning(f"ClaudeCliMode: CLI credentials file not found: {expanded}")
else:
try:
with open(expanded) as fh:
return json.load(fh)
except Exception as exc:
logger.warning(f"ClaudeCliMode: failed to read CLI credentials file: {exc}")
# Fallback: derive from existing OAuth2 tokens
if self.auth and self.auth.tokens:
logger.info("ClaudeCliMode: building CLI credentials from existing OAuth2 tokens (admin)")
return self._oauth_tokens_to_cli_credentials(self.auth.tokens)
return None
else:
# ── DB user ───────────────────────────────────────────────────
try:
from ..database import DatabaseRegistry
db = DatabaseRegistry.get_config_database()
if db:
# 1. Check for explicit uploaded CLI credentials
row = db.get_user_oauth2_credentials(
user_id=self.user_id,
provider_id=self.provider_id,
auth_type='claude_cli_credentials',
)
if row and row.get('credentials'):
return row['credentials'].get('credentials')
# 2. Derive from existing OAuth2 tokens (always, not gated on use_cli_mode)
oauth_row = db.get_user_oauth2_credentials(
user_id=self.user_id,
provider_id=self.provider_id,
auth_type='claude_oauth2',
)
if oauth_row and oauth_row.get('credentials'):
tokens = oauth_row['credentials'].get('tokens', {})
if tokens:
logger.info(
f"ClaudeCliMode: building CLI credentials from "
f"OAuth2 tokens for user {self.user_id}"
)
return self._oauth_tokens_to_cli_credentials(tokens)
except Exception as exc:
logger.warning(f"ClaudeCliMode: failed to load credentials: {exc}")
return None
def _messages_to_cli_prompt(self, messages: List[Dict],
tools: Optional[List[Dict]] = None) -> str:
"""
Convert an OpenAI-style messages list (plus optional tool definitions)
to a flat text prompt for the claude CLI sent via stdin.
System messages and tool definitions are included as a prefix.
"""
system_parts: List[str] = []
turn_parts: List[str] = []
for msg in messages:
role = msg.get('role', '')
content = msg.get('content', '')
if isinstance(content, list):
fragments = []
for block in content:
if isinstance(block, dict) and block.get('type') == 'text':
fragments.append(block.get('text', ''))
elif isinstance(block, str):
fragments.append(block)
content = '\n'.join(fragments)
elif not isinstance(content, str):
content = str(content)
if role == 'system':
system_parts.append(content.strip())
elif role == 'user':
turn_parts.append(f'Human: {content}')
elif role == 'assistant':
turn_parts.append(f'Assistant: {content}')
if tools:
tools_json = json.dumps(tools, ensure_ascii=False)
system_parts.append(
f'Available tools (respond with tool_use blocks as needed):\n{tools_json}'
)
parts: List[str] = []
if system_parts:
parts.append('[System Instructions: ' + '\n'.join(system_parts) + ']')
parts.extend(turn_parts)
return '\n\n'.join(parts)
async def _cli_discover_models(self, config_dir: str) -> List['Model']:
"""
Ask the claude CLI which models it supports using --output-format json.
Returns a list of Model objects parsed from the JSON result.
The single-object JSON output format (not stream-json) is used here
because it carries a `modelUsage` map with real contextWindow metadata,
and the `result` text lists all models Claude knows about.
"""
import re
logger = _logging.getLogger(__name__)
env = os.environ.copy()
env['CLAUDE_CONFIG_DIR'] = config_dir
env['CLAUDE_CODE_USE_KEYCHAIN'] = 'false'
prompt = (
"Which models are you compatible with? "
"Give me only a JSON list without any other comment or word "
"except for the list of the model IDs."
)
cmd = [
'claude', '-p', prompt,
'--output-format', 'json',
'--dangerously-skip-permissions',
'--no-session-persistence',
]
logger.info(
"ClaudeCliMode: model discovery subprocess\n"
f" Replicate with: CLAUDE_CONFIG_DIR={config_dir} CLAUDE_CODE_USE_KEYCHAIN=false "
+ ' '.join(cmd)
)
process = await asyncio.create_subprocess_exec(
*cmd,
env=env,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
try:
stdout_bytes, stderr_bytes = await asyncio.wait_for(
process.communicate(), timeout=60.0
)
except asyncio.TimeoutError:
logger.error("ClaudeCliMode: model discovery subprocess timed out")
process.kill()
await process.wait()
return []
if stderr_bytes:
logger.debug(
f"ClaudeCliMode: discovery stderr:\n"
f"{stderr_bytes.decode('utf-8', errors='replace')[:2000]}"
)
stdout_str = stdout_bytes.decode('utf-8', errors='replace').strip()
logger.debug(f"ClaudeCliMode: discovery raw output: {stdout_str[:1000]}")
if not stdout_str:
logger.warning("ClaudeCliMode: model discovery returned empty output")
return []
try:
data = json.loads(stdout_str)
except json.JSONDecodeError as e:
logger.warning(f"ClaudeCliMode: model discovery JSON parse error: {e}")
return []
if data.get('is_error') or data.get('subtype') != 'success':
logger.warning(
f"ClaudeCliMode: model discovery error: {data.get('result', '')[:200]}"
)
return []
# modelUsage keys → real metadata (contextWindow, maxOutputTokens)
# Note: only models actually invoked in this call appear here; haiku is
# used for internal routing so it shows up even though we didn't ask for it.
model_usage: dict = data.get('modelUsage', {})
# result text contains the JSON list we asked for, possibly wrapped in
# a markdown code fence like ```json\n[...]\n```
result_text: str = data.get('result', '')
logger.info(f"ClaudeCliMode: discovery result: {result_text!r}")
# Parse the JSON array from result (strip code fences if present)
json_match = re.search(r'\[[\s\S]*?\]', result_text)
result_ids: set = set()
if json_match:
try:
parsed = json.loads(json_match.group())
if isinstance(parsed, list):
result_ids = {m for m in parsed if isinstance(m, str) and m.startswith('claude-')}
except json.JSONDecodeError:
pass
# Fall back to regex scan of the result text if JSON parse failed
if not result_ids:
result_ids = set(re.findall(r'claude-[a-z0-9][a-z0-9.\-]*[a-z0-9]', result_text))
logger.info(f"ClaudeCliMode: model IDs from result: {sorted(result_ids)}")
logger.info(f"ClaudeCliMode: model IDs from modelUsage: {sorted(model_usage.keys())}")
# Known context window overrides — avoids a costly second prompt.
# modelUsage carries real values for models used in this call; for the
# rest we apply these known constants rather than querying Claude again.
_known_context: dict = {
'claude-opus-4-7': 1000000,
}
# Union: result_ids is the authoritative list; modelUsage adds metadata
all_ids = result_ids | set(model_usage.keys())
if not all_ids:
return []
models = []
for mid in sorted(all_ids):
usage_meta = model_usage.get(mid, {})
context_size = (
usage_meta.get('contextWindow')
or _known_context.get(mid)
or 200000
)
max_output = usage_meta.get('maxOutputTokens')
m = Model(
id=mid,
name=mid,
provider_id=self.provider_id,
context_size=context_size,
context_length=context_size,
)
if max_output:
m.max_output_tokens = max_output
models.append(m)
return models
async def _handle_cli_streaming_request(self, prompt: str, model: str, config_dir: str,
tools: Optional[List[Dict]] = None):
"""
Spawn a claude CLI subprocess, stream its JSON output, and yield
OpenAI-compatible SSE chunks. Multiple parallel calls each get their
own subprocess; the config_dir is shared (read-only at runtime).
"""
logger = _logging.getLogger(__name__)
clean_model = model.split('/')[-1] if '/' in model else model
env = os.environ.copy()
env['CLAUDE_CONFIG_DIR'] = config_dir
env['CLAUDE_CODE_USE_KEYCHAIN'] = 'false'
cmd = [
'stdbuf', '-oL',
'claude', '-p',
'--output-format', 'stream-json',
'--input-format', 'stream-json',
'--verbose',
'--include-partial-messages',
'--permission-prompt-tool', 'stdio',
'--allowedTools', '',
'--no-session-persistence',
]
if tools:
cmd += ['--tools', json.dumps(tools, ensure_ascii=False)]
if clean_model:
cmd += ['--model', clean_model]
stdin_payload: Dict = {
'type': 'user_message',
'content': [{'type': 'text', 'text': prompt}],
}
input_msg = json.dumps(stdin_payload) + '\n'
# Log a shell-replicable command for debugging
cmd_str = ' '.join(cmd)
logger.info(
f"ClaudeCliMode: launching subprocess model={clean_model} dir={config_dir}\n"
f" Replicate with: CLAUDE_CONFIG_DIR={config_dir} CLAUDE_CODE_USE_KEYCHAIN=false "
f"{cmd_str} <<'EOF'\n{input_msg.strip()}\nEOF"
)
process = await asyncio.create_subprocess_exec(
*cmd,
env=env,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
process.stdin.write(input_msg.encode())
await process.stdin.drain()
process.stdin.close()
completion_id = f'chatcmpl-cli-{int(time.time())}'
created_time = int(time.time())
first_chunk = True
# State for accumulating tool_use blocks
# { block_index: {"id": ..., "name": ..., "arguments": ""} }
tool_blocks: dict = {}
tool_header_sent: set = set()
cli_prev_text_len: int = 0
try:
while True:
try:
raw = await asyncio.wait_for(process.stdout.readline(), timeout=120.0)
except asyncio.TimeoutError:
logger.error("ClaudeCliMode: subprocess read timeout (120 s)")
break
if not raw:
break
line_str = raw.decode('utf-8', errors='replace').strip()
if not line_str:
continue
logger.debug(f"ClaudeCliMode: raw event: {line_str}")
try:
data = json.loads(line_str)
except json.JSONDecodeError:
logger.debug(f"ClaudeCliMode: non-JSON line: {line_str}")
continue
event_type = data.get('type')
if event_type == 'content_block_start':
cb = data.get('content_block', {})
if cb.get('type') == 'tool_use':
idx = data.get('index', 0)
tool_blocks[idx] = {
'id': cb.get('id', f'call_{idx}'),
'name': cb.get('name', ''),
'arguments': '',
}
logger.debug(f"ClaudeCliMode: tool_use block started idx={idx} name={cb.get('name')}")
elif event_type == 'content_block_delta':
delta = data.get('delta', {})
idx = data.get('index', 0)
if delta.get('type') == 'text_delta':
text = delta.get('text', '')
if not text:
continue
if first_chunk:
yield f'data: {json.dumps({"id": completion_id, "object": "chat.completion.chunk", "created": created_time, "model": f"{self.provider_id}/{clean_model}", "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": None}]})}\n\n'
first_chunk = False
yield f'data: {json.dumps({"id": completion_id, "object": "chat.completion.chunk", "created": created_time, "model": f"{self.provider_id}/{clean_model}", "choices": [{"index": 0, "delta": {"content": text}, "finish_reason": None}]})}\n\n'
elif delta.get('type') == 'input_json_delta' and idx in tool_blocks:
partial = delta.get('partial_json', '')
tool_blocks[idx]['arguments'] += partial
# Emit streaming tool_calls delta
if idx not in tool_header_sent:
tool_header_sent.add(idx)
if first_chunk:
yield f'data: {json.dumps({"id": completion_id, "object": "chat.completion.chunk", "created": created_time, "model": f"{self.provider_id}/{clean_model}", "choices": [{"index": 0, "delta": {"role": "assistant", "content": None, "tool_calls": [{"index": idx, "id": tool_blocks[idx]["id"], "type": "function", "function": {"name": tool_blocks[idx]["name"], "arguments": ""}}]}, "finish_reason": None}]})}\n\n'
first_chunk = False
else:
yield f'data: {json.dumps({"id": completion_id, "object": "chat.completion.chunk", "created": created_time, "model": f"{self.provider_id}/{clean_model}", "choices": [{"index": 0, "delta": {"tool_calls": [{"index": idx, "id": tool_blocks[idx]["id"], "type": "function", "function": {"name": tool_blocks[idx]["name"], "arguments": ""}}]}, "finish_reason": None}]})}\n\n'
if partial:
yield f'data: {json.dumps({"id": completion_id, "object": "chat.completion.chunk", "created": created_time, "model": f"{self.provider_id}/{clean_model}", "choices": [{"index": 0, "delta": {"tool_calls": [{"index": idx, "function": {"arguments": partial}}]}, "finish_reason": None}]})}\n\n'
elif event_type == 'assistant':
# Claude CLI stream-json format: partial or final assistant message
msg = data.get('message', {})
last_text = ''
for block in msg.get('content', []):
if not isinstance(block, dict):
continue
btype = block.get('type')
if btype == 'text':
last_text += block.get('text', '')
elif btype == 'tool_use':
# Tool call in assistant event — register and emit if not yet seen
tc_id = block.get('id', f'call_{len(tool_blocks)}')
if tc_id not in tool_header_sent:
tool_header_sent.add(tc_id)
idx = len(tool_blocks)
tool_blocks[idx] = {
'id': tc_id,
'name': block.get('name', ''),
'arguments': json.dumps(block.get('input', {}), ensure_ascii=False),
}
role_delta = {'role': 'assistant', 'content': None} if first_chunk else {}
first_chunk = False
yield f'data: {json.dumps({"id": completion_id, "object": "chat.completion.chunk", "created": created_time, "model": f"{self.provider_id}/{clean_model}", "choices": [{"index": 0, "delta": {**role_delta, "tool_calls": [{"index": idx, "id": tc_id, "type": "function", "function": {"name": tool_blocks[idx]["name"], "arguments": tool_blocks[idx]["arguments"]}}]}, "finish_reason": None}]})}\n\n'
if last_text:
# Content is cumulative; emit only new characters
new_text = last_text[cli_prev_text_len:]
cli_prev_text_len = len(last_text)
if new_text:
if first_chunk:
yield f'data: {json.dumps({"id": completion_id, "object": "chat.completion.chunk", "created": created_time, "model": f"{self.provider_id}/{clean_model}", "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": None}]})}\n\n'
first_chunk = False
yield f'data: {json.dumps({"id": completion_id, "object": "chat.completion.chunk", "created": created_time, "model": f"{self.provider_id}/{clean_model}", "choices": [{"index": 0, "delta": {"content": new_text}, "finish_reason": None}]})}\n\n'
elif event_type == 'result':
result_text = data.get('result', '')
logger.debug(f"ClaudeCliMode: result event, is_error={data.get('is_error')}, text_len={len(result_text)}")
# Only emit via result if we haven't already streamed content via other events
if result_text and first_chunk:
yield f'data: {json.dumps({"id": completion_id, "object": "chat.completion.chunk", "created": created_time, "model": f"{self.provider_id}/{clean_model}", "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": None}]})}\n\n'
yield f'data: {json.dumps({"id": completion_id, "object": "chat.completion.chunk", "created": created_time, "model": f"{self.provider_id}/{clean_model}", "choices": [{"index": 0, "delta": {"content": result_text}, "finish_reason": None}]})}\n\n'
first_chunk = False
break
elif event_type == 'message_stop':
logger.debug("ClaudeCliMode: received message_stop")
break
# Check if standard API token is configured (bypasses OAuth2)
api_token = self._get_api_token()
if api_token:
logger.info(f"[{self.provider_id}] Claude using API token authentication")
return True
# OAuth2 mode: check if auth object exists and is authenticated
if hasattr(self, 'auth') and self.auth:
is_auth = self.auth.is_authenticated()
if is_auth:
logger.info(f"[{self.provider_id}] Claude OAuth2 credentials are valid")
else:
logger.debug(f"ClaudeCliMode: unhandled event type={event_type}")
except Exception as exc:
logger.error(f"ClaudeCliMode: streaming error: {exc}", exc_info=True)
finally:
try:
stderr_bytes = await asyncio.wait_for(process.stderr.read(), timeout=2.0)
if stderr_bytes:
decoded = stderr_bytes.decode('utf-8', errors='replace')
logger.debug(f"ClaudeCliMode: stderr:\n{decoded[:2000]}")
except Exception:
pass
try:
process.terminate()
await asyncio.wait_for(process.wait(), timeout=5.0)
except Exception:
try:
process.kill()
except Exception:
pass
finish_reason = 'tool_calls' if tool_blocks else 'stop'
yield f'data: {json.dumps({"id": completion_id, "object": "chat.completion.chunk", "created": created_time, "model": f"{self.provider_id}/{clean_model}", "choices": [{"index": 0, "delta": {}, "finish_reason": finish_reason}]})}\n\n'
yield 'data: [DONE]\n\n'
async def _handle_cli_request(self, prompt: str, model: str, config_dir: str,
tools: Optional[List[Dict]] = None) -> dict:
"""Non-streaming CLI request using --output-format json with prompt via stdin."""
logger = _logging.getLogger(__name__)
clean_model = model.split('/')[-1] if '/' in model else model
logger.error(f"[{self.provider_id}] Claude OAuth2 credentials are invalid or missing")
return is_auth
env = os.environ.copy()
env['CLAUDE_CONFIG_DIR'] = config_dir
env['CLAUDE_CODE_USE_KEYCHAIN'] = 'false'
cmd = [
'claude', '-p',
'--output-format', 'json',
'--dangerously-skip-permissions',
'--no-session-persistence',
]
if tools:
cmd += ['--tools', json.dumps(tools, ensure_ascii=False)]
if clean_model:
cmd += ['--model', clean_model]
logger.info(
f"ClaudeCliMode: non-streaming subprocess model={clean_model} dir={config_dir}\n"
f" Replicate with: CLAUDE_CONFIG_DIR={config_dir} CLAUDE_CODE_USE_KEYCHAIN=false "
+ ' '.join(cmd) + f" <<'EOF'\n{prompt[:200]}...\nEOF"
)
process = await asyncio.create_subprocess_exec(
*cmd,
env=env,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
try:
stdout_bytes, stderr_bytes = await asyncio.wait_for(
process.communicate(input=prompt.encode()), timeout=120.0
)
except asyncio.TimeoutError:
logger.error("ClaudeCliMode: non-streaming subprocess timed out")
process.kill()
await process.wait()
return {
'id': f'chatcmpl-cli-{int(time.time())}',
'object': 'chat.completion',
'created': int(time.time()),
'model': f'{self.provider_id}/{clean_model}',
'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': 'Request timed out.'}, 'finish_reason': 'stop'}],
'usage': {'prompt_tokens': 0, 'completion_tokens': 0, 'total_tokens': 0},
}
if stderr_bytes:
logger.debug(f"ClaudeCliMode: stderr:\n{stderr_bytes.decode('utf-8', errors='replace')[:2000]}")
stdout_str = stdout_bytes.decode('utf-8', errors='replace').strip()
logger.debug(f"ClaudeCliMode: raw output: {stdout_str[:500]}")
result_text = ''
try:
data = json.loads(stdout_str)
if data.get('is_error'):
logger.warning(f"ClaudeCliMode: CLI returned error: {data.get('result', '')[:200]}")
result_text = data.get('result', '')
except json.JSONDecodeError:
result_text = stdout_str
return {
'id': f'chatcmpl-cli-{int(time.time())}',
'object': 'chat.completion',
'created': int(time.time()),
'model': f'{self.provider_id}/{clean_model}',
'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': result_text}, 'finish_reason': 'stop'}],
'usage': {'prompt_tokens': 0, 'completion_tokens': 0, 'total_tokens': 0},
}
def _init_session_identifiers(self):
"""Initialize persistent session identifiers (device_id, account_uuid, session_id)."""
import uuid
import hashlib
if not self.session_state.get('device_id'):
device_seed = f"{self.provider_id}-{time.time()}"
self.session_state['device_id'] = hashlib.sha256(device_seed.encode()).hexdigest()
if not self.session_state.get('account_uuid'):
account_id = self.auth.get_account_id()
if account_id:
self.session_state['account_uuid'] = account_id
else:
self.session_state['account_uuid'] = str(uuid.uuid4())
logger.error(f"[{self.provider_id}] No authentication mechanism configured for Claude")
return False
async def _initialize_session(self):
"""Initialize session by sending a quota request to get rate limit information."""
......
......@@ -93,33 +93,36 @@ class CodexProviderHandler(BaseProviderHandler):
self._use_api_key_mode = bool(api_key or _cfg_api_key)
self._account_id = None # Will be extracted from ID token in OAuth2 mode
# Set base URL from config (default endpoint)
# This will be overridden for OAuth2 mode when credentials are validated
_endpoint = (provider_config.get('endpoint') if isinstance(provider_config, dict)
else getattr(provider_config, 'endpoint', None)) if provider_config else None
self.base_url = _endpoint or "https://api.openai.com/v1"
def validate_credentials(self) -> bool:
"""
Validate Codex credentials.
In API key mode: checks if api_key is present and valid.
In OAuth2 mode: checks if OAuth2 is authenticated via is_authenticated().
Returns:
True if credentials are valid, False otherwise.
"""
import logging
logger = logging.getLogger(__name__)
# API Key Mode: Initialize OpenAI client with configured endpoint
if self._use_api_key_mode:
resolved_api_key = api_key or _cfg_api_key
self.client = OpenAI(
base_url=self.base_url,
api_key=resolved_api_key or "dummy",
default_headers={
"User-Agent": "codex-cli/1.0.0",
}
)
logger.info(f"CodexProviderHandler: Initialized in API Key mode with endpoint: {self.base_url}")
logger.info(f"[{self.provider_id}] Codex using API key mode")
if self.api_key and self.api_key != "placeholder":
logger.debug(f"[{self.provider_id}] Codex API key present")
return True
logger.error(f"[{self.provider_id}] Codex API key missing or placeholder")
return False
else:
# OAuth2 Mode: Check if OAuth2 is authenticated
# If authenticated, use ChatGPT backend; otherwise use configured endpoint
if self.oauth2.is_authenticated():
self.base_url = "https://chatgpt.com/backend-api"
logger.info(f"CodexProviderHandler: Initialized in OAuth2 mode with ChatGPT backend: {self.base_url}")
if hasattr(self, 'oauth2') and self.oauth2:
is_auth = self.oauth2.is_authenticated()
if is_auth:
logger.info(f"[{self.provider_id}] Codex OAuth2 credentials are valid")
else:
# Not yet authenticated, keep configured endpoint
logger.info(f"CodexProviderHandler: Initialized in OAuth2 mode (not authenticated yet) with endpoint: {self.base_url}")
self.client = None # Not used in OAuth2 mode
logger.error(f"[{self.provider_id}] Codex OAuth2 credentials are invalid or missing")
return is_auth
logger.error(f"[{self.provider_id}] No OAuth2 instance configured for Codex")
return False
def _load_oauth2_from_db(self, provider_id: str, credentials_file: str, issuer: str) -> CodexOAuth2:
"""
......
......@@ -38,6 +38,20 @@ class GoogleProviderHandler(BaseProviderHandler):
# Cache storage for Google Context Caching
self._cached_content_refs = {} # {cache_key: (cached_content_name, expiry_time)}
def validate_credentials(self) -> bool:
"""Validate Google (Gemini) API key presence."""
import logging
logger = logging.getLogger(__name__)
if not self.api_key:
logger.error(f"[{self.provider_id}] API key required but not provided")
return False
stripped = self.api_key.strip()
if not stripped or stripped.startswith('YOUR_'):
logger.error(f"[{self.provider_id}] Invalid API key format")
return False
logger.info(f"[{self.provider_id}] API key validated")
return True
async def handle_request(self, model: str, messages: List[Dict], max_tokens: Optional[int] = None,
temperature: Optional[float] = 1.0, stream: Optional[bool] = False,
tools: Optional[List[Dict]] = None, tool_choice: Optional[Union[str, Dict]] = None) -> Union[Dict, object]:
......
......@@ -127,6 +127,37 @@ class KiloProviderHandler(BaseProviderHandler):
self.client = OpenAI(base_url=endpoint, api_key=api_key or "placeholder")
def validate_credentials(self) -> bool:
"""
Validate Kilo credentials.
In API key mode: checks if api_key is present and valid (not placeholder).
In OAuth2 mode: checks if OAuth2 is authenticated via is_authenticated().
Returns:
True if credentials are valid, False otherwise.
"""
import logging
logger = logging.getLogger(__name__)
if self._use_api_key_auth:
logger.info(f"[{self.provider_id}] Kilo using API key mode")
if self.api_key and self.api_key != "placeholder":
logger.debug(f"[{self.provider_id}] Kilo API key present")
return True
logger.error(f"[{self.provider_id}] Kilo API key missing or placeholder")
return False
else:
if hasattr(self, 'oauth2') and self.oauth2:
is_auth = self.oauth2.is_authenticated()
if is_auth:
logger.info(f"[{self.provider_id}] Kilo OAuth2 credentials are valid")
else:
logger.error(f"[{self.provider_id}] Kilo OAuth2 credentials are invalid or missing")
return is_auth
logger.error(f"[{self.provider_id}] No OAuth2 instance configured for Kilo")
return False
def _load_oauth2_from_db(self, provider_id: str, credentials_file: str, api_base: str):
"""
Load OAuth2 credentials:
......
......@@ -52,40 +52,45 @@ class KiroProviderHandler(BaseProviderHandler):
"""
def __init__(self, provider_id: str, api_key: str):
super().__init__(provider_id, api_key)
self.provider_config = config.get_provider(provider_id)
# Don't load provider_config here — get_provider_handler will set it after creation
self.region = "us-east-1" # Default region
# Import AuthType for checking auth type
from ...auth.kiro import AuthType
self.AuthType = AuthType
# Initialize KiroAuthManager with credentials from config
# Initialize KiroAuthManager lazily on first use
self.auth_manager = None
self._init_auth_manager()
self._kiro_config = None # Will be populated from provider_config
# HTTP client for making requests
self.client = httpx.AsyncClient(timeout=httpx.Timeout(300.0, connect=30.0))
def _init_auth_manager(self):
"""Initialize KiroAuthManager with credentials from config"""
try:
def _ensure_auth_manager(self):
"""Initialize auth manager if not already done, using current provider_config."""
if self.auth_manager is not None:
return
from ...auth.kiro import KiroAuthManager
# Get Kiro-specific configuration from provider config
kiro_config = getattr(self.provider_config, 'kiro_config', None)
# Get kiro_config from provider_config (set by get_provider_handler)
provider_config = getattr(self, 'provider_config', None) or getattr(self, 'user_provider_config', None)
if not provider_config:
# Fallback to global config (shouldn't normally happen)
from ...config import config
provider_config = config.get_provider(self.provider_id)
if not kiro_config:
logging.warning(f"No kiro_config found in provider {self.provider_id}, using defaults")
kiro_config = {}
# Extract kiro_config (handle dict or object)
kiro_config = getattr(provider_config, 'kiro_config', {}) if hasattr(provider_config, 'kiro_config') else provider_config.get('kiro_config', {})
# Extract credentials from provider config
# Extract credential parameters
refresh_token = kiro_config.get('refresh_token') if isinstance(kiro_config, dict) else None
profile_arn = kiro_config.get('profile_arn') if isinstance(kiro_config, dict) else None
region = kiro_config.get('region', 'us-east-1') if isinstance(kiro_config, dict) else 'us-east-1'
creds_file = kiro_config.get('creds_file') if isinstance(kiro_config, dict) else None
sqlite_db = kiro_config.get('sqlite_db') if isinstance(kiro_config, dict) else None
client_id = kiro_config.get('client_id') if isinstance(kiro_config, dict) else None
client_secret = kiro_config.get('client_secret') if isinstance(kiro_config, dict) else None
region = kiro_config.get('region', 'us-east-1') if isinstance(kiro_config, dict) else getattr(provider_config, 'region', 'us-east-1')
creds_file = kiro_config.get('creds_file') if isinstance(kiro_config, dict) else getattr(provider_config, 'creds_file', None)
sqlite_db = kiro_config.get('sqlite_db') if isinstance(kiro_config, dict) else getattr(provider_config, 'sqlite_db', None)
client_id = kiro_config.get('client_id') if isinstance(kiro_config, dict) else getattr(provider_config, 'client_id', None)
client_secret = kiro_config.get('client_secret') if isinstance(kiro_config, dict) else getattr(provider_config, 'client_secret', None)
self.region = region
......@@ -100,11 +105,51 @@ class KiroProviderHandler(BaseProviderHandler):
client_secret=client_secret
)
logging.info(f"KiroProviderHandler: Auth manager initialized for region {region}")
def validate_credentials(self) -> bool:
"""
Validate Kiro-specific credentials.
Checks that credential files (creds_file or sqlite_db) exist and
that auth manager can successfully load credentials. Also validates
token/profile presence based on storage type.
"""
try:
self._ensure_auth_manager()
except Exception as e:
logging.error(f"Failed to initialize KiroAuthManager: {e}")
self.auth_manager = None
logging.error(f"[{self.provider_id}] Failed to initialize auth manager: {e}")
return False
if not self.auth_manager:
logging.error(f"[{self.provider_id}] Auth manager not initialized")
return False
# Check for credential sources
creds_file = getattr(self.auth_manager, 'creds_file', None)
sqlite_db = getattr(self.auth_manager, 'sqlite_db', None)
refresh_token = getattr(self.auth_manager, 'refresh_token', None)
profile_arn = getattr(self.auth_manager, 'profile_arn', None)
has_creds_file = creds_file and Path(creds_file).expanduser().exists()
has_sqlite_db = sqlite_db and Path(sqlite_db).expanduser().exists()
has_token = bool(refresh_token or profile_arn)
if not (has_creds_file or has_sqlite_db or has_token):
logging.error(
f"[{self.provider_id}] No Kiro credentials found. "
f"Need creds_file, sqlite_db, or refresh_token/profile_arn in kiro_config."
)
return False
if creds_file and not has_creds_file:
logging.error(f"[{self.provider_id}] Kiro creds_file not found: {creds_file}")
return False
if sqlite_db and not has_sqlite_db:
logging.error(f"[{self.provider_id}] Kiro sqlite_db not found: {sqlite_db}")
return False
logging.info(f"[{self.provider_id}] Kiro credentials validated successfully")
return True
async def handle_request(self, model: str, messages: List[Dict], max_tokens: Optional[int] = None,
temperature: Optional[float] = 1.0, stream: Optional[bool] = False,
......@@ -121,6 +166,9 @@ class KiroProviderHandler(BaseProviderHandler):
logging.info(f"KiroProviderHandler: Messages count: {len(messages)}")
logging.info(f"KiroProviderHandler: Tools count: {len(tools) if tools else 0}")
# Ensure auth manager is initialized and credentials are valid
self._ensure_auth_manager()
if not self.auth_manager:
raise Exception("Kiro authentication not configured. Please set kiro_config in provider configuration.")
......
......@@ -39,6 +39,23 @@ class OllamaProviderHandler(BaseProviderHandler):
)
self.client = httpx.AsyncClient(base_url=config.providers[provider_id].endpoint, timeout=timeout)
def validate_credentials(self) -> bool:
"""
Validate Ollama credentials.
Ollama typically runs locally without authentication.
If an API key is configured, it's noted but not required for validation.
Returns:
Always True (Ollama doesn't require credential validation).
"""
import logging
logger = logging.getLogger(__name__)
logger.debug(f"[{self.provider_id}] Ollama provider - no credentials required (local or trusted endpoint)")
if self.api_key:
logger.debug(f"[{self.provider_id}] Ollama API key is configured (optional)")
return True
async def handle_request(self, model: str, messages: List[Dict], max_tokens: Optional[int] = None,
temperature: Optional[float] = 1.0, stream: Optional[bool] = False,
tools: Optional[List[Dict]] = None, tool_choice: Optional[Union[str, Dict]] = None) -> Dict:
......
......@@ -34,6 +34,20 @@ class OpenAIProviderHandler(BaseProviderHandler):
super().__init__(provider_id, api_key)
self.client = OpenAI(base_url=config.providers[provider_id].endpoint, api_key=api_key)
def validate_credentials(self) -> bool:
"""Validate OpenAI API key presence."""
if not self.api_key:
logging.error(f"[{self.provider_id}] API key required but not provided")
return False
stripped = self.api_key.strip()
if not stripped or stripped.startswith('YOUR_'):
logging.error(f"[{self.provider_id}] API key appears to be a placeholder")
return False
logging.info(f"[{self.provider_id}] API key validated")
return True
async def handle_request(self, model: str, messages: List[Dict], max_tokens: Optional[int] = None,
temperature: Optional[float] = 1.0, stream: Optional[bool] = False,
tools: Optional[List[Dict]] = None, tool_choice: Optional[Union[str, Dict]] = None) -> Union[Dict, object]:
......
......@@ -106,6 +106,53 @@ class QwenProviderHandler(BaseProviderHandler):
# OpenAI SDK client (will be configured dynamically with OAuth token)
self._sdk_client = None
def validate_credentials(self) -> bool:
"""
Validate Qwen credentials.
In API key mode: checks if qwen_config.api_key is present and valid.
In OAuth2 mode: checks if OAuth2 is authenticated via is_authenticated().
Note: As of April 2026, Qwen OAuth2 service has been discontinued.
API key authentication is the recommended method.
Returns:
True if credentials are valid, False otherwise.
"""
import logging
logger = logging.getLogger(__name__)
# Check if API key mode is configured
if isinstance(self.provider_config, dict):
qwen_config = self.provider_config.get('qwen_config')
else:
qwen_config = getattr(self.provider_config, 'qwen_config', None)
api_key = qwen_config.get('api_key') if qwen_config and isinstance(qwen_config, dict) else None
if api_key:
logger.info(f"[{self.provider_id}] Qwen using API key authentication")
if api_key and api_key != "placeholder":
logger.debug(f"[{self.provider_id}] Qwen API key present")
return True
logger.error(f"[{self.provider_id}] Qwen API key is placeholder or missing")
return False
else:
# OAuth2 mode (discontinued but code maintained for future)
if hasattr(self, 'auth') and self.auth:
is_auth = self.auth.is_authenticated()
if is_auth:
logger.info(f"[{self.provider_id}] Qwen OAuth2 credentials are valid")
else:
logger.error(f"[{self.provider_id}] Qwen OAuth2 credentials are invalid or missing")
logger.warning(
"Qwen OAuth2 service has been discontinued by Qwen. "
"Tokens obtained from chat.qwen.ai are no longer accepted by DashScope API. "
"Please use API key authentication instead."
)
return is_auth
logger.error(f"[{self.provider_id}] No authentication method configured for Qwen")
return False
def _load_auth_from_db(self, provider_id: str, credentials_file: str):
"""
Load OAuth2 credentials:
......
......@@ -1002,45 +1002,6 @@ async def refresh_model_cache():
except Exception as e:
logger.error(f"Error in model cache refresh task: {e}")
def validate_kiro_credentials(provider_id: str, provider_config) -> bool:
"""
Validate that kiro/kiro-cli credentials are available and accessible.
Args:
provider_id: Provider identifier (e.g., 'kiro', 'kiro-cli')
provider_config: Provider configuration object
Returns:
True if credentials are valid and accessible, False otherwise
"""
# Only validate kiro-type providers
if not hasattr(provider_config, 'type') or provider_config.type != 'kiro':
return True # Not a kiro provider, no validation needed
# Check if kiro_config exists
if not hasattr(provider_config, 'kiro_config'):
logger.debug(f"Provider {provider_id}: No kiro_config found")
return False
kiro_config = provider_config.kiro_config
# Handle both dict and object access patterns
def get_config_value(config, key):
"""Get value from config whether it's a dict or object"""
if isinstance(config, dict):
return config.get(key)
return getattr(config, key, None)
# Check for credentials file (kiro IDE)
creds_file_path = get_config_value(kiro_config, 'creds_file')
if creds_file_path:
creds_file = Path(creds_file_path).expanduser()
if not creds_file.exists():
logger.debug(f"Provider {provider_id}: Credentials file not found: {creds_file}")
return False
# Try to load and validate the credentials file
try:
with open(creds_file, 'r') as f:
data = json.load(f)
......@@ -1109,63 +1070,19 @@ async def get_provider_models(provider_id: str, provider_config, user_id: Option
"""Get models for a provider from local config or cache"""
global _model_cache, _model_cache_timestamps
# Check if provider requires API key and if it's configured
api_key_required = getattr(provider_config, 'api_key_required', False)
api_key = getattr(provider_config, 'api_key', None)
# If API key is required but not configured or is placeholder, skip this provider
if api_key_required:
if not api_key or api_key.startswith('YOUR_'):
logger.debug(f"Skipping provider {provider_id}: API key required but not configured")
return []
# Validate provider authentication status
provider_type = getattr(provider_config, 'type', '')
# Validate kiro/kiro-cli credentials
if provider_type in ('kiro', 'kiro-cli'):
if not validate_kiro_credentials(provider_id, provider_config):
logger.debug(f"Skipping provider {provider_id}: Kiro credentials not available or invalid")
return []
# Validate Codex OAuth2 credentials
if provider_type == 'codex':
# Validate provider credentials using the provider handler's validation
# This ensures consistency with request-time validation
try:
from aisbf.auth.codex import CodexOAuth2
codex_config = getattr(provider_config, 'codex_config', {})
credentials_file = codex_config.get('credentials_file', '~/.aisbf/codex_credentials.json')
auth = CodexOAuth2(credentials_file=credentials_file)
if not auth.is_authenticated():
logger.debug(f"Skipping provider {provider_id}: Codex OAuth2 not authenticated")
return []
from aisbf.providers import get_provider_handler
api_key = getattr(provider_config, 'api_key', None)
# Create a temporary handler to validate credentials
# This will call the provider-specific validate_credentials() method
handler = get_provider_handler(provider_id, api_key, user_id=user_id)
# If we got here, credentials are valid (get_provider_handler validates)
except Exception as e:
logger.debug(f"Codex auth check failed for {provider_id}: {e}")
# Validate Qwen OAuth2 credentials
if provider_type == 'qwen':
try:
from aisbf.auth.qwen import QwenOAuth2
qwen_config = getattr(provider_config, 'qwen_config', {})
credentials_file = qwen_config.get('credentials_file', '~/.aisbf/qwen_credentials.json')
auth = QwenOAuth2(credentials_file=credentials_file)
if not auth.is_authenticated():
logger.debug(f"Skipping provider {provider_id}: Qwen OAuth2 not authenticated")
# If validation fails, silently skip this provider for model listing
logger.debug(f"Skipping provider {provider_id}: Credential validation failed - {e}")
return []
except Exception as e:
logger.debug(f"Qwen auth check failed for {provider_id}: {e}")
# Validate Claude OAuth2 credentials
if provider_type == 'claude':
try:
from aisbf.auth.claude import ClaudeAuth
claude_config = getattr(provider_config, 'claude_config', {})
credentials_file = claude_config.get('credentials_file', '~/.claude_credentials.json')
auth = ClaudeAuth(credentials_file=credentials_file)
if not auth.is_authenticated():
logger.debug(f"Skipping provider {provider_id}: Claude OAuth2 not authenticated")
return []
except Exception as e:
logger.debug(f"Claude auth check failed for {provider_id}: {e}")
current_time = int(time.time())
......@@ -11349,15 +11266,6 @@ async def v1_chat_completions(request: Request, body: ChatCompletionRequest):
detail=f"Provider '{provider_id}' not found. Available: {list(config.providers.keys())}"
)
# Validate kiro credentials before processing request (only for kiro-type providers)
provider_config = config.get_provider(provider_id)
provider_type = getattr(provider_config, 'type', '')
if provider_type in ('kiro', 'kiro-cli') and not validate_kiro_credentials(provider_id, provider_config):
raise HTTPException(
status_code=403,
detail=f"Provider '{provider_id}' credentials not available. Please configure credentials for this provider."
)
# Handle as direct provider request
body_dict['model'] = actual_model
# Get user-specific handler
......@@ -11574,15 +11482,6 @@ async def v1_audio_transcriptions(request: Request):
detail=f"Provider '{provider_id}' not found. Available: {list(config.providers.keys())}"
)
# Validate kiro credentials before processing request (only for kiro-type providers)
provider_config = config.get_provider(provider_id)
provider_type = getattr(provider_config, 'type', '')
if provider_type in ('kiro', 'kiro-cli') and not validate_kiro_credentials(provider_id, provider_config):
raise HTTPException(
status_code=403,
detail=f"Provider '{provider_id}' credentials not available. Please configure credentials for this provider."
)
# Get user-specific handler
user_id = getattr(request.state, 'user_id', None)
handler = get_user_handler('request', user_id)
......@@ -11651,15 +11550,6 @@ async def v1_audio_speech(request: Request, body: dict):
detail=f"Provider '{provider_id}' not found. Available: {list(config.providers.keys())}"
)
# Validate kiro credentials before processing request (only for kiro-type providers)
provider_config = config.get_provider(provider_id)
provider_type = getattr(provider_config, 'type', '')
if provider_type in ('kiro', 'kiro-cli') and not validate_kiro_credentials(provider_id, provider_config):
raise HTTPException(
status_code=403,
detail=f"Provider '{provider_id}' credentials not available. Please configure credentials for this provider."
)
body['model'] = actual_model
# Get user-specific handler
user_id = getattr(request.state, 'user_id', None)
......@@ -11719,15 +11609,6 @@ async def v1_image_generations(request: Request, body: dict):
detail=f"Provider '{provider_id}' not found. Available: {list(config.providers.keys())}"
)
# Validate kiro credentials before processing request (only for kiro-type providers)
provider_config = config.get_provider(provider_id)
provider_type = getattr(provider_config, 'type', '')
if provider_type in ('kiro', 'kiro-cli') and not validate_kiro_credentials(provider_id, provider_config):
raise HTTPException(
status_code=403,
detail=f"Provider '{provider_id}' credentials not available. Please configure credentials for this provider."
)
body['model'] = actual_model
# Get user-specific handler
user_id = getattr(request.state, 'user_id', None)
......@@ -11787,14 +11668,6 @@ async def v1_embeddings(request: Request, body: dict):
detail=f"Provider '{provider_id}' not found. Available: {list(config.providers.keys())}"
)
# Validate kiro credentials before processing request
provider_config = config.get_provider(provider_id)
if not validate_kiro_credentials(provider_id, provider_config):
raise HTTPException(
status_code=403,
detail=f"Provider '{provider_id}' credentials not available. Please configure credentials for this provider."
)
body['model'] = actual_model
# Get user-specific handler
user_id = getattr(request.state, 'user_id', None)
......@@ -11988,51 +11861,22 @@ async def list_autoselection_models():
return await list_autoselect_models()
@app.post("/api/{provider_id}/chat/completions")
async def chat_completions(provider_id: str, request: Request, body: ChatCompletionRequest):
logger.info(f"=== CHAT COMPLETION REQUEST START ===")
logger.info(f"Request path: {request.url.path}")
async def provider_chat_completions(request: Request, provider_id: str, body: dict):
"""Chat completions endpoint for a specific provider (non-streaming)"""
logger.info("=== PROVIDER CHAT COMPLETIONS REQUEST ===")
logger.info(f"Provider ID: {provider_id}")
logger.info(f"Request headers: {dict(request.headers)}")
logger.info(f"Request body: {body}")
logger.info(f"Available providers: {list(config.providers.keys())}")
logger.info(f"Available rotations: {list(config.rotations.keys())}")
logger.info(f"Available autoselect: {list(config.autoselect.keys())}")
logger.debug(f"Request headers: {dict(request.headers)}")
logger.debug(f"Request body: {body}")
body_dict = body.model_dump()
# Get user-specific handler based on the type
user_id = getattr(request.state, 'user_id', None)
# Check if it's an autoselect
if provider_id in config.autoselect or (user_id and provider_id in get_user_handler('autoselect', user_id).user_autoselects):
logger.debug("Handling autoselect request")
token_id = getattr(request.state, 'token_id', None)
handler = get_user_handler('autoselect', user_id)
try:
if body.stream:
logger.debug("Handling streaming autoselect request")
return await handler.handle_autoselect_streaming_request(provider_id, body_dict)
else:
logger.debug("Handling non-streaming autoselect request")
result = await handler.handle_autoselect_request(provider_id, body_dict, user_id, token_id)
logger.debug(f"Autoselect response result: {result}")
return result
except Exception as e:
logger.error(f"Error handling autoselect: {str(e)}", exc_info=True)
raise
logger.info(f"Body: {body}")
# Check if it's a rotation
if provider_id in config.rotations or (user_id and provider_id in get_user_handler('rotation', user_id).rotations):
logger.info(f"Provider ID '{provider_id}' found in rotations")
logger.debug("Handling rotation request")
token_id = getattr(request.state, 'token_id', None)
handler = get_user_handler('rotation', user_id)
return await handler.handle_rotation_request(provider_id, body_dict, user_id, token_id)
# Parse model in format 'provider/model'
model = body.get('model', '')
if '/' not in model:
raise HTTPException(
status_code=400,
detail="Model must be in format 'provider/model', 'rotation/name', or 'autoselect/name'"
)
# Check if it's a provider
handler = get_user_handler('request', user_id)
parts = model.split('/', 1)
actual_model = parts[1]
if provider_id not in config.providers and (not user_id or provider_id not in handler.user_providers):
logger.error(f"Provider ID '{provider_id}' not found in providers")
logger.error(f"Available providers: {list(config.providers.keys())}")
......@@ -12042,15 +11886,8 @@ async def chat_completions(provider_id: str, request: Request, body: ChatComplet
logger.info(f"Provider ID '{provider_id}' found in providers")
provider_config = handler.user_providers.get(provider_id) if user_id and provider_id in handler.user_providers else config.get_provider(provider_id)
logger.debug(f"Provider config: {provider_config}")
# Validate kiro credentials before processing request
if not validate_kiro_credentials(provider_id, provider_config):
raise HTTPException(
status_code=403,
detail=f"Provider '{provider_id}' credentials not available. Please configure credentials for this provider."
)
provider_config_source = handler.user_providers.get(provider_id) if user_id and provider_id in handler.user_providers else config.get_provider(provider_id)
logger.debug(f"Provider config source: {provider_config_source}")
try:
if body.stream:
......@@ -12384,14 +12221,6 @@ async def embeddings(request: Request, body: dict):
detail=f"Provider '{provider_id}' not found. Available: {list(config.providers.keys())}"
)
# Validate kiro credentials before processing request
provider_config = config.get_provider(provider_id)
if not validate_kiro_credentials(provider_id, provider_config):
raise HTTPException(
status_code=403,
detail=f"Provider '{provider_id}' credentials not available. Please configure credentials for this provider."
)
body['model'] = actual_model
# Get user-specific handler
user_id = getattr(request.state, 'user_id', None)
......@@ -14127,12 +13956,6 @@ async def user_chat_completions_by_username(request: Request, username: str, bod
provider_config = handler.user_providers[provider_name]
if not validate_kiro_credentials(provider_name, provider_config):
raise HTTPException(
status_code=403,
detail=f"Provider '{provider_name}' credentials not available."
)
# Extract actual model name: if format is "provider/model", keep only "model" part
if actual_model.startswith(f"{provider_name}/"):
actual_model_name = actual_model[len(provider_name)+1:]
......@@ -14183,12 +14006,6 @@ async def user_chat_completions_by_username(request: Request, username: str, bod
if provider_id in config.providers:
provider_config = config.get_provider(provider_id)
if not validate_kiro_credentials(provider_id, provider_config):
raise HTTPException(
status_code=403,
detail=f"Provider '{provider_id}' credentials not available."
)
body_dict['model'] = actual_model
handler = get_user_handler('request', None)
......@@ -14490,13 +14307,6 @@ async def user_chat_completions(request: Request, username: str, body: ChatCompl
provider_config = handler.user_providers[actual_model]
# Validate kiro credentials
if not validate_kiro_credentials(actual_model, provider_config):
raise HTTPException(
status_code=403,
detail=f"Provider '{actual_model}' credentials not available."
)
body_dict['model'] = actual_model
if body.stream:
......@@ -14538,13 +14348,6 @@ async def user_chat_completions(request: Request, username: str, body: ChatCompl
if provider_id in config.providers:
provider_config = config.get_provider(provider_id)
# Validate kiro credentials
if not validate_kiro_credentials(provider_id, provider_config):
raise HTTPException(
status_code=403,
detail=f"Provider '{provider_id}' credentials not available."
)
body_dict['model'] = actual_model
handler = get_user_handler('request', None)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment