Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| import os | |
| import re | |
| import torch | |
| from transformers import ( | |
| AutoTokenizer, | |
| AutoModel, | |
| T5ForConditionalGeneration, | |
| MBartForConditionalGeneration, | |
| AutoModelForSeq2SeqLM, | |
| ) | |
| from tqdm.auto import tqdm | |
| import streamlit as st | |
| from typing import Dict, List | |
def load_model(model_name, device):
    """Instantiate a seq2seq model and load a local checkpoint into it.

    The model is first created with ``AutoModelForSeq2SeqLM.from_pretrained``
    (using ``cache/`` as its cache directory) and moved to ``device``. The
    state dict stored at ``models/<name>-best_loss.bin`` is then loaded on top
    of it, where ``<name>`` is the part of ``model_name`` after the last ``/``.

    Args:
        model_name: Identifier handed to ``from_pretrained``.
        device: Target device for the model and for mapping the checkpoint.

    Returns:
        The model after ``load_state_dict`` has been applied.
    """
    print(f"Using model {model_name}")

    # Make sure the cache directory exists before it is handed to from_pretrained.
    cache_dir = "cache"
    os.makedirs(cache_dir, exist_ok=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir=cache_dir)
    model.to(device)

    # Only the last path component of the identifier names the checkpoint file.
    model_name = model_name.rsplit("/", 1)[-1]
    load_model_path = os.path.join("models", f"{model_name}-best_loss.bin")
    print(f"Loading model from {load_model_path}")

    state_dict = torch.load(load_model_path, map_location=torch.device(device))
    model.load_state_dict(state_dict)
    return model
def load_tokenizer(model_name):
    """Load the tokenizer that belongs to ``model_name``.

    When the name contains ``"mbart"`` (case-insensitive), the tokenizer is
    created with both ``src_lang`` and ``tgt_lang`` set to ``"vi_VN"``;
    otherwise it is created with default arguments.

    Args:
        model_name: Identifier handed to ``AutoTokenizer.from_pretrained``.

    Returns:
        The tokenizer returned by ``AutoTokenizer.from_pretrained``.
    """
    print(f"Loading tokenizer {model_name}")

    # Language codes are only passed for mBART-style checkpoints.
    language_kwargs = {}
    if "mbart" in model_name.lower():
        language_kwargs = {"src_lang": "vi_VN", "tgt_lang": "vi_VN"}

    return AutoTokenizer.from_pretrained(model_name, **language_kwargs)
def prepare_batch_model_inputs(batch, tokenizer, max_len, is_train=False, device="cpu"):
    """Tokenize a batch and move every resulting entry to ``device``.

    Args:
        batch: Mapping with a ``"src"`` entry and, when ``is_train`` is true,
            a ``"tgt"`` entry.
        tokenizer: Callable invoked with the source texts plus the keyword
            arguments ``text_target``, ``padding``, ``max_length``,
            ``truncation`` and ``return_tensors``.
        max_len: Value forwarded as ``max_length``; truncation is enabled.
        is_train: When true, ``batch["tgt"]`` is forwarded as ``text_target``;
            otherwise ``text_target`` is ``None``.
        device: Argument passed to each entry's ``.to(...)``.

    Returns:
        The object returned by ``tokenizer``, with every value replaced by
        the result of ``value.to(device)``.
    """
    # "tgt" is only read when training, so inference batches may omit it.
    if is_train:
        targets = batch["tgt"]
    else:
        targets = None

    encoded = tokenizer(
        batch["src"],
        text_target=targets,
        padding="longest",
        max_length=max_len,
        truncation=True,
        return_tensors="pt",
    )
    encoded.update({name: value.to(device) for name, value in encoded.items()})
    return encoded
def prepare_single_model_inputs(src, tokenizer, max_len, device="cpu"):
    """Tokenize ``src`` and move every resulting entry to ``device``.

    Args:
        src: Input handed to ``tokenizer`` as its positional argument.
        tokenizer: Callable invoked with ``padding="longest"``,
            ``max_length=max_len``, ``truncation=True`` and
            ``return_tensors="pt"``.
        max_len: Value forwarded as ``max_length``.
        device: Argument passed to each entry's ``.to(...)``.

    Returns:
        The object returned by ``tokenizer``, with every value replaced by
        the result of ``value.to(device)``.
    """
    tokenizer_options = {
        "padding": "longest",
        "max_length": max_len,
        "truncation": True,
        "return_tensors": "pt",
    }
    encoded = tokenizer(src, **tokenizer_options)
    encoded.update({name: value.to(device) for name, value in encoded.items()})
    return encoded
