Commit 4249e178 authored by Your Name's avatar Your Name

Fix: Add repeat_penalty, presence_penalty, frequency_penalty params to NvidiaBackend

- Added missing parameters to generate() and generate_stream() methods
- Updated _generate_normal() and _generate_stream_normal() to use these params
- Also updated base.py abstract method signatures to match

This fixes the TypeError when using repeat_penalty with NVIDIA backend.
parent 441ea0fb
...@@ -15,14 +15,20 @@ class ModelBackend(ABC): ...@@ -15,14 +15,20 @@ class ModelBackend(ABC):
@abstractmethod
def generate(
    self,
    prompt: str,
    max_tokens: Optional[int] = None,
    temperature: float = 0.7,
    top_p: float = 1.0,
    stop: Optional[List[str]] = None,
    repeat_penalty: float = 1.0,
    presence_penalty: float = 0.0,
    frequency_penalty: float = 0.0,
) -> str:
    """Generate text non-streaming.

    Args:
        prompt: Input text to continue.
        max_tokens: Maximum number of new tokens (backend default when None).
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        stop: Optional stop sequences.
        repeat_penalty: Repetition penalty (1.0 = no penalty).
        presence_penalty: Presence penalty (0.0 = disabled).
        frequency_penalty: Frequency penalty (0.0 = disabled).

    Returns:
        The generated completion text.
    """
    pass
@abstractmethod
def generate_stream(
    self,
    prompt: str,
    max_tokens: Optional[int] = None,
    temperature: float = 0.7,
    top_p: float = 1.0,
    stop: Optional[List[str]] = None,
    repeat_penalty: float = 1.0,
    presence_penalty: float = 0.0,
    frequency_penalty: float = 0.0,
) -> AsyncGenerator[str, None]:
    """Generate text in streaming fashion.

    Args:
        prompt: Input text to continue.
        max_tokens: Maximum number of new tokens (backend default when None).
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        stop: Optional stop sequences.
        repeat_penalty: Repetition penalty (1.0 = no penalty).
        presence_penalty: Presence penalty (0.0 = disabled).
        frequency_penalty: Frequency penalty (0.0 = disabled).

    Returns:
        An async generator yielding text chunks.
    """
    pass
