Meet Patel
		
	committed on
		
		
					Commit 
							
							·
						
						d909d3b
	
1
								Parent(s):
							
							f30c061
								
Refactor API to use Azure Blob Storage instead of S3; update requirements and remove unused environment variables.
Browse files
 - Dockerfile.api +0 -4
 - api.py +66 -33
 - requirements.txt +1 -1
 
    	
        Dockerfile.api
    CHANGED
    
    | 
         @@ -27,10 +27,6 @@ COPY examples ./examples 
     | 
|
| 27 | 
         
             
            ENV PYTHONPATH=/app
         
     | 
| 28 | 
         
             
            ENV HF_HUB_CACHE=/app/checkpoints/hf_cache
         
     | 
| 29 | 
         
             
            ENV TORCH_HOME=/app/checkpoints
         
     | 
| 30 | 
         
            -
             
     | 
| 31 | 
         
            -
            ENV AWS_REGION=us-east-1
         
     | 
| 32 | 
         
            -
            ENV S3_BUCKET=elevenlabs-clone
         
     | 
| 33 | 
         
            -
            ENV S3_PREFIX=seedvc-outputs
         
     | 
| 34 | 
         
             
            ENV API_KEY=12345
         
     | 
| 35 | 
         | 
| 36 | 
         
             
            EXPOSE 8000
         
     | 
| 
         | 
|
| 27 | 
         
             
            ENV PYTHONPATH=/app
         
     | 
| 28 | 
         
             
            ENV HF_HUB_CACHE=/app/checkpoints/hf_cache
         
     | 
| 29 | 
         
             
            ENV TORCH_HOME=/app/checkpoints
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 30 | 
         
             
            ENV API_KEY=12345
         
     | 
| 31 | 
         | 
| 32 | 
         
             
            EXPOSE 8000
         
     | 
    	
        api.py
    CHANGED
    
    | 
         @@ -4,7 +4,8 @@ import uuid 
     | 
|
| 4 | 
         
             
            from contextlib import asynccontextmanager
         
     | 
| 5 | 
         
             
            from tempfile import NamedTemporaryFile
         
     | 
| 6 | 
         | 
| 7 | 
         
            -
            import  
     | 
| 
         | 
|
| 8 | 
         
             
            import torchaudio
         
     | 
| 9 | 
         
             
            from fastapi import BackgroundTasks, Depends, FastAPI, Header, HTTPException
         
     | 
| 10 | 
         
             
            from fastapi.security import APIKeyHeader
         
     | 
| 
         @@ -39,22 +40,35 @@ async def verify_api_key(authorization: str = Header(None)): 
     | 
|
| 39 | 
         
             
                return token
         
     | 
| 40 | 
         | 
| 41 | 
         | 
| 42 | 
         
            -
            def  
     | 
| 43 | 
         
            -
                 
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 44 | 
         | 
| 45 | 
         
            -
                if os.getenv("AWS_ACCESS_KEY_ID") and os.getenv("AWS_SECRET_ACCESS_KEY"):
         
     | 
| 46 | 
         
            -
                    client_kwargs.update({
         
     | 
| 47 | 
         
            -
                        'aws_access_key_id': os.getenv("AWS_ACCESS_KEY_ID"),
         
     | 
| 48 | 
         
            -
                        'aws_secret_access_key': os.getenv("AWS_SECRET_ACCESS_KEY")
         
     | 
| 49 | 
         
            -
                    })
         
     | 
| 50 | 
         | 
| 51 | 
         
            -
             
     | 
| 52 | 
         | 
| 
         | 
|
| 53 | 
         | 
| 54 | 
         
            -
            s3_client = get_s3_client()
         
     | 
| 55 | 
         | 
| 56 | 
         
            -
             
     | 
| 57 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 58 | 
         | 
| 59 | 
         | 
| 60 | 
         
             
            @asynccontextmanager
         
     | 