def make_input_sentence_from_strings(data):
    """Flatten one KPI record into the key/value sentence used as model input.

    ``data`` must provide these entries:

    * ``"CHỈ TIÊU"``, ``"ĐƠN VỊ"``, ``"ĐIỀU KIỆN"``, ``"KPI mục tiêu tháng"``
      and ``"Đánh giá"`` -- copied into the sentence as-is.
    * ``"Thời gian báo cáo"`` -- an indexable pair; element ``1`` is printed
      before element ``0`` (presumably ``(year, month)`` -- confirm with the
      caller). The reported value is read from
      ``"T<elem1>.<elem0> thực tế"``.
    * ``"Previous month value key"`` / ``"Previous year value key"`` -- names
      of the entries that hold the previous-period values.
    * ``"Previous month compare key"`` / ``"Previous year compare key"`` --
      names of the entries that hold the comparison values.
    * ``"Previous month"`` / ``"Previous year"`` -- indexable pairs formatted
      the same way as the reporting period.

    The result looks like a JSON object body without the enclosing braces,
    for example::

        "CHỈ TIÊU": "...", "ĐƠN VỊ": "%", "ĐIỀU KIỆN": ">=",
        "KPI mục tiêu tháng": 95.0, "Tháng 9.2022": 97.5, "Đánh giá": "Đạt",
        "T8.2022": 96.6, "So sánh T8.2022 Tăng giảm": 1.0,
        "T9.2021": 96.8, "So sánh T9.2021 Tăng giảm": 0.8

    (shown wrapped here; the returned string is a single line).

    Args:
        data: Mapping described above.

    Returns:
        The formatted sentence.

    Raises:
        KeyError: If a required entry is missing from ``data``.
    """
    # The previous-period values live under keys that are themselves stored in data.
    month_value_key = data["Previous month value key"]
    year_value_key = data["Previous year value key"]

    indicator = data["CHỈ TIÊU"]
    measure_unit = data["ĐƠN VỊ"]
    comparator = data["ĐIỀU KIỆN"]
    monthly_target = data["KPI mục tiêu tháng"]

    current_time = data["Thời gian báo cáo"]
    actual_value = data[f"T{current_time[1]}.{current_time[0]} thực tế"]
    rating = data["Đánh giá"]

    month_value = data[month_value_key]
    year_value = data[year_value_key]

    # Same indirection for the comparison ("Tăng giảm") values.
    month_delta_key = data["Previous month compare key"]
    year_delta_key = data["Previous year compare key"]
    month_delta = data[month_delta_key]
    year_delta = data[year_delta_key]

    month_period = data["Previous month"]
    year_period = data["Previous year"]

    # Each period is rendered as "<element 1>.<element 0>".
    report_label = (current_time[1], current_time[0])
    month_label = (month_period[1], month_period[0])
    year_label = (year_period[1], year_period[0])

    template_str = '"CHỈ TIÊU": "{}", "ĐƠN VỊ": "{}", "ĐIỀU KIỆN": "{}", "KPI mục tiêu tháng": {}, "Tháng {}.{}": {}, "Đánh giá": "{}", "T{}.{}": {}, "So sánh T{}.{} Tăng giảm": {}, "T{}.{}": {}, "So sánh T{}.{} Tăng giảm": {}'
    fields = (
        indicator,
        measure_unit,
        comparator,
        monthly_target,
        *report_label,
        actual_value,
        rating,
        *month_label,
        month_value,
        *month_label,
        month_delta,
        *year_label,
        year_value,
        *year_label,
        year_delta,
    )
    return template_str.format(*fields)
def generate_description(
    input_string, model, tokenizer, device, max_len, model_name, beam_size
):
    """Generate and decode text for ``input_string``.

    The model is put in eval mode and moved to ``device``, the input is
    tokenized via ``prepare_single_model_inputs``, and the output of
    ``model.generate`` is decoded with ``tokenizer.batch_decode``.

    Args:
        input_string: Text passed to ``prepare_single_model_inputs``.
        model: Model exposing ``eval``, ``to`` and ``generate``.
        tokenizer: Tokenizer used for encoding and for ``batch_decode``.
        device: Device for the model and the encoded inputs.
        max_len: Used both as the tokenizer ``max_len`` and as the generation
            ``max_length``.
        model_name: When it contains ``"mbart"`` (case-insensitive),
            ``forced_bos_token_id`` is set to the tokenizer's id for ``"vi_VN"``.
        beam_size: Forwarded to ``generate`` as ``num_beams``.

    Returns:
        The result of ``tokenizer.batch_decode`` with special tokens skipped
        and tokenization spaces cleaned up.
    """
    model.eval()
    model = model.to(device)

    encoded = prepare_single_model_inputs(
        input_string, tokenizer, max_len=max_len, device=device
    )

    # mBART-style checkpoints get the language token forced as first output token.
    uses_mbart = "mbart" in model_name.lower()
    if uses_mbart:
        encoded["forced_bos_token_id"] = tokenizer.lang_code_to_id["vi_VN"]

    generation_options = {"max_length": max_len, "num_beams": beam_size}
    generated = model.generate(**encoded, **generation_options)

    decoded = tokenizer.batch_decode(
        generated, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    return decoded