Commit 4249e178 authored by Your Name

Fix: Add repeat_penalty, presence_penalty, frequency_penalty params to NvidiaBackend

- Added missing parameters to generate() and generate_stream() methods
- Updated _generate_normal() and _generate_stream_normal() to use these params
- Also updated base.py abstract method signatures to match

This fixes the TypeError when using repeat_penalty with NVIDIA backend.
parent 441ea0fb
......@@ -15,14 +15,20 @@ class ModelBackend(ABC):
@abstractmethod
def generate(self, prompt: str, max_tokens: Optional[int] = None,
temperature: float = 0.7, top_p: float = 1.0,
stop: Optional[List[str]] = None) -> str:
stop: Optional[List[str]] = None,
repeat_penalty: float = 1.0,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0) -> str:
"""Generate text non-streaming."""
pass
@abstractmethod
def generate_stream(self, prompt: str, max_tokens: Optional[int] = None,
temperature: float = 0.7, top_p: float = 1.0,
stop: Optional[List[str]] = None) -> AsyncGenerator[str, None]:
stop: Optional[List[str]] = None,
repeat_penalty: float = 1.0,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0) -> AsyncGenerator[str, None]:
"""Generate text in streaming fashion."""
pass
......
......@@ -423,7 +423,10 @@ class NvidiaBackend(ModelBackend):
def generate(self, prompt: str, max_tokens: Optional[int] = None,
temperature: float = 0.7, top_p: float = 1.0,
stop: Optional[List[str]] = None,
grammar: Optional[str] = None) -> str:
grammar: Optional[str] = None,
repeat_penalty: float = 1.0,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0) -> str:
"""Generate text non-streaming.
Args:
......@@ -433,6 +436,9 @@ class NvidiaBackend(ModelBackend):
top_p: Top-p sampling
stop: Stop sequences
grammar: Optional regex pattern for constrained generation (outlines)
repeat_penalty: Repetition penalty (1.0 = no penalty)
presence_penalty: Presence penalty
frequency_penalty: Frequency penalty
"""
import torch
from transformers import LogitsProcessor, LogitsProcessorList
......@@ -462,7 +468,7 @@ class NvidiaBackend(ModelBackend):
# Fall through to normal generation
# Normal generation without grammar
return self._generate_normal(prompt, max_tokens, temperature, top_p, stop)
return self._generate_normal(prompt, max_tokens, temperature, top_p, stop, repeat_penalty, presence_penalty, frequency_penalty)
def _generate_with_outlines(self, prompt: str, max_tokens: Optional[int],
temperature: float, top_p: float,
......@@ -496,7 +502,10 @@ class NvidiaBackend(ModelBackend):
def _generate_normal(self, prompt: str, max_tokens: Optional[int],
temperature: float, top_p: float,
stop: Optional[List[str]]) -> str:
stop: Optional[List[str]],
repeat_penalty: float = 1.0,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0) -> str:
"""Normal generation without grammar constraints."""
import torch
from transformers import LogitsProcessor, LogitsProcessorList
......@@ -515,19 +524,35 @@ class NvidiaBackend(ModelBackend):
temperature, top_p, do_sample = self._validate_params(temperature, top_p)
# Build generation kwargs with penalty parameters
gen_kwargs = {
"input_ids": inputs["input_ids"],
"attention_mask": inputs["attention_mask"],
"max_new_tokens": max_tokens,
"temperature": temperature if do_sample else None,
"top_p": top_p if do_sample else None,
"do_sample": do_sample,
"pad_token_id": self.tokenizer.pad_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
"logits_processor": LogitsProcessorList([InvalidLogitsProcessor()]),
}
# Add repetition penalty if not 1.0
if repeat_penalty != 1.0:
gen_kwargs["repetition_penalty"] = repeat_penalty
# Add presence and frequency penalties if not 0.0
if presence_penalty != 0.0 or frequency_penalty != 0.0:
# Transformers uses repetition_penalty for both presence and frequency
# For models that support both, we use the more general repetition_penalty
if repeat_penalty == 1.0:
# If no repetition_penalty set, use presence_penalty as repetition_penalty
# (this is an approximation - models may handle these differently)
gen_kwargs["repetition_penalty"] = max(presence_penalty, frequency_penalty) if max(presence_penalty, frequency_penalty) > 1.0 else 1.0
try:
with torch.no_grad():
outputs = self.model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
max_new_tokens=max_tokens,
temperature=temperature if do_sample else None,
top_p=top_p if do_sample else None,
do_sample=do_sample,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
logits_processor=LogitsProcessorList([InvalidLogitsProcessor()]),
)
outputs = self.model.generate(**gen_kwargs)
generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
......@@ -538,17 +563,7 @@ class NvidiaBackend(ModelBackend):
torch.cuda.empty_cache()
try:
with torch.no_grad():
outputs = self.model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
max_new_tokens=max(1, max_tokens // 2),
temperature=temperature if do_sample else None,
top_p=top_p if do_sample else None,
do_sample=do_sample,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
logits_processor=LogitsProcessorList([InvalidLogitsProcessor()]),
)
outputs = self.model.generate(**gen_kwargs)
generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
......@@ -560,7 +575,10 @@ class NvidiaBackend(ModelBackend):
async def generate_stream(self, prompt: str, max_tokens: Optional[int] = None,
temperature: float = 0.7, top_p: float = 1.0,
stop: Optional[List[str]] = None,
grammar: Optional[str] = None):
grammar: Optional[str] = None,
repeat_penalty: float = 1.0,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0):
"""Generate text in streaming fashion.
Args:
......@@ -570,6 +588,9 @@ class NvidiaBackend(ModelBackend):
top_p: Top-p sampling
stop: Stop sequences
grammar: Optional regex pattern for constrained generation (outlines)
repeat_penalty: Repetition penalty (1.0 = no penalty)
presence_penalty: Presence penalty
frequency_penalty: Frequency penalty
"""
# Check for grammar-guided generation using outlines
use_grammar = grammar
......@@ -598,7 +619,7 @@ class NvidiaBackend(ModelBackend):
# Fall through to normal generation
# Normal streaming generation without grammar
async for chunk in self._generate_stream_normal(prompt, max_tokens, temperature, top_p, stop):
async for chunk in self._generate_stream_normal(prompt, max_tokens, temperature, top_p, stop, repeat_penalty, presence_penalty, frequency_penalty):
yield chunk
async def _generate_stream_outlines(self, prompt: str, max_tokens: Optional[int],
......@@ -634,7 +655,10 @@ class NvidiaBackend(ModelBackend):
async def _generate_stream_normal(self, prompt: str, max_tokens: Optional[int],
temperature: float, top_p: float,
stop: Optional[List[str]]):
stop: Optional[List[str]],
repeat_penalty: float = 1.0,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0):
"""Normal streaming generation without grammar constraints."""
import torch
from transformers import TextIteratorStreamer, LogitsProcessor, LogitsProcessorList, StoppingCriteria, StoppingCriteriaList
......@@ -672,6 +696,10 @@ class NvidiaBackend(ModelBackend):
"logits_processor": LogitsProcessorList([InvalidLogitsProcessor()]),
}
# Add repetition penalty if not 1.0
if repeat_penalty != 1.0:
generation_kwargs["repetition_penalty"] = repeat_penalty
if stop:
class StopOnSequence(StoppingCriteria):
def __init__(self, stop_sequences, tokenizer):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment