import argparse
import torch
from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor
from qwen_omni_utils import process_mm_info
from resemblyzer import VoiceEncoder
from sklearn.cluster import AgglomerativeClustering
import webrtcvad
import librosa
import numpy as np
import os

def get_diarization(audio, sr):
    """Split ``audio`` into speech turns and assign a speaker label to each.

    The signal is scanned with WebRTC VAD in 10 ms frames to find contiguous
    speech regions, each region is embedded with Resemblyzer's
    ``VoiceEncoder``, and the embeddings are grouped with agglomerative
    clustering (distance threshold 0.5). Neighbouring regions that share a
    speaker and are less than 0.1 s apart are merged.

    Args:
        audio: 1-D float NumPy array of mono samples, nominally in [-1, 1].
        sr: Sample rate of ``audio`` in Hz. NOTE(review): WebRTC VAD is
            documented to accept only 8000/16000/32000/48000 Hz; other rates
            are expected to raise inside ``vad.is_speech``.

    Returns:
        A list of ``(start_seconds, end_seconds, "SPEAKER_NN")`` tuples in
        chronological order; empty when the audio is shorter than one frame
        or contains no detected speech.
    """
    frame_duration = 0.01
    frame_length = int(sr * frame_duration)

    # Nothing to analyse: bail out before constructing the VAD or the
    # (expensive to load) speaker encoder.
    if len(audio) < frame_length:
        return []

    vad = webrtcvad.Vad(3)

    # Convert to 16-bit PCM once. Clipping first prevents samples slightly
    # outside [-1, 1] from wrapping around when cast to int16.
    pcm = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)

    # Speech regions are tracked in sample indices so no precision is lost
    # converting to seconds and back before slicing.
    regions = []
    region_start = None
    # "+ 1" so that the final complete frame is classified as well.
    for offset in range(0, len(pcm) - frame_length + 1, frame_length):
        frame_bytes = pcm[offset:offset + frame_length].tobytes()
        is_speech = vad.is_speech(frame_bytes, sr)
        if is_speech and region_start is None:
            region_start = offset
        elif not is_speech and region_start is not None:
            regions.append((region_start, offset))
            region_start = None
    if region_start is not None:
        # Speech ran until the end of the signal.
        regions.append((region_start, len(audio)))
    if not regions:
        return []

    # Every region spans at least one frame, so there is exactly one
    # embedding per region and labels stay aligned with regions.
    encoder = VoiceEncoder()
    embeddings = [encoder.embed_utterance(audio[begin:stop]) for begin, stop in regions]

    if len(embeddings) <= 1:
        labels = [0] * len(regions)
    else:
        # A single threshold-based fit already yields the final partition;
        # refitting with the discovered cluster count would repeat the work.
        clustering = AgglomerativeClustering(n_clusters=None, distance_threshold=0.5).fit(embeddings)
        labels = clustering.labels_

    diarization = [
        (begin / sr, stop / sr, f"SPEAKER_{int(label):02d}")
        for (begin, stop), label in zip(regions, labels)
    ]

    # Merge consecutive turns of the same speaker separated by a small gap.
    merged = []
    current_start, current_end, current_speaker = diarization[0]
    for start, end, speaker in diarization[1:]:
        if speaker == current_speaker and abs(start - current_end) < 0.1:
            current_end = end
        else:
            merged.append((current_start, current_end, current_speaker))
            current_start, current_end, current_speaker = start, end, speaker
    merged.append((current_start, current_end, current_speaker))
    return merged

def _format_timestamp(seconds):
    """Format a non-negative offset in seconds as ``HH:MM:SS.ss``."""
    # Round to centiseconds first so a value such as 59.996 carries into the
    # minutes field instead of being rendered as "60.00" seconds.
    total_centis = int(round(seconds * 100))
    total_seconds, centis = divmod(total_centis, 100)
    minutes, secs = divmod(total_seconds, 60)
    hours, minutes = divmod(minutes, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}.{centis:02d}"


def main():
    """Transcribe an audio file with speaker labels and timestamps.

    Reads the audio path from the command line, diarizes the recording with
    ``get_diarization``, transcribes every speaker turn with
    Qwen2.5-Omni-7B, and writes one ``[start - end] SPEAKER_NN: text`` line
    per turn to a ``.txt`` file next to the input.

    Raises:
        SystemExit: With a non-zero status if the audio file does not exist.
    """
    parser = argparse.ArgumentParser(description='Transcribe audio with speakers and timestamps using Qwen2.5-Omni-7B')
    parser.add_argument('audio_file', help='Path to the audio file')
    args = parser.parse_args()

    audio_file = args.audio_file

    # Fail with a non-zero exit status (message goes to stderr) so callers
    # and shell pipelines can detect the error.
    if not os.path.isfile(audio_file):
        raise SystemExit(f"Error: Audio file '{audio_file}' not found.")

    # Load and diarize before touching the model: a bad or silent file is
    # then reported without paying for a 7B-parameter model load.
    audio, sr = librosa.load(audio_file, sr=16000)
    diarization = get_diarization(audio, sr)

    output_lines = []
    if diarization:
        model = Qwen2_5OmniForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-Omni-7B", torch_dtype="auto", device_map="auto")
        processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B")

        for start, end, speaker_id in diarization:
            # round() guards against float error pushing a boundary one
            # sample early when converting seconds back to indices.
            start_sample = int(round(start * sr))
            end_sample = int(round(end * sr))
            audio_chunk = audio[start_sample:end_sample]

            if len(audio_chunk) == 0:
                continue

            # The waveform is passed directly as a NumPy array (already at
            # 16 kHz). NOTE(review): process_mm_info is expected to accept a
            # path/URL string or an ndarray here, not a dict -- confirm
            # against the installed qwen_omni_utils version.
            conversation = [
                {"role": "user", "content": [
                    {"type": "audio", "audio": audio_chunk},
                    {"type": "text", "text": "Transcribe this audio segment exactly as spoken."}
                ]}
            ]

            prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
            audios, images, videos = process_mm_info(conversation, use_audio_in_video=False)
            inputs = processor(text=prompt, audio=audios, images=images, videos=videos, return_tensors="pt", padding=True, use_audio_in_video=False)
            inputs = inputs.to(model.device).to(model.dtype)

            # Only text is needed, so skip speech synthesis; with
            # return_audio=False generate() returns just the token ids.
            generated_ids = model.generate(**inputs, use_audio_in_video=False, return_audio=False)
            # Drop the prompt tokens so the decoded text is only the answer,
            # not the echoed system/user turns.
            prompt_length = inputs["input_ids"].shape[1]
            new_token_ids = generated_ids[:, prompt_length:]
            transcript = processor.batch_decode(new_token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0].strip()

            start_str = _format_timestamp(start)
            end_str = _format_timestamp(end)
            output_lines.append(f"[{start_str} - {end_str}] {speaker_id}: {transcript}")

    # Write the transcript next to the input file, swapping the extension.
    base_name = os.path.splitext(audio_file)[0]
    output_file = base_name + '.txt'
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write('\n'.join(output_lines))

    print(f"Transcription saved to {output_file}")

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()