#!/usr/bin/env python3
"""
Test script for UDP broadcast functionality
"""

import socket
import json
import time
import threading
import sys

def test_udp_broadcast():
    """Broadcast a sample MBetterClient announcement over UDP.

    A JSON-encoded server description is sent to the loopback broadcast
    address and to the limited broadcast address on port 45123. A failure
    on one address is reported and does not prevent trying the next one.

    Returns:
        bool: True if the datagram was sent to at least one address;
        False if the socket could not be set up or every send failed.
    """
    print("Testing UDP broadcast functionality...")

    # Test server info
    server_info = {
        "service": "MBetterClient",
        "host": "127.0.0.1",
        "port": 5001,
        "ssl": False,
        "url": "http://127.0.0.1:5001",
        "timestamp": time.time()
    }

    # Send to localhost broadcast and general broadcast
    addresses = [
        ('127.255.255.255', 45123),  # Local broadcast
        ('255.255.255.255', 45123),  # Global broadcast
    ]

    sent_count = 0

    try:
        # The context manager closes the socket on every exit path,
        # including when configuring or encoding raises.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

            broadcast_data = json.dumps(server_info).encode('utf-8')

            print(f"Broadcasting: {server_info}")

            for addr in addresses:
                try:
                    sock.sendto(broadcast_data, addr)
                except OSError as e:
                    # Best effort: one unreachable address must not stop
                    # the attempt on the remaining ones.
                    print(f"✗ Failed to send to {addr}: {e}")
                else:
                    print(f"✓ Sent to {addr}")
                    sent_count += 1

    except Exception as e:
        print(f"✗ Broadcast test failed: {e}")
        return False

    # Do not claim success when nothing actually left the machine.
    if sent_count == 0:
        print("✗ Broadcast test failed: could not send to any address")
        return False

    print("✓ Broadcast test completed successfully")
    return True

def test_udp_listener():
    """Check that a UDP listener receives the MBetterClient broadcast.

    A background thread binds port 45123 and waits, with a 5 second receive
    timeout, for a JSON datagram whose ``service`` field is
    ``MBetterClient``. Once the listener is bound (or has failed to bind),
    the main thread sends a broadcast through ``test_udp_broadcast`` and
    then waits up to 6 seconds for the listener thread to finish.

    Returns:
        bool: True only if the listener thread finished and received a
        matching broadcast; False if it timed out, hit an error, could not
        bind, or was still running after the join timeout.
    """
    print("\nTesting UDP listener functionality...")

    # threading.Thread discards the target's return value, so the outcome
    # is handed back to the main thread through this shared dict.
    outcome = {"received": False}
    listener_ready = threading.Event()

    def listen_for_broadcasts():
        """Listen for one matching broadcast and record whether it arrived."""
        try:
            # The context manager closes the socket on every exit path,
            # including a failed bind.
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                sock.bind(('', 45123))
                sock.settimeout(5.0)  # 5 second timeout

                print("Listening for broadcasts on port 45123...")
                listener_ready.set()

                while True:
                    try:
                        data, addr = sock.recvfrom(1024)
                        message = json.loads(data.decode('utf-8'))

                        if message.get('service') == 'MBetterClient':
                            print(f"✓ Received broadcast from {addr}: {message}")
                            outcome["received"] = True
                            return True

                    except socket.timeout:
                        print("⚠ Timeout - no broadcasts received")
                        return False
                    except Exception as e:
                        print(f"✗ Listener error: {e}")
                        return False

        except Exception as e:
            print(f"✗ Failed to create listener: {e}")
            return False
        finally:
            # Unblock the main thread even when setup failed.
            listener_ready.set()

    # Start listener in separate thread
    listener_thread = threading.Thread(target=listen_for_broadcasts, daemon=True)
    listener_thread.start()

    # Wait until the listener is bound (or has failed) rather than
    # sleeping for a fixed interval.
    listener_ready.wait(timeout=1)

    # Send test broadcast
    test_udp_broadcast()

    # Wait for listener to complete
    listener_thread.join(timeout=6)

    if listener_thread.is_alive():
        print("✗ Listener test timed out")
        return False

    # A finished thread is not enough: it must have received the message.
    if not outcome["received"]:
        print("✗ Listener test failed: no MBetterClient broadcast received")
        return False

    print("✓ Listener test completed")
    return True

def test_json_structure():
    """Round-trip a sample announcement through JSON and validate it.

    Returns:
        bool: True when the decoded message contains every required field
        and its service name is ``MBetterClient``; False otherwise, or
        when encoding/decoding raises.
    """
    print("\nTesting JSON message structure...")

    # Representative, well-formed announcement
    sample = dict(
        service="MBetterClient",
        host="192.168.1.100",
        port=5001,
        ssl=True,
        url="https://192.168.1.100:5001",
        timestamp=time.time(),
    )
    expected_keys = ('service', 'host', 'port', 'ssl', 'url', 'timestamp')

    try:
        # Encode, then decode again
        serialized = json.dumps(sample)
        round_tripped = json.loads(serialized)

        # Collect every required key the decoded message lacks
        absent = []
        for key in expected_keys:
            if key not in round_tripped:
                absent.append(key)

        if absent:
            print(f"✗ Missing required fields: {absent}")
            return False

        service_name = round_tripped['service']
        if service_name != 'MBetterClient':
            print(f"✗ Invalid service name: {service_name}")
            return False

        print("✓ JSON structure validation passed")
        print(f"  Sample message: {serialized}")
        return True

    except Exception as e:
        print(f"✗ JSON validation failed: {e}")
        return False

def main():
    """Run every test in order and print a pass/fail summary.

    A test that raises is reported as crashed and counted as failed.

    Returns:
        bool: True when every test returned a truthy result, else False.
    """
    banner = "=" * 60
    print(banner)
    print("UDP Broadcast System Test")
    print(banner)

    suite = (
        ("JSON Structure", test_json_structure),
        ("UDP Broadcast", test_udp_broadcast),
        ("UDP Listener", test_udp_listener),
    )

    outcomes = []
    for label, runner in suite:
        print(f"\n--- {label} Test ---")
        try:
            outcome = runner()
        except Exception as e:
            print(f"✗ {label} test crashed: {e}")
            outcome = False
        outcomes.append((label, outcome))

    # Summary
    print("\n" + banner)
    print("Test Results Summary")
    print(banner)

    passed = 0
    for label, outcome in outcomes:
        if outcome:
            status = "✓ PASS"
            passed += 1
        else:
            status = "✗ FAIL"
        print(f"{label:<20} {status}")

    total = len(outcomes)
    print(f"\nOverall: {passed}/{total} tests passed")

    if passed != total:
        print("⚠ Some tests failed. Check the output above for details.")
        return False

    print("🎉 All tests passed! UDP broadcast system is working correctly.")
    return True

if __name__ == "__main__":
    # Exit status: 0 when all tests passed, 1 on failure or Ctrl-C.
    try:
        exit_code = 0 if main() else 1
    except KeyboardInterrupt:
        print("\n\nTest interrupted by user")
        exit_code = 1
    sys.exit(exit_code)