import os
from typing import Any, Dict, List

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

MAX_INPUT_LENGTH = 256
MAX_OUTPUT_LENGTH = 128


class EndpointHandler:
    def __init__(
        self,
        model_dir: str = "",
        num_threads: int | None = None,
        generation_config: Dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        # Avoid tokenizer fork warnings and thread oversubscription in server workers.
        os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

        # Optionally pin CPU thread counts for more predictable latency.
        if num_threads:
            try:
                torch.set_num_threads(num_threads)
                torch.set_num_interop_threads(max(1, num_threads // 2))
            except Exception:
                # set_num_interop_threads may raise if inter-op parallelism already started.
                pass
            os.environ.setdefault("OMP_NUM_THREADS", str(num_threads))
            os.environ.setdefault("MKL_NUM_THREADS", str(num_threads))

        self.device = "cpu"

        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir, low_cpu_mem_usage=True)
        self.model.eval()
        self.model.to(self.device)

        # bfloat16 inference, enabled unless ENABLE_BF16=0; fall back to float32 if the cast fails.
        self._use_bf16 = False
        if os.getenv("ENABLE_BF16", "1") == "1":
            try:
                self.model = self.model.to(dtype=torch.bfloat16)
                self._use_bf16 = True
            except Exception:
                self._use_bf16 = False

        pad_id = (
            self.tokenizer.pad_token_id
            if self.tokenizer.pad_token_id is not None
            else self.tokenizer.eos_token_id
        )

        # Default generation arguments; the constructor-level generation_config and
        # per-request parameters can both override these.
        default_gen = {
            "max_length": MAX_OUTPUT_LENGTH,
            "num_beams": 4,
            "do_sample": False,
            "no_repeat_ngram_size": 3,
            "early_stopping": True,
            "use_cache": True,
            "pad_token_id": pad_id,
        }
        if generation_config:
            default_gen.update(generation_config)
        self.generation_args = default_gen

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        inputs = data.get("inputs")
        if not inputs:
            raise ValueError("No 'inputs' found in the request data.")

        if isinstance(inputs, str):
            inputs = [inputs]

        per_request_params = data.get("parameters") or {}

        # Flatten a nested "generate_parameters" block into the top-level parameters.
        if isinstance(per_request_params.get("generate_parameters"), dict):
            nested = per_request_params.pop("generate_parameters")
            per_request_params.update(nested)

        # Split off options that belong to batch_decode rather than generate.
        decode_params = {}
        if "clean_up_tokenization_spaces" in per_request_params:
            decode_params["clean_up_tokenization_spaces"] = per_request_params.pop(
                "clean_up_tokenization_spaces"
            )

        # Temperature only applies when sampling: drop it for greedy/beam search,
        # and coerce missing or non-positive values to the neutral default of 1.0.
        do_sample_req = bool(
            per_request_params.get("do_sample", self.generation_args.get("do_sample", False))
        )
        if "temperature" in per_request_params:
            if not do_sample_req:
                per_request_params.pop("temperature", None)
            else:
                try:
                    temp_val = float(per_request_params["temperature"])
                except (TypeError, ValueError):
                    temp_val = None
                if not temp_val or temp_val <= 0:
                    per_request_params["temperature"] = 1.0

        # Only forward parameters that generate() understands; anything else is dropped.
        allowed = set(self.model.generation_config.to_dict().keys()) | {
            "max_length", "min_length", "max_new_tokens", "num_beams", "num_return_sequences",
            "temperature", "top_k", "top_p", "repetition_penalty", "length_penalty",
            "early_stopping", "do_sample", "no_repeat_ngram_size", "use_cache",
            "pad_token_id", "eos_token_id", "bos_token_id", "decoder_start_token_id",
            "num_beam_groups", "diversity_penalty", "penalty_alpha", "typical_p",
            "return_dict_in_generate", "output_scores", "output_attentions", "output_hidden_states",
        }

        per_request_params.pop("attention_mask", None)
        filtered_params = {k: v for k, v in per_request_params.items() if k in allowed}
        gen_args = {**self.generation_args, **filtered_params}

        tokenized_inputs = self.tokenizer(
            inputs,
            max_length=MAX_INPUT_LENGTH,
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(self.device)

        try:
            with torch.inference_mode():
                outputs = self.model.generate(
                    tokenized_inputs["input_ids"],
                    attention_mask=tokenized_inputs["attention_mask"],
                    **gen_args,
                )
            decoded_outputs = self.tokenizer.batch_decode(
                outputs,
                skip_special_tokens=True,
                **decode_params,
            )
            results = [{"generated_text": text} for text in decoded_outputs]
            return results
        except Exception as e:
            # Surface generation failures in the response body instead of raising.
            return [{"generated_text": f"Error: {str(e)}"}]
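

# --- Illustrative local smoke test (not part of the endpoint contract) ---
# A minimal sketch of how the handler above can be exercised outside the
# inference endpoint. The model directory path is an assumption for
# illustration; point it at any local seq2seq checkpoint before running.
if __name__ == "__main__":
    handler = EndpointHandler(model_dir="./model")  # hypothetical local path
    payload = {
        "inputs": "Summarize: The quick brown fox jumps over the lazy dog.",
        "parameters": {"num_beams": 2, "max_length": 64},
    }
    print(handler(payload))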