Commit 673ac596 authored by Your Name's avatar Your Name

Fix: Update VulkanBackend method signatures to match base class

- Added repeat_penalty, presence_penalty, frequency_penalty params to generate() and generate_stream()
- Changed from **kwargs to explicit parameters to match base class abstract methods

This fixes the TypeError when calling VulkanBackend.generate_stream() with extra params.
parent 4249e178
......@@ -536,7 +536,10 @@ class VulkanBackend(ModelBackend):
temperature: float = 0.7,
top_p: float = 1.0,
stop: Optional[List[str]] = None,
grammar: Optional[str] = None
grammar: Optional[str] = None,
repeat_penalty: float = 1.0,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0
) -> str:
"""Generate text non-streaming.
......@@ -547,6 +550,9 @@ class VulkanBackend(ModelBackend):
top_p: Top-p sampling
stop: Stop sequences
grammar: Optional GBNF grammar string for constrained generation
repeat_penalty: Repetition penalty (1.0 = no penalty)
presence_penalty: Presence penalty
frequency_penalty: Frequency penalty
Returns:
Generated text
......@@ -628,13 +634,27 @@ class VulkanBackend(ModelBackend):
async def generate_stream(
self,
prompt: str,
**kwargs
max_tokens: Optional[int] = None,
temperature: float = 0.7,
top_p: float = 1.0,
stop: Optional[List[str]] = None,
grammar: Optional[str] = None,
repeat_penalty: float = 1.0,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0
) -> AsyncIterator[str]:
"""Generate text with streaming.
Args:
prompt: Input prompt (or messages for chat)
**kwargs: Generation parameters
max_tokens: Maximum tokens to generate
temperature: Sampling temperature
top_p: Top-p sampling
stop: Stop sequences
grammar: Optional GBNF grammar string
repeat_penalty: Repetition penalty (1.0 = no penalty)
presence_penalty: Presence penalty
frequency_penalty: Frequency penalty
Yields:
Generated text chunks
......@@ -647,15 +667,15 @@ class VulkanBackend(ModelBackend):
prompt = self._apply_chat_template(prompt, add_generation_prompt=True)
# Set defaults
max_tokens = kwargs.get('max_tokens', 256)
temperature = kwargs.get('temperature', 0.7)
top_p = kwargs.get('top_p', 0.9)
top_k = kwargs.get('top_k', 40)
repeat_penalty = kwargs.get('repeat_penalty', 1.1)
grammar = kwargs.get('grammar', None)
max_tokens = max_tokens if max_tokens is not None else 256
temperature = temperature if temperature is not None else 0.7
top_p = top_p if top_p is not None else 0.9
top_k = 40
repeat_penalty = repeat_penalty if repeat_penalty is not None else 1.1
grammar = grammar
# Get stop strings
stop = kwargs.get('stop', None)
stop = stop if stop is not None else None
if stop is None:
# Get default stop tokens based on template
stop = get_reasoning_stop_tokens(self.chat_template)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment