Spaces:
Runtime error
Runtime error
Commit
·
c520eb1
1
Parent(s):
10f12f1
updated
Browse files
app.py
CHANGED
|
@@ -15,7 +15,8 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
| 15 |
|
| 16 |
# Load base model with device_map="auto" to handle GPUs automatically
|
| 17 |
base_model = AutoModelForCausalLM.from_pretrained(
|
| 18 |
-
base_model_path, torch_dtype=torch.float16, device_map="auto"
|
|
|
|
| 19 |
|
| 20 |
# Load adapter and ensure it is on the correct device
|
| 21 |
model = PeftModel.from_pretrained(base_model, adapter_path).to(device)
|
|
@@ -49,11 +50,7 @@ def generate_text_from_model(prompt: str):
|
|
| 49 |
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
| 50 |
|
| 51 |
# Extract only the assistant's response
|
| 52 |
-
|
| 53 |
-
response_text = generated_text.split("<|assistant|>")[-1].strip()
|
| 54 |
-
else:
|
| 55 |
-
response_text = generated_text.strip()
|
| 56 |
-
|
| 57 |
return response_text
|
| 58 |
except Exception as e:
|
| 59 |
raise HTTPException(status_code=500, detail=str(e))
|
|
@@ -67,4 +64,4 @@ async def root():
|
|
| 67 |
@app.post("/generate/")
|
| 68 |
async def generate_text(request: GenerateRequest):
|
| 69 |
response = generate_text_from_model(request.prompt)
|
| 70 |
-
return response
|
|
|
|
| 15 |
|
| 16 |
# Load base model with device_map="auto" to handle GPUs automatically
|
| 17 |
base_model = AutoModelForCausalLM.from_pretrained(
|
| 18 |
+
base_model_path, torch_dtype=torch.float16, device_map="auto"
|
| 19 |
+
)
|
| 20 |
|
| 21 |
# Load adapter and ensure it is on the correct device
|
| 22 |
model = PeftModel.from_pretrained(base_model, adapter_path).to(device)
|
|
|
|
| 50 |
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
| 51 |
|
| 52 |
# Extract only the assistant's response
|
| 53 |
+
response_text = generated_text.split("<|assistant|>\n")[-1].strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
return response_text
|
| 55 |
except Exception as e:
|
| 56 |
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
| 64 |
@app.post("/generate/")
|
| 65 |
async def generate_text(request: GenerateRequest):
|
| 66 |
response = generate_text_from_model(request.prompt)
|
| 67 |
+
return {"response": response}
|