"""
Convert trained DTLN model to TensorFlow Lite INT8 format
Optimized for Alif E7 Ethos-U55 NPU deployment
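
Example usage (paths below are placeholders):
    python convert_to_tflite.py \
        --model checkpoints/dtln.h5 \
        --output dtln_int8.tflite \
        --calibration-dir data/noisy \
        --test-audio samples/example.wav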
"""
import argparse
import os
import sys
import time
from pathlib import Path

import numpy as np
import soundfile as sf
import librosa
import tensorflow as tf
def load_representative_dataset(
audio_dir,
num_samples=100,
frame_len=512,
sampling_rate=16000
):
"""
Load representative audio dataset for calibration
Args:
audio_dir: Directory containing audio files
num_samples: Number of samples for calibration
frame_len: Frame length
sampling_rate: Audio sampling rate
Returns:
Generator yielding audio samples
"""
    audio_files = list(Path(audio_dir).glob('**/*.wav'))
    if not audio_files:
        raise FileNotFoundError(f"No .wav files found in {audio_dir}")
    if len(audio_files) < num_samples:
        print(f"Warning: Only {len(audio_files)} files found, using all")
        num_samples = len(audio_files)
    selected_files = np.random.choice(audio_files, num_samples, replace=False)
def representative_dataset_gen():
for file_path in selected_files:
# Load audio
audio, sr = sf.read(file_path)
# Resample if needed
if sr != sampling_rate:
audio = librosa.resample(
audio,
orig_sr=sr,
target_sr=sampling_rate
)
# Convert to mono
if len(audio.shape) > 1:
audio = np.mean(audio, axis=1)
# Take 1 second segment
segment_len = sampling_rate
if len(audio) > segment_len:
start = np.random.randint(0, len(audio) - segment_len)
audio = audio[start:start + segment_len]
else:
audio = np.pad(audio, (0, segment_len - len(audio)))
# Normalize
audio = audio / (np.max(np.abs(audio)) + 1e-8)
# Yield as float32 numpy array
yield [audio.astype(np.float32)[np.newaxis, :]]
return representative_dataset_gen
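
# Quick sanity check for the generator above (hypothetical directory shown):
# each yielded batch is a single float32 array of shape (1, 16000).
#
#   gen = load_representative_dataset('data/calibration')()
#   batch = next(gen)
#   assert batch[0].dtype == np.float32 and batch[0].shape == (1, 16000)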
def convert_to_tflite_int8(
model_path,
output_path,
representative_data_dir,
num_calibration_samples=100
):
"""
Convert Keras model to TFLite with full INT8 quantization
Args:
model_path: Path to trained Keras model (.h5)
output_path: Output path for TFLite model (.tflite)
representative_data_dir: Directory with audio for calibration
num_calibration_samples: Number of samples for calibration
Returns:
TFLite model as bytes
"""
print("="*60)
print("Converting to TensorFlow Lite INT8")
print("="*60)
# Load model
print("\n1. Loading model...")
try:
model = tf.keras.models.load_model(
model_path,
compile=False
)
print(f" βœ“ Model loaded from {model_path}")
except Exception as e:
print(f" βœ— Error loading model: {e}")
return None
model.summary()
# Create converter
print("\n2. Creating TFLite converter...")
converter = tf.lite.TFLiteConverter.from_keras_model(model)
# Enable optimizations
converter.optimizations = [tf.lite.Optimize.DEFAULT]
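    # Note: Optimize.DEFAULT alone yields dynamic-range quantization; the
    # representative dataset attached below is what drives full INT8 calibration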
# Set up representative dataset for calibration
print("\n3. Setting up representative dataset...")
representative_dataset = load_representative_dataset(
audio_dir=representative_data_dir,
num_samples=num_calibration_samples
)
converter.representative_dataset = representative_dataset
print(f" βœ“ Using {num_calibration_samples} samples for calibration")
# Force full integer quantization
print("\n4. Configuring INT8 quantization...")
converter.target_spec.supported_ops = [
tf.lite.OpsSet.TFLITE_BUILTINS_INT8
]
# Set input/output to INT8
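    # (Ethos-U55 runs only integer-quantized operators, so INT8 input/output
    # keeps the whole graph NPU-friendly with no float conversion at the edges)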
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
    # Use the newer MLIR-based converter and quantizer (default in recent TF)
converter.experimental_new_converter = True
converter.experimental_new_quantizer = True
print(" βœ“ Quantization configured:")
print(" - Optimization: DEFAULT")
print(" - Ops: TFLITE_BUILTINS_INT8")
print(" - Input type: INT8")
print(" - Output type: INT8")
# Convert
print("\n5. Converting model (this may take a few minutes)...")
try:
tflite_model = converter.convert()
print(" βœ“ Conversion successful!")
except Exception as e:
print(f" βœ— Conversion failed: {e}")
return None
# Save
print(f"\n6. Saving TFLite model to {output_path}...")
with open(output_path, 'wb') as f:
f.write(tflite_model)
# Print statistics
model_size_kb = len(tflite_model) / 1024
print(f" βœ“ Model saved")
print(f" βœ“ Model size: {model_size_kb:.2f} KB")
if model_size_kb > 1024:
print(f" ⚠ Warning: Model size ({model_size_kb:.2f} KB) exceeds 1MB")
print(" Consider reducing LSTM units or other optimizations")
return tflite_model
def convert_to_tflite_dynamic_range(
model_path,
output_path
):
"""
Convert with dynamic range quantization (weights only)
Lighter quantization, good for testing
Args:
model_path: Path to trained Keras model
output_path: Output path for TFLite model
Returns:
TFLite model as bytes
"""
print("Converting with dynamic range quantization...")
model = tf.keras.models.load_model(model_path, compile=False)
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
with open(output_path, 'wb') as f:
f.write(tflite_model)
print(f"βœ“ Model saved to {output_path}")
print(f"βœ“ Size: {len(tflite_model) / 1024:.2f} KB")
return tflite_model
def analyze_tflite_model(tflite_path):
"""
Analyze converted TFLite model
Args:
tflite_path: Path to TFLite model
"""
print("\n" + "="*60)
print("Model Analysis")
print("="*60)
# Load interpreter
interpreter = tf.lite.Interpreter(model_path=tflite_path)
interpreter.allocate_tensors()
# Get input details
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print("\nπŸ“₯ Input Tensor Details:")
for i, detail in enumerate(input_details):
print(f"\n Input {i}:")
print(f" Name: {detail['name']}")
print(f" Shape: {detail['shape']}")
print(f" Type: {detail['dtype']}")
quant = detail['quantization']
if quant[0] or quant[1]:
print(f" Scale: {quant[0]}")
print(f" Zero point: {quant[1]}")
print("\nπŸ“€ Output Tensor Details:")
for i, detail in enumerate(output_details):
print(f"\n Output {i}:")
print(f" Name: {detail['name']}")
print(f" Shape: {detail['shape']}")
print(f" Type: {detail['dtype']}")
quant = detail['quantization']
if quant[0] or quant[1]:
print(f" Scale: {quant[0]}")
print(f" Zero point: {quant[1]}")
# Get tensor details
tensor_details = interpreter.get_tensor_details()
print(f"\nπŸ“Š Total Tensors: {len(tensor_details)}")
    # Group tensors by name prefix as a rough proxy for layer types
    # (the Python Interpreter API does not expose the op list directly)
    print("\nπŸ”§ Tensor Name Prefixes (approximate layer breakdown):")
    ops = {}
    for tensor in tensor_details:
        if tensor.get('name'):
            parts = tensor['name'].split('/')
            if len(parts) > 1:
                op_type = parts[0]
                ops[op_type] = ops.get(op_type, 0) + 1
    for op_type, count in sorted(ops.items()):
        print(f"  {op_type}: {count}")
def test_inference(tflite_path, test_audio_path):
"""
Test inference with TFLite model
Args:
tflite_path: Path to TFLite model
test_audio_path: Path to test audio file
"""
print("\n" + "="*60)
print("Testing Inference")
print("="*60)
# Load model
interpreter = tf.lite.Interpreter(model_path=tflite_path)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# Load test audio
print(f"\nLoading test audio: {test_audio_path}")
audio, sr = sf.read(test_audio_path)
if sr != 16000:
audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)
if len(audio.shape) > 1:
audio = np.mean(audio, axis=1)
# Take 1 second
audio = audio[:16000]
if len(audio) < 16000:
audio = np.pad(audio, (0, 16000 - len(audio)))
# Normalize
audio = audio / (np.max(np.abs(audio)) + 1e-8)
# Prepare input
input_data = audio.astype(np.float32)[np.newaxis, :]
    # Quantize input if needed: q = round(real / scale) + zero_point,
    # clipped to the INT8 range to avoid wrap-around
    input_dtype = input_details[0]['dtype']
    if input_dtype == np.int8:
        input_scale, input_zero_point = input_details[0]['quantization']
        input_data = np.clip(
            np.round(input_data / input_scale) + input_zero_point,
            -128, 127
        ).astype(np.int8)
# Run inference
print("\nRunning inference...")
interpreter.set_tensor(input_details[0]['index'], input_data)
start = time.time()
interpreter.invoke()
latency = (time.time() - start) * 1000
print(f"βœ“ Inference completed")
print(f"βœ“ Latency: {latency:.2f} ms")
# Get output
output_data = interpreter.get_tensor(output_details[0]['index'])
# Dequantize if needed
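    # (inverse of the input mapping: real = (q - zero_point) * scale)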
output_dtype = output_details[0]['dtype']
if output_dtype == np.int8:
output_scale = output_details[0]['quantization'][0]
output_zero_point = output_details[0]['quantization'][1]
output_data = (output_data.astype(np.float32) - output_zero_point) * output_scale
print(f"βœ“ Output shape: {output_data.shape}")
print(f"βœ“ Output range: [{np.min(output_data):.4f}, {np.max(output_data):.4f}]")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description='Convert DTLN model to TFLite INT8 for Ethos-U55'
)
parser.add_argument(
'--model',
type=str,
required=True,
help='Path to trained Keras model (.h5)'
)
parser.add_argument(
'--output',
type=str,
required=True,
help='Output path for TFLite model (.tflite)'
)
parser.add_argument(
'--calibration-dir',
type=str,
required=True,
help='Directory with audio for calibration'
)
parser.add_argument(
'--num-calibration-samples',
type=int,
default=100,
help='Number of samples for calibration'
)
parser.add_argument(
'--test-audio',
type=str,
default=None,
help='Path to test audio file'
)
parser.add_argument(
'--dynamic-range',
action='store_true',
help='Use dynamic range quantization instead of full INT8'
)
args = parser.parse_args()
# Convert model
if args.dynamic_range:
tflite_model = convert_to_tflite_dynamic_range(
args.model,
args.output
)
else:
tflite_model = convert_to_tflite_int8(
model_path=args.model,
output_path=args.output,
representative_data_dir=args.calibration_dir,
num_calibration_samples=args.num_calibration_samples
)
if tflite_model is None:
print("\nβœ— Conversion failed!")
        sys.exit(1)
# Analyze model
analyze_tflite_model(args.output)
# Test inference if test audio provided
if args.test_audio and os.path.exists(args.test_audio):
test_inference(args.output, args.test_audio)
print("\n" + "="*60)
print("βœ“ All done!")
print(f"βœ“ TFLite model saved to: {args.output}")
print("\nNext steps:")
print("1. Use Vela compiler to optimize for Ethos-U55:")
print(f" vela --accelerator-config ethos-u55-256 {args.output}")
print("2. Integrate into Alif E7 application")
print("3. Profile on actual hardware")
print("="*60)