Commit 2a5f2bf6 authored by Your Name's avatar Your Name

Fix generate() method signature to match base class

Now accepts positional args: max_tokens, temperature, top_p, stop
parent 280b91c3
...@@ -524,17 +524,32 @@ class VulkanBackend(ModelBackend): ...@@ -524,17 +524,32 @@ class VulkanBackend(ModelBackend):
def generate( def generate(
self, self,
prompt: str, prompt: str,
**kwargs max_tokens: Optional[int] = None,
temperature: float = 0.7,
top_p: float = 1.0,
stop: Optional[List[str]] = None
) -> str: ) -> str:
"""Generate text from prompt. """Generate text non-streaming.
Args: Args:
prompt: Input prompt (or messages for chat) prompt: Input prompt (or messages for chat)
**kwargs: Generation parameters max_tokens: Maximum tokens to generate
temperature: Sampling temperature
top_p: Top-p sampling
stop: Stop sequences
Returns: Returns:
Generated text Generated text
""" """
kwargs = {}
if max_tokens is not None:
kwargs['max_tokens'] = max_tokens
if temperature is not None:
kwargs['temperature'] = temperature
if top_p is not None:
kwargs['top_p'] = top_p
if stop is not None:
kwargs['stop'] = stop
if self.model is None: if self.model is None:
raise RuntimeError("Model not loaded") raise RuntimeError("Model not loaded")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment