thinkingnew committed on
Commit
c829824
·
1 Parent(s): 17b1867
Files changed (1) hide show
  1. app.py +21 -3
app.py CHANGED
@@ -31,13 +31,31 @@ class GenerateRequest(BaseModel):
31
  # **Use model.generate() instead of pipeline()**
32
  def generate_text_from_model(prompt: str):
33
  try:
34
- input_ids = tokenizer(f"<s>[INST] {prompt} [/INST]", return_tensors="pt").input_ids.to(device)
35
- output_ids = model.generate(input_ids, max_length=512)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
37
- return generated_text
 
 
 
38
  except Exception as e:
39
  raise HTTPException(status_code=500, detail=str(e))
40
 
 
41
  # Root endpoint for testing
42
  @app.get("/")
43
  async def root():
 
31
  # **Use model.generate() instead of pipeline()**
32
  def generate_text_from_model(prompt: str):
33
  try:
34
+ input_data = tokenizer(
35
+ f"<s>[INST] {prompt} [/INST]",
36
+ return_tensors="pt",
37
+ padding=True,
38
+ truncation=True
39
+ )
40
+ input_ids = input_data.input_ids.to(device)
41
+ attention_mask = input_data.attention_mask.to(device)
42
+
43
+ output_ids = model.generate(
44
+ input_ids,
45
+ max_length=512,
46
+ pad_token_id=tokenizer.eos_token_id, # Explicitly setting pad_token_id
47
+ attention_mask=attention_mask
48
+ )
49
+
50
  generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
51
+
52
+ # Extract only the assistant's response
53
+ response_text = generated_text.split("<|assistant|>\n")[-1].strip()
54
+ return response_text
55
  except Exception as e:
56
  raise HTTPException(status_code=500, detail=str(e))
57
 
58
+
59
  # Root endpoint for testing
60
  @app.get("/")
61
  async def root():