Implement wsssh subprocess spawning for terminal

parent a12360be
...@@ -28,6 +28,8 @@ import sys ...@@ -28,6 +28,8 @@ import sys
import os import os
import threading import threading
import uuid import uuid
import subprocess
import fcntl
from flask import Flask, render_template, request, redirect, url_for, flash, jsonify, send_from_directory from flask import Flask, render_template, request, redirect, url_for, flash, jsonify, send_from_directory
from flask_login import LoginManager, UserMixin, login_user, login_required, logout_user, current_user from flask_login import LoginManager, UserMixin, login_user, login_required, logout_user, current_user
from flask_sqlalchemy import SQLAlchemy from flask_sqlalchemy import SQLAlchemy
...@@ -37,12 +39,11 @@ from werkzeug.security import generate_password_hash, check_password_hash ...@@ -37,12 +39,11 @@ from werkzeug.security import generate_password_hash, check_password_hash
clients = {} clients = {}
# Active tunnels: request_id -> {'client_ws': ws, 'wsssh_ws': ws, 'client_id': id} # Active tunnels: request_id -> {'client_ws': ws, 'wsssh_ws': ws, 'client_id': id}
active_tunnels = {} active_tunnels = {}
# Active terminals: request_id -> {'web_sid': sid, 'client_id': id, 'username': username} # Active terminals: request_id -> {'client_id': id, 'username': username, 'proc': proc}
active_terminals = {} active_terminals = {}
debug = False debug = False
server_password = None server_password = None
args = None args = None
loop = None
# Flask app for web interface # Flask app for web interface
app = Flask(__name__) app = Flask(__name__)
...@@ -182,19 +183,23 @@ def logos_files(filename): ...@@ -182,19 +183,23 @@ def logos_files(filename):
@login_required @login_required
def connect_terminal(client_id): def connect_terminal(client_id):
username = request.form.get('username', 'root') username = request.form.get('username', 'root')
if client_id in clients:
request_id = str(uuid.uuid4()) request_id = str(uuid.uuid4())
active_terminals[request_id] = {'client_id': client_id, 'username': username, 'data': []} # Spawn wsssh process
loop.call_soon_threadsafe(lambda: asyncio.create_task(send_terminal_request(request_id, client_id, username))) proc = subprocess.Popen(
['wsssh', '-P', str(args.port), f'{username}@{client_id}.{args.domain}'],
stdout=subprocess.PIPE,
stdin=subprocess.PIPE,
stderr=subprocess.STDOUT,
bufsize=0
)
# Set stdout to non-blocking
fd = proc.stdout.fileno()
fl = fcntl.fcntl(fd, fcntl.F_GETFL)
fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK)
active_terminals[request_id] = {'client_id': client_id, 'username': username, 'proc': proc}
return jsonify({'request_id': request_id}) return jsonify({'request_id': request_id})
return jsonify({'error': 'Client not connected'}), 400
async def send_terminal_request(request_id, client_id, username):
await clients[client_id].send(json.dumps({
"type": "terminal_request",
"request_id": request_id,
"username": username
}))
@app.route('/terminal/<client_id>/data', methods=['GET', 'POST']) @app.route('/terminal/<client_id>/data', methods=['GET', 'POST'])
@login_required @login_required
...@@ -203,37 +208,42 @@ def terminal_data(client_id): ...@@ -203,37 +208,42 @@ def terminal_data(client_id):
request_id = request.form.get('request_id') request_id = request.form.get('request_id')
data = request.form.get('data') data = request.form.get('data')
if request_id in active_terminals: if request_id in active_terminals:
loop.call_soon_threadsafe(lambda: asyncio.create_task(send_terminal_data(request_id, active_terminals[request_id]['client_id'], data))) proc = active_terminals[request_id]['proc']
if proc.poll() is None: # Process is still running
proc.stdin.write(data.encode())
proc.stdin.flush()
return 'OK' return 'OK'
else: else:
request_id = request.args.get('request_id') request_id = request.args.get('request_id')
if request_id in active_terminals: if request_id in active_terminals:
data = active_terminals[request_id]['data'] proc = active_terminals[request_id]['proc']
active_terminals[request_id]['data'] = [] if proc.poll() is None:
return ''.join(data) try:
data = proc.stdout.read(1024).decode('utf-8', errors='ignore')
return data
except:
return ''
else:
# Process terminated
return '\r\nProcess terminated.\r\n'
return '' return ''
async def send_terminal_data(request_id, client_id, data):
await clients[client_id].send(json.dumps({
"type": "terminal_data",
"request_id": request_id,
"data": data
}))
@app.route('/terminal/<client_id>/disconnect', methods=['POST']) @app.route('/terminal/<client_id>/disconnect', methods=['POST'])
@login_required @login_required
def disconnect_terminal(client_id): def disconnect_terminal(client_id):
request_id = request.form.get('request_id') request_id = request.form.get('request_id')
if request_id in active_terminals: if request_id in active_terminals:
loop.call_soon_threadsafe(lambda: asyncio.create_task(send_terminal_close(request_id, active_terminals[request_id]['client_id']))) proc = active_terminals[request_id]['proc']
if proc.poll() is None:
proc.terminate()
try:
proc.wait(timeout=5)
except:
proc.kill()
del active_terminals[request_id] del active_terminals[request_id]
return 'OK' return 'OK'
async def send_terminal_close(request_id, client_id):
await clients[client_id].send(json.dumps({
"type": "terminal_close",
"request_id": request_id
}))
async def handle_websocket(websocket, path=None): async def handle_websocket(websocket, path=None):
try: try:
...@@ -307,19 +317,6 @@ async def handle_websocket(websocket, path=None): ...@@ -307,19 +317,6 @@ async def handle_websocket(websocket, path=None):
})) }))
# Clean up tunnel # Clean up tunnel
del active_tunnels[request_id] del active_tunnels[request_id]
elif data.get('type') == 'terminal_ack':
request_id = data['request_id']
if request_id in active_terminals:
active_terminals[request_id]['data'].append('\r\nConnected successfully!\r\n$ ')
elif data.get('type') == 'terminal_data':
request_id = data['request_id']
if request_id in active_terminals:
active_terminals[request_id]['data'].append(data['data'])
elif data.get('type') == 'terminal_close':
request_id = data['request_id']
if request_id in active_terminals:
active_terminals[request_id]['data'].append('\r\nTerminal closed.\r\n')
del active_terminals[request_id]
except websockets.exceptions.ConnectionClosed: except websockets.exceptions.ConnectionClosed:
# Remove from registry and clean up tunnels # Remove from registry and clean up tunnels
disconnected_client = None disconnected_client = None
...@@ -374,9 +371,6 @@ async def main(): ...@@ -374,9 +371,6 @@ async def main():
# Start WebSocket server # Start WebSocket server
ws_server = await websockets.serve(handle_websocket, args.host, args.port, ssl=ssl_context) ws_server = await websockets.serve(handle_websocket, args.host, args.port, ssl=ssl_context)
global loop
loop = asyncio.get_running_loop()
print(f"WebSocket SSH Daemon running on {args.host}:{args.port}") print(f"WebSocket SSH Daemon running on {args.host}:{args.port}")
# Start web interface if specified # Start web interface if specified
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment