Fix NaN/inf probability tensor error during generation

- Add InvalidLogitsProcessor to replace NaN and Inf values with finite numbers
- Add _validate_generation_params() to clamp temperature and top_p to valid ranges
- Add try-except blocks with fallback to greedy decoding on numerical errors
- Add error handling in streaming responses to prevent crashes
- Fix temperature=0 handling to use greedy decoding instead of sampling
parent 087ba9e1
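The core of the fix, replacing non-finite logit values with large finite ones before sampling, can be sketched in plain Python. This is a minimal illustration of the idea only; the actual change uses `torch.where` on tensors inside a `LogitsProcessor`.

```python
import math

def sanitize_logits(scores):
    """Map NaN and +/-Inf logits to large finite values (illustration of the fix)."""
    clean = []
    for s in scores:
        if math.isnan(s):
            clean.append(-1e9)                    # NaN: treat the token as (almost) impossible
        elif math.isinf(s):
            clean.append(1e9 if s > 0 else -1e9)  # +/-Inf: clamp to finite bounds
        else:
            clean.append(s)
    return clean

print(sanitize_logits([0.5, float("nan"), float("inf"), float("-inf")]))
# prints [0.5, -1000000000.0, 1000000000.0, -1000000000.0]
```

After this pass every score is finite, so softmax can no longer produce the NaN/inf probability tensor that crashes `multinomial` sampling.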
@@ -21,7 +21,7 @@ An OpenAI-compatible API server for HuggingFace models with intelligent memory m
- Python 3.8+
- For NVIDIA GPUs: CUDA toolkit (11.8+ recommended)
- For AMD GPUs: ROCm (5.6+ recommended, 6.0+ preferred)
- For CPU-only: No additional requirements

### Basic Installation

@@ -43,36 +43,44 @@ pip install -r requirements.txt
PyTorch installation varies by platform. Uncomment the appropriate section in [`requirements.txt`](requirements.txt) or install manually:
> **⚠️ WARNING: Shell Redirection Issue**
> When using `>=` in pip commands, always use **quotes** around the package specifier!
> Without quotes, the shell interprets `>` as output redirection.
>
> ❌ Wrong: `pip install torch>=2.0.0` (creates file named "=2.0.0")
> ✅ Correct: `pip install "torch>=2.0.0"` (with quotes)
> ✅ Also correct: `pip install torch==2.0.0` (exact version, no >=)
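To see the redirection behavior for yourself, here is a harmless demo using `echo` in a temporary directory (no real install, `pretend-install` is just placeholder text):

```shell
cd "$(mktemp -d)"

# Unquoted: the shell splits at '>' and redirects output to a file named '=2.0.0'
echo pretend-install torch>=2.0.0
ls    # shows the stray file '=2.0.0'

# Quoted: the full specifier is passed through as a single argument
echo pretend-install "torch>=2.0.0"
```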
#### NVIDIA (CUDA)

```bash
# For CUDA 11.8
pip install "torch>=2.0.0" "torchvision>=0.15.0" "torchaudio>=2.0.0" --index-url https://download.pytorch.org/whl/cu118

# For CUDA 12.1
pip install "torch>=2.0.0" "torchvision>=0.15.0" "torchaudio>=2.0.0" --index-url https://download.pytorch.org/whl/cu121

# For CUDA 12.4 (latest)
pip install "torch>=2.0.0" "torchvision>=0.15.0" "torchaudio>=2.0.0"
```
#### AMD (ROCm)

```bash
# For ROCm 6.0 (recommended for newer AMD GPUs)
pip install "torch>=2.0.0" "torchvision>=0.15.0" "torchaudio>=2.0.0" --index-url https://download.pytorch.org/whl/rocm6.0

# For ROCm 5.6 (for older AMD GPUs)
pip install "torch>=2.0.0" "torchvision>=0.15.0" "torchaudio>=2.0.0" --index-url https://download.pytorch.org/whl/rocm5.6
```

> **Note**: ROCm 5.4.2 is deprecated. Use ROCm 5.6 or 6.0 for better compatibility.
> Check available versions at: https://pytorch.org/get-started/locally/
#### CPU Only

```bash
pip install "torch>=2.0.0" "torchvision>=0.15.0" "torchaudio>=2.0.0" --index-url https://download.pytorch.org/whl/cpu
```
### Optional Dependencies

@@ -83,7 +91,7 @@ For 4-bit and 8-bit quantization support (reduces VRAM requirements):

```bash
# CUDA
pip install "bitsandbytes>=0.41.0"

# ROCm support may require building from source
# See: https://github.com/TimDettmers/bitsandbytes
```
@@ -272,7 +280,7 @@ curl -X POST http://localhost:8000/v1/chat/completions \

```bash
# Install CUDA-enabled PyTorch
pip install "torch>=2.0.0" "torchvision>=0.15.0" "torchaudio>=2.0.0" --index-url https://download.pytorch.org/whl/cu121

# Run with GPU acceleration (automatic)
python coderai --model meta-llama/Llama-2-7b-chat-hf
```

@@ -284,8 +292,8 @@ python coderai --model meta-llama/Llama-2-7b-chat-hf --flash-attn
### ROCm (AMD GPU)

```bash
# Install ROCm-enabled PyTorch (use 6.0 for newer GPUs, 5.6 for older)
pip install "torch>=2.0.0" "torchvision>=0.15.0" "torchaudio>=2.0.0" --index-url https://download.pytorch.org/whl/rocm6.0

# Run with GPU acceleration (automatic)
python coderai --model meta-llama/Llama-2-7b-chat-hf
```

@@ -297,7 +305,7 @@ python coderai --model meta-llama/Llama-2-7b-chat-hf
```bash
# Install CPU-only PyTorch
pip install "torch>=2.0.0" "torchvision>=0.15.0" "torchaudio>=2.0.0" --index-url https://download.pytorch.org/whl/cpu

# Run on CPU (automatic fallback)
python coderai --model microsoft/DialoGPT-medium
```

@@ -352,6 +360,17 @@ python coderai --model meta-llama/Llama-2-70b-chat-hf --load-in-8bit

## Troubleshooting
### Shell Redirection Error: "No such file or directory: '0.0'"

**Problem**: Running `pip install torch>=2.0.0` fails with an error about file "0.0" or "=2.0.0" not found.

**Problem cause**: The shell interprets `>` as output redirection. The command creates a file named "=2.0.0" and installs an unversioned torch package.

**Solutions**:
1. **Use quotes** (recommended): `pip install "torch>=2.0.0"`
2. **Use exact versions**: `pip install torch==2.0.0`
3. **Use requirements.txt**: Add exact versions to requirements.txt and run `pip install -r requirements.txt`
### Out of Memory Errors

**Problem**: `CUDA out of memory` or system RAM exhausted

@@ -373,6 +392,33 @@ python coderai --model meta-llama/Llama-2-70b-chat-hf --load-in-8bit
4. Check GPU compatibility (Ampere, Ada Lovelace, Hopper for NVIDIA)
5. Skip Flash Attention - the server works without it
### Flash Attention: No module named 'torch' during build

**Problem**: Flash Attention build fails with `ModuleNotFoundError: No module named 'torch'` even though PyTorch is installed (e.g., PyTorch 2.9.1+rocm6.4).

**Cause**: pip uses isolated build environments by default, which prevents flash-attention from seeing the installed torch package during compilation.

**Solutions**:
1. **Use the --no-build-isolation flag** (recommended):
   ```bash
   pip install flash-attn --no-build-isolation
   ```
2. **For ROCm systems**, you may also need to limit parallel jobs to avoid resource exhaustion:
   ```bash
   MAX_JOBS=4 pip install flash-attn --no-build-isolation
   ```
3. **Use pre-built wheels** if available for your platform (check https://github.com/Dao-AILab/flash-attention/releases)
4. **ROCm 6.4 compatibility note**: Flash Attention may not officially support ROCm 6.4 yet (it was primarily built against ROCm 6.0). If the build fails on ROCm 6.4, run without Flash Attention:
   ```bash
   python coderai --model meta-llama/Llama-2-7b-chat-hf
   # (omit the --flash-attn flag)
   ```
5. **Fallback**: The server works without Flash Attention; simply omit the `--flash-attn` flag when starting it.
### bitsandbytes Not Working on ROCm

**Problem**: Quantization fails on AMD GPUs

...
@@ -13,6 +13,7 @@ import re
import sys
import time
import uuid
import warnings
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Dict, List, Optional, Union

@@ -28,6 +29,8 @@ from transformers import (
    TextIteratorStreamer,
    StoppingCriteria,
    StoppingCriteriaList,
    LogitsProcessor,
    LogitsProcessorList,
)
from threading import Thread

@@ -45,6 +48,24 @@ def check_flash_attn_availability() -> bool:
    return False
# =============================================================================
# Logits Processor for Numerical Stability
# =============================================================================

class InvalidLogitsProcessor(LogitsProcessor):
    """Replace NaN and Inf values in logits with finite values."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Replace invalid values in logits."""
        # Replace NaN with a very negative number (near -inf but finite)
        scores = torch.where(torch.isnan(scores), torch.tensor(-1e9, dtype=scores.dtype, device=scores.device), scores)
        # Replace +Inf with a large finite number (check only positive Inf, so -Inf is not mapped to +1e9)
        scores = torch.where(torch.isinf(scores) & (scores > 0), torch.tensor(1e9, dtype=scores.dtype, device=scores.device), scores)
        # Replace -Inf (and anything below -1e9) with a very negative finite number
        scores = torch.where(scores < -1e9, torch.tensor(-1e9, dtype=scores.dtype, device=scores.device), scores)
        return scores
# =============================================================================
# Memory Detection and Model Sizing
# =============================================================================

@@ -611,6 +632,22 @@ class ModelManager:
        formatted.append("Assistant:")
        return "\n\n".join(formatted)
    def _validate_generation_params(self, temperature: float, top_p: float) -> tuple:
        """Validate and clamp generation parameters for numerical stability."""
        # Clamp temperature to avoid numerical issues.
        # Temperature must be > 0 for sampling, but very small values can cause issues.
        if temperature <= 0:
            temperature = 1.0
            do_sample = False
        else:
            temperature = max(0.01, min(temperature, 2.0))
            do_sample = True
        # Clamp top_p to [0, 1]
        top_p = max(0.0, min(top_p, 1.0))
        return temperature, top_p, do_sample
    def generate_stream(
        self,
        prompt: str,

@@ -628,6 +665,9 @@ class ModelManager:
        if max_tokens is None:
            max_tokens = 512

        # Validate parameters
        temperature, top_p, do_sample = self._validate_generation_params(temperature, top_p)
        streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_prompt=True,

@@ -638,14 +678,19 @@ class ModelManager:
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "max_new_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "do_sample": do_sample,
            "streamer": streamer,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        # Add logits processor to handle NaN/Inf values
        generation_kwargs["logits_processor"] = LogitsProcessorList([
            InvalidLogitsProcessor()
        ])

        # Handle stop sequences
        if stop:
            class StopOnSequence(StoppingCriteria):

@@ -661,16 +706,34 @@ class ModelManager:
                StopOnSequence(stop, self.tokenizer)
            ])
        # Run generation in a separate thread with error handling
        generated_text = ""
        try:
            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()

            for text in streamer:
                generated_text += text
                yield text

            thread.join()
        except RuntimeError as e:
            if "probability tensor contains" in str(e):
                print(f"Warning: Numerical error during generation: {e}")
                print("This may be due to temperature=0 or numerical instability.")
                print("Trying again with greedy decoding...")
                # Fallback to greedy decoding
                generation_kwargs["do_sample"] = False
                generation_kwargs["temperature"] = None
                generation_kwargs["top_p"] = None
                thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
                thread.start()
                for text in streamer:
                    generated_text += text
                    yield text
                thread.join()
            else:
                raise
    def generate(
        self,

@@ -687,21 +750,46 @@ class ModelManager:
        if max_tokens is None:
            max_tokens = 512

        # Validate parameters
        temperature, top_p, do_sample = self._validate_generation_params(temperature, top_p)

        try:
            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_new_tokens=max_tokens,
                    temperature=temperature if do_sample else None,
                    top_p=top_p if do_sample else None,
                    do_sample=do_sample,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    stopping_criteria=self._create_stopping_criteria(stop) if stop else None,
                    logits_processor=LogitsProcessorList([InvalidLogitsProcessor()]),
                )

            generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
            return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
        except RuntimeError as e:
            if "probability tensor contains" in str(e):
                print(f"Warning: Numerical error during generation: {e}")
                print("Retrying with greedy decoding...")
                # Fallback to greedy decoding
                with torch.no_grad():
                    outputs = self.model.generate(
                        input_ids=inputs["input_ids"],
                        attention_mask=inputs["attention_mask"],
                        max_new_tokens=max_tokens,
                        do_sample=False,
                        pad_token_id=self.tokenizer.pad_token_id,
                        eos_token_id=self.tokenizer.eos_token_id,
                        stopping_criteria=self._create_stopping_criteria(stop) if stop else None,
                        logits_processor=LogitsProcessorList([InvalidLogitsProcessor()]),
                    )
                generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
                return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
            else:
                raise
    def _create_stopping_criteria(self, stop_sequences):
        """Create stopping criteria for stop sequences."""

@@ -821,6 +909,7 @@ async def stream_chat_response(
    generated_text = ""

    try:
        for chunk in model_manager.generate_stream(
            prompt=prompt,
            max_tokens=max_tokens,

@@ -866,6 +955,22 @@ async def stream_chat_response(
        yield f"data: {json.dumps({'choices': [{'finish_reason': 'stop'}]})}\n\n"
        yield "data: [DONE]\n\n"
    except Exception as e:
        print(f"Error during streaming generation: {e}")
        # Send error event
        data = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model_name,
            "choices": [{
                "index": 0,
                "delta": {"content": f"\n[Generation error: {str(e)}]"},
                "finish_reason": "stop",
            }],
        }
        yield f"data: {json.dumps(data)}\n\n"
        yield "data: [DONE]\n\n"
async def generate_chat_response(

@@ -881,6 +986,7 @@ async def generate_chat_response(
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    created = int(time.time())

    try:
        generated_text = model_manager.generate(
            prompt=prompt,
            max_tokens=max_tokens,

@@ -920,6 +1026,9 @@ async def generate_chat_response(
                "total_tokens": len(model_manager.tokenizer.encode(prompt)) + len(model_manager.tokenizer.encode(generated_text)),
            },
        }
    except Exception as e:
        print(f"Error during generation: {e}")
        raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}")
@app.post("/v1/completions")

@@ -968,6 +1077,7 @@ async def stream_completion_response(
    completion_id = f"cmpl-{uuid.uuid4().hex}"
    created = int(time.time())

    try:
        for chunk in model_manager.generate_stream(
            prompt=prompt,
            max_tokens=max_tokens,

@@ -991,6 +1101,10 @@ async def stream_completion_response(
        yield f"data: {json.dumps({'choices': [{'finish_reason': 'stop'}]})}\n\n"
        yield "data: [DONE]\n\n"
    except Exception as e:
        print(f"Error during streaming completion: {e}")
        yield f"data: {json.dumps({'choices': [{'finish_reason': 'stop'}]})}\n\n"
        yield "data: [DONE]\n\n"
async def generate_completion_response(

@@ -1005,6 +1119,7 @@ async def generate_completion_response(
    completion_id = f"cmpl-{uuid.uuid4().hex}"
    created = int(time.time())

    try:
        generated_text = model_manager.generate(
            prompt=prompt,
            max_tokens=max_tokens,

@@ -1030,6 +1145,9 @@ async def generate_completion_response(
                "total_tokens": len(model_manager.tokenizer.encode(prompt)) + len(model_manager.tokenizer.encode(generated_text)),
            },
        }
    except Exception as e:
        print(f"Error during completion: {e}")
        raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}")
# =============================================================================

...
@@ -3,20 +3,27 @@ fastapi>=0.104.0
uvicorn[standard]>=0.24.0
pydantic>=2.5.0

# PyTorch - Uncomment the appropriate version for your system.
# IMPORTANT: Use quotes around version specifiers to prevent shell interpretation!
# The >= operator will be interpreted as output redirection without quotes!
#
# Option 1: Use exact versions (recommended for requirements.txt)
# Option 2: Use quotes: pip install "torch>=2.0.0"

# For NVIDIA (CUDA):
# torch==2.0.0
# torchvision==0.15.0
# torchaudio==2.0.0

# For AMD (ROCm) - see available versions at https://pytorch.org/get-started/locally/
# rocm6.0 is recommended for newer AMD GPUs, rocm5.6 for older ones
# --index-url https://download.pytorch.org/whl/rocm6.0
# torch==2.0.0
# torchvision==0.15.0
# torchaudio==2.0.0

# For CPU only:
torch==2.0.0

# ML dependencies
transformers>=4.35.0
@@ -37,6 +44,16 @@ procname>=0.3.0
# flash-attn>=2.5.0

# Installation instructions:
# IMPORTANT: Always use quotes or exact versions to avoid shell redirection issues!
#
# 1. For NVIDIA GPUs (CUDA 12.1):
#    pip install torch torchvision torchaudio
#
# 2. For AMD GPUs (ROCm 6.0 recommended):
#    pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.0
#
# 3. For CPU only:
#    pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
#
# If you see "No such file or directory: '0.0'" errors, you forgot to use quotes!
# The shell interprets >= as redirection. Fix: pip install "torch>=2.0.0" (with quotes)