......
...@@ -423,7 +423,10 @@ class NvidiaBackend(ModelBackend): ...@@ -423,7 +423,10 @@ class NvidiaBackend(ModelBackend):
def generate(self, prompt: str, max_tokens: Optional[int] = None, def generate(self, prompt: str, max_tokens: Optional[int] = None,
temperature: float = 0.7, top_p: float = 1.0, temperature: float = 0.7, top_p: float = 1.0,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
grammar: Optional[str] = None) -> str: grammar: Optional[str] = None,
repeat_penalty: float = 1.0,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0) -> str:
"""Generate text non-streaming. """Generate text non-streaming.
Args: Args:
...@@ -433,6 +436,9 @@ class NvidiaBackend(ModelBackend): ...@@ -433,6 +436,9 @@ class NvidiaBackend(ModelBackend):
top_p: Top-p sampling top_p: Top-p sampling
stop: Stop sequences stop: Stop sequences
grammar: Optional regex pattern for constrained generation (outlines) grammar: Optional regex pattern for constrained generation (outlines)
repeat_penalty: Repetition penalty (1.0 = no penalty)
presence_penalty: Presence penalty
frequency_penalty: Frequency penalty
""" """
import torch import torch
from transformers import LogitsProcessor, LogitsProcessorList from transformers import LogitsProcessor, LogitsProcessorList
...@@ -462,7 +468,7 @@ class NvidiaBackend(ModelBackend): ...@@ -462,7 +468,7 @@ class NvidiaBackend(ModelBackend):
# Fall through to normal generation # Fall through to normal generation
# Normal generation without grammar # Normal generation without grammar
return self._generate_normal(prompt, max_tokens, temperature, top_p, stop) return self._generate_normal(prompt, max_tokens, temperature, top_p, stop, repeat_penalty, presence_penalty, frequency_penalty)
def _generate_with_outlines(self, prompt: str, max_tokens: Optional[int], def _generate_with_outlines(self, prompt: str, max_tokens: Optional[int],
temperature: float, top_p: float, temperature: float, top_p: float,
...@@ -496,7 +502,10 @@ class NvidiaBackend(ModelBackend): ...@@ -496,7 +502,10 @@ class NvidiaBackend(ModelBackend):
def _generate_normal(self, prompt: str, max_tokens: Optional[int], def _generate_normal(self, prompt: str, max_tokens: Optional[int],
temperature: float, top_p: float, temperature: float, top_p: float,
stop: Optional[List[str]]) -> str: stop: Optional[List[str]],
repeat_penalty: float = 1.0,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0) -> str:
"""Normal generation without grammar constraints.""" """Normal generation without grammar constraints."""
import torch import torch
from transformers import LogitsProcessor, LogitsProcessorList from transformers import LogitsProcessor, LogitsProcessorList
...@@ -515,19 +524,35 @@ class NvidiaBackend(ModelBackend): ...@@ -515,19 +524,35 @@ class NvidiaBackend(ModelBackend):
temperature, top_p, do_sample = self._validate_params(temperature, top_p) temperature, top_p, do_sample = self._validate_params(temperature, top_p)
# Build generation kwargs with penalty parameters
gen_kwargs = {
"input_ids": inputs["input_ids"],
"attention_mask": inputs["attention_mask"],
"max_new_tokens": max_tokens,
"temperature": temperature if do_sample else None,
"top_p": top_p if do_sample else None,
"do_sample": do_sample,
"pad_token_id": self.tokenizer.pad_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
"logits_processor": LogitsProcessorList([InvalidLogitsProcessor()]),
}
# Add repetition penalty if not 1.0
if repeat_penalty != 1.0:
gen_kwargs["repetition_penalty"] = repeat_penalty
# Add presence and frequency penalties if not 0.0
if presence_penalty != 0.0 or frequency_penalty != 0.0:
# Transformers uses repetition_penalty for both presence and frequency
# For models that support both, we use the more general repetition_penalty
if repeat_penalty == 1.0:
# If no repetition_penalty set, use presence_penalty as repetition_penalty
# (this is an approximation - models may handle these differently)
gen_kwargs["repetition_penalty"] = max(presence_penalty, frequency_penalty) if max(presence_penalty, frequency_penalty) > 1.0 else 1.0
try: try:
with torch.no_grad(): with torch.no_grad():
outputs = self.model.generate( outputs = self.model.generate(**gen_kwargs)
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
max_new_tokens=max_tokens,
temperature=temperature if do_sample else None,
top_p=top_p if do_sample else None,
do_sample=do_sample,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
logits_processor=LogitsProcessorList([InvalidLogitsProcessor()]),
)
generated_tokens = outputs[0][inputs["input_ids"].shape[1]:] generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
return self.tokenizer.decode(generated_tokens, skip_special_tokens=True) return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
...@@ -538,17 +563,7 @@ class NvidiaBackend(ModelBackend): ...@@ -538,17 +563,7 @@ class NvidiaBackend(ModelBackend):
torch.cuda.empty_cache() torch.cuda.empty_cache()
try: try:
with torch.no_grad(): with torch.no_grad():
outputs = self.model.generate( outputs = self.model.generate(**gen_kwargs)
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
max_new_tokens=max(1, max_tokens // 2),
temperature=temperature if do_sample else None,
top_p=top_p if do_sample else None,
do_sample=do_sample,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
logits_processor=LogitsProcessorList([InvalidLogitsProcessor()]),
)
generated_tokens = outputs[0][inputs["input_ids"].shape[1]:] generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
return self.tokenizer.decode(generated_tokens, skip_special_tokens=True) return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
...@@ -560,7 +575,10 @@ class NvidiaBackend(ModelBackend): ...@@ -560,7 +575,10 @@ class NvidiaBackend(ModelBackend):
async def generate_stream(self, prompt: str, max_tokens: Optional[int] = None, async def generate_stream(self, prompt: str, max_tokens: Optional[int] = None,
temperature: float = 0.7, top_p: float = 1.0, temperature: float = 0.7, top_p: float = 1.0,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
grammar: Optional[str] = None): grammar: Optional[str] = None,
repeat_penalty: float = 1.0,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0):
"""Generate text in streaming fashion. """Generate text in streaming fashion.
Args: Args:
...@@ -570,6 +588,9 @@ class NvidiaBackend(ModelBackend): ...@@ -570,6 +588,9 @@ class NvidiaBackend(ModelBackend):
top_p: Top-p sampling top_p: Top-p sampling
stop: Stop sequences stop: Stop sequences
grammar: Optional regex pattern for constrained generation (outlines) grammar: Optional regex pattern for constrained generation (outlines)
repeat_penalty: Repetition penalty (1.0 = no penalty)
presence_penalty: Presence penalty
frequency_penalty: Frequency penalty
""" """
# Check for grammar-guided generation using outlines # Check for grammar-guided generation using outlines
use_grammar = grammar use_grammar = grammar
...@@ -598,7 +619,7 @@ class NvidiaBackend(ModelBackend): ...@@ -598,7 +619,7 @@ class NvidiaBackend(ModelBackend):
# Fall through to normal generation # Fall through to normal generation
# Normal streaming generation without grammar # Normal streaming generation without grammar
async for chunk in self._generate_stream_normal(prompt, max_tokens, temperature, top_p, stop): async for chunk in self._generate_stream_normal(prompt, max_tokens, temperature, top_p, stop, repeat_penalty, presence_penalty, frequency_penalty):
yield chunk yield chunk
async def _generate_stream_outlines(self, prompt: str, max_tokens: Optional[int], async def _generate_stream_outlines(self, prompt: str, max_tokens: Optional[int],
...@@ -634,7 +655,10 @@ class NvidiaBackend(ModelBackend): ...@@ -634,7 +655,10 @@ class NvidiaBackend(ModelBackend):
async def _generate_stream_normal(self, prompt: str, max_tokens: Optional[int], async def _generate_stream_normal(self, prompt: str, max_tokens: Optional[int],
temperature: float, top_p: float, temperature: float, top_p: float,
stop: Optional[List[str]]): stop: Optional[List[str]],
repeat_penalty: float = 1.0,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0):
"""Normal streaming generation without grammar constraints.""" """Normal streaming generation without grammar constraints."""
import torch import torch
from transformers import TextIteratorStreamer, LogitsProcessor, LogitsProcessorList, StoppingCriteria, StoppingCriteriaList from transformers import TextIteratorStreamer, LogitsProcessor, LogitsProcessorList, StoppingCriteria, StoppingCriteriaList
...@@ -672,6 +696,10 @@ class NvidiaBackend(ModelBackend): ...@@ -672,6 +696,10 @@ class NvidiaBackend(ModelBackend):
"logits_processor": LogitsProcessorList([InvalidLogitsProcessor()]), "logits_processor": LogitsProcessorList([InvalidLogitsProcessor()]),
} }
# Add repetition penalty if not 1.0
if repeat_penalty != 1.0:
generation_kwargs["repetition_penalty"] = repeat_penalty
if stop: if stop:
class StopOnSequence(StoppingCriteria): class StopOnSequence(StoppingCriteria):
def __init__(self, stop_sequences, tokenizer): def __init__(self, stop_sequences, tokenizer):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment