Update app.py
Browse files
app.py
CHANGED
|
@@ -40,21 +40,30 @@ async def load_models():
|
|
| 40 |
|
| 41 |
try:
|
| 42 |
# 感情分析モデル(軽量版を使用)
|
|
|
|
| 43 |
sentiment_classifier = pipeline(
|
| 44 |
"sentiment-analysis",
|
| 45 |
-
model="cardiffnlp/twitter-roberta-base-sentiment-latest"
|
|
|
|
| 46 |
)
|
|
|
|
| 47 |
|
| 48 |
# テキスト生成モデル(軽量版)
|
|
|
|
| 49 |
text_generator = pipeline(
|
| 50 |
"text-generation",
|
| 51 |
-
model="distilgpt2"
|
|
|
|
| 52 |
)
|
|
|
|
| 53 |
|
| 54 |
-
print("✅
|
| 55 |
|
| 56 |
except Exception as e:
|
| 57 |
print(f"❌ モデルロードエラー: {e}")
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
@app.get("/")
|
| 60 |
async def root():
|
|
@@ -97,27 +106,50 @@ async def generate_text(request: TextRequest):
|
|
| 97 |
detail="Text generation model not loaded. Please try again later."
|
| 98 |
)
|
| 99 |
|
| 100 |
-
#
|
| 101 |
-
|
|
|
|
| 102 |
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
num_return_sequences=1,
|
| 107 |
-
temperature=0.7,
|
| 108 |
-
do_sample=True,
|
| 109 |
-
pad_token_id=text_generator.tokenizer.eos_token_id
|
| 110 |
-
)
|
| 111 |
|
| 112 |
-
|
|
|
|
|
|
|
| 113 |
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
|
|
|
|
|
|
|
|
|
|
| 120 |
except Exception as e:
|
|
|
|
| 121 |
raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
|
| 122 |
|
| 123 |
@app.get("/models")
|
|
@@ -135,16 +167,24 @@ async def get_models():
|
|
| 135 |
"platform": "Hugging Face Spaces"
|
| 136 |
}
|
| 137 |
|
| 138 |
-
@app.get("/
|
| 139 |
-
async def
|
| 140 |
-
"""
|
|
|
|
|
|
|
|
|
|
| 141 |
return {
|
| 142 |
-
"
|
| 143 |
-
"
|
|
|
|
| 144 |
"sentiment": sentiment_classifier is not None,
|
| 145 |
-
"
|
| 146 |
},
|
| 147 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
}
|
| 149 |
|
| 150 |
# Spaces用の追加設定
|
|
|
|
| 40 |
|
| 41 |
try:
|
| 42 |
# 感情分析モデル(軽量版を使用)
|
| 43 |
+
print("Loading sentiment analysis model...")
|
| 44 |
sentiment_classifier = pipeline(
|
| 45 |
"sentiment-analysis",
|
| 46 |
+
model="cardiffnlp/twitter-roberta-base-sentiment-latest",
|
| 47 |
+
return_all_scores=False
|
| 48 |
)
|
| 49 |
+
print("✅ Sentiment model loaded")
|
| 50 |
|
| 51 |
# テキスト生成モデル(軽量版)
|
| 52 |
+
print("Loading text generation model...")
|
| 53 |
text_generator = pipeline(
|
| 54 |
"text-generation",
|
| 55 |
+
model="distilgpt2",
|
| 56 |
+
pad_token_id=50256 # GPT-2のEOSトークンID
|
| 57 |
)
|
| 58 |
+
print("✅ Text generation model loaded")
|
| 59 |
|
| 60 |
+
print("✅ 全てのモデルのロードが完了しました")
|
| 61 |
|
| 62 |
except Exception as e:
|
| 63 |
print(f"❌ モデルロードエラー: {e}")
|
| 64 |
+
print(f"エラーの詳細: {type(e).__name__}")
|
| 65 |
+
import traceback
|
| 66 |
+
traceback.print_exc()
|
| 67 |
|
| 68 |
@app.get("/")
|
| 69 |
async def root():
|
|
|
|
| 106 |
detail="Text generation model not loaded. Please try again later."
|
| 107 |
)
|
| 108 |
|
| 109 |
+
# 入力テキストの検証
|
| 110 |
+
if not request.text or len(request.text.strip()) == 0:
|
| 111 |
+
raise HTTPException(status_code=400, detail="Text cannot be empty")
|
| 112 |
|
| 113 |
+
# Spacesの制限を考慮して短めに設定
|
| 114 |
+
max_length = min(request.max_length, 100)
|
| 115 |
+
input_length = len(request.text.split())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
|
| 117 |
+
# 入力より長い出力を生成するように調整
|
| 118 |
+
if max_length <= input_length:
|
| 119 |
+
max_length = input_length + 20
|
| 120 |
|
| 121 |
+
try:
|
| 122 |
+
result = text_generator(
|
| 123 |
+
request.text,
|
| 124 |
+
max_length=max_length,
|
| 125 |
+
num_return_sequences=1,
|
| 126 |
+
temperature=0.7,
|
| 127 |
+
do_sample=True,
|
| 128 |
+
truncation=True,
|
| 129 |
+
pad_token_id=text_generator.tokenizer.eos_token_id
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
generated_text = result[0]["generated_text"]
|
| 133 |
+
|
| 134 |
+
return GenerateResponse(
|
| 135 |
+
input_text=request.text,
|
| 136 |
+
generated_text=generated_text,
|
| 137 |
+
model_name="distilgpt2"
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
except Exception as model_error:
|
| 141 |
+
# モデル固有のエラーをキャッチ
|
| 142 |
+
print(f"Model error: {model_error}")
|
| 143 |
+
raise HTTPException(
|
| 144 |
+
status_code=500,
|
| 145 |
+
detail=f"Model processing failed: {str(model_error)}"
|
| 146 |
+
)
|
| 147 |
|
| 148 |
+
except HTTPException:
|
| 149 |
+
# HTTPExceptionは再発生
|
| 150 |
+
raise
|
| 151 |
except Exception as e:
|
| 152 |
+
print(f"Unexpected error: {e}")
|
| 153 |
raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
|
| 154 |
|
| 155 |
@app.get("/models")
|
|
|
|
| 167 |
"platform": "Hugging Face Spaces"
|
| 168 |
}
|
| 169 |
|
| 170 |
+
@app.get("/debug")
|
| 171 |
+
async def debug_info():
|
| 172 |
+
"""デバッグ情報を取得"""
|
| 173 |
+
import sys
|
| 174 |
+
import torch
|
| 175 |
+
|
| 176 |
return {
|
| 177 |
+
"python_version": sys.version,
|
| 178 |
+
"torch_version": torch.__version__ if 'torch' in sys.modules else "not installed",
|
| 179 |
+
"models_loaded": {
|
| 180 |
"sentiment": sentiment_classifier is not None,
|
| 181 |
+
"generator": text_generator is not None
|
| 182 |
},
|
| 183 |
+
"generator_tokenizer": {
|
| 184 |
+
"vocab_size": text_generator.tokenizer.vocab_size if text_generator else None,
|
| 185 |
+
"eos_token_id": text_generator.tokenizer.eos_token_id if text_generator else None,
|
| 186 |
+
"pad_token_id": text_generator.tokenizer.pad_token_id if text_generator else None
|
| 187 |
+
} if text_generator else None
|
| 188 |
}
|
| 189 |
|
| 190 |
# Spaces用の追加設定
|