Spaces:
Runtime error
Runtime error
| import os | |
| import logging | |
| import shutil | |
| import time | |
| import re | |
| from fastapi import FastAPI, Request, UploadFile | |
| from fastapi.middleware import Middleware | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import PlainTextResponse, StreamingResponse | |
| from .rag import ChatPDF | |
# CORS is fully open: any origin, any method and any header is accepted.
middleware = [
    Middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_methods=["*"],
        allow_headers=["*"],
    ),
]
app = FastAPI(middleware=middleware)

# Staging directory (under the user's home) where uploads are written
# before they are handed to the assistant.
files_dir = os.path.expanduser("~/wtp_be_files/")

# One assistant instance is shared by every request handled by this module.
session_assistant = ChatPDF()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class Session:
    """State of the single user session this server handles at a time.

    The values are plain class-level defaults; the module-level ``session``
    object below is the one instance the handlers read and mutate.
    """

    # True while an upload or a query response is being processed.
    isBusy: bool = False
    # ID of the user that currently owns the session.
    curUserID: str = ""
    # ID of the user that owned the session before the current one.
    prevUserID: str = ""
    # ``time.time()`` value recorded for the owner's most recent request.
    lastQueryTimestamp: float = 0


session = Session()
async def resolve_availability(request: Request, call_next):
    """Decide whether the caller may use the single shared session.

    NOTE(review): the signature matches an HTTP middleware hook; its
    registration is not visible here - confirm.

    Outcomes, in the order they are checked:
      * 503 "Server is busy"      - an upload/query is still being processed.
      * 400 "Bad request"         - the ``id`` query parameter is missing.
      * pass through              - the caller is the current session owner
                                    (its timestamp is refreshed).
      * 419 "Session has expired" - the caller is the previous owner.
      * pass through              - the owner has been idle for at least 300
                                    seconds, so the caller takes over.
      * 503 "Server is busy"      - anyone else while the owner is active.
    """
    if session.isBusy:
        return PlainTextResponse("Server is busy", status_code=503)

    params = dict(request.query_params)
    if "id" not in params:
        return PlainTextResponse("Bad request", status_code=400)

    requester = params["id"]
    now = time.time()

    # The current owner just refreshes its activity timestamp.
    if requester == session.curUserID:
        session.lastQueryTimestamp = now
        return await call_next(request)

    # The user that was replaced is told explicitly that its session is gone.
    if requester == session.prevUserID:
        return PlainTextResponse("Session has expired", status_code=419)

    # A newcomer may only take over once the owner has been idle long enough.
    if now - session.lastQueryTimestamp < 300:
        return PlainTextResponse("Server is busy", status_code=503)

    session.lastQueryTimestamp = now
    session.prevUserID = session.curUserID
    session.curUserID = requester
    return await call_next(request)
def astreamer(generator):
    """Yield every item of ``generator``, logging timing, then free the session.

    Each chunk is logged with the milliseconds elapsed since streaming began,
    and the total time is logged once the source is exhausted.

    ``session.isBusy`` is cleared in a ``finally`` clause so the flag is
    released not only on normal completion but also when the source raises
    or when this generator is closed before it is exhausted. Otherwise every
    later request would be rejected as "Server is busy".

    Args:
        generator: Any iterable of chunks to forward unchanged.

    Yields:
        The items of ``generator``, in order.
    """
    started = time.time()
    try:
        for chunk in generator:
            logger.info(
                "Chunk being yielded (time %dms) - %s",
                int((time.time() - started) * 1000),
                chunk,
            )
            yield chunk
        logger.info("X-Process-Time: %dms", int((time.time() - started) * 1000))
    finally:
        session.isBusy = False
async def process_input(text: str):
    """Answer a user query as a streamed response.

    The session is marked busy first. Then:
      * an empty / whitespace-only (or missing) query streams a fixed
        "query is empty" message;
      * a query made while the assistant holds no PDFs streams a fixed
        "please provide a PDF" message;
      * otherwise the stripped query is passed to ``session_assistant.ask``
        and its ``response_gen`` is streamed.

    The busy flag is normally released by ``astreamer`` when streaming ends.
    If preparing the response fails, the flag is released here before the
    exception propagates, so one failed query cannot leave the server busy.

    Args:
        text: Raw query text supplied by the client.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.
    """
    session.isBusy = True
    try:
        query = text.strip() if text else ""
        if not query:
            message = "Your query is empty. Please provide a query for me to process."
            chunks = re.split(r'(\s)', message)
        elif session_assistant.pdf_count > 0:
            chunks = session_assistant.ask(query).response_gen
        else:
            message = "Please provide the PDF document you'd like to add."
            chunks = re.split(r'(\s)', message)
    except Exception:
        # Nothing will be streamed, so astreamer will never reset the flag.
        session.isBusy = False
        raise
    return StreamingResponse(astreamer(chunks), media_type='text/event-stream')
def upload(files: list[UploadFile]):
    """Stage the uploaded files on disk, ingest them, and stream a confirmation.

    Steps:
      1. Mark the session busy.
      2. Recreate the staging directory ``files_dir`` from scratch, so that
         leftovers from an interrupted upload are neither ingested again nor
         make directory creation fail.
      3. Copy each upload into the staging directory. Only the final path
         component of the client-supplied filename is used, so names such as
         ``../../x`` cannot escape the staging directory.
      4. Ingest the staging directory, but only if every file was copied.
      5. Always remove the staging directory afterwards.

    On any failure the busy flag is released before the exception propagates;
    on success it is released by ``astreamer`` once the confirmation message
    has been streamed.

    Args:
        files: Uploaded files to add to the assistant.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.

    Raises:
        ValueError: If an upload has no usable filename.
        OSError: If staging a file on disk fails.
    """
    session.isBusy = True
    try:
        shutil.rmtree(files_dir, ignore_errors=True)
        os.makedirs(files_dir, exist_ok=True)
        try:
            for file in files:
                try:
                    # The filename comes from the client: keep only its last
                    # path component and reject names that are not a file.
                    filename = os.path.basename(file.filename or "")
                    if filename in ("", ".", ".."):
                        raise ValueError(
                            f"Invalid upload filename: {file.filename!r}"
                        )
                    path = os.path.join(files_dir, filename)
                    file.file.seek(0)
                    with open(path, 'wb') as destination:
                        shutil.copyfileobj(file.file, destination)
                finally:
                    file.file.close()
            session_assistant.ingest(files_dir)
        finally:
            # Best-effort cleanup; it must not mask an earlier error.
            shutil.rmtree(files_dir, ignore_errors=True)
    except Exception:
        # Nothing will be streamed, so astreamer will never reset the flag.
        session.isBusy = False
        raise
    message = "All files have been added successfully to your account. Your first query may take a little longer as the system indexes your documents. Please be patient while we process your request."
    generator = re.split(r'(\s)', message)
    return StreamingResponse(astreamer(generator), media_type='text/event-stream')
def clear():
    """Clear the assistant's files and stream a confirmation message.

    The session is marked busy while the assistant is cleared. On success the
    flag is released by ``astreamer`` once the message has been streamed. If
    clearing fails, the flag is released here before the exception
    propagates, so the server does not stay permanently "busy".

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.
    """
    session.isBusy = True
    try:
        session_assistant.clear()
    except Exception:
        # Nothing will be streamed, so astreamer will never reset the flag.
        session.isBusy = False
        raise
    message = "Your files have been cleared successfully."
    generator = re.split(r'(\s)', message)
    return StreamingResponse(astreamer(generator), media_type='text/event-stream')
def ping():
    """Return the fixed reply used to check that the service responds."""
    reply = "Pong!"
    return reply