# stt/serve.py
from reazonspeech.nemo.asr import load_model, transcribe, audio_from_numpy
import torch
from fastapi import FastAPI, HTTPException, UploadFile, File
import uvicorn
import numpy as np
import io
from pydub import AudioSegment
import time
import logging
from transformers import WhisperProcessor, WhisperForConditionalGeneration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Previous reazonspeech-nemo model, kept for reference:
# model = load_model(device)

# WhisperProcessor is not a torch module and has no .to(); only the model
# is moved to the device.
processor = WhisperProcessor.from_pretrained("Ivydata/whisper-small-japanese")
model = WhisperForConditionalGeneration.from_pretrained("Ivydata/whisper-small-japanese").to(device)

def transcribe_audio(audio_data_bytes):
    """Legacy reazonspeech-nemo path. Requires the reazonspeech model from
    load_model(), which is currently commented out above, so this function
    is unused while the Whisper pipeline is active."""
    try:
        start_time = time.time()
        audio_segment = AudioSegment.from_mp3(io.BytesIO(audio_data_bytes))
        # Get the audio data as a numpy array (assumes 16-bit samples,
        # pydub's default for decoded MP3)
        audio_data_int16 = np.array(audio_segment.get_array_of_samples())
        # Convert to float32 normalized to [-1, 1]
        audio_data_float32 = audio_data_int16.astype(np.float32) / 32768.0
        # Process with reazonspeech
        audio = audio_from_numpy(audio_data_float32, samplerate=audio_segment.frame_rate)
        result = transcribe(model, audio)
        end_time = time.time()
        logger.info(f"Time taken: {end_time - start_time:.2f} seconds")
        return result
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

def transcribe_whisper(audio_data_bytes):
    try:
        start_time = time.time()
        audio_segment = AudioSegment.from_mp3(io.BytesIO(audio_data_bytes))
        # Whisper's feature extractor expects 16 kHz mono input, so downmix
        # and resample before converting to numpy
        audio_segment = audio_segment.set_channels(1).set_frame_rate(16000)
        # Get the audio data as a numpy array (assumes 16-bit samples,
        # pydub's default for decoded MP3)
        audio_data_int16 = np.array(audio_segment.get_array_of_samples())
        # Convert to float32 normalized to [-1, 1]
        audio_data_float32 = audio_data_int16.astype(np.float32) / 32768.0
        # Compute log-mel input features and decode with Whisper
        input_features = processor(
            audio=audio_data_float32,
            sampling_rate=audio_segment.frame_rate,
            return_tensors="pt",
        ).input_features.to(device)
        predicted_ids = model.generate(input_features=input_features)
        decoded = processor.batch_decode(predicted_ids, skip_special_tokens=True)
        result_text = decoded[0] if isinstance(decoded, list) and len(decoded) > 0 else str(decoded)
        end_time = time.time()
        logger.info(f"Time taken: {end_time - start_time:.2f} seconds")
        return {"text": result_text}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
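
# Note (sketch, version-dependent): decoding could be pinned to Japanese
# transcription instead of relying on Whisper's language auto-detection.
# Recent transformers releases accept language/task kwargs on generate():
#
#   predicted_ids = model.generate(input_features=input_features,
#                                  language="ja", task="transcribe")
#
# Check the installed transformers version's Whisper generate() signature
# before relying on this.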
app = FastAPI()

@app.post("/transcribe")
async def transcribe_endpoint(file: UploadFile = File(...)):
    audio_data = await file.read()
    try:
        # Use the Whisper pipeline; transcribe_audio is the legacy
        # reazonspeech path and needs the commented-out model above
        result = transcribe_whisper(audio_data)
        return {
            "result": [
                {
                    "text": result["text"],
                }
            ]
        }
    except HTTPException:
        # Returned to the client in Japanese: "An error occurred, please try again"
        return {
            "result": [
                {
                    "text": "γ‚¨γƒ©γƒΌγŒη™Ίη”Ÿγ—γΎγ—γŸ, もう一度試してください",
                }
            ]
        }

if __name__ == "__main__":
    logger.info(f"Model loaded on {device}")
    uvicorn.run(app, host="0.0.0.0", port=7860)
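
# Example client call (sketch; assumes the server is running locally on
# port 7860 and that `sample.mp3` is an MP3 file on disk):
#
#   import requests
#   with open("sample.mp3", "rb") as f:
#       r = requests.post(
#           "http://localhost:7860/transcribe",
#           files={"file": ("sample.mp3", f, "audio/mpeg")},
#       )
#   print(r.json()["result"][0]["text"])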