thinkingnew committed
Commit 678c4c4 · 1 Parent(s): c829824
Files changed (1)
  1. app.py +7 -2
app.py CHANGED
@@ -21,8 +21,9 @@ base_model = AutoModelForCausalLM.from_pretrained(
 # Load adapter and ensure it is on the correct device
 model = PeftModel.from_pretrained(base_model, adapter_path).to(device)
 
-# Load tokenizer
+# Load tokenizer and ensure padding token is set
 tokenizer = AutoTokenizer.from_pretrained(base_model_path)
+tokenizer.pad_token = tokenizer.eos_token  # Avoids padding issues
 
 # Define request model for validation
 class GenerateRequest(BaseModel):
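Note on the new pad-token line: many causal-LM tokenizers ship without a pad token, so padding a batch fails or generate() warns until one is set. A minimal standalone sketch of the same idea; the None-guard and the "gpt2" placeholder path are assumptions, not part of this commit:

from transformers import AutoTokenizer

base_model_path = "gpt2"  # placeholder; app.py defines its own base_model_path
tokenizer = AutoTokenizer.from_pretrained(base_model_path)
if tokenizer.pad_token is None:
    # Reuse EOS as pad so padded batches have a valid token id
    # (the guard is an assumption; the commit assigns unconditionally)
    tokenizer.pad_token = tokenizer.eos_token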
@@ -40,6 +41,7 @@ def generate_text_from_model(prompt: str):
         input_ids = input_data.input_ids.to(device)
         attention_mask = input_data.attention_mask.to(device)
 
+        # Generate output
        output_ids = model.generate(
             input_ids,
             max_length=512,
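The hunk above moves attention_mask to the device, but the visible lines only show input_ids being passed to generate(). A sketch of a call that forwards the mask and pad-token id explicitly; both extra kwargs are assumptions, since the truncated hunk does not show the full argument list:

output_ids = model.generate(
    input_ids,
    attention_mask=attention_mask,        # disambiguates padding when pad == eos
    max_length=512,
    pad_token_id=tokenizer.pad_token_id,  # silences the pad_token_id warning
)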
@@ -52,6 +54,9 @@ def generate_text_from_model(prompt: str):
         # Extract only the assistant's response
         response_text = generated_text.split("<|assistant|>\n")[-1].strip()
         return response_text
+    except torch.cuda.OutOfMemoryError:
+        torch.cuda.empty_cache()
+        raise HTTPException(status_code=500, detail="CUDA Out of Memory. Try using a smaller model or lowering max_length.")
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
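Two details about the new out-of-memory handler: the specific except clause must precede the generic except Exception (after a catch-all it would be unreachable), and torch.cuda.OutOfMemoryError exists only in PyTorch 1.13+. Also, torch.cuda.empty_cache() releases cached allocator blocks, not live tensors. A standalone sketch of the same ordering; the helper name is hypothetical:

import torch
from fastapi import HTTPException

def guarded_generate(generate_fn):
    # Hypothetical wrapper mirroring the commit's handler ordering:
    # the specific CUDA OOM clause comes before the catch-all
    try:
        return generate_fn()
    except torch.cuda.OutOfMemoryError:
        torch.cuda.empty_cache()  # frees cached blocks; live tensors stay allocated
        raise HTTPException(status_code=500, detail="CUDA Out of Memory. Try using a smaller model or lowering max_length.")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))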
 
@@ -65,4 +70,4 @@ async def root():
 @app.post("/generate/")
 async def generate_text(request: GenerateRequest):
     response = generate_text_from_model(request.prompt)
-    return {"response": response}
+    return {"response": response}
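A usage sketch for the /generate/ endpoint; serving the app with uvicorn on localhost:8000 is an assumption:

import requests

resp = requests.post(
    "http://localhost:8000/generate/",
    json={"prompt": "Hello, who are you?"},  # field name matches GenerateRequest.prompt
)
resp.raise_for_status()
print(resp.json()["response"])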