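"""Custom CPU inference handler (EndpointHandler) for Hugging Face Inference Endpoints.

Loads a seq2seq tokenizer and model from `model_dir`, sanitizes per-request
generation parameters, and returns generated text for single or batched string inputs.
"""
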
import os
from typing import Any, Dict, List

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

MAX_INPUT_LENGTH = 256   # truncation limit applied when tokenizing inputs
MAX_OUTPUT_LENGTH = 128  # default cap on the generated sequence length

class EndpointHandler:
    def __init__(
        self,
        model_dir: str = "",
        num_threads: int | None = None,
        generation_config: Dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
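        """Initialize tokenizer, model, and default generation settings on CPU.

        model_dir: path to the model/tokenizer files (passed to from_pretrained).
        num_threads: optional CPU thread count applied to torch and OMP/MKL.
        generation_config: optional overrides merged into the default generation args.
        """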
        # Set environment hints for CPU efficiency
        os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

        # Configure torch threading for CPU
        if num_threads:
            try:
                torch.set_num_threads(num_threads)
                torch.set_num_interop_threads(max(1, num_threads // 2))
            except Exception:
                pass
            os.environ.setdefault("OMP_NUM_THREADS", str(num_threads))
            os.environ.setdefault("MKL_NUM_THREADS", str(num_threads))

        self.device = "cpu"  # Force CPU usage

        # Load tokenizer & model with CPU-friendly settings
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir, low_cpu_mem_usage=True)
        self.model.eval()
        self.model.to(self.device)

        # Optional bfloat16 cast on CPU (beneficial on Sapphire Rapids/oneDNN)
        self._use_bf16 = False
        if os.getenv("ENABLE_BF16", "1") == "1":
            try:
                self.model = self.model.to(dtype=torch.bfloat16)
                self._use_bf16 = True
            except Exception:
                self._use_bf16 = False

        # Determine a safe pad token id
        pad_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id

        # Default generation config (beam search, no sampling), overridable by caller
        default_gen = {
            "max_length": MAX_OUTPUT_LENGTH,
            "num_beams": 4,              # Greedy for CPU speed
            "do_sample": False,
            "no_repeat_ngram_size": 3,
            "early_stopping": True,
            "use_cache": True,
            "pad_token_id": pad_id,
        }
        if generation_config:
            default_gen.update(generation_config)
        self.generation_args = default_gen

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
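        """Run generation for a request.

        Expects {"inputs": str | list[str], "parameters": {...}} and returns a
        list of {"generated_text": ...} dicts, one per input.
        """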
        inputs = data.get("inputs")
        if not inputs:
            raise ValueError("No 'inputs' found in the request data.")

        if isinstance(inputs, str):
            inputs = [inputs]

        # Allow per-request overrides under 'parameters'
        per_request_params = data.get("parameters") or {}
        # Unpack nested generate_parameters dict if provided
        if isinstance(per_request_params.get("generate_parameters"), dict):
            nested = per_request_params.pop("generate_parameters")
            per_request_params.update(nested)
        # Extract decode-only params
        decode_params = {}
        if "clean_up_tokenization_spaces" in per_request_params:
            decode_params["clean_up_tokenization_spaces"] = per_request_params.pop("clean_up_tokenization_spaces")

        # Sanitize sampling-related params to prevent invalid configs
        do_sample_req = bool(per_request_params.get("do_sample", self.generation_args.get("do_sample", False)))
        if "temperature" in per_request_params:
            # If not sampling, drop temperature entirely
            if not do_sample_req:
                per_request_params.pop("temperature", None)
            else:
                # Coerce to a strictly positive float; fall back to 1.0 if invalid
                try:
                    temp_val = float(per_request_params["temperature"])
                except (TypeError, ValueError):
                    temp_val = None
                if temp_val is None or temp_val <= 0:
                    per_request_params["temperature"] = 1.0
                else:
                    per_request_params["temperature"] = temp_val

        # Filter only supported generation args to avoid warnings
        allowed = set(self.model.generation_config.to_dict().keys()) | {
            "max_length","min_length","max_new_tokens","num_beams","num_return_sequences","temperature","top_k","top_p",
            "repetition_penalty","length_penalty","early_stopping","do_sample","no_repeat_ngram_size","use_cache",
            "pad_token_id","eos_token_id","bos_token_id","decoder_start_token_id","num_beam_groups","diversity_penalty",
            "penalty_alpha","typical_p","return_dict_in_generate","output_scores","output_attentions","output_hidden_states"
        }
        # Important: don't pass attention_mask via kwargs since we pass it explicitly
        per_request_params.pop("attention_mask", None)
        filtered_params = {k: v for k, v in per_request_params.items() if k in allowed}
        gen_args = {**self.generation_args, **filtered_params}

        tokenized_inputs = self.tokenizer(
            inputs,
            max_length=MAX_INPUT_LENGTH,
            padding=True,
            truncation=True,
            return_tensors="pt"
        ).to(self.device)

        try:
            with torch.inference_mode():
                outputs = self.model.generate(
                    tokenized_inputs["input_ids"],
                    attention_mask=tokenized_inputs["attention_mask"],
                    **gen_args
                )
            decoded_outputs = self.tokenizer.batch_decode(
                outputs,
                skip_special_tokens=True,
                **decode_params
            )
            results = [{"generated_text": text} for text in decoded_outputs]
            return results
        except Exception as e:
            return [{"generated_text": f"Error: {str(e)}"}]