thinkingnew committed on
Commit
17b1867
·
1 Parent(s): 30d0db5
Files changed (1) hide show
  1. app.py +5 -23
app.py CHANGED
@@ -1,6 +1,6 @@
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
4
  from peft import PeftModel
5
  import torch
6
 
@@ -31,27 +31,10 @@ class GenerateRequest(BaseModel):
31
  # **Use model.generate() instead of pipeline()**
32
  def generate_text_from_model(prompt: str):
33
  try:
34
- input_data = tokenizer(
35
- f"<s>[INST] {prompt} [/INST]",
36
- return_tensors="pt",
37
- padding=True,
38
- truncation=True
39
- )
40
- input_ids = input_data.input_ids.to(device)
41
- attention_mask = input_data.attention_mask.to(device)
42
-
43
- output_ids = model.generate(
44
- input_ids,
45
- max_length=512,
46
- pad_token_id=tokenizer.eos_token_id,
47
- attention_mask=attention_mask
48
- )
49
-
50
  generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
51
-
52
- # Extract only the assistant's response
53
- response_text = generated_text.split("<|assistant|>\n")[-1].strip()
54
- return response_text
55
  except Exception as e:
56
  raise HTTPException(status_code=500, detail=str(e))
57
 
@@ -64,5 +47,4 @@ async def root():
64
  @app.post("/generate/")
65
  async def generate_text(request: GenerateRequest):
66
  response = generate_text_from_model(request.prompt)
67
- return {"response": response}
68
-
 
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
  from peft import PeftModel
5
  import torch
6
 
 
31
# **Use model.generate() instead of pipeline()**
def generate_text_from_model(prompt: str) -> str:
    """Generate a completion for *prompt* with the globally loaded model.

    The prompt is wrapped in the ``<s>[INST] ... [/INST]`` instruction
    template before tokenization (Llama/Mistral chat format —
    presumably matching the fine-tuned adapter; confirm against training).

    Args:
        prompt: Raw user prompt text.

    Returns:
        The decoded generated text (special tokens stripped).

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    try:
        # Tokenize once, then move both tensors to the model's device.
        encoded = tokenizer(f"<s>[INST] {prompt} [/INST]", return_tensors="pt")
        input_ids = encoded.input_ids.to(device)
        attention_mask = encoded.attention_mask.to(device)

        # Pass attention_mask and pad_token_id explicitly: this commit dropped
        # them, which triggers the "pad_token_id not set" warning and can yield
        # unreliable output when the pad token falls back to eos.
        output_ids = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=512,
            pad_token_id=tokenizer.eos_token_id,
        )

        generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return generated_text
    except Exception as e:
        # Surface model/tokenizer failures to the API caller as a 500.
        raise HTTPException(status_code=500, detail=str(e))
40
 
 
47
@app.post("/generate/")
async def generate_text(request: GenerateRequest):
    """Handle POST /generate/: run generation on the submitted prompt."""
    return {"response": generate_text_from_model(request.prompt)}