ming committed
Commit 6e48ad3 · 1 Parent(s): 56b5c90

Add support for sshleifer/distilbart-cnn-6-6 model for V2 API


- Added BART model detection and specific input handling
- Updated default model from t5-small to sshleifer/distilbart-cnn-6-6
- BART models now receive direct text input without prefixes (see the sketch below)
- Updated warmup and health check methods for BART compatibility
- Updated Dockerfile and README documentation

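The core of the change is the input-format difference between the two model families. A minimal standalone sketch of that difference (assuming `transformers` is available; the model IDs and the 512/1024 token limits are taken from the diffs below, the sample text is placeholder):

```python
from transformers import AutoTokenizer

text = "The stock market rallied today after strong earnings reports."

# T5-family models are trained with a task prefix on the input.
t5_tokenizer = AutoTokenizer.from_pretrained("t5-small")
t5_inputs = t5_tokenizer(
    "summarize: " + text, return_tensors="pt", max_length=512, truncation=True
)

# BART/DistilBART summarizers take the raw article text directly.
bart_tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-cnn-6-6")
bart_inputs = bart_tokenizer(
    text, return_tensors="pt", max_length=1024, truncation=True
)
```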
Dockerfile CHANGED
@@ -7,7 +7,7 @@ ENV PYTHONDONTWRITEBYTECODE=1 \
     PYTHONPATH=/app \
     ENABLE_V1_WARMUP=false \
     ENABLE_V2_WARMUP=true \
-    HF_MODEL_ID=t5-small \
+    HF_MODEL_ID=sshleifer/distilbart-cnn-6-6 \
     HF_HOME=/tmp/huggingface
 
 # Set work directory
README.md CHANGED
@@ -82,7 +82,7 @@ The service uses the following environment variables:
 - `ENABLE_V1_WARMUP`: Enable V1 warmup (default: `false`)
 
 ### V2 Configuration (HuggingFace)
-- `HF_MODEL_ID`: HuggingFace model ID (default: `t5-small`)
+- `HF_MODEL_ID`: HuggingFace model ID (default: `sshleifer/distilbart-cnn-6-6`)
 - `HF_DEVICE_MAP`: Device mapping (default: `auto` for GPU fallback to CPU)
 - `HF_TORCH_DTYPE`: Torch dtype (default: `auto`)
 - `HF_HOME`: HuggingFace cache directory (default: `/tmp/huggingface`)
@@ -121,7 +121,7 @@ This app is optimized for deployment on Hugging Face Spaces using Docker SDK.
 ```bash
 ENABLE_V1_WARMUP=false
 ENABLE_V2_WARMUP=true
-HF_MODEL_ID=t5-small
+HF_MODEL_ID=sshleifer/distilbart-cnn-6-6
 HF_HOME=/tmp/huggingface
 ```
 
@@ -134,7 +134,7 @@ HF_HOME=/tmp/huggingface
 - **Startup time**: ~30-60 seconds (when V1 warmup enabled)
 
 ### V2 (HuggingFace Streaming) - Primary on HF Spaces
-- **V2 Model**: t5-small (~250MB download)
+- **V2 Model**: sshleifer/distilbart-cnn-6-6 (~300MB download)
 - **Memory usage**: ~500MB RAM (when V2 warmup enabled)
 - **Inference speed**: Real-time token streaming
 - **Startup time**: ~30-60 seconds (includes model download when V2 warmup enabled)
@@ -144,7 +144,7 @@ HF_HOME=/tmp/huggingface
 - **V2 warmup enabled by default** (`ENABLE_V2_WARMUP=true`)
 - **HuggingFace Spaces**: V2-only deployment (no Ollama)
 - **Local development**: V1 endpoints work if Ollama is running externally
-- **t5-small model**: Optimized for HuggingFace Spaces free tier
+- **distilbart-cnn-6-6 model**: Optimized for HuggingFace Spaces free tier with CNN/DailyMail fine-tuning
 
 ## 🛠️ Development
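For a quick end-to-end check of the new default outside the service, something like this should work (a sketch assuming `transformers` is installed; the sample article is placeholder):

```python
from transformers import pipeline

# Downloads ~300MB of weights on first run, matching the README estimate.
summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-6-6")

article = (
    "The city council voted on Tuesday to approve a new transit plan "
    "that expands bus service and adds two light-rail lines by 2030."
)
print(summarizer(article, max_length=60, min_length=10)[0]["summary_text"])
```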
app/core/config.py CHANGED
@@ -34,7 +34,7 @@ class Settings(BaseSettings):
     max_tokens_default: int = Field(default=256, env="MAX_TOKENS_DEFAULT", ge=1)
 
     # V2 HuggingFace Configuration
-    hf_model_id: str = Field(default="t5-small", env="HF_MODEL_ID")
+    hf_model_id: str = Field(default="sshleifer/distilbart-cnn-6-6", env="HF_MODEL_ID")
     hf_device_map: str = Field(default="auto", env="HF_DEVICE_MAP")  # "auto" for GPU fallback to CPU
     hf_torch_dtype: str = Field(default="auto", env="HF_TORCH_DTYPE")  # "auto" for automatic dtype selection
     hf_cache_dir: str = Field(default="/tmp/huggingface", env="HF_HOME")  # HuggingFace cache directory
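The `env=` keyword on `Field` is pydantic v1-style `BaseSettings`; under that assumption, the environment variable still overrides the new baked-in default, which is what the Dockerfile change relies on. A minimal sketch:

```python
import os
from pydantic import BaseSettings, Field  # pydantic v1-style settings (assumption)

class Settings(BaseSettings):
    hf_model_id: str = Field(default="sshleifer/distilbart-cnn-6-6", env="HF_MODEL_ID")

print(Settings().hf_model_id)   # sshleifer/distilbart-cnn-6-6 (code default)
os.environ["HF_MODEL_ID"] = "t5-small"
print(Settings().hf_model_id)   # t5-small (env var wins at instantiation)
```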
app/services/hf_streaming_summarizer.py CHANGED
@@ -93,8 +93,15 @@ class HFStreamingSummarizer:
             logger.warning("⚠️ HuggingFace model not initialized, skipping warmup")
             return
 
-        # Use T5 format for warmup
-        test_prompt = "summarize: This is a test."
+        # Determine appropriate test prompt based on model type
+        if "t5" in settings.hf_model_id.lower():
+            test_prompt = "summarize: This is a test."
+        elif "bart" in settings.hf_model_id.lower():
+            # BART models expect direct text input
+            test_prompt = "This is a test article for summarization."
+        else:
+            # Generic fallback
+            test_prompt = "This is a test article for summarization."
 
         try:
             # Run in executor to avoid blocking
@@ -175,6 +182,15 @@ class HFStreamingSummarizer:
                 max_length=512,
                 truncation=True
             )
+        elif "bart" in settings.hf_model_id.lower():
+            # BART models (including DistilBART) expect direct text input
+            # No prefixes or chat templates needed
+            inputs = self.tokenizer(
+                text,
+                return_tensors="pt",
+                max_length=1024,
+                truncation=True
+            )
         else:
             # Other models use chat template
             messages = [
@@ -267,8 +283,16 @@ class HFStreamingSummarizer:
             return False
 
         try:
-            # Quick test generation with T5 format
-            test_input = self.tokenizer("summarize: test", return_tensors="pt")
+            # Determine appropriate test input based on model type
+            if "t5" in settings.hf_model_id.lower():
+                test_input_text = "summarize: test"
+            elif "bart" in settings.hf_model_id.lower():
+                # BART models expect direct text input
+                test_input_text = "This is a test article."
+            else:
+                test_input_text = "This is a test article."
+
+            test_input = self.tokenizer(test_input_text, return_tensors="pt")
             test_input = test_input.to(self.model.device)
 
             with torch.no_grad():
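The same `"t5"`/`"bart"` substring check now appears three times (warmup, generation, health check). A hypothetical consolidation, not part of this commit, could look like:

```python
def resolve_test_prompt(model_id: str) -> str:
    """Pick a warmup/health-check prompt for the configured model.

    Hypothetical helper; mirrors the branching added in this commit.
    """
    model_id = model_id.lower()
    if "t5" in model_id:
        return "summarize: This is a test."  # T5 requires the task prefix
    # BART/DistilBART and the generic fallback both take raw text
    return "This is a test article for summarization."
```

Substring matching on the model ID is simple but somewhat brittle; inspecting the loaded model's `config.model_type` (e.g. `"t5"` or `"bart"`) would be a more robust discriminator if this logic grows.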