Spaces:
Sleeping
Sleeping
| """ | |
| GAN Interactive Demo - Aplicación Gradio | |
| Visualización interactiva del espacio latente y generación de dígitos MNIST | |
| """ | |
| import gradio as gr | |
| import tensorflow as tf | |
| from tensorflow import keras | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from sklearn.decomposition import PCA | |
| from sklearn.manifold import TSNE | |
| import plotly.graph_objects as go | |
| import plotly.express as px | |
| from PIL import Image | |
| import io | |
| import os | |
# Configuration
LATENT_DIM = 100  # size of the generator's input noise vector
MODEL_DIR = "models"

# Load the trained generator. The app still starts without it: every
# generation callback checks `generator is None` and degrades gracefully.
print("Cargando modelo generador...")
try:
    generator = keras.models.load_model(
        os.path.join(MODEL_DIR, "generator.h5"), compile=False
    )
    print("✓ Generador cargado exitosamente")
except Exception as e:  # top-level boundary: report and continue without a model
    print(f"Error cargando generador: {e}")
    generator = None

# Load pre-generated latent vectors (and the images made from them) for the
# exploration tab; fall back to sampling fresh vectors if loading fails.
try:
    latent_vectors = np.load(os.path.join(MODEL_DIR, "latent_vectors.npy"))
    generated_images_cache = np.load(os.path.join(MODEL_DIR, "generated_images.npy"))
    print(f"✓ Vectores latentes cargados: {latent_vectors.shape}")
except Exception as e:  # best-effort: any load failure triggers the fallback below
    # Report why the cached files were not used instead of discarding the error.
    print(f"No se pudieron cargar los vectores pre-generados: {e}")
    print("Generando nuevos vectores latentes...")
    latent_vectors = np.random.normal(0, 1, (1000, LATENT_DIM))
    if generator is not None:
        generated_images_cache = generator(latent_vectors, training=False).numpy()
    else:
        generated_images_cache = None

# Dimensionality reduction used by the latent-space visualisation tab.
print("Calculando reducción dimensional...")
pca = PCA(n_components=3)
latent_pca = pca.fit_transform(latent_vectors)

# t-SNE runs on a subset only, to keep startup fast.
_tsne_inputs = latent_vectors[:500]
# scikit-learn requires perplexity < n_samples; the guard only matters when a
# small latent_vectors.npy was loaded (for >= 31 vectors this is 30, as before).
tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(_tsne_inputs) - 1))
latent_tsne = tsne.fit_transform(_tsne_inputs)
print("✓ Aplicación lista")
# ==================== GENERATION FUNCTIONS ====================
def generate_random_digit():
    """Generate one digit from a freshly sampled latent vector.

    Returns:
        tuple: ``(image, info)``. ``image`` is a grayscale PIL image, or None
        when the generator is unavailable. ``info`` is a status string showing
        the first 5 latent values, or "Modelo no disponible".
    """
    if generator is None:
        return None, "Modelo no disponible"

    # Draw a standard-normal latent vector (batch of 1).
    latent_vector = np.random.normal(0, 1, (1, LATENT_DIM))
    generated_image = generator(latent_vector, training=False)
    # Assumes a 4-D (batch, H, W, channel) output; take the single channel.
    image = generated_image[0, :, :, 0].numpy()

    # Map the assumed [-1, 1] range (implied by the *127.5+127.5 rescale) to
    # 0..255. Clip first so out-of-range values cannot wrap around in uint8.
    image = np.clip(image * 127.5 + 127.5, 0, 255).astype(np.uint8)

    # Pillow infers mode "L" for a 2-D uint8 array; the explicit `mode`
    # argument is deprecated.
    pil_image = Image.fromarray(image)
    return pil_image, f"Vector latente: {latent_vector[0, :5]}... (primeros 5 valores)"
def generate_from_sliders(*slider_values, base_seed=0):
    """Generate a digit whose leading latent dimensions come from the sliders.

    Args:
        *slider_values: one number per slider. They overwrite latent
            dimensions ``0 .. len(slider_values) - 1`` (the UI wires 10).
        base_seed: seed for the remaining latent dimensions. The fixed
            default keeps those dimensions identical between clicks, so the
            change in the image reflects only the slider changes. Pass
            ``None`` to draw fresh noise on every call.

    Returns:
        A grayscale PIL image, or None when the generator is unavailable.
    """
    if generator is None:
        # The UI binds a single output component, so return a single value
        # (the previous 2-tuple could not be mapped onto one output).
        return None

    # Base vector: the slider-controlled prefix is overwritten below; the rest
    # comes from a seeded, local RNG.
    rng = np.random.default_rng(base_seed)
    latent_vector = rng.normal(0, 1, (1, LATENT_DIM))
    latent_vector[0, :len(slider_values)] = slider_values

    generated_image = generator(latent_vector, training=False)
    image = generated_image[0, :, :, 0].numpy()

    # Rescale from the assumed [-1, 1] range to 0..255, clipping so that the
    # uint8 cast cannot wrap around.
    image = np.clip(image * 127.5 + 127.5, 0, 255).astype(np.uint8)
    return Image.fromarray(image)  # mode "L" is inferred for 2-D uint8
def interpolate_digits(start_seed, end_seed, steps):
    """Render a linear latent-space interpolation between two seeded vectors.

    Args:
        start_seed: seed for the starting latent vector (converted with int()).
        end_seed: seed for the ending latent vector (converted with int()).
        steps: number of frames including both endpoints (converted with
            int(); values below 1 are treated as 1).

    Returns:
        A PIL image with the frames laid out in a grid of at most 10 columns,
        each titled with its 1-based position, or None when the generator is
        unavailable.
    """
    if generator is None:
        return None

    # Local RandomState objects reproduce exactly the vectors that
    # np.random.seed(s) + np.random.normal(...) produced, without re-seeding
    # the global RNG. Re-seeding it made every later "random" generation in
    # the app replay a fixed sequence.
    latent_start = np.random.RandomState(int(start_seed)).normal(0, 1, (1, LATENT_DIM))
    latent_end = np.random.RandomState(int(end_seed)).normal(0, 1, (1, LATENT_DIM))

    n_steps = max(int(steps), 1)
    alphas = np.linspace(0, 1, n_steps).reshape(-1, 1)
    # One batched forward pass: row k = (1 - a_k) * start + a_k * end.
    latents = (1 - alphas) * latent_start + alphas * latent_end
    frames = generator(latents, training=False).numpy()[:, :, :, 0]
    # Rescale from the assumed [-1, 1] range; clip so uint8 cannot wrap.
    frames = np.clip(frames * 127.5 + 127.5, 0, 255).astype(np.uint8)

    cols = min(10, n_steps)
    rows = (n_steps + cols - 1) // cols  # ceiling division
    # squeeze=False always yields a 2-D array of Axes, even for a 1x1 grid
    # (the old `axes.reshape` crashed when plt.subplots returned a bare Axes).
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 1.5, rows * 1.5), squeeze=False)
    try:
        for idx, ax in enumerate(axes.flat):
            if idx < n_steps:
                ax.imshow(frames[idx], cmap='gray')
                ax.axis('off')
                ax.set_title(f'{idx+1}', fontsize=8)
            else:
                ax.axis('off')  # hide unused cells in the last row
        fig.tight_layout()

        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    finally:
        # Close this specific figure; plt.close() with no argument closes the
        # "current" figure, which may belong to a concurrent request.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def visualize_latent_space_pca():
    """Build an interactive 3-D scatter of the PCA-projected latent vectors.

    Reads the module-level ``latent_pca`` array and uses its first three
    columns as the x/y/z coordinates. Points are coloured by the third column.

    Returns:
        plotly.graph_objects.Figure
    """
    labels = [f'Punto {i}' for i in range(len(latent_pca))]
    fig = go.Figure(data=[go.Scatter3d(
        x=latent_pca[:, 0],
        y=latent_pca[:, 1],
        z=latent_pca[:, 2],
        mode='markers',
        marker=dict(
            size=3,
            color=latent_pca[:, 2],
            colorscale='Viridis',
            showscale=True,
            colorbar=dict(title="PC3"),
            opacity=0.7
        ),
        text=labels,
        # `text` already reads "Punto i"; the template must not prefix it
        # again (it previously rendered "Punto Punto i").
        hovertemplate='<b>%{text}</b><br>PC1: %{x:.2f}<br>PC2: %{y:.2f}<br>PC3: %{z:.2f}<extra></extra>'
    )])
    fig.update_layout(
        title='Espacio Latente - Visualización PCA 3D',
        scene=dict(
            xaxis_title='Componente Principal 1',
            yaxis_title='Componente Principal 2',
            zaxis_title='Componente Principal 3',
            bgcolor='rgba(240, 240, 240, 0.9)'
        ),
        width=800,
        height=600,
        showlegend=False
    )
    return fig
def visualize_latent_space_tsne():
    """Build an interactive 2-D scatter of the t-SNE-embedded latent vectors.

    Reads the module-level ``latent_tsne`` array and uses its first two columns
    as x/y. Points are coloured by their row index.

    Returns:
        plotly.graph_objects.Figure
    """
    n_points = len(latent_tsne)
    fig = go.Figure(data=[go.Scatter(
        x=latent_tsne[:, 0],
        y=latent_tsne[:, 1],
        mode='markers',
        marker=dict(
            size=6,
            color=np.arange(n_points),
            colorscale='Plasma',
            showscale=True,
            colorbar=dict(title="Índice"),
            opacity=0.7
        ),
        text=[f'Punto {i}' for i in range(n_points)],
        # `text` already reads "Punto i"; don't prefix it again in the template.
        hovertemplate='<b>%{text}</b><br>t-SNE 1: %{x:.2f}<br>t-SNE 2: %{y:.2f}<extra></extra>'
    )])
    fig.update_layout(
        title='Espacio Latente - Visualización t-SNE 2D',
        xaxis_title='Dimensión t-SNE 1',
        yaxis_title='Dimensión t-SNE 2',
        width=800,
        height=600,
        plot_bgcolor='rgba(240, 240, 240, 0.9)'
    )
    return fig
def generate_from_latent_index(index):
    """Return the cached image for a latent vector selected by index.

    Args:
        index: requested position (converted with int()); wrapped modulo the
            cache length, so any integer is accepted.

    Returns:
        tuple: ``(image, info)``. ``image`` is a grayscale PIL image, or None
        when no cache is available. ``info`` shows the wrapped index and the
        first 5 values of the corresponding entry of ``latent_vectors``.
    """
    # An empty cache would otherwise raise ZeroDivisionError in the modulo.
    if generated_images_cache is None or len(generated_images_cache) == 0:
        return None, "Cache no disponible"

    index = int(index) % len(generated_images_cache)
    image = generated_images_cache[index, :, :, 0]
    # Rescale from the assumed [-1, 1] range; clip so uint8 cannot wrap.
    image = np.clip(image * 127.5 + 127.5, 0, 255).astype(np.uint8)

    pil_image = Image.fromarray(image)  # mode "L" is inferred for 2-D uint8
    return pil_image, f"Índice: {index}\nVector latente: {latent_vectors[index, :5]}..."
def generate_grid_comparison():
    """Render 16 freshly generated digits as a 4x4 grid image.

    Returns:
        A PIL image of the grid, or None when the generator is unavailable.
    """
    if generator is None:
        return None

    # Draw 16 standard-normal latent vectors and generate them in one batch.
    latent_vectors_batch = np.random.normal(0, 1, (16, LATENT_DIM))
    images = generator(latent_vectors_batch, training=False).numpy()[:, :, :, 0]
    # Rescale from the assumed [-1, 1] range; clip so uint8 cannot wrap.
    images = np.clip(images * 127.5 + 127.5, 0, 255).astype(np.uint8)

    fig, axes = plt.subplots(4, 4, figsize=(10, 10))
    try:
        # Row-major fill: axes.flat visits (0,0), (0,1), ... like i * 4 + j.
        for ax, image in zip(axes.flat, images):
            ax.imshow(image, cmap='gray')
            ax.axis('off')
        fig.tight_layout()

        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    finally:
        # Close this figure explicitly; plt.close() with no argument closes
        # whichever figure is "current", which may not be this one.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
# ==================== GRADIO INTERFACE ====================
# Custom CSS passed to gr.Blocks(css=...): Arial font for the whole app and
# larger, bold tab-navigation buttons.
custom_css = """
.gradio-container {
font-family: 'Arial', sans-serif;
}
.tab-nav button {
font-size: 16px;
font-weight: bold;
}
"""
# Build the interface. Inside each `with` context, components are laid out in
# the order they are created, so statement order here defines the page layout.
with gr.Blocks(css=custom_css, title="GAN Interactive Demo - MNIST", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎨 GAN Interactive Demo - Exploración del Espacio Latente
    ### Generative Adversarial Network entrenada en MNIST

    Explora cómo una GAN aprende a generar dígitos manuscritos desde vectores de ruido aleatorio.
    Inspirado en el TensorFlow Projector, esta demo te permite navegar el espacio latente de 100 dimensiones.
    """)

    with gr.Tabs():
        # TAB 1: one-click random generation
        with gr.Tab("🎲 Generación Aleatoria"):
            gr.Markdown("### Genera dígitos aleatorios con un clic")
            with gr.Row():
                with gr.Column(scale=1):
                    btn_generate = gr.Button("🎲 Generar Dígito Aleatorio", variant="primary", size="lg")
                    latent_info = gr.Textbox(label="Información del Vector Latente", lines=2)
                with gr.Column(scale=1):
                    output_image = gr.Image(label="Dígito Generado", type="pil")
            btn_generate.click(
                fn=generate_random_digit,
                outputs=[output_image, latent_info]
            )

        # TAB 2: manual control of the first 10 latent dimensions
        with gr.Tab("🎛️ Control Manual"):
            gr.Markdown("### Controla las primeras 10 dimensiones del vector latente")
            gr.Markdown("Ajusta los sliders para ver cómo cada dimensión afecta la generación")
            with gr.Row():
                with gr.Column(scale=1):
                    sliders = []
                    for i in range(10):
                        slider = gr.Slider(
                            minimum=-3,
                            maximum=3,
                            value=0,
                            step=0.1,
                            label=f"Dimensión {i+1}"
                        )
                        sliders.append(slider)
                    btn_generate_sliders = gr.Button("Generar desde Sliders", variant="primary")
                with gr.Column(scale=1):
                    output_image_sliders = gr.Image(label="Dígito Generado", type="pil")
            # Slider values are passed positionally, in creation order.
            btn_generate_sliders.click(
                fn=generate_from_sliders,
                inputs=sliders,
                outputs=output_image_sliders
            )

        # TAB 3: interpolation between two seeded latent vectors
        with gr.Tab("🔄 Interpolación"):
            gr.Markdown("### Morphing entre dos dígitos")
            gr.Markdown("Observa cómo la GAN transforma suavemente un dígito en otro")
            with gr.Row():
                with gr.Column(scale=1):
                    start_seed = gr.Number(label="Semilla Inicial", value=42)
                    end_seed = gr.Number(label="Semilla Final", value=123)
                    steps = gr.Slider(
                        minimum=5,
                        maximum=20,
                        value=10,
                        step=1,
                        label="Número de Pasos"
                    )
                    btn_interpolate = gr.Button("🔄 Generar Interpolación", variant="primary")
                with gr.Column(scale=2):
                    output_interpolation = gr.Image(label="Secuencia de Interpolación")
            btn_interpolate.click(
                fn=interpolate_digits,
                inputs=[start_seed, end_seed, steps],
                outputs=output_interpolation
            )

        # TAB 4: latent-space exploration (PCA / t-SNE plots and cached images)
        with gr.Tab("🌌 Espacio Latente"):
            gr.Markdown("### Visualización del Espacio Latente de 100 Dimensiones")
            gr.Markdown("Similar al TensorFlow Projector: explora cómo se distribuyen los vectores latentes")
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("#### Visualización 3D (PCA)")
                    btn_pca = gr.Button("Mostrar PCA 3D", variant="secondary")
                    plot_pca = gr.Plot(label="Espacio Latente - PCA")
                    btn_pca.click(
                        fn=visualize_latent_space_pca,
                        outputs=plot_pca
                    )
                with gr.Column(scale=1):
                    gr.Markdown("#### Visualización 2D (t-SNE)")
                    btn_tsne = gr.Button("Mostrar t-SNE 2D", variant="secondary")
                    plot_tsne = gr.Plot(label="Espacio Latente - t-SNE")
                    btn_tsne.click(
                        fn=visualize_latent_space_tsne,
                        outputs=plot_tsne
                    )
            gr.Markdown("---")
            gr.Markdown("#### Genera desde un punto específico del espacio")
            with gr.Row():
                with gr.Column(scale=1):
                    latent_index = gr.Slider(
                        minimum=0,
                        # Derive the range from the vectors actually loaded
                        # (they may come from latent_vectors.npy) instead of
                        # hard-coding 999, so every vector is reachable.
                        maximum=max(len(latent_vectors) - 1, 0),
                        value=0,
                        step=1,
                        label="Índice del Vector Latente"
                    )
                    btn_generate_index = gr.Button("Generar desde Índice", variant="primary")
                    latent_index_info = gr.Textbox(label="Información", lines=2)
                with gr.Column(scale=1):
                    output_image_index = gr.Image(label="Dígito Generado", type="pil")
            btn_generate_index.click(
                fn=generate_from_latent_index,
                inputs=latent_index,
                outputs=[output_image_index, latent_index_info]
            )

        # TAB 5: 4x4 grid of random digits
        with gr.Tab("📊 Grid de Dígitos"):
            gr.Markdown("### Genera múltiples dígitos simultáneamente")
            gr.Markdown("Observa la diversidad y calidad de las generaciones")
            with gr.Row():
                with gr.Column(scale=1):
                    btn_grid = gr.Button("🎨 Generar Grid 4×4", variant="primary", size="lg")
                with gr.Column(scale=2):
                    output_grid = gr.Image(label="Grid de 16 Dígitos Generados")
            btn_grid.click(
                fn=generate_grid_comparison,
                outputs=output_grid
            )

    # Footer. Blank lines separate the blocks so that "**Arquitectura:**" and
    # the final line are not rendered as continuations of the preceding list.
    gr.Markdown("""
    ---
    ### 📚 Sobre esta Demo

    Esta aplicación interactiva demuestra el poder de las **Redes Generativas Adversarias (GANs)** entrenadas en el dataset MNIST.

    **Características:**
    - **Espacio Latente de 100 dimensiones**: Cada dígito es generado desde un vector de 100 números aleatorios
    - **Visualización dimensional**: PCA y t-SNE reducen las 100 dimensiones a 2D/3D para visualización
    - **Interpolación suave**: Demuestra que el espacio latente es continuo y significativo
    - **Generación instantánea**: Sin necesidad de re-entrenar

    **Arquitectura:**
    - **Generador**: 7×7×256 → 14×14×64 → 28×28×1 (Conv2DTranspose + BatchNorm + LeakyReLU)
    - **Discriminador**: 28×28×1 → 14×14×64 → 7×7×128 → Logit (Conv2D + Dropout)
    - **Entrenamiento**: 50 épocas, Adam optimizer, Binary Cross-Entropy loss

    🎓 **Creado para la clase de Machine Learning**
    """)
# Entry point: serve the demo on all network interfaces, port 7860,
# without a public share link.
if __name__ == "__main__":
    launch_options = {
        "server_name": "0.0.0.0",  # listen on every interface
        "server_port": 7860,
        "share": False,  # no public tunnel
    }
    demo.launch(**launch_options)