Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
| import argparse | |
| import numpy as np | |
| import connexion | |
| from flask_cors import CORS | |
| from flask import render_template, redirect, send_from_directory | |
| import utils.path_fixes as pf | |
| from utils.f import ifnone | |
| from model_api import get_details | |
# Connexion application object; `static_folder` and `specification_dir` are
# passed straight through to connexion.FlaskApp.
app = connexion.FlaskApp(
    __name__,
    static_folder="client/dist",
    specification_dir=".",
)

# The underlying Flask application exposed by the connexion wrapper.
flask_app = app.app

# Apply flask_cors to the Flask application (default CORS settings).
CORS(flask_app)
# Command-line options for running this module directly as a script.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument("--debug", action="store_true", help=" Debug mode")
# `type=int` keeps a port supplied on the command line the same type as the
# integer default and makes argparse reject non-numeric values up front.
parser.add_argument(
    "--port", default=5051, type=int, help="Port to run the app. "
)
# Flask main routes
def hello_world():
    """Redirect the caller to the exBERT client page.

    Returns:
        The response built by ``flask.redirect`` for ``client/exBERT.html``.
    """
    # NOTE(review): no route decorator is visible on this function in this
    # view of the file -- confirm how it is registered with the app.
    landing_page = "client/exBERT.html"
    return redirect(landing_page)
# send everything from client as static content
def send_static_client(path):
    """Serve a single file out of the client distribution directory.

    Intended to back requests of the form ``/client/<path:path>``.

    :param path: path from the api call, resolved inside ``pf.CLIENT_DIST``
    :return: the result of ``flask.send_from_directory`` for that file
    """
    # send_from_directory is handed a plain string, not the path object.
    client_root = str(pf.CLIENT_DIST)
    return send_from_directory(client_root, path)
| # ====================================================================== | |
| ## CONNEXION API ## | |
| # ====================================================================== | |
def get_model_details(**request):
    """Report how many layers and attention heads a model has.

    Args:
        request['model']: The model name, forwarded to ``get_details``.

    Returns:
        {
            status: 200,
            payload: {
                nlayers (int): ``num_hidden_layers`` of the model config
                nheads (int): ``num_attention_heads`` of the model config
            }
        }
    """
    # The counts are read from the `config` attribute of whatever
    # `get_details` returns for this model name.
    config = get_details(request["model"]).config

    return {
        "status": 200,
        "payload": {
            "nlayers": config.num_hidden_layers,
            "nheads": config.num_attention_heads,
        },
    }
def get_attentions_and_preds(**request):
    """Run a sentence through a model and export one layer's results.

    Args:
        request['model']: Model name, forwarded to ``get_details``
        request['sentence']: Sentence to get the attentions for
        request['layer']: Which layer to extract from (coerced with ``int``)

    Returns:
        A dict with ``status`` 200 and, as ``payload``, whatever
        ``to_json(layer)`` produces for the processed sentence. That payload
        is documented for this endpoint as:

        {
            aa: {
                att: Array((nheads, ntoks, ntoks))
                left: [{
                    text (str),
                    topk_words (List[str]),
                    topk_probs (List[float])
                }, ...]
                right: [{
                    text (str),
                    topk_words (List[str]),
                    topk_probs (List[float])
                }, ...]
            }
        }
    """
    model_details = get_details(request["model"])
    text = request["sentence"]
    layer_index = int(request["layer"])

    # Process the sentence, then serialise only the requested layer.
    layer_payload = model_details.from_sentence(text).to_json(layer_index)

    return {"status": 200, "payload": layer_payload}
def update_masked_attention(**request):
    """From tokens and indices of what should be masked, get the attentions and predictions.

    All inputs are read from ``payload = request['payload']``.

    Args:
        payload['model'] (str): Model name, forwarded to ``get_details``
        payload['tokens'] (List[str]): Tokens to pass through the model
        payload['sentence'] (str): Original sentence the tokens came from
        payload['mask'] (List[int]): Which token indices to mask
        payload['layer'] (int): Which layer to extract information from

    Returns:
        A dict with ``status`` 200 and, as ``payload``, whatever
        ``to_json(layer)`` produces for the masked tokens. That payload is
        documented for this endpoint as:

        {
            aa: {
                att: Array((nheads, ntoks, ntoks))
                left: [{
                    text (str),
                    topk_words (List[str]),
                    topk_probs (List[float])
                }, ...]
                right: [{
                    text (str),
                    topk_words (List[str]),
                    topk_probs (List[float])
                }, ...]
            }
        }
    """
    payload = request["payload"]

    details = get_details(payload["model"])
    tokens = payload["tokens"]
    sentence = payload["sentence"]
    # Build the set once so each per-token membership test is O(1) instead of
    # scanning the mask list for every token.
    masked_positions = set(payload["mask"])
    layer = int(payload["layer"])

    mask_token = details.tok.mask_token

    # Replace each masked position with the tokenizer's mask token.
    # NOTE(review): `ifnone(mask_token, tok)` presumably falls back to the
    # original token when the tokenizer defines no mask token -- confirm
    # against utils.f.ifnone.
    token_inputs = [
        ifnone(mask_token, tok) if idx in masked_positions else tok
        for idx, tok in enumerate(tokens)
    ]

    deets = details.from_tokens(token_inputs, sentence)

    return {
        "status": 200,
        "payload": deets.to_json(layer),
    }
# Register the API specification file with the connexion app.
app.add_api("swagger.yaml")

if __name__ == "__main__":
    # Executed directly: read the command-line options and start the server.
    # parse_known_args is used, so unrecognised arguments are ignored.
    args, _ = parser.parse_known_args()
    print("Initiating app")
    app.run(port=args.port, use_reloader=False, debug=args.debug)
else:
    # Imported rather than executed: only the setup above has run.
    print("SETTING UP ENDPOINTS")