Commit 673ac596 authored by Your Name's avatar Your Name

Fix: Update VulkanBackend method signatures to match base class

- Added repeat_penalty, presence_penalty, frequency_penalty params to generate() and generate_stream()
- Changed from **kwargs to explicit parameters to match base class abstract methods

This fixes the TypeError when calling VulkanBackend.generate_stream() with extra params.
parent 4249e178
...@@ -536,7 +536,10 @@ class VulkanBackend(ModelBackend): ...@@ -536,7 +536,10 @@ class VulkanBackend(ModelBackend):
temperature: float = 0.7, temperature: float = 0.7,
top_p: float = 1.0, top_p: float = 1.0,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
grammar: Optional[str] = None grammar: Optional[str] = None,
repeat_penalty: float = 1.0,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0
) -> str: ) -> str:
"""Generate text non-streaming. """Generate text non-streaming.
...@@ -547,6 +550,9 @@ class VulkanBackend(ModelBackend): ...@@ -547,6 +550,9 @@ class VulkanBackend(ModelBackend):
top_p: Top-p sampling top_p: Top-p sampling
stop: Stop sequences stop: Stop sequences
grammar: Optional GBNF grammar string for constrained generation grammar: Optional GBNF grammar string for constrained generation
repeat_penalty: Repetition penalty (1.0 = no penalty)
presence_penalty: Presence penalty
frequency_penalty: Frequency penalty
Returns: Returns:
Generated text Generated text
...@@ -628,13 +634,27 @@ class VulkanBackend(ModelBackend): ...@@ -628,13 +634,27 @@ class VulkanBackend(ModelBackend):
async def generate_stream( async def generate_stream(
self, self,
prompt: str, prompt: str,
**kwargs max_tokens: Optional[int] = None,
temperature: float = 0.7,
top_p: float = 1.0,
stop: Optional[List[str]] = None,
grammar: Optional[str] = None,
repeat_penalty: float = 1.0,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0
) -> AsyncIterator[str]: ) -> AsyncIterator[str]:
"""Generate text with streaming. """Generate text with streaming.
Args: Args:
prompt: Input prompt (or messages for chat) prompt: Input prompt (or messages for chat)
**kwargs: Generation parameters max_tokens: Maximum tokens to generate
temperature: Sampling temperature
top_p: Top-p sampling
stop: Stop sequences
grammar: Optional GBNF grammar string
repeat_penalty: Repetition penalty (1.0 = no penalty)
presence_penalty: Presence penalty
frequency_penalty: Frequency penalty
Yields: Yields:
Generated text chunks Generated text chunks
...@@ -647,15 +667,15 @@ class VulkanBackend(ModelBackend): ...@@ -647,15 +667,15 @@ class VulkanBackend(ModelBackend):
prompt = self._apply_chat_template(prompt, add_generation_prompt=True) prompt = self._apply_chat_template(prompt, add_generation_prompt=True)
# Set defaults # Set defaults
max_tokens = kwargs.get('max_tokens', 256) max_tokens = max_tokens if max_tokens is not None else 256
temperature = kwargs.get('temperature', 0.7) temperature = temperature if temperature is not None else 0.7
top_p = kwargs.get('top_p', 0.9) top_p = top_p if top_p is not None else 0.9
top_k = kwargs.get('top_k', 40) top_k = 40
repeat_penalty = kwargs.get('repeat_penalty', 1.1) repeat_penalty = repeat_penalty if repeat_penalty is not None else 1.1
grammar = kwargs.get('grammar', None) grammar = grammar
# Get stop strings # Get stop strings
stop = kwargs.get('stop', None) stop = stop if stop is not None else None
if stop is None: if stop is None:
# Get default stop tokens based on template # Get default stop tokens based on template
stop = get_reasoning_stop_tokens(self.chat_template) stop = get_reasoning_stop_tokens(self.chat_template)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment