grgsaliba committed on
Commit
a02fb02
·
verified ·
1 Parent(s): e5d5706

Upload convert_to_tflite.py with huggingface_hub

Files changed (1)
  1. convert_to_tflite.py +414 -0
convert_to_tflite.py ADDED
@@ -0,0 +1,414 @@
"""
Convert a trained DTLN model to TensorFlow Lite INT8 format,
optimized for Alif E7 Ethos-U55 NPU deployment.
"""
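
# Example invocation (all flags are defined in the argparse section at the
# bottom of this file; the file and directory names below are placeholders):
#
#   python convert_to_tflite.py \
#       --model dtln_model.h5 \
#       --output dtln_int8.tflite \
#       --calibration-dir ./calibration_audio \
#       --test-audio ./sample.wav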

import tensorflow as tf
import numpy as np
import soundfile as sf
import librosa
from pathlib import Path
import argparse
import os
import time


def load_representative_dataset(
    audio_dir,
    num_samples=100,
    frame_len=512,
    sampling_rate=16000
):
    """
    Load a representative audio dataset for quantization calibration.

    Args:
        audio_dir: Directory containing audio files
        num_samples: Number of samples to use for calibration
        frame_len: Frame length (currently unused by this loader)
        sampling_rate: Audio sampling rate in Hz

    Returns:
        Generator function yielding audio samples
    """
    audio_files = list(Path(audio_dir).glob('**/*.wav'))

    if len(audio_files) < num_samples:
        print(f"Warning: Only {len(audio_files)} files found, using all")
        num_samples = len(audio_files)

    selected_files = np.random.choice(audio_files, num_samples, replace=False)

    def representative_dataset_gen():
        for file_path in selected_files:
            # Load audio
            audio, sr = sf.read(file_path)

            # Resample if needed
            if sr != sampling_rate:
                audio = librosa.resample(
                    audio,
                    orig_sr=sr,
                    target_sr=sampling_rate
                )

            # Convert to mono
            if len(audio.shape) > 1:
                audio = np.mean(audio, axis=1)

            # Take a random 1-second segment (pad shorter files)
            segment_len = sampling_rate
            if len(audio) > segment_len:
                start = np.random.randint(0, len(audio) - segment_len)
                audio = audio[start:start + segment_len]
            else:
                audio = np.pad(audio, (0, segment_len - len(audio)))

            # Normalize to [-1, 1]
            audio = audio / (np.max(np.abs(audio)) + 1e-8)

            # Yield as a float32 numpy array with a batch dimension
            yield [audio.astype(np.float32)[np.newaxis, :]]

    return representative_dataset_gen
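
# Calibration note: TFLiteConverter iterates this generator and runs the model
# on each yielded [input_array] to record activation ranges, so every yield
# must match the model's input signature (here a (1, 16000) float32 waveform,
# i.e. the 1-second segments produced above).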


def convert_to_tflite_int8(
    model_path,
    output_path,
    representative_data_dir,
    num_calibration_samples=100
):
    """
    Convert a Keras model to TFLite with full INT8 quantization.

    Args:
        model_path: Path to trained Keras model (.h5)
        output_path: Output path for TFLite model (.tflite)
        representative_data_dir: Directory with audio for calibration
        num_calibration_samples: Number of samples for calibration

    Returns:
        TFLite model as bytes, or None on failure
    """
    print("=" * 60)
    print("Converting to TensorFlow Lite INT8")
    print("=" * 60)

    # Load model
    print("\n1. Loading model...")
    try:
        model = tf.keras.models.load_model(
            model_path,
            compile=False
        )
        print(f"   ✓ Model loaded from {model_path}")
    except Exception as e:
        print(f"   ✗ Error loading model: {e}")
        return None

    model.summary()

    # Create converter
    print("\n2. Creating TFLite converter...")
    converter = tf.lite.TFLiteConverter.from_keras_model(model)

    # Enable optimizations
    converter.optimizations = [tf.lite.Optimize.DEFAULT]

    # Set up representative dataset for calibration
    print("\n3. Setting up representative dataset...")
    representative_dataset = load_representative_dataset(
        audio_dir=representative_data_dir,
        num_samples=num_calibration_samples
    )
    converter.representative_dataset = representative_dataset
    print(f"   ✓ Using {num_calibration_samples} samples for calibration")

    # Force full integer quantization
    print("\n4. Configuring INT8 quantization...")
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS_INT8
    ]

    # Set input/output to INT8
    converter.inference_input_type = tf.int8
    converter.inference_output_type = tf.int8

    # Additional converter options for Ethos-U55
    converter.experimental_new_converter = True
    converter.experimental_new_quantizer = True
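
    # Note: full integer quantization matters here because the Ethos-U55
    # executes only quantized (integer) operators; layers left in float are
    # placed on the host Cortex-M CPU by the Vela compiler instead of the NPU.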

    print("   ✓ Quantization configured:")
    print("     - Optimization: DEFAULT")
    print("     - Ops: TFLITE_BUILTINS_INT8")
    print("     - Input type: INT8")
    print("     - Output type: INT8")

    # Convert
    print("\n5. Converting model (this may take a few minutes)...")
    try:
        tflite_model = converter.convert()
        print("   ✓ Conversion successful!")
    except Exception as e:
        print(f"   ✗ Conversion failed: {e}")
        return None

    # Save
    print(f"\n6. Saving TFLite model to {output_path}...")
    with open(output_path, 'wb') as f:
        f.write(tflite_model)

    # Print statistics
    model_size_kb = len(tflite_model) / 1024
    print("   ✓ Model saved")
    print(f"   ✓ Model size: {model_size_kb:.2f} KB")

    if model_size_kb > 1024:
        print(f"   ⚠ Warning: Model size ({model_size_kb:.2f} KB) exceeds 1 MB")
        print("     Consider reducing LSTM units or other optimizations")

    return tflite_model


def convert_to_tflite_dynamic_range(
    model_path,
    output_path
):
    """
    Convert with dynamic range quantization (weights only).
    Lighter quantization; useful for quick testing.

    Args:
        model_path: Path to trained Keras model
        output_path: Output path for TFLite model

    Returns:
        TFLite model as bytes
    """
    print("Converting with dynamic range quantization...")

    model = tf.keras.models.load_model(model_path, compile=False)

    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]

    tflite_model = converter.convert()

    with open(output_path, 'wb') as f:
        f.write(tflite_model)

    print(f"✓ Model saved to {output_path}")
    print(f"✓ Size: {len(tflite_model) / 1024:.2f} KB")

    return tflite_model
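
# Deployment note: dynamic-range quantization keeps activations in float, so
# the resulting model will not map onto the Ethos-U55 (which requires fully
# integer ops); use this mode only for quick desktop accuracy checks.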


def analyze_tflite_model(tflite_path):
    """
    Analyze a converted TFLite model.

    Args:
        tflite_path: Path to TFLite model
    """
    print("\n" + "=" * 60)
    print("Model Analysis")
    print("=" * 60)

    # Load interpreter
    interpreter = tf.lite.Interpreter(model_path=tflite_path)
    interpreter.allocate_tensors()

    # Get input/output details
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    print("\n📥 Input Tensor Details:")
    for i, detail in enumerate(input_details):
        print(f"\n  Input {i}:")
        print(f"    Name: {detail['name']}")
        print(f"    Shape: {detail['shape']}")
        print(f"    Type: {detail['dtype']}")
        quant = detail['quantization']
        if quant[0] or quant[1]:
            print(f"    Scale: {quant[0]}")
            print(f"    Zero point: {quant[1]}")

    print("\n📤 Output Tensor Details:")
    for i, detail in enumerate(output_details):
        print(f"\n  Output {i}:")
        print(f"    Name: {detail['name']}")
        print(f"    Shape: {detail['shape']}")
        print(f"    Type: {detail['dtype']}")
        quant = detail['quantization']
        if quant[0] or quant[1]:
            print(f"    Scale: {quant[0]}")
            print(f"    Zero point: {quant[1]}")

    # Get tensor details
    tensor_details = interpreter.get_tensor_details()
    print(f"\n📊 Total Tensors: {len(tensor_details)}")

    # Group tensors by name prefix. This is a heuristic: it approximates the
    # layers present, since the interpreter API does not expose the actual
    # operator list.
    print("\n🔧 Tensor Name Groups:")
    ops = {}
    for tensor in tensor_details:
        if 'name' in tensor and tensor['name']:
            parts = tensor['name'].split('/')
            if len(parts) > 1:
                op_type = parts[0]
                ops[op_type] = ops.get(op_type, 0) + 1

    for op_type, count in sorted(ops.items()):
        print(f"  {op_type}: {count}")
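
    # For a true per-operator breakdown, newer TensorFlow releases also provide
    # tf.lite.experimental.Analyzer.analyze(model_path=...), which prints the
    # operator list from the flatbuffer (availability depends on the TF version).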


def test_inference(tflite_path, test_audio_path):
    """
    Test inference with the TFLite model.

    Args:
        tflite_path: Path to TFLite model
        test_audio_path: Path to test audio file
    """
    print("\n" + "=" * 60)
    print("Testing Inference")
    print("=" * 60)

    # Load model
    interpreter = tf.lite.Interpreter(model_path=tflite_path)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Load test audio
    print(f"\nLoading test audio: {test_audio_path}")
    audio, sr = sf.read(test_audio_path)

    if sr != 16000:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)

    if len(audio.shape) > 1:
        audio = np.mean(audio, axis=1)

    # Take 1 second (pad shorter files)
    audio = audio[:16000]
    if len(audio) < 16000:
        audio = np.pad(audio, (0, 16000 - len(audio)))

    # Normalize
    audio = audio / (np.max(np.abs(audio)) + 1e-8)

    # Prepare input
    input_data = audio.astype(np.float32)[np.newaxis, :]

    # Quantize input if needed (round and clip to the int8 range rather than
    # truncating, which would bias the result)
    input_dtype = input_details[0]['dtype']
    if input_dtype == np.int8:
        input_scale = input_details[0]['quantization'][0]
        input_zero_point = input_details[0]['quantization'][1]
        input_data = np.clip(
            np.round(input_data / input_scale) + input_zero_point,
            -128, 127
        ).astype(np.int8)
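
    # (Affine quantization: q = round(x / scale) + zero_point maps real values
    # to int8; the dequantization step further below applies the inverse,
    # x ≈ scale * (q - zero_point).)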

    # Run inference
    print("\nRunning inference...")
    interpreter.set_tensor(input_details[0]['index'], input_data)

    start = time.time()
    interpreter.invoke()
    latency = (time.time() - start) * 1000

    print("✓ Inference completed")
    print(f"✓ Latency: {latency:.2f} ms")
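
    # Note: this latency is host interpreter time on the development machine;
    # it is not indicative of Ethos-U55 performance. Use Vela's performance
    # estimates and on-target profiling for representative numbers.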

    # Get output
    output_data = interpreter.get_tensor(output_details[0]['index'])

    # Dequantize if needed
    output_dtype = output_details[0]['dtype']
    if output_dtype == np.int8:
        output_scale = output_details[0]['quantization'][0]
        output_zero_point = output_details[0]['quantization'][1]
        output_data = (output_data.astype(np.float32) - output_zero_point) * output_scale

    print(f"✓ Output shape: {output_data.shape}")
    print(f"✓ Output range: [{np.min(output_data):.4f}, {np.max(output_data):.4f}]")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Convert DTLN model to TFLite INT8 for Ethos-U55'
    )
    parser.add_argument(
        '--model',
        type=str,
        required=True,
        help='Path to trained Keras model (.h5)'
    )
    parser.add_argument(
        '--output',
        type=str,
        required=True,
        help='Output path for TFLite model (.tflite)'
    )
    parser.add_argument(
        '--calibration-dir',
        type=str,
        required=True,
        help='Directory with audio for calibration'
    )
    parser.add_argument(
        '--num-calibration-samples',
        type=int,
        default=100,
        help='Number of samples for calibration'
    )
    parser.add_argument(
        '--test-audio',
        type=str,
        default=None,
        help='Path to test audio file'
    )
    parser.add_argument(
        '--dynamic-range',
        action='store_true',
        help='Use dynamic range quantization instead of full INT8'
    )

    args = parser.parse_args()

    # Convert model
    if args.dynamic_range:
        tflite_model = convert_to_tflite_dynamic_range(
            args.model,
            args.output
        )
    else:
        tflite_model = convert_to_tflite_int8(
            model_path=args.model,
            output_path=args.output,
            representative_data_dir=args.calibration_dir,
            num_calibration_samples=args.num_calibration_samples
        )

    if tflite_model is None:
        print("\n✗ Conversion failed!")
        raise SystemExit(1)

    # Analyze model
    analyze_tflite_model(args.output)

    # Test inference if test audio provided
    if args.test_audio and os.path.exists(args.test_audio):
        test_inference(args.output, args.test_audio)

    print("\n" + "=" * 60)
    print("✓ All done!")
    print(f"✓ TFLite model saved to: {args.output}")
    print("\nNext steps:")
    print("1. Use the Vela compiler to optimize for Ethos-U55:")
    print(f"   vela --accelerator-config ethos-u55-256 {args.output}")
    print("2. Integrate into the Alif E7 application")
    print("3. Profile on actual hardware")
    print("=" * 60)
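
# Vela note: with the command printed above, Vela (pip install ethos-u-vela)
# writes the compiled model to ./output/<name>_vela.tflite by default; that
# file, not the plain .tflite, is what gets linked into the Alif E7 firmware.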