| 
         @@ -62,8 +76,8 @@ async def lifespan(app: FastAPI): 
     | 
|
| 62 | 
         
             
                global models
         
     | 
| 63 | 
         
             
                logger.info("Loading Seed-VC model...")
         
     | 
| 64 | 
         
             
                try:
         
     | 
| 
         | 
|
| 65 | 
         
             
                    models = load_models()
         
     | 
| 66 | 
         
            -
             
     | 
| 67 | 
         
             
                    logger.info("Seed-VC model loaded successfully")
         
     | 
| 68 | 
         
             
                except Exception as e:
         
     | 
| 69 | 
         
             
                    logger.error(f"Failed to load model: {e}")
         
     | 
| 
         @@ -77,8 +91,8 @@ app = FastAPI(title="Seed-VC API", 
     | 
|
| 77 | 
         
             
                          lifespan=lifespan)
         
     | 
| 78 | 
         | 
| 79 | 
         
             
            TARGET_VOICES = {
         
     | 
| 80 | 
         
            -
                " 
     | 
| 81 | 
         
            -
                " 
     | 
| 82 | 
         
             
                "trump": "examples/reference/trump_0.wav",
         
     | 
| 83 | 
         
             
            }
         
     | 
| 84 | 
         | 
| 
         @@ -88,6 +102,18 @@ class VoiceConversionRequest(BaseModel): 
     | 
|
| 88 | 
         
             
                target_voice: str
         
     | 
| 89 | 
         | 
| 90 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 91 | 
         
             
            @app.post("/convert", dependencies=[Depends(verify_api_key)])
         
     | 
| 92 | 
         
             
            async def generate_speech(request: VoiceConversionRequest, background_tasks: BackgroundTasks):
         
     | 
| 93 | 
         
             
                if not models:
         
     | 
| 
         @@ -107,39 +133,46 @@ async def generate_speech(request: VoiceConversionRequest, background_tasks: Bac 
     | 
|
| 107 | 
         
             
                    output_filename = f"{audio_id}.wav"
         
     | 
| 108 | 
         
             
                    local_path = f"/tmp/{output_filename}"
         
     | 
| 109 | 
         | 
| 110 | 
         
            -
                    logger.info("Downloading source audio")
         
     | 
| 111 | 
         
            -
                    source_temp = NamedTemporaryFile(delete=False, suffix=".wav")
         
     | 
| 112 | 
         
             
                    try:
         
     | 
| 113 | 
         
            -
                         
     | 
| 114 | 
         
            -
                            S3_BUCKET, Key=request.source_audio_key, Fileobj=source_temp)
         
     | 
| 115 | 
         
            -
                        source_temp.close()
         
     | 
| 116 | 
         
             
                    except Exception as e:
         
     | 
| 117 | 
         
            -
                         
     | 
| 118 | 
         
             
                        raise HTTPException(
         
     | 
| 119 | 
         
             
                            status_code=404, detail="Source audio not found")
         
     | 
| 120 | 
         | 
| 121 | 
         
             
                    vc_wave, sr = process_voice_conversion(
         
     | 
| 122 | 
         
            -
                        models=models, source= 
     | 
| 123 | 
         | 
| 124 | 
         
            -
                    os.unlink( 
     | 
| 125 | 
         | 
| 126 | 
         
             
                    torchaudio.save(local_path, vc_wave, sr)
         
     | 
| 127 | 
         | 
| 128 | 
         
            -
                    # Upload to  
     | 
| 129 | 
         
            -
                     
     | 
| 130 | 
         
            -
                     
     | 
| 131 | 
         
            -
             
     | 
| 132 | 
         
            -
             
     | 
| 133 | 
         
            -
             
     | 
| 134 | 
         
            -
             
     | 
| 135 | 
         
            -
                         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 136 | 
         
             
                    )
         
     | 
| 
         | 
|
| 137 | 
         | 
| 138 | 
         
             
                    background_tasks.add_task(os.remove, local_path)
         
     | 
| 139 | 
         | 
| 140 | 
         
             
                    return {
         
     | 
| 141 | 
         
            -
                        "audio_url":  
     | 
| 142 | 
         
            -
                        " 
     | 
| 143 | 
         
             
                    }
         
     | 
| 144 | 
         
             
                except Exception as e:
         
     | 
| 145 | 
         
             
                    logger.error(f"Error in voice conversion: {e}")
         
     | 
| 
         | 
|
| 4 | 
         
             
            from contextlib import asynccontextmanager
         
     | 
| 5 | 
         
             
            from tempfile import NamedTemporaryFile
         
     | 
| 6 | 
         | 
| 7 | 
         
            +
            from azure.storage.blob import BlobServiceClient, generate_blob_sas, BlobSasPermissions
         
     | 
| 8 | 
         
            +
            from datetime import datetime, timedelta
         
     | 
| 9 | 
         
             
            import torchaudio
         
     | 
| 10 | 
         
             
            from fastapi import BackgroundTasks, Depends, FastAPI, Header, HTTPException
         
     | 
| 11 | 
         
             
            from fastapi.security import APIKeyHeader
         
     | 
| 
         | 
|
| 40 | 
         
             
                return token
         
     | 
| 41 | 
         | 
| 42 | 
         | 
| 43 | 
         
            +
            def get_azure_blob_client():
         
     | 
| 44 | 
         
            +
                account_name = os.getenv("AZURE_STORAGE_ACCOUNT_NAME", "getpoints")
         
     | 
| 45 | 
         
            +
                account_key = os.getenv("AZURE_STORAGE_KEY", "ts/PL1cr3X1F9JWgksAtqcWsQvPBK9UJ3BtNQBL98kYU17U3JxEiFI2vJrNDzmAyFRleOdRdoG03+ASt9RDnZA==")
         
     | 
| 46 | 
         
            +
                blob_endpoint = os.getenv("AZURE_BLOB_ENDPOINT", "https://getpoints.blob.core.windows.net/")
         
     | 
| 47 | 
         
            +
                blob_service_client = BlobServiceClient(
         
     | 
| 48 | 
         
            +
                    account_url=blob_endpoint,
         
     | 
| 49 | 
         
            +
                    credential=account_key
         
     | 
| 50 | 
         
            +
                )
         
     | 
| 51 | 
         
            +
                return blob_service_client
         
     | 
| 52 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 53 | 
         | 
| 54 | 
         
            +
            blob_client = get_azure_blob_client()
         
     | 
| 55 | 
         | 
| 56 | 
         
            +
            AZURE_CONTAINER_NAME = os.getenv("AZURE_CONTAINER_NAME", "seedvc-outputs")
         
     | 
| 57 | 
         | 
| 
         | 
|
| 58 | 
         | 
| 59 | 
         
            +
            async def ensure_container_exists():
         
     | 
| 60 | 
         
            +
                """Ensure the Azure container exists, create if it doesn't"""
         
     | 
| 61 | 
         
            +
                try:
         
     | 
| 62 | 
         
            +
                    container_client = blob_client.get_container_client(AZURE_CONTAINER_NAME)
         
     | 
| 63 | 
         
            +
                    container_client.get_container_properties()
         
     | 
| 64 | 
         
            +
                    logger.info(f"Container '{AZURE_CONTAINER_NAME}' already exists")
         
     | 
| 65 | 
         
            +
                except Exception:
         
     | 
| 66 | 
         
            +
                    try:
         
     | 
| 67 | 
         
            +
                        blob_client.create_container(AZURE_CONTAINER_NAME)
         
     | 
| 68 | 
         
            +
                        logger.info(f"Created container '{AZURE_CONTAINER_NAME}'")
         
     | 
| 69 | 
         
            +
                    except Exception as e:
         
     | 
| 70 | 
         
            +
                        logger.error(f"Failed to create container '{AZURE_CONTAINER_NAME}': {e}")
         
     | 
| 71 | 
         
            +
                        raise
         
     | 
| 72 | 
         | 
| 73 | 
         | 
| 74 | 
         
             
            @asynccontextmanager
         
     | 
| 
         | 
|
| 76 | 
         
             
                global models
         
     | 
| 77 | 
         
             
                logger.info("Loading Seed-VC model...")
         
     | 
| 78 | 
         
             
                try:
         
     | 
| 79 | 
         
            +
                    await ensure_container_exists()
         
     | 
| 80 | 
         
             
                    models = load_models()
         
     | 
| 
         | 
|
| 81 | 
         
             
                    logger.info("Seed-VC model loaded successfully")
         
     | 
| 82 | 
         
             
                except Exception as e:
         
     | 
| 83 | 
         
             
                    logger.error(f"Failed to load model: {e}")
         
     | 
| 
         | 
|
| 91 | 
         
             
                          lifespan=lifespan)
         
     | 
