Update to use Qwen2_5OmniForConditionalGeneration and correct inference format

parent 417e1b19
Pipeline #205 canceled with stages
......@@ -4,3 +4,4 @@ librosa
resemblyzer
webrtcvad
scikit-learn
soundfile
\ No newline at end of file
import argparse
import torch
from transformers import AutoProcessor, AutoModel
from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor
from qwen_omni_utils import process_mm_info
from resemblyzer import VoiceEncoder
from sklearn.cluster import AgglomerativeClustering
import webrtcvad
......@@ -74,8 +75,8 @@ def main():
return
# Load Qwen2.5-Omni-7B model
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B", trust_remote_code=True)
model = AutoModel.from_pretrained("Qwen/Qwen2.5-Omni-7B", trust_remote_code=True)
model = Qwen2_5OmniForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-Omni-7B", torch_dtype="auto", device_map="auto")
processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B")
# Load audio
audio, sr = librosa.load(audio_file, sr=16000)
......@@ -101,12 +102,16 @@ def main():
{"type": "text", "text": "Transcribe this audio segment exactly as spoken."}
]}
]
inputs = processor(conversation=conversation, return_tensors="pt")
# Generate transcription
with torch.no_grad():
generated_ids = model.generate(**inputs, max_new_tokens=200, do_sample=False)
text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
# Preparation for inference
text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
audios, images, videos = process_mm_info(conversation, use_audio_in_video=False)
inputs = processor(text=text, audio=audios, images=images, videos=videos, return_tensors="pt", padding=True, use_audio_in_video=False)
inputs = inputs.to(model.device).to(model.dtype)
# Inference: Generation of the output text
text_ids, _ = model.generate(**inputs, use_audio_in_video=False)
text = processor.batch_decode(text_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
# Format timestamps
start_min, start_sec = divmod(start, 60)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment