Mir-2002 committed on
Commit d712ca9 · verified · 1 Parent(s): 7975585

Update handler.py

Files changed (1)
  1. handler.py +126 -126
handler.py CHANGED
@@ -1,127 +1,127 @@
-from typing import Any, Dict, List
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
-import torch
-import os
-
-MAX_INPUT_LENGTH = 256
-MAX_OUTPUT_LENGTH = 128
-
-class EndpointHandler:
-    def __init__(self, model_dir: str = "", num_threads: int | None = None, generation_config: Dict[str, Any] | None = None, **kwargs: Any) -> None:
-        # Set environment hints for CPU efficiency
-        os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
-
-        # Configure torch threading for CPU
-        if num_threads:
-            try:
-                torch.set_num_threads(num_threads)
-                torch.set_num_interop_threads(max(1, num_threads // 2))
-            except Exception:
-                pass
-            os.environ.setdefault("OMP_NUM_THREADS", str(num_threads))
-            os.environ.setdefault("MKL_NUM_THREADS", str(num_threads))
-
-        self.device = "cpu" # Force CPU usage
-
-        # Load tokenizer & model with CPU-friendly settings
-        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
-        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir, low_cpu_mem_usage=True)
-        self.model.eval()
-        self.model.to(self.device)
-
-        # Optional bfloat16 cast on CPU (beneficial on Sapphire Rapids/oneDNN)
-        self._use_bf16 = False
-        if os.getenv("ENABLE_BF16", "1") == "1":
-            try:
-                self.model = self.model.to(dtype=torch.bfloat16)
-                self._use_bf16 = True
-            except Exception:
-                self._use_bf16 = False
-
-        # Determine a safe pad token id
-        pad_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
-
-        # Default fast generation config (greedy) overridable by caller
-        default_gen = {
-            "max_length": MAX_OUTPUT_LENGTH,
-            "num_beams": 1, # Greedy for CPU speed
-            "do_sample": False,
-            "no_repeat_ngram_size": 3,
-            "early_stopping": True,
-            "use_cache": True,
-            "pad_token_id": pad_id,
-        }
-        if generation_config:
-            default_gen.update(generation_config)
-        self.generation_args = default_gen
-
-    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
-        inputs = data.get("inputs")
-        if not inputs:
-            raise ValueError("No 'inputs' found in the request data.")
-
-        if isinstance(inputs, str):
-            inputs = [inputs]
-
-        # Allow per-request overrides under 'parameters'
-        per_request_params = data.get("parameters") or {}
-        # Unpack nested generate_parameters dict if provided
-        if isinstance(per_request_params.get("generate_parameters"), dict):
-            nested = per_request_params.pop("generate_parameters")
-            per_request_params.update(nested)
-        # Extract decode-only params
-        decode_params = {}
-        if "clean_up_tokenization_spaces" in per_request_params:
-            decode_params["clean_up_tokenization_spaces"] = per_request_params.pop("clean_up_tokenization_spaces")
-
-        # Sanitize sampling-related params to prevent invalid configs
-        do_sample_req = bool(per_request_params.get("do_sample", self.generation_args.get("do_sample", False)))
-        if "temperature" in per_request_params:
-            # If not sampling, drop temperature entirely
-            if not do_sample_req:
-                per_request_params.pop("temperature", None)
-            else:
-                # Ensure strictly positive float
-                try:
-                    temp_val = float(per_request_params["temperature"])
-                except (TypeError, ValueError):
-                    temp_val = None
-                if not temp_val or temp_val <= 0:
-                    per_request_params["temperature"] = 1.0
-
-        # Filter only supported generation args to avoid warnings
-        allowed = set(self.model.generation_config.to_dict().keys()) | {
-            "max_length","min_length","max_new_tokens","num_beams","num_return_sequences","temperature","top_k","top_p",
-            "repetition_penalty","length_penalty","early_stopping","do_sample","no_repeat_ngram_size","use_cache",
-            "pad_token_id","eos_token_id","bos_token_id","decoder_start_token_id","num_beam_groups","diversity_penalty",
-            "penalty_alpha","typical_p","return_dict_in_generate","output_scores","output_attentions","output_hidden_states"
-        }
-        # Important: don't pass attention_mask via kwargs since we pass it explicitly
-        per_request_params.pop("attention_mask", None)
-        filtered_params = {k: v for k, v in per_request_params.items() if k in allowed}
-        gen_args = {**self.generation_args, **filtered_params}
-
-        tokenized_inputs = self.tokenizer(
-            inputs,
-            max_length=MAX_INPUT_LENGTH,
-            padding=True,
-            truncation=True,
-            return_tensors="pt"
-        ).to(self.device)
-
-        try:
-            with torch.inference_mode():
-                outputs = self.model.generate(
-                    tokenized_inputs["input_ids"],
-                    attention_mask=tokenized_inputs["attention_mask"],
-                    **gen_args
-                )
-            decoded_outputs = self.tokenizer.batch_decode(
-                outputs,
-                skip_special_tokens=True,
-                **decode_params
-            )
-            results = [{"generated_text": text} for text in decoded_outputs]
-            return results
-        except Exception as e:
+from typing import Any, Dict, List
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+import torch
+import os
+
+MAX_INPUT_LENGTH = 256
+MAX_OUTPUT_LENGTH = 128
+
+class EndpointHandler:
+    def __init__(self, model_dir: str = "", num_threads: int | None = None, generation_config: Dict[str, Any] | None = None, **kwargs: Any) -> None:
+        # Set environment hints for CPU efficiency
+        os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
+
+        # Configure torch threading for CPU
+        if num_threads:
+            try:
+                torch.set_num_threads(num_threads)
+                torch.set_num_interop_threads(max(1, num_threads // 2))
+            except Exception:
+                pass
+            os.environ.setdefault("OMP_NUM_THREADS", str(num_threads))
+            os.environ.setdefault("MKL_NUM_THREADS", str(num_threads))
+
+        self.device = "cpu" # Force CPU usage
+
+        # Load tokenizer & model with CPU-friendly settings
+        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
+        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir, low_cpu_mem_usage=True)
+        self.model.eval()
+        self.model.to(self.device)
+
+        # Optional bfloat16 cast on CPU (beneficial on Sapphire Rapids/oneDNN)
+        self._use_bf16 = False
+        if os.getenv("ENABLE_BF16", "1") == "1":
+            try:
+                self.model = self.model.to(dtype=torch.bfloat16)
+                self._use_bf16 = True
+            except Exception:
+                self._use_bf16 = False
+
+        # Determine a safe pad token id
+        pad_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
+
+        # Default generation config (beam search), overridable by caller
+        default_gen = {
+            "max_length": MAX_OUTPUT_LENGTH,
+            "num_beams": 4, # Beam search; slower on CPU than greedy decoding
+            "do_sample": False,
+            "no_repeat_ngram_size": 3,
+            "early_stopping": True,
+            "use_cache": True,
+            "pad_token_id": pad_id,
+        }
+        if generation_config:
+            default_gen.update(generation_config)
+        self.generation_args = default_gen
+
+    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+        inputs = data.get("inputs")
+        if not inputs:
+            raise ValueError("No 'inputs' found in the request data.")
+
+        if isinstance(inputs, str):
+            inputs = [inputs]
+
+        # Allow per-request overrides under 'parameters'
+        per_request_params = data.get("parameters") or {}
+        # Unpack nested generate_parameters dict if provided
+        if isinstance(per_request_params.get("generate_parameters"), dict):
+            nested = per_request_params.pop("generate_parameters")
+            per_request_params.update(nested)
+        # Extract decode-only params
+        decode_params = {}
+        if "clean_up_tokenization_spaces" in per_request_params:
+            decode_params["clean_up_tokenization_spaces"] = per_request_params.pop("clean_up_tokenization_spaces")
+
+        # Sanitize sampling-related params to prevent invalid configs
+        do_sample_req = bool(per_request_params.get("do_sample", self.generation_args.get("do_sample", False)))
+        if "temperature" in per_request_params:
+            # If not sampling, drop temperature entirely
+            if not do_sample_req:
+                per_request_params.pop("temperature", None)
+            else:
+                # Ensure strictly positive float
+                try:
+                    temp_val = float(per_request_params["temperature"])
+                except (TypeError, ValueError):
+                    temp_val = None
+                if not temp_val or temp_val <= 0:
+                    per_request_params["temperature"] = 1.0
+
+        # Filter only supported generation args to avoid warnings
+        allowed = set(self.model.generation_config.to_dict().keys()) | {
+            "max_length","min_length","max_new_tokens","num_beams","num_return_sequences","temperature","top_k","top_p",
+            "repetition_penalty","length_penalty","early_stopping","do_sample","no_repeat_ngram_size","use_cache",
+            "pad_token_id","eos_token_id","bos_token_id","decoder_start_token_id","num_beam_groups","diversity_penalty",
+            "penalty_alpha","typical_p","return_dict_in_generate","output_scores","output_attentions","output_hidden_states"
+        }
+        # Important: don't pass attention_mask via kwargs since we pass it explicitly
+        per_request_params.pop("attention_mask", None)
+        filtered_params = {k: v for k, v in per_request_params.items() if k in allowed}
+        gen_args = {**self.generation_args, **filtered_params}
+
+        tokenized_inputs = self.tokenizer(
+            inputs,
+            max_length=MAX_INPUT_LENGTH,
+            padding=True,
+            truncation=True,
+            return_tensors="pt"
+        ).to(self.device)
+
+        try:
+            with torch.inference_mode():
+                outputs = self.model.generate(
+                    tokenized_inputs["input_ids"],
+                    attention_mask=tokenized_inputs["attention_mask"],
+                    **gen_args
+                )
+            decoded_outputs = self.tokenizer.batch_decode(
+                outputs,
+                skip_special_tokens=True,
+                **decode_params
+            )
+            results = [{"generated_text": text} for text in decoded_outputs]
+            return results
+        except Exception as e:
             return [{"generated_text": f"Error: {str(e)}"}]
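
For reference, a minimal local smoke test for this handler might look like the sketch below (it is not part of the commit). It assumes handler.py is importable and that a seq2seq checkpoint has been downloaded locally; the ./model path and the request payload are made-up examples, while the constructor arguments and the request shape ("inputs" plus optional "parameters" / nested "generate_parameters") come from the code above.

# Hypothetical local smoke test; "./model" is an assumed checkpoint directory.
from handler import EndpointHandler

handler = EndpointHandler(model_dir="./model", num_threads=4)

request = {
    "inputs": "def add(a, b): return a + b",
    "parameters": {
        # Nested generate_parameters are unpacked by __call__ before filtering
        "generate_parameters": {"num_beams": 1},
        # Decode-only option is routed to tokenizer.batch_decode
        "clean_up_tokenization_spaces": True,
    },
}

print(handler(request))  # -> [{"generated_text": "..."}]

Because per-request parameters are merged over the defaults, passing "num_beams": 1 as above restores greedy decoding for latency-sensitive callers even though this version of the handler defaults to beam search with 4 beams.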