thinkingnew committed
Commit c520eb1 · Parent(s): 10f12f1
Files changed (1):
  1. app.py +4 -7

app.py CHANGED
@@ -15,7 +15,8 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 
 # Load base model with device_map="auto" to handle GPUs automatically
 base_model = AutoModelForCausalLM.from_pretrained(
-    base_model_path, torch_dtype=torch.float16, device_map="auto")
+    base_model_path, torch_dtype=torch.float16, device_map="auto"
+)
 
 # Load adapter and ensure it is on the correct device
 model = PeftModel.from_pretrained(base_model, adapter_path).to(device)
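
For context, the load pattern this hunk reflows looks roughly like the sketch below when written out standalone; the change itself is purely cosmetic (the closing parenthesis moves to its own line). The checkpoint and adapter paths here are placeholders, since base_model_path and adapter_path are defined earlier in app.py and do not appear in this diff.

# Standalone sketch; both paths are placeholders, not the values in app.py.
import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

base_model_path = "some-org/some-base-model"   # placeholder
adapter_path = "some-org/some-lora-adapter"    # placeholder

device = "cuda" if torch.cuda.is_available() else "cpu"

# device_map="auto" lets accelerate place the fp16 weights on the
# available GPUs (or CPU) without manual .to() calls on the base model.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_path, torch_dtype=torch.float16, device_map="auto"
)

# app.py then attaches the LoRA adapter and moves the wrapper to `device`.
model = PeftModel.from_pretrained(base_model, adapter_path).to(device)
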
@@ -49,11 +50,7 @@ def generate_text_from_model(prompt: str):
         generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
         # Extract only the assistant's response
-        if "<|assistant|>" in generated_text:
-            response_text = generated_text.split("<|assistant|>")[-1].strip()
-        else:
-            response_text = generated_text.strip()
-
+        response_text = generated_text.split("<|assistant|>\n")[-1].strip()
         return response_text
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
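
The dropped if/else guard was redundant: str.split returns the whole string as a one-element list when the separator never occurs, so [-1] already falls back to the full text. A small illustration with toy strings (not real app output):

# str.split degrades gracefully when the marker is absent, which is why
# the one-liner can replace the old if/else. Toy strings only.
with_marker = "<|system|>\nBe concise.\n<|assistant|>\nHello there!"
without_marker = "plain completion with no chat markers"

assert with_marker.split("<|assistant|>\n")[-1].strip() == "Hello there!"
assert without_marker.split("<|assistant|>\n")[-1].strip() == without_marker

Note that the separator also gained a trailing "\n", which presumably matches the chat template this model emits; if the tag ever appears without a following newline, the new code returns the full generated text rather than the tail.
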
@@ -67,4 +64,4 @@ async def root():
 @app.post("/generate/")
 async def generate_text(request: GenerateRequest):
     response = generate_text_from_model(request.prompt)
-    return response
+    return {"response": response}
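
Returning a dict instead of a bare string changes the response body from a JSON string to a JSON object, {"response": "..."}, giving clients a stable envelope to extend. A self-contained sketch of the endpoint shape follows, with the model call stubbed out; GenerateRequest is defined elsewhere in app.py, and the diff only shows that it carries a prompt field.

from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()

class GenerateRequest(BaseModel):
    prompt: str  # the only field visible in this diff

def generate_text_from_model(prompt: str) -> str:
    # Stand-in for the real model call in app.py.
    return f"echo: {prompt}"

@app.post("/generate/")
async def generate_text(request: GenerateRequest):
    response = generate_text_from_model(request.prompt)
    # A dict serializes as a JSON object: {"response": "..."}.
    return {"response": response}

With this change, a POST to /generate/ with {"prompt": "hi"} yields {"response": "echo: hi"} rather than a bare JSON string.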
 