Better capabilities detectiong and autosearch on HF

parent 227c3dbd
......@@ -54,6 +54,8 @@ def set_config_manager(mgr):
"""Set the shared ConfigManager instance."""
global config_manager
config_manager = mgr
from codai.models.capabilities import init_capability_cache
init_capability_cache(str(mgr.config_dir))
def _next_whisper_server_model_id(audio_models) -> str:
......@@ -859,7 +861,9 @@ def _scan_caches() -> dict:
result: dict = {"hf": [], "gguf": []}
from codai.models.cache import get_all_cache_dirs, get_model_cache_dir
from codai.models.capabilities import detect_model_capabilities
from codai.models.capabilities import (
detect_model_capabilities, lookup_capability_cache,
)
caches = get_all_cache_dirs()
# Collect configured models: key (path/id) → (settings_dict, model_type)
......@@ -917,7 +921,8 @@ def _scan_caches() -> dict:
continue # skip adding to hf list
cfg = configured_settings.get(repo.repo_id, ({}, None))
caps = detect_model_capabilities(repo.repo_id)
caps = (lookup_capability_cache(repo.repo_id)
or detect_model_capabilities(repo.repo_id))
result["hf"].append({
"id": repo.repo_id,
"size_gb": round(size_bytes / 1e9, 2),
......@@ -1784,7 +1789,7 @@ async def api_hf_search(
import urllib.request
import urllib.parse
import json as _json
from codai.models.capabilities import detect_model_capabilities
from codai.models.capabilities import detect_capabilities_from_pipeline_tag
if sort not in ("downloads", "likes", "lastModified", "createdAt"):
sort = "downloads"
......@@ -1861,17 +1866,25 @@ async def api_hf_search(
except Exception:
pass
return [
{
"id": m.get("modelId") or m.get("id", ""),
from codai.models.capabilities import update_capability_cache
results = []
for m in merged[:20]:
mid = m.get("modelId") or m.get("id", "")
caps = detect_capabilities_from_pipeline_tag(
m.get("pipeline_tag", ""), mid,
)
# Only cache when pipeline_tag gave us authoritative information
if m.get("pipeline_tag"):
update_capability_cache(mid, caps)
results.append({
"id": mid,
"downloads": m.get("downloads", 0),
"likes": m.get("likes", 0),
"pipeline_tag": m.get("pipeline_tag", ""),
"vram_available": vram_gb,
"capabilities": detect_model_capabilities(m.get("modelId") or m.get("id", "")).to_list(),
}
for m in merged[:20]
]
"capabilities": caps.to_list(),
})
return results
except Exception as e:
raise HTTPException(status_code=502, detail=f"HuggingFace API error: {e}")
......
......@@ -188,9 +188,9 @@ td code{font-family:var(--mono);font-size:11.5px;background:var(--raised);paddin
.modal-close{background:none;border:none;color:var(--text-3);cursor:pointer;font-size:1.125rem;line-height:1;padding:.125rem;border-radius:3px;transition:color .1s}
.modal-close:hover{color:var(--text)}
.modal-body{padding:1.125rem}
.donate-modal-box{max-width:560px}
.nav-donate-btn{background:transparent;border:none;cursor:pointer;color:var(--text-3)}
.nav-donate-btn:hover{color:var(--text)}
.donate-modal-box{max-width:660px}
.nav-donate-btn{background:var(--accent);border:none;cursor:pointer;color:#fff;border-radius:6px;font-weight:600;padding:.3rem .75rem;margin-left:.375rem;transition:opacity .15s,transform .1s}
.nav-donate-btn:hover{opacity:.88;transform:scale(1.03)}
.donate-tagline{font-size:13px;color:var(--text-2);margin:0 0 1.25rem;line-height:1.55}
.donate-coins{display:flex;gap:.875rem}
.donate-coin{flex:1;display:flex;flex-direction:column;align-items:center;gap:.75rem;padding:1rem;background:var(--raised);border:1px solid var(--border);border-radius:8px}
......@@ -198,6 +198,8 @@ td code{font-family:var(--mono);font-size:11.5px;background:var(--raised);paddin
.donate-qr{display:block;border-radius:5px}
.donate-addr-row{display:flex;align-items:center;gap:.5rem;width:100%;background:var(--card);border:1px solid var(--border);border-radius:5px;padding:.375rem .5rem}
.donate-addr{flex:1;font-family:var(--mono);font-size:9px;color:var(--text-2);word-break:break-all;line-height:1.5;user-select:all}
.donate-paypal-link{color:var(--accent);text-decoration:none}
.donate-paypal-link:hover{text-decoration:underline}
.donate-copy{flex-shrink:0;background:none;border:1px solid var(--border);border-radius:4px;cursor:pointer;padding:.2rem .5rem;font-size:11px;font-weight:600;color:var(--text-3);font-family:var(--font);transition:color .1s,border-color .1s}
.donate-copy:hover{color:var(--text);border-color:var(--border-2)}
@media(max-width:480px){.donate-coins{flex-direction:column}}
......
......@@ -31,7 +31,7 @@
<a href="/admin/archive" class="nav-link {% if '/archive' in request.url.path %}active{% endif %}">Archive</a>
<a href="/admin/settings" class="nav-link {% if '/settings' in request.url.path %}active{% endif %}">Settings</a>
{% endif %}
<button class="nav-link nav-donate-btn" onclick="document.getElementById('donateModal').classList.add('show')">Donate &#9829;</button>
<button class="nav-link nav-donate-btn" onclick="document.getElementById('donateModal').classList.add('show')">&#9829; Donate</button>
</div>
</div>
<div class="topnav-right">
......@@ -67,13 +67,23 @@
<button class="donate-copy" onclick="donateCopy('ethAddr',this)">Copy</button>
</div>
</div>
<div class="donate-coin">
<span class="donate-coin-label">PayPal</span>
<img class="donate-qr" src="https://api.qrserver.com/v1/create-qr-code/?size=160x160&color=DDE1F0&bgcolor=161820&qzone=2&data=https://paypal.me/nexlab" alt="PayPal QR code" width="160" height="160">
<div class="donate-addr-row">
<a class="donate-addr donate-paypal-link" id="ppAddr" href="https://paypal.me/nexlab" target="_blank" rel="noopener">paypal.me/nexlab</a>
<a class="donate-copy" href="https://paypal.me/nexlab" target="_blank" rel="noopener">Open</a>
</div>
</div>
</div>
</div>
</div>
</div>
<script>
function donateCopy(id, btn) {
navigator.clipboard.writeText(document.getElementById(id).textContent).then(function() {
var el = document.getElementById(id);
var text = el.href || el.textContent;
navigator.clipboard.writeText(text.trim()).then(function() {
var orig = btn.innerHTML;
btn.innerHTML = '&#10003;';
setTimeout(function(){ btn.innerHTML = orig; }, 1500);
......
......@@ -2087,6 +2087,7 @@ a.dl { display:inline-block; margin-top:.4rem; }
// State
// ─────────────────────────────────────────────────────────────────
let models = [], activeModel = null, chatHistory = [], chatBusy = false, attachedImage = null;
let _localCapSet = new Set(); // capabilities available from locally downloaded (not necessarily configured) models
let _imgPollTimer = null;
let _vidPollTimer = null;
let _audPollTimer = null;
......@@ -2290,28 +2291,42 @@ const STUDIO_CAPABILITIES = {
}
};
// Maps a capability token to the best HuggingFace search parameters
// NOTE: HF API `search=` is AND — keep q to one or two co-occurring terms.
// The pipeline filter already narrows by category; q just refines within it.
const CAP_TO_HF_SEARCH = {
'image_generation': { pipeline:'text-to-image', q:'', gguf:'no-gguf' },
'image_to_image': { pipeline:'image-to-image', q:'', gguf:'no-gguf' },
'inpainting': { pipeline:'image-to-image', q:'inpainting', gguf:'no-gguf' },
'image_upscaling': { pipeline:'image-to-image', q:'upscale', gguf:'no-gguf' },
'depth_estimation': { pipeline:'depth-estimation', q:'', gguf:'no-gguf' },
'image_segmentation': { pipeline:'image-segmentation', q:'', gguf:'no-gguf' },
'video_generation': { pipeline:'text-to-video', q:'', gguf:'no-gguf' },
'image_to_video': { pipeline:'image-to-video', q:'', gguf:'no-gguf' },
'video_to_video': { pipeline:'', q:'video to video', gguf:'no-gguf' },
'video_interpolation': { pipeline:'', q:'frame interpolation',gguf:'no-gguf' },
'video_upscaling': { pipeline:'', q:'video upscaling', gguf:'no-gguf' },
'speech_to_text': { pipeline:'automatic-speech-recognition',q:'', gguf:'gguf' },
'text_to_speech': { pipeline:'text-to-speech', q:'', gguf:'no-gguf' },
'audio_generation': { pipeline:'text-to-audio', q:'', gguf:'no-gguf' },
'text_generation': { pipeline:'text-generation', q:'', gguf:'gguf' },
'audio_to_audio': { pipeline:'audio-to-audio', q:'voice conversion', gguf:'no-gguf' },
'subtitle_generation': { pipeline:'automatic-speech-recognition',q:'', gguf:'gguf' },
'image_to_3d': { pipeline:'', q:'image to 3d', gguf:'no-gguf' },
'video_to_3d': { pipeline:'', q:'video depth 3d', gguf:'no-gguf' },
'model_3d_generation': { pipeline:'', q:'3d generation', gguf:'no-gguf' },
'model_3d_to_image': { pipeline:'', q:'3d rendering', gguf:'no-gguf' },
// Text / LLM
'text_generation': { pipeline:'text-generation', q:'instruct', gguf:'gguf' },
'embeddings': { pipeline:'feature-extraction', q:'embedding', gguf:'all' },
'image_to_text': { pipeline:'image-to-text', q:'vision', gguf:'gguf' },
// Image generation & editing
'image_generation': { pipeline:'text-to-image', q:'diffusion', gguf:'no-gguf' },
'image_to_image': { pipeline:'image-to-image', q:'diffusion', gguf:'no-gguf' },
'inpainting': { pipeline:'image-to-image', q:'inpainting', gguf:'no-gguf' },
'image_upscaling': { pipeline:'image-to-image', q:'upscaler', gguf:'no-gguf' },
'face_restoration': { pipeline:'image-to-image', q:'face restoration', gguf:'no-gguf' },
'style_transfer': { pipeline:'image-to-image', q:'style transfer', gguf:'no-gguf' },
'controlnet': { pipeline:'', q:'controlnet', gguf:'no-gguf' },
// Image analysis
'depth_estimation': { pipeline:'depth-estimation', q:'depth', gguf:'no-gguf' },
'image_segmentation': { pipeline:'image-segmentation', q:'segmentation', gguf:'no-gguf' },
'object_detection': { pipeline:'object-detection', q:'detection', gguf:'no-gguf' },
// Image 3D
'image_to_3d': { pipeline:'', q:'image-to-3d', gguf:'no-gguf' },
'model_3d_generation': { pipeline:'', q:'3d generation', gguf:'no-gguf' },
'model_3d_to_image': { pipeline:'', q:'3d rendering', gguf:'no-gguf' },
// Video
'video_generation': { pipeline:'text-to-video', q:'video', gguf:'no-gguf' },
'image_to_video': { pipeline:'image-to-video', q:'image-to-video', gguf:'no-gguf' },
'video_to_video': { pipeline:'', q:'video diffusion', gguf:'no-gguf' },
'video_interpolation': { pipeline:'', q:'frame interpolation', gguf:'no-gguf' },
'video_upscaling': { pipeline:'', q:'video upscaling', gguf:'no-gguf' },
'video_to_3d': { pipeline:'', q:'video depth', gguf:'no-gguf' },
// Audio
'speech_to_text': { pipeline:'automatic-speech-recognition',q:'whisper', gguf:'gguf' },
'subtitle_generation': { pipeline:'automatic-speech-recognition',q:'whisper', gguf:'gguf' },
'text_to_speech': { pipeline:'text-to-speech', q:'tts', gguf:'no-gguf' },
'audio_generation': { pipeline:'text-to-audio', q:'music', gguf:'no-gguf' },
'audio_to_audio': { pipeline:'audio-to-audio', q:'voice conversion', gguf:'no-gguf' },
};
function capSearchUrl(cap) {
const s = CAP_TO_HF_SEARCH[cap];
......@@ -2322,10 +2337,14 @@ function capSearchUrl(cap) {
function capMissingHtml(caps, label) {
if (!caps.length) return '';
const links = caps.map(cap => {
const url = capSearchUrl(cap);
const chip = `<span class="cap-chip dim">${cap.replace(/_/g,' ')}</span>`;
if (_localCapSet.has(cap)) {
const url = `/admin/models?local_cap=${encodeURIComponent(cap)}`;
return `<a href="${url}" class="cap-find-link" title="You have a local model with ${cap.replace(/_/g,' ')} — click to configure it">${chip}<span class="cap-find-icon" style="color:#6ecf7e">↑ configure</span></a>`;
}
const url = capSearchUrl(cap);
return url
? `<a href="${url}" target="_blank" class="cap-find-link" title="Find ${cap.replace(/_/g,' ')} model on HuggingFace">${chip}<span class="cap-find-icon">↗</span></a>`
? `<a href="${url}" target="_blank" class="cap-find-link" title="Find ${cap.replace(/_/g,' ')} model on HuggingFace">${chip}<span class="cap-find-icon">↗ HuggingFace</span></a>`
: chip;
}).join(' ');
return `<div class="cap-missing"><strong>${label}:</strong> ${links}</div>`;
......@@ -2335,7 +2354,7 @@ const SUB_CAPABILITY_RULES = {
'img-gen': { category:'image', requiresAny:['image_generation'] },
'img-edit': { category:'image', requiresAny:['image_to_image'] },
'img-inpaint': { category:'image', requiresAny:['inpainting'] },
'img-upscale': { category:'image', optional:['image_upscaling'] },
'img-upscale': { category:'image', requiresAny:['image_upscaling'] },
'img-depth': { category:'image', requiresAny:['depth_estimation'] },
'img-seg': { category:'image', requiresAny:['image_segmentation'] },
'img-faceswap': { category:'image', optional:['image_to_image'] },
......@@ -2349,7 +2368,7 @@ const SUB_CAPABILITY_RULES = {
'vid-interp': { category:'video', optional:['video_interpolation'], fallbackTypes:['video'] },
'vid-sub': { category:'video', optional:['subtitle_generation'], fallbackTypes:['video'] },
'vid-dub': { category:'video', optional:['subtitle_generation','speech_to_text','text_to_speech'], fallbackTypes:['video'] },
'vid-up': { category:'video', optional:['video_upscaling'], fallbackTypes:['video'] },
'vid-up': { category:'video', requiresAny:['video_upscaling'], fallbackTypes:['video'] },
'vid-faceswap': { category:'video', optional:['video_to_video'], fallbackTypes:['video'] },
'vid-outfit': { category:'video', optional:['video_to_video'], fallbackTypes:['video'] },
'img-to3d': { category:'image', optional:['depth_estimation','image_to_3d'], fallbackTypes:['image'] },
......@@ -3249,6 +3268,20 @@ async function loadModels() {
}
}
async function loadLocalCapabilities() {
try {
const r = await fetch('/admin/api/cached-models');
if (!r.ok) return;
const d = await r.json();
_localCapSet.clear();
[...(d.hf||[]), ...(d.gguf||[])].forEach(m => {
(m.capabilities||[]).forEach(cap => _localCapSet.add(cap));
});
// Re-render capability cards now that local data is available
renderCapabilityCards();
} catch(e) {}
}
const BADGE = {text:'mb-text',vision:'mb-vision',image:'mb-image',video:'mb-video',
audio:'mb-audio',tts:'mb-tts',audio_gen:'mb-audiogen',embedding:'mb-embed'};
const BLABEL = {text:'LLM',vision:'VLM',image:'IMG',video:'VID',audio:'STT',
......@@ -5908,6 +5941,7 @@ async function profEnvDelete(name) {
// ─────────────────────────────────────────────────────────────────
loadModels();
loadLocalCapabilities();
loadVoiceProfiles();
profCharLoad();
profEnvLoad();
......
......@@ -640,6 +640,7 @@ function closeModal(id){document.getElementById(id).classList.remove('show')}
/* ── Global settings ─────────────────────────────────── */
let _defaultOffloadDir = './offload';
let _highlightCap = null; // capability to highlight in local models list (from ?local_cap= param)
async function loadGlobalSettings(){
try{
......@@ -1218,7 +1219,8 @@ async function loadCachedModels(){
const loaded = _loadedKeys.has(m.id) || [..._loadedKeys].some(k=>k.endsWith(':'+m.id)||k===m.id);
const capBadges = fmtCapabilities(m.capabilities||[]);
const instBadgeHf = m.in_config ? _instanceBadge([m.id], (m.settings||{}).max_instances||1) : '';
return `<tr style="border-top:1px solid var(--border)">
const hlHf = _highlightCap && (m.capabilities||[]).includes(_highlightCap);
return `<tr${hlHf?' class="local-cap-highlight"':''} style="border-top:1px solid var(--border)${hlHf?';background:rgba(110,207,126,.07);outline:2px solid rgba(110,207,126,.25);outline-offset:-1px':''}">
<td style="padding:.4rem .25rem;font-family:monospace;font-size:12px;max-width:260px;overflow:hidden;text-overflow:ellipsis;white-space:nowrap" title="${esc(m.id)}">${esc(m.id)}</td>
<td style="text-align:right;padding:.4rem .25rem;white-space:nowrap;color:var(--text-2)">${fmtGB(m.size_gb)}</td>
<td style="text-align:right;padding:.4rem .25rem;color:var(--text-2)">${m.file_count}</td>
......@@ -1280,7 +1282,8 @@ async function loadCachedModels(){
return `<span class="badge ${active?'badge-ok':'badge-user'}" style="font-size:10px;padding:.1rem .3rem;cursor:pointer" title="${esc(f.filename)}" onclick="_switchGgufQuant(${idx},${JSON.stringify(f.path)})">${esc(q||f.filename)}</span>`;
}).join(' ')
: '';
return `<tr style="border-top:1px solid var(--border)" id="gguf-row-${idx}">
const hlGguf = _highlightCap && files.some(f=>(f.capabilities||[]).includes(_highlightCap));
return `<tr${hlGguf?' class="local-cap-highlight"':''} id="gguf-row-${idx}" style="border-top:1px solid var(--border)${hlGguf?';background:rgba(110,207,126,.07);outline:2px solid rgba(110,207,126,.25);outline-offset:-1px':''}">
<td style="padding:.4rem .25rem;font-family:monospace;font-size:11px;max-width:280px;overflow:hidden;text-overflow:ellipsis;white-space:nowrap" title="${esc(primary.filename)}">${esc(base)}<br>${quantBadges}</td>
<td style="text-align:right;padding:.4rem .25rem;white-space:nowrap;color:var(--text-2)">${fmtGB(primary.size_gb)}</td>
<td style="padding:.4rem .25rem;font-size:11px">${capBadges||'<span class="muted small">—</span>'}</td>
......@@ -1303,6 +1306,12 @@ async function loadCachedModels(){
'<th style="text-align:center;padding:.3rem .25rem;font-weight:700">Config</th>'+
'<th></th></tr></thead><tbody>'+rows.join('')+'</tbody></table>';
}
// Scroll to first highlighted model if a local_cap deep-link is active
if (_highlightCap) {
const first = document.querySelector('.local-cap-highlight');
if (first) first.scrollIntoView({behavior:'smooth', block:'center'});
}
// Rebuild set of locally cached model IDs for search results indicator
_cachedSearchIds.clear();
(d.hf||[]).forEach(m => _cachedSearchIds.add(m.id));
......@@ -1383,8 +1392,27 @@ loadGlobalSettings();
refreshLocal();
// ── Deep-link from Studio: /admin/models?tab=search&q=...&pipeline=...&gguf=...
// ── or: /admin/models?local_cap=CAPABILITY — highlight local models with that capability
(function applyDeepLink(){
const p = new URLSearchParams(location.search);
// Local capability highlight mode
const localCap = p.get('local_cap');
if (localCap) {
_highlightCap = localCap;
const banner = document.createElement('div');
banner.id = 'local-cap-banner';
banner.style.cssText = 'background:var(--accent-s);border:1px solid rgba(99,102,241,.35);border-radius:8px;padding:.5rem .875rem;margin-bottom:.75rem;font-size:13px;display:flex;align-items:center;gap:.5rem;flex-wrap:wrap';
banner.innerHTML = `<span style="color:#A5B4FC;font-weight:600">Looking for capability:</span>`
+ `<span class="badge badge-user" style="font-size:12px">${localCap.replace(/_/g,' ')}</span>`
+ `<span style="color:var(--text-2)">— models highlighted below can be configured to enable it</span>`
+ `<button onclick="document.getElementById('local-cap-banner').remove();_highlightCap=null" style="margin-left:auto;background:none;border:none;color:var(--text-2);cursor:pointer;font-size:16px;line-height:1">×</button>`;
const tabLocal = document.getElementById('tab-local');
tabLocal.insertBefore(banner, tabLocal.firstChild);
// local tab is already active by default — no tab switch needed
return;
}
if (p.get('tab') !== 'search') return;
// Switch to the HF search tab
const tabBtn = document.querySelector('.tab[onclick*="search"]');
......@@ -1398,7 +1426,7 @@ refreshLocal();
const pipelineEl = document.getElementById('filter-pipeline');
if (pipelineEl && pipeline) pipelineEl.value = pipeline;
// Set gguf mode toggle
const gguf = p.get('gguf') || 'gguf';
const gguf = p.get('gguf') || 'all';
document.querySelectorAll('.tog-btn').forEach(btn => {
btn.classList.toggle('on', btn.dataset.val === gguf);
if (btn.dataset.val === gguf) btn.click(); // fires the existing toggle handler
......
......@@ -17,7 +17,11 @@
"""Model capabilities module."""
from dataclasses import dataclass
from typing import List
from threading import Lock
from typing import List, Optional
import json
import os
import time
@dataclass
......@@ -180,12 +184,22 @@ def detect_model_capabilities(model_name: str) -> ModelCapabilities:
caps.video_generation = True
return caps
# ── Image: upscaling (checked before general SD rule to catch SD-family upscalers) ──
if any(x in n for x in ['real-esrgan', 'esrgan', 'swinir', 'edsr',
'bsrgan', 'hat-', 'dat-',
'x2-upscaler', 'x4-upscaler', 'x2_upscaler', 'x4_upscaler',
'latent-upscaler', 'latent_upscaler',
'ldm-super-resolution', 'rcan-', 'sr3-']):
caps.image_upscaling = True
caps.image_to_image = True
return caps
# ── Image: generation ────────────────────────────────────────────────────
if any(x in n for x in ['inpaint', 'instruct-pix2pix', 'paint-by-example']):
caps.inpainting = True
caps.image_generation = True
caps.image_to_image = True
caps.text_generation = True # T2I models process text
caps.text_generation = True
return caps
if 'controlnet' in n:
......@@ -194,14 +208,34 @@ def detect_model_capabilities(model_name: str) -> ModelCapabilities:
caps.text_generation = True
return caps
if any(x in n for x in ['stable-diffusion', 'sd15', 'sdxl', 'sd-xl',
'playground', 'flux', 'kandinsky', 'deepfloyd',
'pixart', 'dalle', 'waifu', 'pony',
'realistic-vision', 'realistic_vision']):
if any(x in n for x in [
# Stable Diffusion family
'stable-diffusion', 'sd15', 'sdxl', 'sd-xl', 'sd-turbo', 'sdxl-turbo',
'sd3', 'stable-cascade', 'stable-zero123',
# Other major T2I families
'flux', 'flux-dev', 'flux-schnell',
'kandinsky', 'deepfloyd', 'pixart', 'hunyuan-dit',
'dalle', 'dall-e',
'playground', 'playgroundai',
'imagen', 'parti-',
# Community / fine-tuned SD models (common naming conventions)
'dreamshaper', 'dreamlike', 'juggernaut', 'revanimated',
'epicrealism', 'absolutereality', 'counterfeit', 'deliberate',
'anything-v', 'openjourney', 'realvis', 'photon-',
'waifu', 'pony', 'realistic-vision', 'realistic_vision',
'aingdiffusion', 'majicmix', 'chillout', 'ghostmix',
# Speed / distilled variants (turbo, lightning, hyper)
'image-turbo', 'image-lightning', 'image-hyper', 'image-flash',
'diffusion-turbo', 'diffusion-lightning',
# Explicit T2I naming
'text-to-image', 'txt2img', 'text2img', 't2i-adapter',
# Generic diffusion bases
'latent-diffusion', 'ldm-',
]):
caps.image_generation = True
caps.image_to_image = True
caps.inpainting = True # most SD/SDXL/Flux support inpainting variant
caps.text_generation = True # T2I models process text
caps.inpainting = True # most SD/SDXL/Flux checkpoints support inpainting via mask
caps.text_generation = True
return caps
# ── Image: analysis / processing ─────────────────────────────────────────
......@@ -217,12 +251,6 @@ def detect_model_capabilities(model_name: str) -> ModelCapabilities:
caps.image_to_text = True
return caps
if any(x in n for x in ['real-esrgan', 'esrgan', 'swinir', 'edsr',
'bsrgan', 'hat-', 'dat-']):
caps.image_upscaling = True
caps.image_to_image = True
return caps
if any(x in n for x in ['codeformer', 'gfpgan', 'restoreformer']):
caps.face_restoration = True
caps.image_upscaling = True
......@@ -235,20 +263,39 @@ def detect_model_capabilities(model_name: str) -> ModelCapabilities:
return caps
# ── Vision / multimodal LLMs ─────────────────────────────────────────────
if any(x in n for x in ['vision', 'vl-', '-vl', 'llava', 'qwen2-vl',
'qwen-vl', 'phi-4-mini', 'pixtral', 'clip',
'blip', 'internvl', 'moondream', 'idefics',
'cogvlm', 'minigpt', 'flamingo']):
if any(x in n for x in [
'vision', 'vl-', '-vl', 'llava', 'llava-next', 'llavanext',
'qwen2-vl', 'qwen-vl', 'qwen2.5-vl',
'phi-4-mini', 'phi-3-vision', 'phi3-vision',
'pixtral', 'mistral-pixtral',
'clip', 'clip-vit',
'blip', 'blip-2', 'blip2',
'internvl', 'internlm-xcomposer',
'moondream', 'idefics', 'idefics2', 'idefics3',
'cogvlm', 'cogvlm2',
'minigpt', 'minigpt4',
'flamingo', 'openflamingo',
'paligemma', 'gemma-2-vl',
'deepseek-vl', 'deepseek-vl2',
'minicpm-v', 'minicpmv',
'mllm', 'multimodal',
]):
caps.image_to_text = True
caps.text_generation = True
return caps
# ── Embeddings ───────────────────────────────────────────────────────────
if any(x in n for x in ['embed', 'bge-', 'e5-', 'minilm',
'sentence-transformer', 'nomic-embed',
'instructor-', 'gte-', 'jina-embed']):
if any(x in n for x in [
'embed', 'bge-', 'e5-', 'minilm',
'sentence-transformer', 'sentence-bert',
'nomic-embed', 'nomic-text',
'instructor-', 'gte-', 'jina-embed', 'jina-clip',
'all-mpnet', 'all-minilm', 'paraphrase-',
'multilingual-e5', 'multilingual-mpnet',
'text-embedding', 'voyage-',
]):
caps.embeddings = True
caps.text_generation = True # Embedding models process text
caps.text_generation = True
return caps
# ── GGUF quantised text models ───────────────────────────────────────────
......@@ -258,4 +305,130 @@ def detect_model_capabilities(model_name: str) -> ModelCapabilities:
# Default: text generation
caps.text_generation = True
return caps
# ── HuggingFace pipeline_tag → capability fields ─────────────────────────────
_PIPELINE_TAG_CAPS: dict = {
'text-generation': ['text_generation'],
'text2text-generation': ['text_generation'],
'image-to-text': ['image_to_text', 'text_generation'],
'visual-question-answering': ['image_to_text', 'text_generation'],
'image-text-to-text': ['image_to_text', 'text_generation'],
'text-to-image': ['image_generation', 'image_to_image', 'text_generation'],
'unconditional-image-generation': ['image_generation'],
'image-to-image': ['image_to_image'], # sub-typed below
'automatic-speech-recognition': ['speech_to_text'],
'audio-to-audio': ['audio_to_audio'],
'text-to-speech': ['text_to_speech'],
'text-to-audio': ['audio_generation'],
'text-to-video': ['video_generation', 'text_generation'],
'image-to-video': ['image_to_video'],
'feature-extraction': ['embeddings', 'text_generation'],
'sentence-similarity': ['embeddings', 'text_generation'],
'depth-estimation': ['depth_estimation', 'image_to_text'],
'image-segmentation': ['image_segmentation', 'image_to_text'],
'object-detection': ['object_detection', 'image_to_text'],
'image-classification': ['image_to_text'],
'video-classification': ['image_to_text'],
'zero-shot-image-classification': ['image_to_text'],
'mask-generation': ['image_segmentation', 'image_to_text'],
}
def detect_capabilities_from_pipeline_tag(
pipeline_tag: str, model_name: str = ""
) -> ModelCapabilities:
"""
Detect capabilities using the HuggingFace pipeline_tag as the primary signal,
supplemented by name heuristics for sub-types (e.g. upscaling vs I2I).
Falls back to detect_model_capabilities when tag is absent or unrecognised.
"""
tag = (pipeline_tag or "").lower().strip()
fields = _PIPELINE_TAG_CAPS.get(tag)
if not fields:
return detect_model_capabilities(model_name)
caps = ModelCapabilities()
for f in fields:
if hasattr(caps, f):
setattr(caps, f, True)
# For generic image-to-image, use name heuristics to identify sub-type
if tag == 'image-to-image' and model_name:
name_caps = detect_model_capabilities(model_name)
if name_caps.image_upscaling:
caps.image_upscaling = True
if name_caps.inpainting:
caps.inpainting = True
caps.image_generation = True
if name_caps.face_restoration:
caps.face_restoration = True
caps.image_upscaling = True
if name_caps.style_transfer:
caps.style_transfer = True
return caps
# ── Persistent capability cache ───────────────────────────────────────────────
# Populated from HF search results (pipeline_tag-based, authoritative).
# Used as first-pass during local model scans so pipeline_tag info survives offline.
_CAP_CACHE_TTL = 90 * 86400 # 90 days
_cap_cache: dict = {}
_cap_cache_path: Optional[str] = None
_cap_cache_lock = Lock()
def init_capability_cache(config_dir: str) -> None:
"""Load the on-disk capability cache. Call once at server startup."""
global _cap_cache, _cap_cache_path
_cap_cache_path = os.path.join(config_dir, "capability_cache.json")
try:
with open(_cap_cache_path) as f:
loaded = json.load(f)
now = time.time()
# Prune stale entries on load
_cap_cache = {k: v for k, v in loaded.items()
if now - v.get("ts", 0) < _CAP_CACHE_TTL}
except Exception:
_cap_cache = {}
def _flush_capability_cache() -> None:
if _cap_cache_path is None:
return
try:
with open(_cap_cache_path, "w") as f:
json.dump(_cap_cache, f, indent=2)
except Exception:
pass
def update_capability_cache(model_id: str, caps: ModelCapabilities) -> None:
"""Store authoritative (pipeline_tag-derived) capabilities and persist to disk."""
if not model_id:
return
with _cap_cache_lock:
_cap_cache[model_id] = {"caps": caps.to_list(), "ts": int(time.time())}
_flush_capability_cache()
def lookup_capability_cache(model_id: str) -> Optional[ModelCapabilities]:
"""Return cached ModelCapabilities, or None if absent or expired."""
if not model_id:
return None
entry = _cap_cache.get(model_id)
if not entry:
return None
if time.time() - entry.get("ts", 0) > _CAP_CACHE_TTL:
with _cap_cache_lock:
_cap_cache.pop(model_id, None)
return None
caps = ModelCapabilities()
for field in entry.get("caps", []):
if hasattr(caps, field):
setattr(caps, field, True)
return caps
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment