File size: 12,463 Bytes
a02fb02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
"""
Convert trained DTLN model to TensorFlow Lite INT8 format
Optimized for Alif E7 Ethos-U55 NPU deployment
"""

import tensorflow as tf
import numpy as np
import soundfile as sf
import librosa
from pathlib import Path
import argparse
import os


def load_representative_dataset(
    audio_dir,
    num_samples=100,
    frame_len=512,
    sampling_rate=16000
):
    """
    Load representative audio dataset for calibration.

    Args:
        audio_dir: Directory containing audio files (searched recursively
            for ``*.wav``)
        num_samples: Number of samples for calibration
        frame_len: Frame length (not used here; kept for interface
            compatibility with callers)
        sampling_rate: Target audio sampling rate in Hz

    Returns:
        A zero-argument generator function; each yielded item is a
        single-element list holding a float32 array of shape
        ``[1, sampling_rate]`` (one second of mono audio).

    Raises:
        ValueError: If no .wav files are found under audio_dir.
    """
    audio_files = list(Path(audio_dir).glob('**/*.wav'))

    if not audio_files:
        # Fail fast with a clear message; otherwise np.random.choice
        # below raises an opaque "a must be non-empty" error.
        raise ValueError(f"No .wav files found in {audio_dir}")

    if len(audio_files) < num_samples:
        print(f"Warning: Only {len(audio_files)} files found, using all")
        num_samples = len(audio_files)

    selected_files = np.random.choice(audio_files, num_samples, replace=False)

    def representative_dataset_gen():
        for file_path in selected_files:
            # Load audio: shape is [frames] (mono) or [frames, channels]
            audio, sr = sf.read(file_path)

            # Convert to mono BEFORE resampling. librosa.resample operates
            # along the last axis, so resampling a [frames, channels] array
            # would "resample" the 2-element channel axis and corrupt the
            # signal. (The original code resampled first — that was a bug.)
            if len(audio.shape) > 1:
                audio = np.mean(audio, axis=1)

            # Resample if needed
            if sr != sampling_rate:
                audio = librosa.resample(
                    audio,
                    orig_sr=sr,
                    target_sr=sampling_rate
                )

            # Take a random 1-second segment; zero-pad shorter clips
            segment_len = sampling_rate
            if len(audio) > segment_len:
                start = np.random.randint(0, len(audio) - segment_len)
                audio = audio[start:start + segment_len]
            else:
                audio = np.pad(audio, (0, segment_len - len(audio)))

            # Peak-normalize; epsilon guards against all-zero audio
            audio = audio / (np.max(np.abs(audio)) + 1e-8)

            # Yield as a batch of one, float32, as the TFLite converter expects
            yield [audio.astype(np.float32)[np.newaxis, :]]

    return representative_dataset_gen


def convert_to_tflite_int8(
    model_path,
    output_path,
    representative_data_dir,
    num_calibration_samples=100
):
    """
    Convert Keras model to TFLite with full INT8 quantization
    
    Args:
        model_path: Path to trained Keras model (.h5)
        output_path: Output path for TFLite model (.tflite)
        representative_data_dir: Directory with audio for calibration
        num_calibration_samples: Number of samples for calibration
    
    Returns:
        TFLite model as bytes
    """
    print("="*60)
    print("Converting to TensorFlow Lite INT8")
    print("="*60)
    
    # Load model
    print("\n1. Loading model...")
    try:
        model = tf.keras.models.load_model(
            model_path,
            compile=False
        )
        print(f"   βœ“ Model loaded from {model_path}")
    except Exception as e:
        print(f"   βœ— Error loading model: {e}")
        return None
    
    model.summary()
    
    # Create converter
    print("\n2. Creating TFLite converter...")
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    
    # Enable optimizations
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    
    # Set up representative dataset for calibration
    print("\n3. Setting up representative dataset...")
    representative_dataset = load_representative_dataset(
        audio_dir=representative_data_dir,
        num_samples=num_calibration_samples
    )
    converter.representative_dataset = representative_dataset
    print(f"   βœ“ Using {num_calibration_samples} samples for calibration")
    
    # Force full integer quantization
    print("\n4. Configuring INT8 quantization...")
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS_INT8
    ]
    
    # Set input/output to INT8
    converter.inference_input_type = tf.int8
    converter.inference_output_type = tf.int8
    
    # Additional optimizations for Ethos-U55
    converter.experimental_new_converter = True
    converter.experimental_new_quantizer = True
    
    print("   βœ“ Quantization configured:")
    print("     - Optimization: DEFAULT")
    print("     - Ops: TFLITE_BUILTINS_INT8")
    print("     - Input type: INT8")
    print("     - Output type: INT8")
    
    # Convert
    print("\n5. Converting model (this may take a few minutes)...")
    try:
        tflite_model = converter.convert()
        print("   βœ“ Conversion successful!")
    except Exception as e:
        print(f"   βœ— Conversion failed: {e}")
        return None
    
    # Save
    print(f"\n6. Saving TFLite model to {output_path}...")
    with open(output_path, 'wb') as f:
        f.write(tflite_model)
    
    # Print statistics
    model_size_kb = len(tflite_model) / 1024
    print(f"   βœ“ Model saved")
    print(f"   βœ“ Model size: {model_size_kb:.2f} KB")
    
    if model_size_kb > 1024:
        print(f"   ⚠ Warning: Model size ({model_size_kb:.2f} KB) exceeds 1MB")
        print("   Consider reducing LSTM units or other optimizations")
    
    return tflite_model


def convert_to_tflite_dynamic_range(
    model_path,
    output_path
):
    """
    Convert a Keras model to TFLite using dynamic range quantization.

    Only weights are quantized; activations remain float. This is a
    lighter-weight option than full INT8, useful for quick testing.

    Args:
        model_path: Path to trained Keras model
        output_path: Output path for TFLite model

    Returns:
        The TFLite model as bytes.
    """
    print("Converting with dynamic range quantization...")

    keras_model = tf.keras.models.load_model(model_path, compile=False)

    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_bytes = converter.convert()

    Path(output_path).write_bytes(tflite_bytes)

    print(f"βœ“ Model saved to {output_path}")
    print(f"βœ“ Size: {len(tflite_bytes) / 1024:.2f} KB")

    return tflite_bytes


def analyze_tflite_model(tflite_path):
    """
    Print a summary of a converted TFLite model: input/output tensor
    details (name, shape, dtype, quantization parameters), total tensor
    count, and a rough per-prefix operation count derived from tensor
    names.

    Args:
        tflite_path: Path to TFLite model
    """
    print("\n" + "="*60)
    print("Model Analysis")
    print("="*60)

    # Load the model into an interpreter to inspect its tensors
    interpreter = tf.lite.Interpreter(model_path=tflite_path)
    interpreter.allocate_tensors()

    def describe(heading, item_label, details):
        # Shared printer for input and output tensor detail dicts.
        print(heading)
        for idx, info in enumerate(details):
            print(f"\n  {item_label} {idx}:")
            print(f"    Name: {info['name']}")
            print(f"    Shape: {info['shape']}")
            print(f"    Type: {info['dtype']}")
            scale, zero_point = info['quantization']
            # Only show quantization params when they are non-trivial
            if scale or zero_point:
                print(f"    Scale: {scale}")
                print(f"    Zero point: {zero_point}")

    describe("\nπŸ“₯ Input Tensor Details:", "Input", interpreter.get_input_details())
    describe("\nπŸ“€ Output Tensor Details:", "Output", interpreter.get_output_details())

    tensor_details = interpreter.get_tensor_details()
    print(f"\nπŸ“Š Total Tensors: {len(tensor_details)}")

    # Approximate op counts by grouping on the first path component of
    # each tensor name (e.g. "lstm/..." -> "lstm").
    print("\nπŸ”§ Model Operations:")
    op_counts = {}
    for info in tensor_details:
        name = info.get('name')
        if name:
            prefix, sep, _ = name.partition('/')
            if sep:
                op_counts[prefix] = op_counts.get(prefix, 0) + 1

    for op_name in sorted(op_counts):
        print(f"  {op_name}: {op_counts[op_name]}")


