Image generation with coderai

parent b0c0e3c9
......@@ -2049,7 +2049,8 @@ class RequestHandler:
multipart=multipart_payload,
stream=wants_stream,
)
response_headers = payload.get('headers') or {}
_skip_headers = {'content-length', 'transfer-encoding', 'content-encoding'}
response_headers = {k: v for k, v in (payload.get('headers') or {}).items() if k.lower() not in _skip_headers}
if payload.get('stream_chunks'):
media_type = payload.get('content_type') or 'text/event-stream'
......
......@@ -2608,3 +2608,193 @@ async def dashboard_billing(request: Request):
"stripe_publishable_key": stripe_publishable_key,
}
)
# ── Studio v1 API proxy ───────────────────────────────────────────────────────
# All /dashboard/api/studio/v1/* and /dashboard/api/studio/u/{username}/v1/*
# requests are forwarded to the appropriate RequestHandler method so the
# dashboard session can drive the same capabilities as the token-auth API.
_STUDIO_V1_BODY_MODEL_FIELDS = {
"v1/video/dub": ["video_model", "stt_model", "tts_model", "model"],
"v1/audio/clone": ["model", "tts_model"],
"v1/audio/convert": ["model", "audio_model", "tts_model", "stt_model"],
"v1/audio/split": ["model", "audio_model"],
"v1/audio/denoise": ["model", "audio_model"],
"v1/images/faceswap": ["model", "image_model", "video_model"],
"v1/images/outfit": ["model", "image_model", "video_model"],
"v1/images/to3d": ["model", "render_model", "image_model", "video_model"],
"v1/images/from3d": ["model", "render_model", "image_model", "video_model"],
"v1/video/to3d": ["model", "render_model", "image_model", "video_model"],
"v1/video/from3d": ["model", "render_model", "image_model", "video_model"],
"v1/3d/generate": ["model", "render_model", "image_model", "video_model"],
}
def _studio_normalize_model(path: str, body: dict) -> str:
"""Pick the first non-empty model field for multi-model endpoints."""
for field in _STUDIO_V1_BODY_MODEL_FIELDS.get(path, ["model"]):
v = (body.get(field) or '').strip()
if v:
return v
return (body.get('model') or '').strip()
async def _studio_v1_dispatch(request: Request, path: str, user_id: Optional[int]):
from aisbf.handlers import RequestHandler, RotationHandler, AutoselectHandler
try:
body = await request.json()
except Exception:
body = {}
body = dict(body)
full_path = f"v1/{path}"
model_str = _studio_normalize_model(full_path, body)
kind, source_id, target_id = _parse_studio_model_id(model_str)
if kind == 'rotation':
rot = RotationHandler(user_id=user_id)
if not hasattr(rot, '_select_provider_and_model'):
raise HTTPException(status_code=400, detail=f"Rotation '{source_id}' cannot be resolved for this endpoint")
provider_id, actual_model = rot._select_provider_and_model(source_id)
elif kind == 'autoselect':
asel_cfg = (_config.autoselect or {}).get(source_id) if _config else None
if not asel_cfg:
asel = AutoselectHandler(user_id=user_id)
asel_cfg = getattr(asel, 'user_autoselects', {}).get(source_id)
if not asel_cfg:
raise HTTPException(status_code=404, detail=f"Autoselect '{source_id}' not found")
fallback = getattr(asel_cfg, 'fallback', None) or (asel_cfg.get('fallback') if isinstance(asel_cfg, dict) else '') or ''
if '/' in fallback:
provider_id, actual_model = fallback.split('/', 1)
else:
raise HTTPException(status_code=400, detail=f"Cannot resolve autoselect '{source_id}' for non-chat endpoint")
elif kind == 'provider':
provider_id, actual_model = source_id, target_id
else:
raise HTTPException(status_code=400, detail="model required (format: provider/model, rotation/name, or autoselect/name)")
body['model'] = actual_model
handler = RequestHandler(user_id=user_id)
return await handler.handle_generic_proxy(request, provider_id, full_path, body)
async def _studio_progress(request: Request, endpoint_path: str, user_id: Optional[int]):
from aisbf.handlers import RequestHandler
handler = RequestHandler(user_id=user_id)
provider = request.query_params.get('provider', '').strip()
if provider:
return await handler.handle_generic_proxy(request, provider, endpoint_path, {}, method="GET")
return JSONResponse({"active": False, "current": 0, "total": 0, "pct": 0, "elapsed": 0, "it_per_s": 0})
@router.post("/dashboard/api/studio/v1/{path:path}")
async def dashboard_studio_v1_proxy(request: Request, path: str):
auth_check = require_dashboard_auth(request)
if auth_check:
return auth_check
return await _studio_v1_dispatch(request, path, user_id=None)
@router.post("/dashboard/api/studio/u/{username}/v1/{path:path}")
async def dashboard_user_studio_v1_proxy(request: Request, username: str, path: str):
scope, user_id = _dashboard_studio_user_scope(request, username)
return await _studio_v1_dispatch(request, path, user_id=user_id)
@router.get("/dashboard/api/studio/images/progress")
async def dashboard_studio_images_progress(request: Request):
auth_check = require_dashboard_auth(request)
if auth_check:
return auth_check
return await _studio_progress(request, "v1/images/progress", user_id=None)
@router.get("/dashboard/api/studio/video/progress")
async def dashboard_studio_video_progress(request: Request):
auth_check = require_dashboard_auth(request)
if auth_check:
return auth_check
return await _studio_progress(request, "v1/video/progress", user_id=None)
@router.get("/dashboard/api/studio/audio/progress")
async def dashboard_studio_audio_progress(request: Request):
auth_check = require_dashboard_auth(request)
if auth_check:
return auth_check
return await _studio_progress(request, "v1/audio/progress", user_id=None)
@router.get("/dashboard/api/studio/u/{username}/images/progress")
async def dashboard_user_studio_images_progress(request: Request, username: str):
scope, user_id = _dashboard_studio_user_scope(request, username)
return await _studio_progress(request, "v1/images/progress", user_id=user_id)
@router.get("/dashboard/api/studio/u/{username}/video/progress")
async def dashboard_user_studio_video_progress(request: Request, username: str):
scope, user_id = _dashboard_studio_user_scope(request, username)
return await _studio_progress(request, "v1/video/progress", user_id=user_id)
@router.get("/dashboard/api/studio/u/{username}/audio/progress")
async def dashboard_user_studio_audio_progress(request: Request, username: str):
scope, user_id = _dashboard_studio_user_scope(request, username)
return await _studio_progress(request, "v1/audio/progress", user_id=user_id)
async def _studio_audio_transcription(request: Request, user_id: Optional[int]):
from aisbf.handlers import RequestHandler
from starlette.datastructures import FormData
form = await request.form()
model_str = (form.get('model') or '').strip()
kind, source_id, target_id = _parse_studio_model_id(model_str)
provider_id = source_id if kind == 'provider' else (model_str.split('/')[0] if '/' in model_str else model_str)
actual_model = target_id if kind == 'provider' else model_str
updated_form = FormData([(k, actual_model if k == 'model' else v) for k, v in form.multi_items()])
handler = RequestHandler(user_id=user_id)
return await handler.handle_audio_transcription(request, provider_id, updated_form)
async def _studio_audio_speech(request: Request, user_id: Optional[int]):
from aisbf.handlers import RequestHandler
body = await request.json()
body = dict(body)
model_str = (body.get('model') or '').strip()
kind, source_id, target_id = _parse_studio_model_id(model_str)
if kind == 'provider':
provider_id, actual_model = source_id, target_id
else:
provider_id, actual_model = model_str.split('/', 1) if '/' in model_str else (model_str, model_str)
body['model'] = actual_model
handler = RequestHandler(user_id=user_id)
return await handler.handle_text_to_speech(request, provider_id, body)
@router.post("/dashboard/api/studio/audio/transcriptions")
async def dashboard_studio_audio_transcriptions(request: Request):
auth_check = require_dashboard_auth(request)
if auth_check:
return auth_check
return await _studio_audio_transcription(request, user_id=None)
@router.post("/dashboard/api/studio/u/{username}/audio/transcriptions")
async def dashboard_user_studio_audio_transcriptions(request: Request, username: str):
scope, user_id = _dashboard_studio_user_scope(request, username)
return await _studio_audio_transcription(request, user_id=user_id)
@router.post("/dashboard/api/studio/audio/speech")
async def dashboard_studio_audio_speech(request: Request):
auth_check = require_dashboard_auth(request)
if auth_check:
return auth_check
return await _studio_audio_speech(request, user_id=None)
@router.post("/dashboard/api/studio/u/{username}/audio/speech")
async def dashboard_user_studio_audio_speech(request: Request, username: str):
scope, user_id = _dashboard_studio_user_scope(request, username)
return await _studio_audio_speech(request, user_id=user_id)
......@@ -31,14 +31,15 @@ let bindingSearchState = {};
let selectedBindingId = 'chat';
let _pendingBindingFocusKey = null;
function _startVidPoll(prefix) {
function _startVidPoll(prefix, provider) {
if (_vidPollTimer) { clearInterval(_vidPollTimer); _vidPollTimer = null; }
const wrap = $(prefix+'-pbar-wrap'), fill = $(prefix+'-pbar-fill'), lbl = $(prefix+'-pbar-label');
if (!wrap) return;
wrap.classList.add('active'); fill.style.width='0%'; lbl.textContent='';
const _progUrl = buildStudioUrl('/video/progress') + (provider ? '?provider='+encodeURIComponent(provider) : '');
_vidPollTimer = setInterval(async () => {
try {
const p = await (await dashboardFetch(buildStudioUrl('/video/progress'))).json();
const p = await (await dashboardFetch(_progUrl)).json();
if (p.total > 0) {
fill.style.width = p.pct + '%';
const spd = p.it_per_s > 0 ? ` · ${p.it_per_s} it/s` : (p.elapsed > 0 ? ` · ${p.elapsed}s` : '');
......@@ -59,14 +60,15 @@ function _stopVidPoll(prefix, done) {
else { wrap.classList.remove('active'); }
}
function _startAudPoll(prefix) {
function _startAudPoll(prefix, provider) {
if (_audPollTimer) { clearInterval(_audPollTimer); _audPollTimer = null; }
const wrap = $(prefix+'-pbar-wrap'), fill = $(prefix+'-pbar-fill'), lbl = $(prefix+'-pbar-label');
if (!wrap) return;
wrap.classList.add('active'); fill.style.width='0%'; lbl.textContent='';
const _progUrl = buildStudioUrl('/audio/progress') + (provider ? '?provider='+encodeURIComponent(provider) : '');
_audPollTimer = setInterval(async () => {
try {
const p = await (await dashboardFetch(buildStudioUrl('/audio/progress'))).json();
const p = await (await dashboardFetch(_progUrl)).json();
if (p.total > 0) {
fill.style.width = p.pct + '%';
const unit = p.unit || 'it';
......@@ -2033,6 +2035,17 @@ function modelForSub(sub) {
return activeModel?.id || '';
}
function providerForModelId(modelId) {
const m = models.find(x => x.id === modelId);
if (m?.sourceId) return m.sourceId;
const parts = (modelId || '').split('/');
if (parts.length === 3 && parts[0] === 'provider') return parts[1];
if (parts.length >= 2) return parts[0];
return '';
}
function providerForSub(sub) { return providerForModelId(modelForSub(sub)); }
// Assigns a model to a single-cap sub without changing the global activeModel.
function selectSubModel(sub, model) {
const cap = SUB_API_CAP[sub]
......@@ -2678,9 +2691,11 @@ async function genImage() {
const wrap=$('ig-pbar-wrap'), fill=$('ig-pbar-fill'), lbl=$('ig-pbar-label');
wrap.classList.add('active'); fill.style.width='0%'; lbl.textContent='';
$('ig-prog').scrollIntoView({behavior:'smooth', block:'nearest'});
const _igProvider = providerForSub('img-gen');
const _igProgUrl = buildStudioUrl('/images/progress') + (_igProvider ? '?provider='+encodeURIComponent(_igProvider) : '');
_imgPollTimer = setInterval(async()=>{
try{
const p=await (await dashboardFetch(buildStudioUrl('/images/progress'))).json();
const p=await (await dashboardFetch(_igProgUrl)).json();
if(p.total>0){
fill.style.width=p.pct+'%';
const spd = p.it_per_s>0 ? ` · ${p.it_per_s} it/s` : (p.elapsed>0 ? ` · ${p.elapsed}s` : '');
......@@ -2816,8 +2831,8 @@ async function genVideo(mode) {
const prefixMap = {t2v:'vt', i2v:'vi', v2v:'vv'};
const prog = progMap[mode], outId = outMap[mode], prefix = prefixMap[mode];
$(prog).textContent='Generating… (this may take several minutes)';
_startVidPoll(prefix);
const subId = {t2v:'vid-t2v', i2v:'vid-i2v', v2v:'vid-v2v'}[mode];
_startVidPoll(prefix, providerForSub(subId));
const body = {model:modelForSub(subId), mode};
if (mode==='t2v') {
body.prompt=val('vt-prompt'); body.negative_prompt=val('vt-neg');
......@@ -2882,7 +2897,7 @@ async function genTi2V() {
$('ti-prog').textContent='Generating… (may take several minutes)';
const [initImg, endImg, srcVid] = await Promise.all([b64OrNull('ti-init'), b64OrNull('ti-end'), b64OrNull('ti-vid')]);
if (!srcVid && !initImg) { $('ti-prog').textContent='Select an initial image or source video.'; return; }
_startVidPoll('ti');
_startVidPoll('ti', providerForSub('vid-ti2v'));
const body = {
model:modelForSub('vid-ti2v'),
......@@ -3664,7 +3679,7 @@ async function runPipeline4() {
async function genAudio() {
if (!activeModel) return;
$('ag-prog').textContent='Generating audio…';
_startAudPoll('ag');
_startAudPoll('ag', providerForSub('aud-gen'));
const melody = await b64OrNull('ag-melody');
const body = {
model:modelForSub('aud-gen'), prompt:val('ag-prompt'),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment