# Mir-2002's picture
# Update handler.py
# d712ca9 verified
import logging
import os
from functools import cached_property
from typing import Any, Dict, List

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Truncation limit (in tokens) applied when tokenizing request inputs.
MAX_INPUT_LENGTH: int = 256
# Default ``max_length`` for ``generate()``; handler config and request parameters may override it.
MAX_OUTPUT_LENGTH: int = 128
class EndpointHandler:
def __init__(self, model_dir: str = "", num_threads: int | None = None, generation_config: Dict[str, Any] | None = None, **kwargs: Any) -> None:
# Set environment hints for CPU efficiency
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
# Configure torch threading for CPU
if num_threads:
try:
torch.set_num_threads(num_threads)
torch.set_num_interop_threads(max(1, num_threads // 2))
except Exception:
pass
os.environ.setdefault("OMP_NUM_THREADS", str(num_threads))
os.environ.setdefault("MKL_NUM_THREADS", str(num_threads))
self.device = "cpu" # Force CPU usage
# Load tokenizer & model with CPU-friendly settings
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir, low_cpu_mem_usage=True)
self.model.eval()
self.model.to(self.device)
# Optional bfloat16 cast on CPU (beneficial on Sapphire Rapids/oneDNN)
self._use_bf16 = False
if os.getenv("ENABLE_BF16", "1") == "1":
try:
self.model = self.model.to(dtype=torch.bfloat16)
self._use_bf16 = True
except Exception:
self._use_bf16 = False
# Determine a safe pad token id
pad_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
# Default fast generation config (greedy) overridable by caller
default_gen = {
"max_length": MAX_OUTPUT_LENGTH,
"num_beams": 4, # Greedy for CPU speed
"do_sample": False,
"no_repeat_ngram_size": 3,
"early_stopping": True,
"use_cache": True,
"pad_token_id": pad_id,
}
if generation_config:
default_gen.update(generation_config)
self.generation_args = default_gen
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
inputs = data.get("inputs")
if not inputs:
raise ValueError("No 'inputs' found in the request data.")
if isinstance(inputs, str):
inputs = [inputs]
# Allow per-request overrides under 'parameters'
per_request_params = data.get("parameters") or {}
# Unpack nested generate_parameters dict if provided
if isinstance(per_request_params.get("generate_parameters"), dict):
nested = per_request_params.pop("generate_parameters")
per_request_params.update(nested)
# Extract decode-only params
decode_params = {}
if "clean_up_tokenization_spaces" in per_request_params:
decode_params["clean_up_tokenization_spaces"] = per_request_params.pop("clean_up_tokenization_spaces")
# Sanitize sampling-related params to prevent invalid configs
do_sample_req = bool(per_request_params.get("do_sample", self.generation_args.get("do_sample", False)))
if "temperature" in per_request_params:
# If not sampling, drop temperature entirely
if not do_sample_req:
per_request_params.pop("temperature", None)
else:
# Ensure strictly positive float
try:
temp_val = float(per_request_params["temperature"])
except (TypeError, ValueError):
temp_val = None
if not temp_val or temp_val <= 0:
per_request_params["temperature"] = 1.0
# Filter only supported generation args to avoid warnings
allowed = set(self.model.generation_config.to_dict().keys()) | {
"max_length","min_length","max_new_tokens","num_beams","num_return_sequences","temperature","top_k","top_p",
"repetition_penalty","length_penalty","early_stopping","do_sample","no_repeat_ngram_size","use_cache",
"pad_token_id","eos_token_id","bos_token_id","decoder_start_token_id","num_beam_groups","diversity_penalty",
"penalty_alpha","typical_p","return_dict_in_generate","output_scores","output_attentions","output_hidden_states"
}
# Important: don't pass attention_mask via kwargs since we pass it explicitly
per_request_params.pop("attention_mask", None)
filtered_params = {k: v for k, v in per_request_params.items() if k in allowed}
gen_args = {**self.generation_args, **filtered_params}
tokenized_inputs = self.tokenizer(
inputs,
max_length=MAX_INPUT_LENGTH,
padding=True,
truncation=True,
return_tensors="pt"
).to(self.device)
try:
with torch.inference_mode():
outputs = self.model.generate(
tokenized_inputs["input_ids"],
attention_mask=tokenized_inputs["attention_mask"],
**gen_args
)
decoded_outputs = self.tokenizer.batch_decode(
outputs,
skip_special_tokens=True,
**decode_params
)
results = [{"generated_text": text} for text in decoded_outputs]
return results
except Exception as e:
return [{"generated_text": f"Error: {str(e)}"}]