Spaces:
Runtime error
Runtime error
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from fastapi.responses import JSONResponse | |
# Create the FastAPI (ASGI) application instance.
app = FastAPI()

# Hugging Face Hub identifier shared by the tokenizer and the model below.
_MODEL_ID = "chatdb/natural-sql-7b"

# Load the tokenizer and the model weights in float16; device_map="auto"
# delegates placement of the model's layers across the available devices.
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    _MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",
)
# PostgreSQL DDL describing the database the generated queries target.
# It is interpolated verbatim into the model prompt, so any edit here changes
# what the model sees. The trailing `--` comments list the expected values of
# enum-like columns (users.role, users.country, leaves.type, leaves.status).
schema = """
CREATE TABLE users (
id SERIAL PRIMARY KEY,
manager_id INTEGER,
first_name VARCHAR(100) NOT NULL,
last_name VARCHAR(100) NOT NULL,
designation VARCHAR(100),
email VARCHAR(100) UNIQUE NOT NULL,
phone VARCHAR(15) UNIQUE NOT NULL,
password TEXT NOT NULL,
role VARCHAR(50) NOT NULL, -- employee, manager, hr
country VARCHAR(50) NOT NULL, -- pakistan, uae, uk
fcm_token VARCHAR(255),
image VARCHAR(255) DEFAULT '',
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE leaves_balances (
id SERIAL PRIMARY KEY,
sick_available FLOAT NOT NULL,
casual_available FLOAT NOT NULL,
wfh_available FLOAT NOT NULL,
sick_taken FLOAT NOT NULL,
casual_taken FLOAT NOT NULL,
wfh_taken FLOAT NOT NULL,
user_id INTEGER UNIQUE REFERENCES users(id) ON DELETE CASCADE,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE leaves (
id SERIAL PRIMARY KEY,
user_id INTEGER REFERENCES users(id) ON DELETE CASCADE,
manager_id INTEGER REFERENCES users(id) ON DELETE CASCADE,
username VARCHAR(100) NOT NULL,
type VARCHAR(50) NOT NULL, -- sick, casual, wfh
from_date TIMESTAMP NOT NULL,
to_date TIMESTAMP NOT NULL,
comments TEXT,
status VARCHAR(50) DEFAULT 'pending', -- pending, approved, rejected
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE user_otps (
id SERIAL PRIMARY KEY,
user_id INTEGER REFERENCES users(id) ON DELETE CASCADE,
otp INTEGER NOT NULL,
otp_expiry TIMESTAMP NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
"""
# Define the request body model using Pydantic
class QuestionRequest(BaseModel):
    """JSON request body carrying the natural-language question to turn into SQL."""

    # The question to translate into a SQL query; required (no default).
    question: str
def _build_prompt(question: str) -> str:
    """Return the model prompt asking for SQL that answers *question* over ``schema``."""
    # The prompt deliberately ends with an opening ```sql fence so that the
    # model's continuation is the query itself.
    return f"""
### Task
Generate a SQL query to answer the following question: `{question}`
### PostgreSQL Database Schema
The query will run on a database with the following schema:
{schema}
### Answer
Here is the SQL query that answers the question: `{question}`
```sql
"""


def _generate_completion(prompt: str) -> str:
    """Run greedy decoding on *prompt* and return only the newly generated text."""
    # Move inputs to the device holding the model instead of hard-coding "cuda":
    # with device_map="auto" the model may have been placed on CPU.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        # NOTE(review): 100001 is hard-coded as both EOS and PAD id; presumably
        # the end-of-text token of this model's tokenizer -- confirm.
        eos_token_id=100001,
        pad_token_id=100001,
        max_new_tokens=400,
        do_sample=False,
        num_beams=1,
    )
    # For a causal LM, generate() returns prompt + continuation; decode only
    # the continuation so the prompt text never leaks into the parsed answer.
    prompt_length = inputs["input_ids"].shape[1]
    return tokenizer.decode(generated_ids[0, prompt_length:], skip_special_tokens=True)


def _extract_sql(completion: str) -> str:
    """Return the SQL statement contained in a model completion.

    If the completion opens another ```sql block, only the text after the last
    one is kept; everything from the following closing ``` fence onward
    (typically the end of the code block) is dropped.
    """
    body = completion.rsplit("```sql", 1)[-1]
    sql, _, _ = body.partition("```")
    return sql.strip()


@app.post("/generate_sql")
async def generate_sql(request: QuestionRequest):
    """
    Endpoint to generate a SQL query based on a given question.
    The schema is defined within the code (in the `schema` variable).

    Raises:
        HTTPException: 400 if the question is empty or whitespace-only.

    Returns:
        JSONResponse whose body is ``{"sql_query": <generated SQL>}``.
    """
    # Local import: FastAPI's helper for running blocking code in a worker thread.
    from fastapi.concurrency import run_in_threadpool

    question = request.question.strip()
    if not question:
        raise HTTPException(status_code=400, detail="No question provided")

    prompt = _build_prompt(question)
    # model.generate is blocking and slow; running it in a thread keeps the
    # event loop responsive to other requests while a query is generated.
    # NOTE(review): concurrent requests may now generate in parallel on the
    # shared model; add a lock if GPU memory becomes a problem.
    completion = await run_in_threadpool(_generate_completion, prompt)
    sql_query = _extract_sql(completion)
    return JSONResponse(content={'sql_query': sql_query})