| 92 | 
         | 
| 93 | 
         
             
            TARGET_VOICES = {
         
     | 
| 94 | 
         
            +
                "male": "examples/reference/s1p2.wav",
         
     | 
| 95 | 
         
            +
                "female": "examples/reference/s1p1.wav",
         
     | 
| 96 | 
         
             
                "trump": "examples/reference/trump_0.wav",
         
     | 
| 97 | 
         
             
            }
         
     | 
| 98 | 
         | 
| 
         | 
|
| 102 | 
         
             
                target_voice: str
         
     | 
| 103 | 
         | 
| 104 | 
         | 
| 105 | 
         
            +
            def download_blob_to_temp(blob_name):
         
     | 
| 106 | 
         
            +
                temp_file = NamedTemporaryFile(delete=False, suffix=".wav")
         
     | 
| 107 | 
         
            +
                blob_client_instance = blob_client.get_blob_client(
         
     | 
| 108 | 
         
            +
                    container=AZURE_CONTAINER_NAME,
         
     | 
| 109 | 
         
            +
                    blob=blob_name
         
     | 
| 110 | 
         
            +
                )
         
     | 
| 111 | 
         
            +
                with open(temp_file.name, "wb") as f:
         
     | 
| 112 | 
         
            +
                    download_stream = blob_client_instance.download_blob()
         
     | 
| 113 | 
         
            +
                    f.write(download_stream.readall())
         
     | 
| 114 | 
         
            +
                return temp_file.name
         
     | 
| 115 | 
         
            +
             
     | 
| 116 | 
         
            +
             
     | 
| 117 | 
         
             
            @app.post("/convert", dependencies=[Depends(verify_api_key)])
         
     | 
| 118 | 
         
             
            async def generate_speech(request: VoiceConversionRequest, background_tasks: BackgroundTasks):
         
     | 
| 119 | 
         
             
                if not models:
         
     | 
| 
         | 
|
| 133 | 
         
             
                    output_filename = f"{audio_id}.wav"
         
     | 
| 134 | 
         
             
                    local_path = f"/tmp/{output_filename}"
         
     | 
| 135 | 
         | 
| 136 | 
         
            +
                    logger.info("Downloading source audio from Azure Blob Storage")
         
     | 
| 
         | 
|
| 137 | 
         
             
                    try:
         
     | 
| 138 | 
         
            +
                        source_temp_path = download_blob_to_temp(request.source_audio_key)
         
     | 
| 
         | 
|
| 
         | 
|
| 139 | 
         
             
                    except Exception as e:
         
     | 
| 140 | 
         
            +
                        logger.error(f"Failed to download source audio: {e}")
         
     | 
| 141 | 
         
             
                        raise HTTPException(
         
     | 
| 142 | 
         
             
                            status_code=404, detail="Source audio not found")
         
     | 
| 143 | 
         | 
