thinkingnew committed
Commit 4ec308a · 1 Parent(s): 8c6d7e5
Files changed (1)
  1. app.py +27 -6
app.py CHANGED
@@ -5,18 +5,39 @@ import torch
 
 app = FastAPI()
 
-# Load Model from Hugging Face Hub
+# Define paths
 base_model_path = "NousResearch/Hermes-3-Llama-3.2-3B"
 adapter_path = "thinkingnew/llama_invs_adapter"
 
+# Check if GPU is available
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Load base model
 base_model = AutoModelForCausalLM.from_pretrained(
-    base_model_path, torch_dtype=torch.float16, device_map="auto"
-)
-model = PeftModel.from_pretrained(base_model, adapter_path)
+    base_model_path, torch_dtype=torch.float16 if device == "cuda" else torch.float32, device_map="auto"
+).to(device)
+
+# Load adapter
+model = PeftModel.from_pretrained(base_model, adapter_path).to(device)
+
+# Load tokenizer
 tokenizer = AutoTokenizer.from_pretrained(base_model_path)
 
+# Load pipeline once (for better performance)
+text_pipe = pipeline(
+    task="text-generation",
+    model=model,
+    tokenizer=tokenizer,
+    max_length=512
+)
+
+# Root endpoint for testing
+@app.get("/")
+async def root():
+    return {"message": "Model is running! Use /generate/ for text generation."}
+
+# Text generation endpoint
 @app.post("/generate/")
 async def generate_text(prompt: str):
-    pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_length=512)
-    result = pipe(f"<s>[INST] {prompt} [/INST]")
+    result = text_pipe(f"<s>[INST] {prompt} [/INST]")
     return {"response": result[0]['generated_text']}