def test_inference(tflite_path, test_audio_path):
    """
    Run a single inference with the TFLite model on a test audio file.

    Loads the audio, converts to 16 kHz mono, trims/pads to exactly one
    second, peak-normalizes, quantizes the input if the model expects
    INT8, then reports inference latency and output statistics.

    Args:
        tflite_path: Path to TFLite model
        test_audio_path: Path to test audio file
    """
    import time  # hoisted to function top (was buried mid-function)

    print("\n" + "="*60)
    print("Testing Inference")
    print("="*60)

    # Load model
    interpreter = tf.lite.Interpreter(model_path=tflite_path)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Load test audio
    print(f"\nLoading test audio: {test_audio_path}")
    audio, sr = sf.read(test_audio_path)

    # Convert to mono BEFORE resampling: librosa.resample works along the
    # last axis, so resampling a [frames, channels] array would corrupt it.
    if len(audio.shape) > 1:
        audio = np.mean(audio, axis=1)

    if sr != 16000:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)

    # Trim/pad to exactly 1 second
    audio = audio[:16000]
    if len(audio) < 16000:
        audio = np.pad(audio, (0, 16000 - len(audio)))

    # Peak-normalize; epsilon guards against silent audio
    audio = audio / (np.max(np.abs(audio)) + 1e-8)

    # Prepare input as a batch of one
    input_data = audio.astype(np.float32)[np.newaxis, :]

    # Quantize input if the model expects INT8
    if input_details[0]['dtype'] == np.int8:
        input_scale, input_zero_point = input_details[0]['quantization']
        input_data = (input_data / input_scale + input_zero_point).astype(np.int8)

    # Run inference
    print("\nRunning inference...")
    interpreter.set_tensor(input_details[0]['index'], input_data)

    # perf_counter is monotonic and higher-resolution than time.time,
    # so it is the right clock for millisecond latency measurement.
    start = time.perf_counter()
    interpreter.invoke()
    latency = (time.perf_counter() - start) * 1000

    print(f"βœ“ Inference completed")
    print(f"βœ“ Latency: {latency:.2f} ms")

    # Get output
    output_data = interpreter.get_tensor(output_details[0]['index'])

    # Dequantize output if needed
    if output_details[0]['dtype'] == np.int8:
        output_scale, output_zero_point = output_details[0]['quantization']
        output_data = (output_data.astype(np.float32) - output_zero_point) * output_scale

    print(f"βœ“ Output shape: {output_data.shape}")
    print(f"βœ“ Output range: [{np.min(output_data):.4f}, {np.max(output_data):.4f}]")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Convert DTLN model to TFLite INT8 for Ethos-U55'
    )
    parser.add_argument(
        '--model',
        type=str,
        required=True,
        help='Path to trained Keras model (.h5)'
    )
    parser.add_argument(
        '--output',
        type=str,
        required=True,
        help='Output path for TFLite model (.tflite)'
    )
    parser.add_argument(
        '--calibration-dir',
        type=str,
        required=True,
        help='Directory with audio for calibration'
    )
    parser.add_argument(
        '--num-calibration-samples',
        type=int,
        default=100,
        help='Number of samples for calibration'
    )
    parser.add_argument(
        '--test-audio',
        type=str,
        default=None,
        help='Path to test audio file'
    )
    parser.add_argument(
        '--dynamic-range',
        action='store_true',
        help='Use dynamic range quantization instead of full INT8'
    )

    args = parser.parse_args()

    # Convert model (dynamic-range is the lighter, weights-only option)
    if args.dynamic_range:
        tflite_model = convert_to_tflite_dynamic_range(
            args.model,
            args.output
        )
    else:
        tflite_model = convert_to_tflite_int8(
            model_path=args.model,
            output_path=args.output,
            representative_data_dir=args.calibration_dir,
            num_calibration_samples=args.num_calibration_samples
        )

    if tflite_model is None:
        print("\nβœ— Conversion failed!")
        # raise SystemExit instead of exit(): the exit() builtin is a
        # site-module convenience meant for interactive use and may be
        # absent when scripts run with `python -S` or are frozen.
        raise SystemExit(1)

    # Analyze model
    analyze_tflite_model(args.output)

    # Test inference if test audio provided
    if args.test_audio and os.path.exists(args.test_audio):
        test_inference(args.output, args.test_audio)

    print("\n" + "="*60)
    print("βœ“ All done!")
    print(f"βœ“ TFLite model saved to: {args.output}")
    print("\nNext steps:")
    print("1. Use Vela compiler to optimize for Ethos-U55:")
    print(f"   vela --accelerator-config ethos-u55-256 {args.output}")
    print("2. Integrate into Alif E7 application")
    print("3. Profile on actual hardware")
    print("="*60)