Fix market

parent a8b5285c
...@@ -5,6 +5,7 @@ import time ...@@ -5,6 +5,7 @@ import time
import logging import logging
import threading import threading
import hmac as _hmac import hmac as _hmac
import ipaddress
from typing import Optional from typing import Optional
from fastapi import Request from fastapi import Request
from fastapi.responses import JSONResponse, RedirectResponse from fastapi.responses import JSONResponse, RedirectResponse
...@@ -83,6 +84,23 @@ def _is_local_client(request: Request) -> bool: ...@@ -83,6 +84,23 @@ def _is_local_client(request: Request) -> bool:
return ip in _LOCAL_IPS if ip else False return ip in _LOCAL_IPS if ip else False
def _is_private_or_local_ip(ip: Optional[str]) -> bool:
if not ip:
return False
if ip in _LOCAL_IPS:
return True
try:
addr = ipaddress.ip_address(ip)
except ValueError:
return False
return any((
addr.is_private,
addr.is_loopback,
addr.is_link_local,
addr.is_reserved,
))
class GenocidalBlockingMiddleware(BaseHTTPMiddleware): class GenocidalBlockingMiddleware(BaseHTTPMiddleware):
"""Block Israeli IPs/domains.""" """Block Israeli IPs/domains."""
...@@ -105,7 +123,7 @@ class GenocidalBlockingMiddleware(BaseHTTPMiddleware): ...@@ -105,7 +123,7 @@ class GenocidalBlockingMiddleware(BaseHTTPMiddleware):
return True return True
from aisbf import geolocation from aisbf import geolocation
client_ip = _get_client_ip(request) client_ip = _get_client_ip(request)
if client_ip: if client_ip and not _is_private_or_local_ip(client_ip):
country = await geolocation.get_ip_country(client_ip) country = await geolocation.get_ip_country(client_ip)
if country == 'IL': if country == 'IL':
return True return True
......
...@@ -3534,6 +3534,77 @@ class DatabaseManager: ...@@ -3534,6 +3534,77 @@ class DatabaseManager:
cursor.execute(query) cursor.execute(query)
return [self._load_market_listing_row(row) for row in cursor.fetchall()] return [self._load_market_listing_row(row) for row in cursor.fetchall()]
def list_market_listings_paginated(
self,
page: int = 1,
limit: int = 25,
search: Optional[str] = None,
source_type: Optional[str] = None,
active_filter: Optional[str] = None,
online_filter: Optional[str] = None,
owner_username: Optional[str] = None,
) -> Dict[str, Any]:
page = max(int(page or 1), 1)
limit = max(1, min(int(limit or 25), 100))
offset = (page - 1) * limit
where_clauses = []
params: List[Any] = []
placeholder = self.placeholder
if search:
like = f"%{search.strip()}%"
where_clauses.append(
f"(title LIKE {placeholder} OR description LIKE {placeholder} OR source_id LIKE {placeholder} OR provider_id LIKE {placeholder} OR model_id LIKE {placeholder} OR owner_username LIKE {placeholder})"
)
params.extend([like, like, like, like, like, like])
if source_type:
where_clauses.append(f"source_type = {placeholder}")
params.append(source_type)
if owner_username:
where_clauses.append(f"owner_username = {placeholder}")
params.append(owner_username)
if active_filter == 'active':
where_clauses.append("is_active = 1")
elif active_filter == 'inactive':
where_clauses.append("is_active = 0")
if online_filter == 'online':
where_clauses.append("provider_id IS NOT NULL")
elif online_filter == 'offline':
where_clauses.append("provider_id IS NULL")
where_sql = f" WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
with self._get_connection() as conn:
cursor = conn.cursor()
cursor.execute(f"SELECT COUNT(*) FROM market_listings{where_sql}", tuple(params))
total_row = cursor.fetchone()
total = int(total_row[0] if total_row else 0)
query = f'''
SELECT id, owner_user_id, owner_username, source_scope, source_type, source_id, listing_key,
title, description, provider_id, model_id, endpoint, currency_code,
price_per_million_tokens, metadata, config_snapshot, price_per_1000_requests,
provider_price_per_million_tokens, provider_price_per_1000_requests, is_active,
created_at, updated_at
FROM market_listings
{where_sql}
ORDER BY created_at DESC
LIMIT {placeholder} OFFSET {placeholder}
'''
cursor.execute(query, tuple(params + [limit, offset]))
items = [self._load_market_listing_row(row) for row in cursor.fetchall()]
return {
'items': items,
'total': total,
'page': page,
'limit': limit,
}
def get_market_listing(self, listing_id: int) -> Optional[Dict[str, Any]]: def get_market_listing(self, listing_id: int) -> Optional[Dict[str, Any]]:
with self._get_connection() as conn: with self._get_connection() as conn:
cursor = conn.cursor() cursor = conn.cursor()
......
This diff is collapsed.
...@@ -2,6 +2,19 @@ ...@@ -2,6 +2,19 @@
* Copyright (C) 2026 Stefy Lanza <stefy@nexlab.net> * Copyright (C) 2026 Stefy Lanza <stefy@nexlab.net>
* *
* AISBF - AI Service Broker Framework || AI Should Be Free * AISBF - AI Service Broker Framework || AI Should Be Free
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/ */
.studio { .studio {
......
...@@ -2,6 +2,19 @@ ...@@ -2,6 +2,19 @@
* Copyright (C) 2026 Stefy Lanza <stefy@nexlab.net> * Copyright (C) 2026 Stefy Lanza <stefy@nexlab.net>
* *
* AISBF - AI Service Broker Framework || AI Should Be Free * AISBF - AI Service Broker Framework || AI Should Be Free
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/ */
// ───────────────────────────────────────────────────────────────── // ─────────────────────────────────────────────────────────────────
......
...@@ -98,7 +98,9 @@ ...@@ -98,7 +98,9 @@
{% if pagination.has_prev %} {% if pagination.has_prev %}
<a class="btn" href="{{ url_for(request, '/dashboard/admin/market') }}?q={{ filters.q|urlencode }}&source_type={{ filters.source_type|urlencode }}&active_filter={{ filters.active_filter|urlencode }}&online_filter={{ filters.online_filter|urlencode }}&owner_username={{ filters.owner_username|urlencode }}&limit={{ filters.limit }}&page={{ pagination.prev_page }}" style="background: var(--bg-accent); color: var(--color-text);">Previous</a> <a class="btn" href="{{ url_for(request, '/dashboard/admin/market') }}?q={{ filters.q|urlencode }}&source_type={{ filters.source_type|urlencode }}&active_filter={{ filters.active_filter|urlencode }}&online_filter={{ filters.online_filter|urlencode }}&owner_username={{ filters.owner_username|urlencode }}&limit={{ filters.limit }}&page={{ pagination.prev_page }}" style="background: var(--bg-accent); color: var(--color-text);">Previous</a>
{% endif %} {% endif %}
<a class="btn" href="{{ url_for(request, '/dashboard/admin/market') }}?q={{ filters.q|urlencode }}&source_type={{ filters.source_type|urlencode }}&active_filter={{ filters.active_filter|urlencode }}&online_filter={{ filters.online_filter|urlencode }}&owner_username={{ filters.owner_username|urlencode }}&limit={{ filters.limit }}&page={{ pagination.page }}" style="background: var(--color-link); color: white;">Current Page</a> {% if pagination.total_pages > 1 %}
<a class="btn" href="{{ url_for(request, '/dashboard/admin/market') }}?q={{ filters.q|urlencode }}&source_type={{ filters.source_type|urlencode }}&active_filter={{ filters.active_filter|urlencode }}&online_filter={{ filters.online_filter|urlencode }}&owner_username={{ filters.owner_username|urlencode }}&limit={{ filters.limit }}&page={{ pagination.page }}" style="background: var(--color-link); color: white;">{{ pagination.page }}</a>
{% endif %}
{% if pagination.has_next %} {% if pagination.has_next %}
<a class="btn" href="{{ url_for(request, '/dashboard/admin/market') }}?q={{ filters.q|urlencode }}&source_type={{ filters.source_type|urlencode }}&active_filter={{ filters.active_filter|urlencode }}&online_filter={{ filters.online_filter|urlencode }}&owner_username={{ filters.owner_username|urlencode }}&limit={{ filters.limit }}&page={{ pagination.next_page }}" style="background: var(--bg-accent); color: var(--color-text);">Next</a> <a class="btn" href="{{ url_for(request, '/dashboard/admin/market') }}?q={{ filters.q|urlencode }}&source_type={{ filters.source_type|urlencode }}&active_filter={{ filters.active_filter|urlencode }}&online_filter={{ filters.online_filter|urlencode }}&owner_username={{ filters.owner_username|urlencode }}&limit={{ filters.limit }}&page={{ pagination.next_page }}" style="background: var(--bg-accent); color: var(--color-text);">Next</a>
{% endif %} {% endif %}
......
...@@ -19,6 +19,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>. ...@@ -19,6 +19,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
{% block title %}Providers - AISBF Dashboard{% endblock %} {% block title %}Providers - AISBF Dashboard{% endblock %}
{% block content %} {% block content %}
<script id="providers-bootstrap" type="application/json">{{ providers_data | tojson }}</script>
<h2 style="margin-bottom: 30px;">Providers Configuration</h2> <h2 style="margin-bottom: 30px;">Providers Configuration</h2>
{% if success %} {% if success %}
...@@ -306,10 +307,10 @@ const IS_LOCAL_CLIENT = {{ 'true' if is_local_client else 'false' }}; ...@@ -306,10 +307,10 @@ const IS_LOCAL_CLIENT = {{ 'true' if is_local_client else 'false' }};
const BASE_PATH = {{ (request.scope.get('root_path', '') or '') | tojson }}; const BASE_PATH = {{ (request.scope.get('root_path', '') or '') | tojson }};
// Marker used by the AISBF Chrome Extension to auto-detect this page and configure itself. // Marker used by the AISBF Chrome Extension to auto-detect this page and configure itself.
window.AISBF_PROVIDERS_PAGE = { serverUrl: window.location.origin + BASE_PATH }; window.AISBF_PROVIDERS_PAGE = { serverUrl: window.location.origin + BASE_PATH };
let providersData = JSON.parse({{ providers_json | tojson }}); let providersData = {{ providers_data | tojson }};
const STUDIO_CAPABILITY_CHOICES = JSON.parse({{ studio_capability_choices_json | tojson }}); const STUDIO_CAPABILITY_CHOICES = {{ studio_capability_choices | tojson }};
const STUDIO_ADAPTER_CHOICES = JSON.parse({{ studio_adapter_choices_json | tojson }}); const STUDIO_ADAPTER_CHOICES = {{ studio_adapter_choices | tojson }};
const STUDIO_ADAPTER_PROFILE_CHOICES = JSON.parse({{ studio_adapter_profile_choices_json | tojson }}); const STUDIO_ADAPTER_PROFILE_CHOICES = {{ studio_adapter_profile_choices | tojson }};
let expandedProviders = new Set(); let expandedProviders = new Set();
let currentProviderPage = 0; let currentProviderPage = 0;
const PROVIDERS_PAGE_SIZE = 10; const PROVIDERS_PAGE_SIZE = 10;
...@@ -1208,11 +1209,11 @@ function renderProviderDetails(key) { ...@@ -1208,11 +1209,11 @@ function renderProviderDetails(key) {
<div><strong>Connected At:</strong> ${escHtmlAttr(formatBrokerTimestamp(brokerSession.connected_at))}</div> <div><strong>Connected At:</strong> ${escHtmlAttr(formatBrokerTimestamp(brokerSession.connected_at))}</div>
<div><strong>Remote Endpoint:</strong> ${escHtmlAttr(brokerSession.endpoint || 'Unknown')}</div> <div><strong>Remote Endpoint:</strong> ${escHtmlAttr(brokerSession.endpoint || 'Unknown')}</div>
<div><strong>Transport:</strong> ${escHtmlAttr(brokerSession.transport || 'broker')}</div> <div><strong>Transport:</strong> ${escHtmlAttr(brokerSession.transport || 'broker')}</div>
<div><strong>GPU Count:</strong> ${escHtmlAttr(String(brokerMetadata.gpu_count ?? (Array.isArray(brokerMetadata.gpus) ? brokerMetadata.gpus.length : 0) || '0'))}</div> <div><strong>GPU Count:</strong> ${escHtmlAttr(String((brokerMetadata.gpu_count ?? (Array.isArray(brokerMetadata.gpus) ? brokerMetadata.gpus.length : 0)) || '0'))}</div>
<div><strong>VRAM:</strong> ${escHtmlAttr(formatVramMb(brokerMetadata.available_vram_mb))} free / ${escHtmlAttr(formatVramMb(brokerMetadata.total_vram_mb))} total</div> <div><strong>VRAM:</strong> ${escHtmlAttr(formatVramMb(brokerMetadata.available_vram_mb))} free / ${escHtmlAttr(formatVramMb(brokerMetadata.total_vram_mb))} total</div>
<div><strong>GPUs:</strong> ${escHtmlAttr(renderGpuSummary(brokerMetadata))}</div> <div><strong>GPUs:</strong> ${escHtmlAttr(renderGpuSummary(brokerMetadata))}</div>
<div><strong>Avg Latency:</strong> ${escHtmlAttr(formatPerfNumber(performance.avg_latency_ms, 1, ' ms'))}</div> <div><strong>Avg Latency:</strong> ${escHtmlAttr(formatPerfNumber(performance.avg_latency_ms, 1, ' ms'))}</div>
<div><strong>Avg Throughput:</strong> ${escHtmlAttr(formatPerfNumber(performance.avg_tokens_per_second, 2, ' tok/s'))}</div> <div><strong>Avg Throughput:</strong> ${escHtmlAttr(formatPerfNumber(performance.avg_tokens_per_second, 1, ' tok/s'))}</div>
<div><strong>Avg Tokens:</strong> ${escHtmlAttr(formatPerfNumber(performance.avg_total_tokens, 1))}</div> <div><strong>Avg Tokens:</strong> ${escHtmlAttr(formatPerfNumber(performance.avg_total_tokens, 1))}</div>
<div><strong>Success Rate:</strong> ${escHtmlAttr(formatPerfNumber((performance.success_rate ?? 0) * 100, 1, '%'))}</div> <div><strong>Success Rate:</strong> ${escHtmlAttr(formatPerfNumber((performance.success_rate ?? 0) * 100, 1, '%'))}</div>
<div><strong>Samples:</strong> ${escHtmlAttr(String(performance.sample_count ?? 0))} / ${escHtmlAttr(String(performance.window_size ?? 100))}</div> <div><strong>Samples:</strong> ${escHtmlAttr(String(performance.sample_count ?? 0))} / ${escHtmlAttr(String(performance.window_size ?? 100))}</div>
......
...@@ -27,7 +27,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>. ...@@ -27,7 +27,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
{% block content %} {% block content %}
<script id="studio-bootstrap" type="application/json">{{ studio_bootstrap_json|safe }}</script> <script id="studio-bootstrap" type="application/json">{{ studio_bootstrap_json|safe }}</script>
<div class="studio"> <div class="studio" data-studio-shell="dashboard">
<!-- Sidebar --> <!-- Sidebar -->
<aside class="sidebar"> <aside class="sidebar">
...@@ -1039,9 +1039,10 @@ along with this program. If not, see <https://www.gnu.org/licenses/>. ...@@ -1039,9 +1039,10 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
<!-- ═══════════════ PIPELINES ═══════════════ --> <!-- ═══════════════ PIPELINES ═══════════════ -->
<div class="panel" id="panel-pipe"> <div class="panel" id="panel-pipe">
<div class="pipe-panel"> <div class="pipe-panel">
<div class="diag-card" id="studio-diagnostics"> <div class="diag-card" id="studio-diagnostics" data-empty-message="No diagnostics yet.">
<div class="diag-title">Diagnostics</div> <div class="diag-title">Diagnostics</div>
<div class="diag-sub">Evaluated from the current model type, declared capabilities, and existing Studio fallback rules.</div> <div class="diag-sub">Evaluated from the current model type, declared capabilities, and existing Studio fallback rules.</div>
<span data-i18n="studio.diagnostics_empty">No diagnostics yet.</span>
<div class="diag-groups" id="diag-groups"></div> <div class="diag-groups" id="diag-groups"></div>
</div> </div>
<div class="hist-card" id="studio-history"> <div class="hist-card" id="studio-history">
......
...@@ -311,14 +311,14 @@ async function apiCall(method, url, body) { ...@@ -311,14 +311,14 @@ async function apiCall(method, url, body) {
} }
let providersData = {}; let providersData = {};
const STUDIO_CAPABILITY_CHOICES = JSON.parse({{ studio_capability_choices_json | tojson }}); const STUDIO_CAPABILITY_CHOICES = {{ studio_capability_choices | tojson }};
const STUDIO_ADAPTER_CHOICES = JSON.parse({{ studio_adapter_choices_json | tojson }}); const STUDIO_ADAPTER_CHOICES = {{ studio_adapter_choices | tojson }};
const STUDIO_ADAPTER_PROFILE_CHOICES = JSON.parse({{ studio_adapter_profile_choices_json | tojson }}); const STUDIO_ADAPTER_PROFILE_CHOICES = {{ studio_adapter_profile_choices | tojson }};
let expandedProviders = new Set(); let expandedProviders = new Set();
let currentProviderPage = 0; let currentProviderPage = 0;
const PROVIDERS_PAGE_SIZE = 10; const PROVIDERS_PAGE_SIZE = 10;
let providerSearchFilter = ''; let providerSearchFilter = '';
let rawProviders = JSON.parse({{ user_providers_json | tojson }}); let rawProviders = JSON.parse({{ user_providers_bootstrap_json | safe }});
let providerMasterOrder = []; let providerMasterOrder = [];
let _providerDS = null; let _providerDS = null;
...@@ -1125,7 +1125,7 @@ function renderProviderDetails(key) { ...@@ -1125,7 +1125,7 @@ function renderProviderDetails(key) {
<div><strong>Connected At:</strong> ${escHtmlAttr(formatBrokerTimestamp(brokerSession.connected_at))}</div> <div><strong>Connected At:</strong> ${escHtmlAttr(formatBrokerTimestamp(brokerSession.connected_at))}</div>
<div><strong>Remote Endpoint:</strong> ${escHtmlAttr(brokerSession.endpoint || 'Unknown')}</div> <div><strong>Remote Endpoint:</strong> ${escHtmlAttr(brokerSession.endpoint || 'Unknown')}</div>
<div><strong>Transport:</strong> ${escHtmlAttr(brokerSession.transport || 'broker')}</div> <div><strong>Transport:</strong> ${escHtmlAttr(brokerSession.transport || 'broker')}</div>
<div><strong>GPU Count:</strong> ${escHtmlAttr(String(brokerMetadata.gpu_count ?? (Array.isArray(brokerMetadata.gpus) ? brokerMetadata.gpus.length : 0) || '0'))}</div> <div><strong>GPU Count:</strong> ${escHtmlAttr(String((brokerMetadata.gpu_count ?? (Array.isArray(brokerMetadata.gpus) ? brokerMetadata.gpus.length : 0)) || '0'))}</div>
<div><strong>VRAM:</strong> ${escHtmlAttr(formatVramMb(brokerMetadata.available_vram_mb))} free / ${escHtmlAttr(formatVramMb(brokerMetadata.total_vram_mb))} total</div> <div><strong>VRAM:</strong> ${escHtmlAttr(formatVramMb(brokerMetadata.available_vram_mb))} free / ${escHtmlAttr(formatVramMb(brokerMetadata.total_vram_mb))} total</div>
<div><strong>GPUs:</strong> ${escHtmlAttr(renderGpuSummary(brokerMetadata))}</div> <div><strong>GPUs:</strong> ${escHtmlAttr(renderGpuSummary(brokerMetadata))}</div>
<div><strong>Avg Latency:</strong> ${escHtmlAttr(formatPerfNumber(performance.avg_latency_ms, 1, ' ms'))}</div> <div><strong>Avg Latency:</strong> ${escHtmlAttr(formatPerfNumber(performance.avg_latency_ms, 1, ' ms'))}</div>
......
...@@ -337,6 +337,17 @@ def _load_json_parse_bootstrap(response_text: str, marker: str): ...@@ -337,6 +337,17 @@ def _load_json_parse_bootstrap(response_text: str, marker: str):
return json.loads(json.loads(js_string_literal)) return json.loads(json.loads(js_string_literal))
def test_json_parse_bootstrap_escapes_inline_script_sequences():
from aisbf.routes.dashboard.providers import _json_parse_bootstrap
payload = [{"config": {"name": 'Alice\'s "Provider" </script><script>alert(1)</script>'}}]
bootstrap = _json_parse_bootstrap(payload)
assert '</script><script>alert(1)</script>' not in bootstrap
assert '\\u003c/script\\u003e\\u003cscript\\u003ealert(1)\\u003c/script\\u003e' in bootstrap
assert json.loads(json.loads(bootstrap))[0]["config"]["name"] == 'Alice\'s "Provider" </script><script>alert(1)</script>'
@pytest.fixture @pytest.fixture
def runtime_fixture(monkeypatch): def runtime_fixture(monkeypatch):
db = MarketReferenceRuntimeDbStub() db = MarketReferenceRuntimeDbStub()
...@@ -624,10 +635,10 @@ def test_dashboard_admin_providers_bootstrap_uses_json_parse(monkeypatch): ...@@ -624,10 +635,10 @@ def test_dashboard_admin_providers_bootstrap_uses_json_parse(monkeypatch):
"request": request, "request": request,
"session": {}, "session": {},
"__version__": "test", "__version__": "test",
"providers_json": json.dumps(providers_payload), "providers_data": providers_payload,
"studio_capability_choices_json": "[]", "studio_capability_choices": [],
"studio_adapter_choices_json": "[]", "studio_adapter_choices": [],
"studio_adapter_profile_choices_json": "[]", "studio_adapter_profile_choices": [],
"claude_cli_mode": False, "claude_cli_mode": False,
"is_local_client": True, "is_local_client": True,
"success": None, "success": None,
...@@ -637,9 +648,9 @@ def test_dashboard_admin_providers_bootstrap_uses_json_parse(monkeypatch): ...@@ -637,9 +648,9 @@ def test_dashboard_admin_providers_bootstrap_uses_json_parse(monkeypatch):
response_text = response.body.decode() response_text = response.body.decode()
assert response.status_code == 200 assert response.status_code == 200
assert "let providersData = JSON.parse(" in response_text assert "let providersData = {" in response_text
providers_bootstrap = _load_json_parse_bootstrap(response_text, "let providersData") bootstrap_fragment = response_text.split("let providersData = ", 1)[1].split(";\n", 1)[0]
bootstrap_fragment = response_text.split("let providersData = JSON.parse(", 1)[1].split("\n", 1)[0] providers_bootstrap = json.loads(bootstrap_fragment)
assert '</script><script>alert(2)</script>' not in bootstrap_fragment assert '</script><script>alert(2)</script>' not in bootstrap_fragment
assert '\\u003c/script\\u003e\\u003cscript\\u003ealert(2)\\u003c/script\\u003e' in bootstrap_fragment assert '\\u003c/script\\u003e\\u003cscript\\u003ealert(2)\\u003c/script\\u003e' in bootstrap_fragment
assert providers_bootstrap["danger-provider"]["name"] == 'Admin "Provider" </script><script>alert(2)</script>' assert providers_bootstrap["danger-provider"]["name"] == 'Admin "Provider" </script><script>alert(2)</script>'
......
...@@ -2,14 +2,20 @@ import pytest ...@@ -2,14 +2,20 @@ import pytest
import time import time
import ipaddress import ipaddress
from unittest.mock import patch, AsyncMock, Mock from unittest.mock import patch, AsyncMock, Mock
from aisbf.geolocation import get_ip_country, _subnet_cache, _find_in_cache, _fallback_prefix from fastapi import FastAPI
from fastapi.testclient import TestClient
from aisbf.geolocation import get_ip_country, _subnet_cache, _failure_cache, _find_in_cache, _fallback_prefix
from aisbf.app.middleware import GenocidalBlockingMiddleware
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def clear_cache(): def clear_cache():
_subnet_cache.clear() _subnet_cache.clear()
_failure_cache.clear()
yield yield
_subnet_cache.clear() _subnet_cache.clear()
_failure_cache.clear()
def _mock_json_response(country: str, network: str, status: int = 200): def _mock_json_response(country: str, network: str, status: int = 200):
...@@ -47,6 +53,47 @@ def test_find_in_cache_expired_entry_removed(): ...@@ -47,6 +53,47 @@ def test_find_in_cache_expired_entry_removed():
assert "10.0.0.0/8" not in _subnet_cache assert "10.0.0.0/8" not in _subnet_cache
def _build_geo_test_client():
app = FastAPI()
@app.get("/")
async def root():
return {"ok": True}
app.add_middleware(GenocidalBlockingMiddleware, server_ip_blocked_ref=lambda: False)
return TestClient(app)
def test_localhost_ip_skips_geolocation_lookup():
client = _build_geo_test_client()
with patch("aisbf.geolocation.get_ip_country", new_callable=AsyncMock) as mock_geo:
response = client.get("/", headers={"X-Forwarded-For": "127.0.0.1"})
assert response.status_code == 200
mock_geo.assert_not_called()
def test_private_rfc1918_ip_skips_geolocation_lookup():
client = _build_geo_test_client()
with patch("aisbf.geolocation.get_ip_country", new_callable=AsyncMock) as mock_geo:
response = client.get("/", headers={"X-Forwarded-For": "192.168.1.25"})
assert response.status_code == 200
mock_geo.assert_not_called()
def test_public_ip_still_checks_geolocation():
client = _build_geo_test_client()
with patch("aisbf.geolocation.get_ip_country", new_callable=AsyncMock, return_value=None) as mock_geo:
response = client.get("/", headers={"X-Forwarded-For": "8.8.8.8"})
assert response.status_code == 200
mock_geo.assert_awaited_once_with("8.8.8.8")
# --- invalid IP --- # --- invalid IP ---
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -136,7 +183,7 @@ async def test_non_200_not_cached(): ...@@ -136,7 +183,7 @@ async def test_non_200_not_cached():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_failure_then_success_retries(): async def test_failure_then_success_uses_failure_backoff():
fail = Mock(status_code=500) fail = Mock(status_code=500)
ok = _mock_json_response("IT", "1.2.0.0/16") ok = _mock_json_response("IT", "1.2.0.0/16")
with patch('httpx.AsyncClient.get', new_callable=AsyncMock) as mock_get: with patch('httpx.AsyncClient.get', new_callable=AsyncMock) as mock_get:
...@@ -146,8 +193,9 @@ async def test_failure_then_success_retries(): ...@@ -146,8 +193,9 @@ async def test_failure_then_success_retries():
mock_get.return_value = ok mock_get.return_value = ok
result = await get_ip_country("1.2.3.4") result = await get_ip_country("1.2.3.4")
assert result == "IT" assert result is None
assert "1.2.0.0/16" in _subnet_cache mock_get.assert_called_once()
assert "1.2.0.0/16" not in _subnet_cache
# --- TTL expiry --- # --- TTL expiry ---
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment