Commit 678c4c4 · 1 Parent(s): c829824
updated
app.py CHANGED
@@ -21,8 +21,9 @@ base_model = AutoModelForCausalLM.from_pretrained(
 # Load adapter and ensure it is on the correct device
 model = PeftModel.from_pretrained(base_model, adapter_path).to(device)
 
-# Load tokenizer
+# Load tokenizer and ensure padding token is set
 tokenizer = AutoTokenizer.from_pretrained(base_model_path)
+tokenizer.pad_token = tokenizer.eos_token  # Avoids padding issues
 
 # Define request model for validation
 class GenerateRequest(BaseModel):
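Many causal-LM tokenizers ship without a dedicated pad token, so any call that pads a batch fails until one is assigned. A minimal sketch of the guard this hunk adds, using a stand-in model id (the Space's actual base_model_path is not shown in the diff):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in for base_model_path
if tokenizer.pad_token is None:
    # Reuse EOS for padding; avoids "Asking to pad but the tokenizer
    # does not have a padding token" when encoding padded batches.
    tokenizer.pad_token = tokenizer.eos_token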
@@ -40,6 +41,7 @@ def generate_text_from_model(prompt: str):
         input_ids = input_data.input_ids.to(device)
         attention_mask = input_data.attention_mask.to(device)
 
+        # Generate output
         output_ids = model.generate(
             input_ids,
             max_length=512,
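The comment line is the only change in this hunk; the remaining generate arguments fall outside the diff context. For reference, a self-contained sketch of such a generation call, assuming model, tokenizer, and device from the setup above (settings here are illustrative, not the Space's actual values):

import torch

input_data = tokenizer(prompt, return_tensors="pt")
input_ids = input_data.input_ids.to(device)
attention_mask = input_data.attention_mask.to(device)

with torch.no_grad():  # inference only; skip gradient bookkeeping
    output_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,  # explicit mask avoids the warning HF emits when it must guess
        max_length=512,
        pad_token_id=tokenizer.eos_token_id,
    )
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)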
@@ -52,6 +54,9 @@ def generate_text_from_model(prompt: str):
         # Extract only the assistant's response
         response_text = generated_text.split("<|assistant|>\n")[-1].strip()
         return response_text
+    except torch.cuda.OutOfMemoryError:
+        torch.cuda.empty_cache()
+        raise HTTPException(status_code=500, detail="CUDA Out of Memory. Try using a smaller model or lowering max_length.")
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
 
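The except ordering in this hunk matters: torch.cuda.OutOfMemoryError (a distinct exception type since PyTorch 1.13) must be caught before the generic Exception handler, or the specific branch is unreachable. Condensed shape of the handler:

import torch
from fastapi import HTTPException

def generate_text_from_model(prompt: str) -> str:
    try:
        response_text = ...  # tokenize, generate, decode as in the hunks above
        return response_text
    except torch.cuda.OutOfMemoryError:
        torch.cuda.empty_cache()  # release cached blocks so later requests can allocate
        raise HTTPException(status_code=500, detail="CUDA Out of Memory. Try using a smaller model or lowering max_length.")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

Note that empty_cache() only returns PyTorch's cached allocator blocks to the driver; it does not shrink the model itself, so repeated OOMs still point to max_length or the model size.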
@@ -65,4 +70,4 @@ async def root():
 @app.post("/generate/")
 async def generate_text(request: GenerateRequest):
     response = generate_text_from_model(request.prompt)
-    return {"response": response}
+    return {"response": response}
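The final hunk's removed and added lines are textually identical, which usually indicates a whitespace or end-of-line change. Functionally, the endpoint validates the POST body against GenerateRequest, so it expects JSON with a prompt field. A hypothetical client call (the URL is a placeholder; substitute the deployed Space's actual endpoint):

import requests

resp = requests.post(
    "http://localhost:7860/generate/",  # placeholder; Spaces commonly serve on port 7860
    json={"prompt": "Explain LoRA adapters in one sentence."},
    timeout=120,  # generation can be slow, especially on CPU
)
resp.raise_for_status()
print(resp.json()["response"])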