| 144 | 
         
             
                    vc_wave, sr = process_voice_conversion(
         
     | 
| 145 | 
         
            +
                        models=models, source=source_temp_path, target_name=target_audio_path, output=None)
         
     | 
| 146 | 
         | 
| 147 | 
         
            +
                    os.unlink(source_temp_path)
         
     | 
| 148 | 
         | 
| 149 | 
         
             
                    torchaudio.save(local_path, vc_wave, sr)
         
     | 
| 150 | 
         | 
| 151 | 
         
            +
                    # Upload to Azure Blob Storage
         
     | 
| 152 | 
         
            +
                    blob_name = f"seedvc-outputs/{output_filename}"
         
     | 
| 153 | 
         
            +
                    blob_client_instance = blob_client.get_blob_client(
         
     | 
| 154 | 
         
            +
                        container=AZURE_CONTAINER_NAME,
         
     | 
| 155 | 
         
            +
                        blob=blob_name
         
     | 
| 156 | 
         
            +
                    )
         
     | 
| 157 | 
         
            +
                    with open(local_path, "rb") as data:
         
     | 
| 158 | 
         
            +
                        blob_client_instance.upload_blob(data, overwrite=True)
         
     | 
| 159 | 
         
            +
             
     | 
| 160 | 
         
            +
                    # Generate SAS URL for temporary access
         
     | 
| 161 | 
         
            +
                    sas_token = generate_blob_sas(
         
     | 
| 162 | 
         
            +
                        account_name=blob_client.account_name,
         
     | 
| 163 | 
         
            +
                        container_name=AZURE_CONTAINER_NAME,
         
     | 
| 164 | 
         
            +
                        blob_name=blob_name,
         
     | 
| 165 | 
         
            +
                        account_key=os.getenv("AZURE_STORAGE_KEY", "ts/PL1cr3X1F9JWgksAtqcWsQvPBK9UJ3BtNQBL98kYU17U3JxEiFI2vJrNDzmAyFRleOdRdoG03+ASt9RDnZA=="),
         
     | 
| 166 | 
         
            +
                        permission=BlobSasPermissions(read=True),
         
     | 
| 167 | 
         
            +
                        expiry=datetime.utcnow() + timedelta(hours=1)
         
     | 
| 168 | 
         
             
                    )
         
     | 
| 169 | 
         
            +
                    blob_url = f"{blob_client_instance.url}?{sas_token}"
         
     | 
| 170 | 
         | 
| 171 | 
         
             
                    background_tasks.add_task(os.remove, local_path)
         
     | 
| 172 | 
         | 
| 173 | 
         
             
                    return {
         
     | 
| 174 | 
         
            +
                        "audio_url": blob_url,
         
     | 
| 175 | 
         
            +
                        "blob_name": blob_name
         
     | 
| 176 | 
         
             
                    }
         
     | 
| 177 | 
         
             
                except Exception as e:
         
     | 
| 178 | 
         
             
                    logger.error(f"Error in voice conversion: {e}")
         
     | 
    	
        requirements.txt
    CHANGED
    
    | 
         @@ -20,6 +20,6 @@ funasr==1.1.5 
     | 
|
| 20 | 
         
             
            numpy==1.26.4
         
     | 
| 21 | 
         
             
            pyyaml
         
     | 
| 22 | 
         
             
            python-dotenv
         
     | 
| 23 | 
         
            -
            boto3
         
     | 
| 24 | 
         
             
            uvicorn
         
     | 
| 25 | 
         
             
            fastapi
         
     | 
| 
         | 
| 
         | 
|
| 20 | 
         
             
            numpy==1.26.4
         
     | 
| 21 | 
         
             
            pyyaml
         
     | 
| 22 | 
         
             
            python-dotenv
         
     | 
| 
         | 
|
| 23 | 
         
             
            uvicorn
         
     | 
| 24 | 
         
             
            fastapi
         
     | 
| 25 | 
         
            +
            azure-storage-blob
         
     |