Commit 4718ce47 authored by Your Name's avatar Your Name

Add credential validation for kiro/kiro-cli providers

- Added validate_kiro_credentials() function to check credential availability
- Validates kiro IDE credentials from ~/.config/Code/User/globalStorage/amazon.q/credentials.json
- Validates kiro-cli credentials from ~/.local/share/kiro-cli/data.sqlite3
- Integrated validation into get_provider_models() to exclude providers without credentials
- Added validation checks to all request endpoints (chat, audio, images, embeddings)
- Providers only appear in model listings and accept requests when credentials are valid
- Returns HTTP 403 when credentials are missing or invalid
parent 1da131cd
......@@ -591,6 +591,100 @@ async def refresh_model_cache():
except Exception as e:
logger.error(f"Error in model cache refresh task: {e}")
def validate_kiro_credentials(provider_id: str, provider_config) -> bool:
"""
Validate that kiro/kiro-cli credentials are available and accessible.
Args:
provider_id: Provider identifier (e.g., 'kiro', 'kiro-cli')
provider_config: Provider configuration object
Returns:
True if credentials are valid and accessible, False otherwise
"""
# Only validate kiro-type providers
if not hasattr(provider_config, 'type') or provider_config.type != 'kiro':
return True # Not a kiro provider, no validation needed
# Check if kiro_config exists
if not hasattr(provider_config, 'kiro_config'):
logger.debug(f"Provider {provider_id}: No kiro_config found")
return False
kiro_config = provider_config.kiro_config
# Check for credentials file (kiro IDE)
if hasattr(kiro_config, 'creds_file') and kiro_config.creds_file:
creds_file = Path(kiro_config.creds_file).expanduser()
if not creds_file.exists():
logger.debug(f"Provider {provider_id}: Credentials file not found: {creds_file}")
return False
# Try to load and validate the credentials file
try:
with open(creds_file, 'r') as f:
data = json.load(f)
# Check for required fields
if not data.get('refreshToken') and not data.get('accessToken'):
logger.debug(f"Provider {provider_id}: No valid tokens in credentials file")
return False
logger.debug(f"Provider {provider_id}: Valid credentials file found")
return True
except Exception as e:
logger.debug(f"Provider {provider_id}: Error reading credentials file: {e}")
return False
# Check for SQLite database (kiro-cli)
if hasattr(kiro_config, 'sqlite_db') and kiro_config.sqlite_db:
sqlite_db = Path(kiro_config.sqlite_db).expanduser()
if not sqlite_db.exists():
logger.debug(f"Provider {provider_id}: SQLite database not found: {sqlite_db}")
return False
# Try to check if the database has valid tokens
try:
import sqlite3
conn = sqlite3.connect(str(sqlite_db))
cursor = conn.cursor()
# Check for token keys
token_keys = [
"kirocli:social:token",
"kirocli:odic:token",
"codewhisperer:odic:token"
]
found_token = False
for key in token_keys:
cursor.execute("SELECT value FROM auth_kv WHERE key = ?", (key,))
row = cursor.fetchone()
if row:
try:
token_data = json.loads(row[0])
if token_data.get('access_token') or token_data.get('refresh_token'):
found_token = True
break
except:
pass
conn.close()
if not found_token:
logger.debug(f"Provider {provider_id}: No valid tokens in SQLite database")
return False
logger.debug(f"Provider {provider_id}: Valid SQLite credentials found")
return True
except Exception as e:
logger.debug(f"Provider {provider_id}: Error reading SQLite database: {e}")
return False
# No valid credential source found
logger.debug(f"Provider {provider_id}: No valid credential source configured")
return False
async def get_provider_models(provider_id: str, provider_config) -> list:
"""Get models for a provider from local config or cache"""
global _model_cache, _model_cache_timestamps
......@@ -605,6 +699,11 @@ async def get_provider_models(provider_id: str, provider_config) -> list:
logger.debug(f"Skipping provider {provider_id}: API key required but not configured")
return []
# Validate kiro/kiro-cli credentials
if not validate_kiro_credentials(provider_id, provider_config):
logger.debug(f"Skipping provider {provider_id}: Kiro credentials not available or invalid")
return []
# If provider has local model config, use it
if hasattr(provider_config, 'models') and provider_config.models:
models = []
......@@ -1651,6 +1750,14 @@ async def v1_chat_completions(request: Request, body: ChatCompletionRequest):
detail=f"Provider '{provider_id}' not found. Available providers: {list(config.providers.keys())}, or use 'rotation/name' or 'autoselect/name'"
)
# Validate kiro credentials before processing request
provider_config = config.get_provider(provider_id)
if not validate_kiro_credentials(provider_id, provider_config):
raise HTTPException(
status_code=403,
detail=f"Provider '{provider_id}' credentials not available. Please configure credentials for this provider."
)
# Handle as direct provider request
body_dict['model'] = actual_model
if body.stream:
......@@ -1811,6 +1918,14 @@ async def v1_audio_transcriptions(request: Request):
detail=f"Provider '{provider_id}' not found. Available: {list(config.providers.keys())}"
)
# Validate kiro credentials before processing request
provider_config = config.get_provider(provider_id)
if not validate_kiro_credentials(provider_id, provider_config):
raise HTTPException(
status_code=403,
detail=f"Provider '{provider_id}' credentials not available. Please configure credentials for this provider."
)
# Create new form data with updated model
from starlette.datastructures import FormData
updated_form = FormData()
......@@ -1875,6 +1990,14 @@ async def v1_audio_speech(request: Request, body: dict):
detail=f"Provider '{provider_id}' not found. Available: {list(config.providers.keys())}"
)
# Validate kiro credentials before processing request
provider_config = config.get_provider(provider_id)
if not validate_kiro_credentials(provider_id, provider_config):
raise HTTPException(
status_code=403,
detail=f"Provider '{provider_id}' credentials not available. Please configure credentials for this provider."
)
body['model'] = actual_model
return await request_handler.handle_text_to_speech(request, provider_id, body)
......@@ -1931,6 +2054,14 @@ async def v1_image_generations(request: Request, body: dict):
detail=f"Provider '{provider_id}' not found. Available: {list(config.providers.keys())}"
)
# Validate kiro credentials before processing request
provider_config = config.get_provider(provider_id)
if not validate_kiro_credentials(provider_id, provider_config):
raise HTTPException(
status_code=403,
detail=f"Provider '{provider_id}' credentials not available. Please configure credentials for this provider."
)
body['model'] = actual_model
return await request_handler.handle_image_generation(request, provider_id, body)
......@@ -1987,6 +2118,14 @@ async def v1_embeddings(request: Request, body: dict):
detail=f"Provider '{provider_id}' not found. Available: {list(config.providers.keys())}"
)
# Validate kiro credentials before processing request
provider_config = config.get_provider(provider_id)
if not validate_kiro_credentials(provider_id, provider_config):
raise HTTPException(
status_code=403,
detail=f"Provider '{provider_id}' credentials not available. Please configure credentials for this provider."
)
body['model'] = actual_model
return await request_handler.handle_embeddings(request, provider_id, body)
......@@ -2207,6 +2346,13 @@ async def chat_completions(provider_id: str, request: Request, body: ChatComplet
provider_config = config.get_provider(provider_id)
logger.debug(f"Provider config: {provider_config}")
# Validate kiro credentials before processing request
if not validate_kiro_credentials(provider_id, provider_config):
raise HTTPException(
status_code=403,
detail=f"Provider '{provider_id}' credentials not available. Please configure credentials for this provider."
)
try:
if body.stream:
......@@ -2502,6 +2648,14 @@ async def embeddings(request: Request, body: dict):
detail=f"Provider '{provider_id}' not found. Available: {list(config.providers.keys())}"
)
# Validate kiro credentials before processing request
provider_config = config.get_provider(provider_id)
if not validate_kiro_credentials(provider_id, provider_config):
raise HTTPException(
status_code=403,
detail=f"Provider '{provider_id}' credentials not available. Please configure credentials for this provider."
)
body['model'] = actual_model
return await request_handler.handle_embeddings(request, provider_id, body)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment