ming committed
Commit 6e48ad3 · 1 Parent(s): 56b5c90

Add support for sshleifer/distilbart-cnn-6-6 model for V2 API
- Added BART model detection and specific input handling
- Updated default model from t5-small to sshleifer/distilbart-cnn-6-6
- BART models now receive direct text input without prefixes
- Updated warmup and health check methods for BART compatibility
- Updated Dockerfile and README documentation
- Dockerfile +1 -1
- README.md +4 -4
- app/core/config.py +1 -1
- app/services/hf_streaming_summarizer.py +28 -4
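The behavioral core of this commit is the input convention: BART-family models take the article text directly, while T5-family models need a `summarize: ` task prefix. A minimal standalone sketch of that convention using only the public `transformers` API (the article string is illustrative, not from this repo):

```python
# Sketch: summarizing with sshleifer/distilbart-cnn-6-6 via transformers.
# BART-family models take raw article text; a T5-family model would instead
# need the input prefixed with "summarize: ".
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_id = "sshleifer/distilbart-cnn-6-6"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

article = "Scientists announced a new battery chemistry on Tuesday ..."
inputs = tokenizer(article, return_tensors="pt", max_length=1024, truncation=True)
summary_ids = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
```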
Dockerfile CHANGED

@@ -7,7 +7,7 @@ ENV PYTHONDONTWRITEBYTECODE=1 \
     PYTHONPATH=/app \
     ENABLE_V1_WARMUP=false \
     ENABLE_V2_WARMUP=true \
-    HF_MODEL_ID=t5-small \
+    HF_MODEL_ID=sshleifer/distilbart-cnn-6-6 \
     HF_HOME=/tmp/huggingface

 # Set work directory
README.md CHANGED

@@ -82,7 +82,7 @@ The service uses the following environment variables:
 - `ENABLE_V1_WARMUP`: Enable V1 warmup (default: `false`)

 ### V2 Configuration (HuggingFace)
-- `HF_MODEL_ID`: HuggingFace model ID (default: `t5-small`)
+- `HF_MODEL_ID`: HuggingFace model ID (default: `sshleifer/distilbart-cnn-6-6`)
 - `HF_DEVICE_MAP`: Device mapping (default: `auto` for GPU fallback to CPU)
 - `HF_TORCH_DTYPE`: Torch dtype (default: `auto`)
 - `HF_HOME`: HuggingFace cache directory (default: `/tmp/huggingface`)

@@ -121,7 +121,7 @@ This app is optimized for deployment on Hugging Face Spaces using Docker SDK.
 ```bash
 ENABLE_V1_WARMUP=false
 ENABLE_V2_WARMUP=true
-HF_MODEL_ID=t5-small
+HF_MODEL_ID=sshleifer/distilbart-cnn-6-6
 HF_HOME=/tmp/huggingface
 ```

@@ -134,7 +134,7 @@ HF_HOME=/tmp/huggingface
 - **Startup time**: ~30-60 seconds (when V1 warmup enabled)

 ### V2 (HuggingFace Streaming) - Primary on HF Spaces
-- **V2 Model**: t5-small
+- **V2 Model**: sshleifer/distilbart-cnn-6-6 (~300MB download)
 - **Memory usage**: ~500MB RAM (when V2 warmup enabled)
 - **Inference speed**: Real-time token streaming
 - **Startup time**: ~30-60 seconds (includes model download when V2 warmup enabled)

@@ -144,7 +144,7 @@ HF_HOME=/tmp/huggingface
 - **V2 warmup enabled by default** (`ENABLE_V2_WARMUP=true`)
 - **HuggingFace Spaces**: V2-only deployment (no Ollama)
 - **Local development**: V1 endpoints work if Ollama is running externally
-- **t5-small model**: …
+- **distilbart-cnn-6-6 model**: Optimized for HuggingFace Spaces free tier with CNN/DailyMail fine-tuning

 ## 🛠️ Development
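For context on the `HF_DEVICE_MAP` and `HF_TORCH_DTYPE` defaults documented above, this is roughly how such values map onto `transformers` loading arguments; a sketch of the standard API, not code from this repo:

```python
from transformers import AutoModelForSeq2SeqLM

# device_map="auto" (requires the accelerate package) places the model on a
# GPU when one is available and falls back to CPU; torch_dtype="auto" takes
# the dtype stored in the checkpoint. Both mirror the README defaults.
model = AutoModelForSeq2SeqLM.from_pretrained(
    "sshleifer/distilbart-cnn-6-6",
    device_map="auto",
    torch_dtype="auto",
)
```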
app/core/config.py CHANGED

@@ -34,7 +34,7 @@ class Settings(BaseSettings):
     max_tokens_default: int = Field(default=256, env="MAX_TOKENS_DEFAULT", ge=1)

     # V2 HuggingFace Configuration
-    hf_model_id: str = Field(default="t5-small", env="HF_MODEL_ID")
+    hf_model_id: str = Field(default="sshleifer/distilbart-cnn-6-6", env="HF_MODEL_ID")
     hf_device_map: str = Field(default="auto", env="HF_DEVICE_MAP")  # "auto" for GPU fallback to CPU
     hf_torch_dtype: str = Field(default="auto", env="HF_TORCH_DTYPE")  # "auto" for automatic dtype selection
     hf_cache_dir: str = Field(default="/tmp/huggingface", env="HF_HOME")  # HuggingFace cache directory
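The `Field(..., env=...)` pattern above is pydantic v1-style `BaseSettings`; under that assumption, an environment variable takes precedence over the coded default, which is what lets the Dockerfile's `HF_MODEL_ID` drive the setting:

```python
import os
from pydantic import BaseSettings, Field  # pydantic v1-style settings, as in the diff

class Settings(BaseSettings):
    hf_model_id: str = Field(default="sshleifer/distilbart-cnn-6-6", env="HF_MODEL_ID")

os.environ["HF_MODEL_ID"] = "t5-small"
print(Settings().hf_model_id)  # -> "t5-small": the env var overrides the default
```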
app/services/hf_streaming_summarizer.py CHANGED

@@ -93,8 +93,15 @@ class HFStreamingSummarizer:
             logger.warning("⚠️ HuggingFace model not initialized, skipping warmup")
             return

-        # …
-        …
+        # Determine appropriate test prompt based on model type
+        if "t5" in settings.hf_model_id.lower():
+            test_prompt = "summarize: This is a test."
+        elif "bart" in settings.hf_model_id.lower():
+            # BART models expect direct text input
+            test_prompt = "This is a test article for summarization."
+        else:
+            # Generic fallback
+            test_prompt = "This is a test article for summarization."

         try:
             # Run in executor to avoid blocking

@@ -175,6 +182,15 @@ class HFStreamingSummarizer:
                 max_length=512,
                 truncation=True
             )
+        elif "bart" in settings.hf_model_id.lower():
+            # BART models (including DistilBART) expect direct text input
+            # No prefixes or chat templates needed
+            inputs = self.tokenizer(
+                text,
+                return_tensors="pt",
+                max_length=1024,
+                truncation=True
+            )
         else:
             # Other models use chat template
             messages = [

@@ -267,8 +283,16 @@ class HFStreamingSummarizer:
             return False

         try:
-            # …
-            …
+            # Determine appropriate test input based on model type
+            if "t5" in settings.hf_model_id.lower():
+                test_input_text = "summarize: test"
+            elif "bart" in settings.hf_model_id.lower():
+                # BART models expect direct text input
+                test_input_text = "This is a test article."
+            else:
+                test_input_text = "This is a test article."
+
+            test_input = self.tokenizer(test_input_text, return_tensors="pt")
             test_input = test_input.to(self.model.device)

             with torch.no_grad():
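The "real-time token streaming" the README advertises can be reproduced outside the service with `transformers`' `TextIteratorStreamer`; a self-contained sketch of that standard library pattern (not necessarily how `hf_streaming_summarizer.py` implements it internally):

```python
from threading import Thread
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer

model_id = "sshleifer/distilbart-cnn-6-6"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

text = "A long news article to summarize ..."
inputs = tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)
streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

# generate() blocks until completion, so run it on a worker thread and
# consume decoded text chunks from the streamer as they are produced
thread = Thread(
    target=model.generate,
    kwargs={**inputs, "streamer": streamer, "max_new_tokens": 128},
)
thread.start()
for chunk in streamer:
    print(chunk, end="", flush=True)
thread.join()
```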