Commit 2e7a8460 authored by Your Name's avatar Your Name

feat: Add Kiro AWS Event Stream parsing, converters, and TODO roadmap

- Add aisbf/kiro_parsers.py: AWS Event Stream parser for Kiro API responses
- Update kiro_converters_openai.py: Add build_kiro_payload_from_dict function
- Update kiro_converters.py: Minor fixes
- Update kiro_auth.py: Add AWS SSO OIDC authentication support
- Update handlers.py: Enhance streaming and error handling
- Update main.py: Add proxy headers middleware and configuration
- Update setup.py: Add version bump
- Add TODO.md: Comprehensive roadmap for caching and performance improvements

Features:
- Kiro AWS Event Stream parsing for non-streaming responses
- OpenAI-to-Kiro payload conversion
- AWS SSO OIDC authentication for Kiro
- Proxy headers middleware for reverse proxy support
- TODO roadmap with prioritized items for future development
parent 52b44029
This diff is collapsed.
This diff is collapsed.
...@@ -100,9 +100,13 @@ class KiroAuthManager: ...@@ -100,9 +100,13 @@ class KiroAuthManager:
"""Load credentials from SQLite database (kiro-cli)""" """Load credentials from SQLite database (kiro-cli)"""
if not self.sqlite_db: if not self.sqlite_db:
return return
# Expand ~ in path
db_path = Path(self.sqlite_db).expanduser()
conn = None
try: try:
conn = sqlite3.connect(self.sqlite_db) conn = sqlite3.connect(str(db_path))
cursor = conn.cursor() cursor = conn.cursor()
# Try to find token in SQLite # Try to find token in SQLite
...@@ -120,9 +124,13 @@ class KiroAuthManager: ...@@ -120,9 +124,13 @@ class KiroAuthManager:
if token_data: if token_data:
self._access_token = token_data.get('access_token') self._access_token = token_data.get('access_token')
self._refresh_token = token_data.get('refresh_token') self._refresh_token = token_data.get('refresh_token')
self.refresh_token = token_data.get('refresh_token') # Update public refresh_token too
self._expires_at = datetime.fromisoformat( self._expires_at = datetime.fromisoformat(
token_data.get('expires_at', '1970-01-01T00:00:00Z') token_data.get('expires_at', '1970-01-01T00:00:00Z')
) )
# Also try to get profile_arn from token data
if 'profile_arn' in token_data:
self.profile_arn = token_data['profile_arn']
logger.info(f"Loaded credentials from SQLite key: {token_key}") logger.info(f"Loaded credentials from SQLite key: {token_key}")
# Try to get device registration for AWS SSO OIDC # Try to get device registration for AWS SSO OIDC
...@@ -133,23 +141,60 @@ class KiroAuthManager: ...@@ -133,23 +141,60 @@ class KiroAuthManager:
reg_data = json.loads(row[0]) reg_data = json.loads(row[0])
self.client_id = reg_data.get('clientId') self.client_id = reg_data.get('clientId')
self.client_secret = reg_data.get('clientSecret') self.client_secret = reg_data.get('clientSecret')
# Also check for profile_arn in registration data
if 'profileArn' in reg_data:
self.profile_arn = reg_data['profileArn']
break break
# If profile_arn still not found, try to query it directly from the database
if not self.profile_arn:
# Try common profile ARN keys
profile_keys = [
"kirocli:profile:arn",
"codewhisperer:profile:arn",
"kirocli:social:profile",
"codewhisperer:social:profile"
]
for profile_key in profile_keys:
cursor.execute("SELECT value FROM auth_kv WHERE key = ?", (profile_key,))
row = cursor.fetchone()
if row:
try:
profile_data = json.loads(row[0])
if isinstance(profile_data, dict):
self.profile_arn = profile_data.get('arn') or profile_data.get('profileArn')
elif isinstance(profile_data, str):
self.profile_arn = profile_data
if self.profile_arn:
logger.info(f"Loaded profile ARN from SQLite key: {profile_key}")
break
except json.JSONDecodeError:
# Value might be a plain string
self.profile_arn = row[0]
logger.info(f"Loaded profile ARN (plain string) from SQLite key: {profile_key}")
break
except Exception as e: except Exception as e:
logger.error(f"Failed to load from SQLite: {e}") logger.error(f"Failed to load from SQLite: {e}")
finally: finally:
conn.close() if conn:
conn.close()
def _load_from_creds_file(self): def _load_from_creds_file(self):
"""Load credentials from JSON file""" """Load credentials from JSON file"""
if not self.creds_file: if not self.creds_file:
return return
# Expand ~ in path
creds_path = Path(self.creds_file).expanduser()
try: try:
with open(self.creds_file, 'r') as f: with open(creds_path, 'r') as f:
data = json.load(f) data = json.load(f)
self.refresh_token = data.get('refreshToken', self.refresh_token) refresh_token_value = data.get('refreshToken', self.refresh_token)
self.refresh_token = refresh_token_value
self._refresh_token = refresh_token_value # Keep private token in sync
self._access_token = data.get('accessToken') self._access_token = data.get('accessToken')
self.profile_arn = data.get('profileArn', self.profile_arn) self.profile_arn = data.get('profileArn', self.profile_arn)
...@@ -205,6 +250,7 @@ class KiroAuthManager: ...@@ -205,6 +250,7 @@ class KiroAuthManager:
self._access_token = data['accessToken'] self._access_token = data['accessToken']
if 'refreshToken' in data: if 'refreshToken' in data:
self.refresh_token = data['refreshToken'] self.refresh_token = data['refreshToken']
self._refresh_token = data['refreshToken'] # Keep private token in sync
# Calculate expiration (1 hour default) # Calculate expiration (1 hour default)
self._expires_at = datetime.now(timezone.utc) + timedelta(seconds=3600) self._expires_at = datetime.now(timezone.utc) + timedelta(seconds=3600)
...@@ -234,6 +280,7 @@ class KiroAuthManager: ...@@ -234,6 +280,7 @@ class KiroAuthManager:
self._access_token = data['access_token'] self._access_token = data['access_token']
if 'refresh_token' in data: if 'refresh_token' in data:
self.refresh_token = data['refresh_token'] self.refresh_token = data['refresh_token']
self._refresh_token = data['refresh_token'] # Keep private token in sync
expires_in = data.get('expires_in', 3600) expires_in = data.get('expires_in', 3600)
self._expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in) self._expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
...@@ -242,9 +289,12 @@ class KiroAuthManager: ...@@ -242,9 +289,12 @@ class KiroAuthManager:
"""Save updated credentials to file""" """Save updated credentials to file"""
if not self.creds_file: if not self.creds_file:
return return
# Expand ~ in path
creds_path = Path(self.creds_file).expanduser()
try: try:
with open(self.creds_file, 'r') as f: with open(creds_path, 'r') as f:
data = json.load(f) data = json.load(f)
except (FileNotFoundError, json.JSONDecodeError): except (FileNotFoundError, json.JSONDecodeError):
data = {} data = {}
...@@ -257,7 +307,7 @@ class KiroAuthManager: ...@@ -257,7 +307,7 @@ class KiroAuthManager:
'region': self.region 'region': self.region
}) })
with open(self.creds_file, 'w') as f: with open(creds_path, 'w') as f:
json.dump(data, f, indent=2) json.dump(data, f, indent=2)
def get_auth_headers(self, token: str) -> dict: def get_auth_headers(self, token: str) -> dict:
......
...@@ -311,8 +311,7 @@ def get_truncation_recovery_system_addition() -> str: ...@@ -311,8 +311,7 @@ def get_truncation_recovery_system_addition() -> str:
Returns: Returns:
System prompt addition text (empty string if truncation recovery is disabled) System prompt addition text (empty string if truncation recovery is disabled)
""" """
from kiro.config import TRUNCATION_RECOVERY # Use module-level constant (defined at top of file)
if not TRUNCATION_RECOVERY: if not TRUNCATION_RECOVERY:
return "" return ""
......
...@@ -30,22 +30,121 @@ Contains functions for: ...@@ -30,22 +30,121 @@ Contains functions for:
""" """
import logging import logging
import re
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
# Use standard Python logging # Use standard Python logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Configuration - hidden models that need normalization # Import Kiro models for type hints
from .kiro_models import ChatMessage, Tool
# Hidden models - not returned by Kiro /ListAvailableModels API but still functional.
# These need special internal IDs that differ from their display names.
# Format: "normalized_display_name" → "internal_kiro_id"
# Matches kiro-gateway's config.py HIDDEN_MODELS
HIDDEN_MODELS = { HIDDEN_MODELS = {
"claude-sonnet-4-5": "anthropic.claude-3-5-sonnet-20241022-v2:0", # Claude 3.7 Sonnet - legacy flagship model, still works!
"claude-haiku-4-5": "anthropic.claude-3-5-haiku-20241022-v1:0", "claude-3.7-sonnet": "CLAUDE_3_7_SONNET_20250219_V1_0",
"claude-opus-4-5": "anthropic.claude-3-5-opus-20250514-v1:0",
"claude-sonnet-4": "anthropic.claude-3-5-sonnet-20240620-v1:0",
} }
def normalize_model_name(name: str) -> str:
"""
Normalize client model name to Kiro format.
Ported from kiro-gateway's model_resolver.py normalize_model_name().
Transformations applied:
1. claude-haiku-4-5 → claude-haiku-4.5 (dash to dot for minor version)
2. claude-haiku-4-5-20251001 → claude-haiku-4.5 (strip date suffix)
3. claude-haiku-4-5-latest → claude-haiku-4.5 (strip 'latest' suffix)
4. claude-sonnet-4-20250514 → claude-sonnet-4 (strip date, no minor)
5. claude-3-7-sonnet → claude-3.7-sonnet (legacy format normalization)
6. claude-3-7-sonnet-20250219 → claude-3.7-sonnet (legacy + strip date)
7. claude-4.5-opus-high → claude-opus-4.5 (inverted format with suffix)
Args:
name: External model name from client
Returns:
Normalized model name in Kiro format
"""
if not name:
return name
# Lowercase for consistent matching
name_lower = name.lower()
# Pattern 1: Standard format - claude-{family}-{major}-{minor}(-{suffix})?
# Matches: claude-haiku-4-5, claude-haiku-4-5-20251001, claude-haiku-4-5-latest
# IMPORTANT: Minor version is 1-2 digits only! 8-digit dates should NOT match here.
standard_pattern = r'^(claude-(?:haiku|sonnet|opus)-\d+)-(\d{1,2})(?:-(?:\d{8}|latest|\d+))?$'
match = re.match(standard_pattern, name_lower)
if match:
base = match.group(1) # claude-haiku-4
minor = match.group(2) # 5
return f"{base}.{minor}" # claude-haiku-4.5
# Pattern 2: Standard format without minor - claude-{family}-{major}(-{date})?
# Matches: claude-sonnet-4, claude-sonnet-4-20250514
no_minor_pattern = r'^(claude-(?:haiku|sonnet|opus)-\d+)(?:-\d{8})?$'
match = re.match(no_minor_pattern, name_lower)
if match:
return match.group(1) # claude-sonnet-4
# Pattern 3: Legacy format - claude-{major}-{minor}-{family}(-{suffix})?
# Matches: claude-3-7-sonnet, claude-3-7-sonnet-20250219
legacy_pattern = r'^(claude)-(\d+)-(\d+)-(haiku|sonnet|opus)(?:-(?:\d{8}|latest|\d+))?$'
match = re.match(legacy_pattern, name_lower)
if match:
prefix = match.group(1) # claude
major = match.group(2) # 3
minor = match.group(3) # 7
family = match.group(4) # sonnet
return f"{prefix}-{major}.{minor}-{family}" # claude-3.7-sonnet
# Pattern 4: Already normalized with dot but has date suffix
# Matches: claude-haiku-4.5-20251001, claude-3.7-sonnet-20250219
dot_with_date_pattern = r'^(claude-(?:\d+\.\d+-)?(?:haiku|sonnet|opus)(?:-\d+\.\d+)?)-\d{8}$'
match = re.match(dot_with_date_pattern, name_lower)
if match:
return match.group(1)
# Pattern 5: Inverted format with suffix - claude-{major}.{minor}-{family}-{suffix}
# Matches: claude-4.5-opus-high, claude-4.5-sonnet-low
# Convert to: claude-{family}-{major}.{minor}
# NOTE: Requires a suffix to avoid matching already-normalized formats
inverted_with_suffix_pattern = r'^claude-(\d+)\.(\d+)-(haiku|sonnet|opus)-(.+)$'
match = re.match(inverted_with_suffix_pattern, name_lower)
if match:
major = match.group(1) # 4
minor = match.group(2) # 5
family = match.group(3) # opus
return f"claude-{family}-{major}.{minor}" # claude-opus-4.5
# No transformation needed - return as-is
return name
def get_model_id_for_kiro(model: str, hidden_models: dict) -> str: def get_model_id_for_kiro(model: str, hidden_models: dict) -> str:
"""Normalize model name for Kiro API""" """
return hidden_models.get(model, model) Get the model ID to send to Kiro API.
Normalizes the name first (dashes→dots, strip dates),
then checks hidden_models for special internal IDs.
Ported from kiro-gateway's model_resolver.py get_model_id_for_kiro().
Args:
model: External model name from client
hidden_models: Dict mapping display names to internal Kiro IDs
Returns:
Model ID to send to Kiro API
"""
normalized = normalize_model_name(model)
return hidden_models.get(normalized, normalized)
# Import from core - reuse shared logic # Import from core - reuse shared logic
from .kiro_converters import ( from .kiro_converters import (
......
This diff is collapsed.
...@@ -1308,9 +1308,35 @@ async def dashboard_prompts(request: Request): ...@@ -1308,9 +1308,35 @@ async def dashboard_prompts(request: Request):
prompts_data = [] prompts_data = []
for prompt_file in prompt_files: for prompt_file in prompt_files:
# Check user config first
config_path = Path.home() / '.aisbf' / prompt_file['filename'] config_path = Path.home() / '.aisbf' / prompt_file['filename']
if not config_path.exists(): if not config_path.exists():
config_path = Path(__file__).parent / 'config' / prompt_file['filename'] # Try installed locations
installed_dirs = [
Path.home() / '.local' / 'share' / 'aisbf',
Path('/usr/share/aisbf'),
Path(__file__).parent, # For source tree
]
source_path = None
for installed_dir in installed_dirs:
test_path = installed_dir / prompt_file['filename']
if test_path.exists():
source_path = test_path
break
# Also check config subdirectory
test_path = installed_dir / 'config' / prompt_file['filename']
if test_path.exists():
source_path = test_path
break
if source_path:
# Copy to user config directory
config_path.parent.mkdir(parents=True, exist_ok=True)
import shutil
shutil.copy2(source_path, config_path)
logger.info(f"Copied prompt from {source_path} to {config_path}")
if config_path.exists(): if config_path.exists():
with open(config_path) as f: with open(config_path) as f:
...@@ -1321,6 +1347,14 @@ async def dashboard_prompts(request: Request): ...@@ -1321,6 +1347,14 @@ async def dashboard_prompts(request: Request):
'filename': prompt_file['filename'], 'filename': prompt_file['filename'],
'content': content 'content': content
}) })
else:
# Add empty prompt if file not found
prompts_data.append({
'key': prompt_file['key'],
'name': prompt_file['name'],
'filename': prompt_file['filename'],
'content': f'# {prompt_file["name"]}\n\nPrompt template not found. Please add your prompt here.'
})
# Check for success parameter # Check for success parameter
success = request.query_params.get('success') success = request.query_params.get('success')
...@@ -1391,10 +1425,37 @@ async def dashboard_settings(request: Request): ...@@ -1391,10 +1425,37 @@ async def dashboard_settings(request: Request):
if auth_check: if auth_check:
return auth_check return auth_check
# Load aisbf.json # Load aisbf.json - check user config first, then installed locations
config_path = Path.home() / '.aisbf' / 'aisbf.json' config_path = Path.home() / '.aisbf' / 'aisbf.json'
if not config_path.exists(): if not config_path.exists():
config_path = Path(__file__).parent / 'config' / 'aisbf.json' # Try installed locations
installed_dirs = [
Path.home() / '.local' / 'share' / 'aisbf',
Path('/usr/share/aisbf'),
Path(__file__).parent, # For source tree
]
source_path = None
for installed_dir in installed_dirs:
test_path = installed_dir / 'aisbf.json'
if test_path.exists():
source_path = test_path
break
# Also check config subdirectory
test_path = installed_dir / 'config' / 'aisbf.json'
if test_path.exists():
source_path = test_path
break
if source_path:
# Copy to user config directory
config_path.parent.mkdir(parents=True, exist_ok=True)
import shutil
shutil.copy2(source_path, config_path)
logger.info(f"Copied config from {source_path} to {config_path}")
else:
raise HTTPException(status_code=500, detail="Configuration file not found in any location")
with open(config_path) as f: with open(config_path) as f:
aisbf_config = json.load(f) aisbf_config = json.load(f)
...@@ -1550,12 +1611,21 @@ async def dashboard_docs(request: Request): ...@@ -1550,12 +1611,21 @@ async def dashboard_docs(request: Request):
if auth_check: if auth_check:
return auth_check return auth_check
# Try to find DOCUMENTATION.md # Try to find DOCUMENTATION.md in multiple locations
doc_path = Path(__file__).parent / 'DOCUMENTATION.md' search_paths = [
if not doc_path.exists(): Path.home() / '.aisbf' / 'DOCUMENTATION.md',
doc_path = Path.home() / '.aisbf' / 'DOCUMENTATION.md' Path.home() / '.local' / 'share' / 'aisbf' / 'DOCUMENTATION.md',
Path('/usr/share/aisbf') / 'DOCUMENTATION.md',
Path(__file__).parent / 'DOCUMENTATION.md',
]
doc_path = None
for path in search_paths:
if path.exists():
doc_path = path
break
if doc_path.exists(): if doc_path and doc_path.exists():
with open(doc_path, encoding='utf-8') as f: with open(doc_path, encoding='utf-8') as f:
markdown_content = f.read() markdown_content = f.read()
# Convert markdown to HTML with extensions for better formatting # Convert markdown to HTML with extensions for better formatting
...@@ -1580,12 +1650,21 @@ async def dashboard_about(request: Request): ...@@ -1580,12 +1650,21 @@ async def dashboard_about(request: Request):
if auth_check: if auth_check:
return auth_check return auth_check
# Try to find README.md # Try to find README.md in multiple locations
readme_path = Path(__file__).parent / 'README.md' search_paths = [
if not readme_path.exists(): Path.home() / '.aisbf' / 'README.md',
readme_path = Path.home() / '.aisbf' / 'README.md' Path.home() / '.local' / 'share' / 'aisbf' / 'README.md',
Path('/usr/share/aisbf') / 'README.md',
Path(__file__).parent / 'README.md',
]
readme_path = None
for path in search_paths:
if path.exists():
readme_path = path
break
if readme_path.exists(): if readme_path and readme_path.exists():
with open(readme_path, encoding='utf-8') as f: with open(readme_path, encoding='utf-8') as f:
markdown_content = f.read() markdown_content = f.read()
# Convert markdown to HTML with extensions for better formatting # Convert markdown to HTML with extensions for better formatting
...@@ -1610,12 +1689,21 @@ async def dashboard_license(request: Request): ...@@ -1610,12 +1689,21 @@ async def dashboard_license(request: Request):
if auth_check: if auth_check:
return auth_check return auth_check
# Try to find LICENSE.txt # Try to find LICENSE.txt in multiple locations
license_path = Path(__file__).parent / 'LICENSE.txt' search_paths = [
if not license_path.exists(): Path.home() / '.aisbf' / 'LICENSE.txt',
license_path = Path.home() / '.aisbf' / 'LICENSE.txt' Path.home() / '.local' / 'share' / 'aisbf' / 'LICENSE.txt',
Path('/usr/share/aisbf') / 'LICENSE.txt',
Path(__file__).parent / 'LICENSE.txt',
]
license_path = None
for path in search_paths:
if path.exists():
license_path = path
break
if license_path.exists(): if license_path and license_path.exists():
with open(license_path, encoding='utf-8') as f: with open(license_path, encoding='utf-8') as f:
content = f.read() content = f.read()
# Convert to HTML with pre tags to preserve formatting # Convert to HTML with pre tags to preserve formatting
......
...@@ -82,10 +82,15 @@ setup( ...@@ -82,10 +82,15 @@ setup(
'main.py', 'main.py',
'requirements.txt', 'requirements.txt',
'aisbf.sh', 'aisbf.sh',
'DOCUMENTATION.md',
'README.md',
'LICENSE.txt',
'config/providers.json', 'config/providers.json',
'config/rotations.json', 'config/rotations.json',
'config/autoselect.json', 'config/autoselect.json',
'config/autoselect.md', 'config/autoselect.md',
'config/condensation_conversational.md',
'config/condensation_semantic.md',
'config/aisbf.json', 'config/aisbf.json',
]), ]),
# Install aisbf package to share directory for venv installation # Install aisbf package to share directory for venv installation
...@@ -104,6 +109,7 @@ setup( ...@@ -104,6 +109,7 @@ setup(
'aisbf/kiro_converters.py', 'aisbf/kiro_converters.py',
'aisbf/kiro_converters_openai.py', 'aisbf/kiro_converters_openai.py',
'aisbf/kiro_models.py', 'aisbf/kiro_models.py',
'aisbf/kiro_parsers.py',
'aisbf/kiro_utils.py', 'aisbf/kiro_utils.py',
]), ]),
# Install dashboard templates # Install dashboard templates
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment