Update to use Qwen2_5OmniForConditionalGeneration and correct inference format
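
The generic Auto classes don't match the published Qwen2.5-Omni inference flow: the Omni
checkpoint needs its dedicated processor for the multimodal chat-template preprocessing, and its
generate() returns a (text_ids, audio) pair rather than plain token ids. This commit switches to
the pattern from the model card. A minimal end-to-end sketch of that pattern (the segment path
and conversation contents are illustrative, not part of this repo):

    from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor
    from qwen_omni_utils import process_mm_info

    model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2.5-Omni-7B", torch_dtype="auto", device_map="auto")
    processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B")

    conversation = [
        {"role": "user", "content": [
            {"type": "audio", "audio": "segment_000.wav"},  # illustrative path
            {"type": "text", "text": "Transcribe this audio segment exactly as spoken."},
        ]},
    ]

    # Chat template -> prompt string; process_mm_info loads the audio referenced above
    text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
    audios, images, videos = process_mm_info(conversation, use_audio_in_video=False)
    inputs = processor(text=text, audio=audios, images=images, videos=videos,
                       return_tensors="pt", padding=True, use_audio_in_video=False)
    inputs = inputs.to(model.device).to(model.dtype)

    text_ids, _ = model.generate(**inputs, use_audio_in_video=False)  # second value is talker audio
    print(processor.batch_decode(text_ids, skip_special_tokens=True)[0])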

parent 417e1b19
Pipeline #205 canceled with stages
@@ -3,4 +3,5 @@ transformers
 librosa
 resemblyzer
 webrtcvad
-scikit-learn
\ No newline at end of file
+scikit-learn
+soundfile
\ No newline at end of file
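
soundfile presumably comes in so each diarized segment can be written out as a wav that the new
conversation format references by path (process_mm_info loads audio given a file path). A sketch
of that step, with illustrative names; the actual helper in this repo isn't shown in the diff:

    import soundfile as sf

    # audio: the 16 kHz mono array from librosa.load; start/end: segment bounds in seconds
    segment = audio[int(start * sr):int(end * sr)]
    sf.write("segment_000.wav", segment, sr)  # illustrative output path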

@@ -1,6 +1,7 @@
 import argparse
 import torch
-from transformers import AutoProcessor, AutoModel
+from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor
+from qwen_omni_utils import process_mm_info
 from resemblyzer import VoiceEncoder
 from sklearn.cluster import AgglomerativeClustering
 import webrtcvad
@@ -74,8 +75,8 @@ def main():
         return
 
     # Load Qwen2.5-Omni-7B model
-    processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B", trust_remote_code=True)
-    model = AutoModel.from_pretrained("Qwen/Qwen2.5-Omni-7B", trust_remote_code=True)
+    model = Qwen2_5OmniForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-Omni-7B", torch_dtype="auto", device_map="auto")
+    processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B")
 
     # Load audio
     audio, sr = librosa.load(audio_file, sr=16000)
@@ -101,12 +102,16 @@ def main():
             {"type": "text", "text": "Transcribe this audio segment exactly as spoken."}
         ]}
     ]
-    inputs = processor(conversation=conversation, return_tensors="pt")
-    # Generate transcription
-    with torch.no_grad():
-        generated_ids = model.generate(**inputs, max_new_tokens=200, do_sample=False)
-    text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
+    # Preparation for inference
+    text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
+    audios, images, videos = process_mm_info(conversation, use_audio_in_video=False)
+    inputs = processor(text=text, audio=audios, images=images, videos=videos, return_tensors="pt", padding=True, use_audio_in_video=False)
+    inputs = inputs.to(model.device).to(model.dtype)
+
+    # Inference: Generation of the output text
+    text_ids, _ = model.generate(**inputs, use_audio_in_video=False)
+    text = processor.batch_decode(text_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
 
     # Format timestamps
     start_min, start_sec = divmod(start, 60)
     end_min, end_sec = divmod(end, 60)
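
Note: the audio half of the generate() output is discarded (text_ids, _ = ...), so the talker can
be skipped entirely; per the Qwen2.5-Omni model card this saves roughly 2 GB of GPU memory.
Either variant below should work (a sketch, not part of this commit):

    model.disable_talker()  # call once, right after from_pretrained

    # or per call: return only the text ids
    text_ids = model.generate(**inputs, use_audio_in_video=False, return_audio=False)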