# Hugging Face Spaces app (page status at capture time: Sleeping)
import logging
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM, AutoTokenizer
# Module-level logger; the root logger is configured at INFO level so the
# info/error messages emitted below are shown.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
@dataclass
class MarketingFeature:
    """Structure to hold marketing-relevant feature information.

    Declared as a dataclass so it can be built with keyword arguments, as
    MARKETING_FEATURES does. Without the decorator, those calls raise
    ``TypeError`` at import time.
    """

    feature_id: int  # feature index (presumably into the layer's SAE) — TODO confirm
    name: str  # human-readable label
    category: str  # grouping key, e.g. "technical" or "seo"
    description: str  # what the feature detects
    interpretation_guide: str  # how to read the activation level
    layer: int  # model layer the feature belongs to
    threshold: float = 0.1  # NOTE(review): not read anywhere in the visible source
# Hand-curated feature descriptors, all at model layer 20.
# NOTE(review): nothing in the visible source reads this list; the analyzer
# labels the features it finds generically as "Feature <id>".
MARKETING_FEATURES = [
    MarketingFeature(
        feature_id=fid,
        name=label,
        category=group,
        description=summary,
        interpretation_guide=guide,
        layer=20,
    )
    for fid, label, group, summary, guide in (
        (
            35,
            "Technical Term Detector",
            "technical",
            "Detects technical and specialized terminology",
            "High activation indicates strong technical focus",
        ),
        (
            6680,
            "Compound Technical Terms",
            "technical",
            "Identifies complex technical concepts",
            "Consider simplifying language if activation is too high",
        ),
        (
            2,
            "SEO Keyword Detector",
            "seo",
            "Identifies potential SEO keywords",
            "High activation suggests strong SEO potential",
        ),
    )
]
class JumpReLUSAE(nn.Module):
    """JumpReLU sparse autoencoder using the Gemma Scope parameter names.

    All parameters start as zeros. They are expected to be overwritten through
    ``load_state_dict`` with the keys ``W_enc``, ``W_dec``, ``threshold``,
    ``b_enc`` and ``b_dec``.
    """

    def __init__(self, d_model, d_sae):
        super().__init__()
        # (name, shape) pairs, registered in the same order as before.
        layout = (
            ("W_enc", (d_model, d_sae)),
            ("W_dec", (d_sae, d_model)),
            ("threshold", (d_sae,)),
            ("b_enc", (d_sae,)),
            ("b_dec", (d_model,)),
        )
        for param_name, shape in layout:
            setattr(self, param_name, nn.Parameter(torch.zeros(shape)))

    def encode(self, input_acts):
        """Project inputs (last dim d_model) to feature activations (last dim d_sae).

        A feature survives only where its pre-activation exceeds its learned
        ``threshold``. Surviving values are additionally passed through ReLU.
        """
        pre = input_acts @ self.W_enc + self.b_enc
        return torch.nn.functional.relu(pre) * (pre > self.threshold)

    def decode(self, acts):
        """Reconstruct model-space vectors from feature activations."""
        return self.b_dec + acts @ self.W_dec

    def forward(self, acts):
        """Encode then decode, returning the reconstruction."""
        return self.decode(self.encode(acts))
class MarketingAnalyzer:
    """Score text against Gemma Scope SAE features on top of ``google/gemma-2-2b``.

    For each request the model is run once over the text. The output of one
    decoder layer is captured with a forward hook and encoded with that
    layer's JumpReLU SAE (downloaded from ``google/gemma-scope-2b-pt-res``).
    Features are then ranked by how strongly they fire.
    """

    # Decoder layer whose SAE analyze_content uses.
    DEFAULT_LAYER = 20
    # Number of strongest features analyze_content reports.
    TOP_K = 3

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        torch.set_grad_enabled(False)  # Avoid memory issues (inference only)
        # layer -> loaded SAE. The params file path depends only on the layer,
        # so one load serves every feature of that layer.
        self._sae_cache = {}
        self._initialize_model()

    def _initialize_model(self):
        """Load the Gemma-2-2b causal LM and tokenizer; log and re-raise on failure."""
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                "google/gemma-2-2b", device_map="auto"
            )
            self.tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
            self.model.eval()
            logger.info("Model initialized successfully")
        except Exception as e:
            logger.error(f"Error initializing model: {str(e)}")
            raise

    def _get_sae_cache(self) -> Dict[int, nn.Module]:
        """Return the per-layer SAE cache, creating it if ``__init__`` was bypassed."""
        cache = self.__dict__.get("_sae_cache")
        if cache is None:
            cache = self._sae_cache = {}
        return cache

    def _load_sae(self, feature_id: Optional[int] = None, layer: int = 20):
        """Return the JumpReLU SAE for ``layer``, loading it on first use.

        The downloaded params file is selected by ``layer`` only and holds the
        weights for all of that layer's features. The SAE is therefore cached
        per layer and reused. ``feature_id`` is used only to label the error log.

        Args:
            feature_id: Optional feature the caller is interested in (logging only).
            layer: Decoder layer whose SAE should be loaded.

        Returns:
            The SAE module on ``self.device``, or ``None`` if loading failed.
            The error is logged rather than raised.
        """
        cache = self._get_sae_cache()
        if layer in cache:
            return cache[layer]
        try:
            path = hf_hub_download(
                repo_id="google/gemma-scope-2b-pt-res",
                filename=f"layer_{layer}/width_16k/average_l0_71/params.npz",
                force_download=False,
            )
            # The context manager closes the underlying .npz file handle.
            with np.load(path) as params:
                d_model, d_sae = params["W_enc"].shape
                sae = JumpReLUSAE(d_model, d_sae).to(self.device)
                sae.load_state_dict(
                    {k: torch.from_numpy(v).to(self.device) for k, v in params.items()}
                )
        except Exception as e:
            target = (
                f"feature {feature_id}" if feature_id is not None else f"layer {layer}"
            )
            logger.error(f"Error loading SAE for {target}: {str(e)}")
            return None
        cache[layer] = sae
        return sae

    def _gather_activations(self, text: str, layer: int):
        """Run the model on ``text`` and capture the output of decoder ``layer``.

        Returns:
            ``(hidden, inputs)``. ``hidden`` is the layer's hidden-state tensor;
            ``inputs`` is the tokenizer output moved to ``self.device``.

        Raises:
            RuntimeError: if the hook did not capture anything.
        """
        inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
        target_act = None

        def hook(mod, hook_inputs, outputs):
            nonlocal target_act
            # Depending on the transformers version, a decoder layer may return
            # a tuple (hidden state first) or the hidden-state tensor itself.
            # Indexing a bare tensor with [0] would silently drop the batch dim.
            target_act = outputs[0] if isinstance(outputs, tuple) else outputs

        handle = self.model.model.layers[layer].register_forward_hook(hook)
        try:
            with torch.no_grad():
                self.model(**inputs)
        finally:
            # Always detach the hook; a failed forward must not leave it active.
            handle.remove()
        if target_act is None:
            raise RuntimeError(f"Forward hook on layer {layer} captured no activations")
        return target_act, inputs

    def _encode_text(self, text: str, sae, layer: int) -> torch.Tensor:
        """Return SAE feature activations for ``text`` with the BOS position dropped.

        Assumes the captured hidden state is ``(batch, seq_len, d_model)``, so
        the result is ``(batch, seq_len - 1, d_sae)``.
        """
        activations, _ = self._gather_activations(text, layer)
        sae_acts = sae.encode(activations.to(device=self.device, dtype=torch.float32))
        return sae_acts[:, 1:]  # Skip BOS token

    @staticmethod
    def _activation_stats(
        sae_acts: torch.Tensor, feature_id: Optional[int] = None
    ) -> Tuple[float, float]:
        """Return ``(mean, max)`` of ``sae_acts``, optionally for a single feature.

        If ``feature_id`` is given, only that feature's column (last dimension)
        is summarized; otherwise every element is. Empty input gives ``(0.0, 0.0)``.
        """
        if feature_id is not None:
            sae_acts = sae_acts[..., feature_id]
        if sae_acts.numel() == 0:
            return 0.0, 0.0
        return float(sae_acts.mean()), float(sae_acts.max())

    @staticmethod
    def _rank_features(sae_acts: torch.Tensor, top_k: int = 3) -> List[Dict]:
        """Return the ``top_k`` features of ``sae_acts`` ranked by peak activation.

        The last dimension of ``sae_acts`` indexes features; all leading
        dimensions (batch, tokens) are pooled. Each entry holds ``feature_id``
        (int), ``mean_activation`` and ``max_activation`` (float). Entries are
        sorted by ``max_activation`` in descending order. With no tokens,
        every feature scores 0.0.
        """
        d_sae = sae_acts.shape[-1]
        if d_sae == 0 or top_k <= 0:
            return []
        per_token = sae_acts.reshape(-1, d_sae)
        if per_token.shape[0] == 0:
            means = maxes = per_token.new_zeros(d_sae)
        else:
            means = per_token.mean(dim=0)
            maxes = per_token.max(dim=0).values
        top = torch.topk(maxes, k=min(top_k, d_sae))  # sorted descending
        return [
            {
                "feature_id": int(idx),
                "mean_activation": float(means[idx]),
                "max_activation": float(val),
            }
            for val, idx in zip(top.values.tolist(), top.indices.tolist())
        ]

    def _get_feature_activations(
        self, text: str, sae, layer: int = 20, feature_id: Optional[int] = None
    ) -> Tuple[float, float]:
        """Return ``(mean, max)`` SAE activation over the tokens of ``text``.

        If ``feature_id`` is given, the statistics cover only that feature.
        Without it they cover all features, which is the legacy behaviour.
        """
        return self._activation_stats(self._encode_text(text, sae, layer), feature_id)

    def analyze_content(self, text: str) -> Dict:
        """Analyze content and find the most relevant SAE features.

        The text is encoded once with the layer SAE. Every feature is ranked by
        its peak activation over the text's tokens, and the ``TOP_K`` strongest
        are reported.

        Returns:
            Dict with ``text``, ``features`` (feature_id -> details),
            ``categories`` (category -> list of feature details) and
            ``recommendations`` (list of str). If the SAE cannot be loaded,
            ``features`` stays empty.

        Raises:
            Any exception from the model forward pass, after logging it.
        """
        results = {
            "text": text,
            "features": {},
            "categories": {},
            "recommendations": [],
        }
        try:
            layer = self.DEFAULT_LAYER
            sae = self._load_sae(layer=layer)
            if sae is None:
                # Nothing to score against; the failure was already logged.
                return results
            top_features = self._rank_features(
                self._encode_text(text, sae, layer), self.TOP_K
            )
            for feature_data in top_features:
                feature_id = feature_data["feature_id"]
                # Placeholder labels; no Neuronpedia metadata is fetched here.
                feature_name = f"Feature {feature_id}"
                feature_category = "neural"  # Default category
                feature_result = {
                    "name": feature_name,
                    "category": feature_category,
                    "activation_score": feature_data["mean_activation"],
                    "max_activation": feature_data["max_activation"],
                    "interpretation": self._interpret_activation(
                        feature_data["mean_activation"], feature_id
                    ),
                }
                results["features"][feature_id] = feature_result
                results["categories"].setdefault(feature_category, []).append(
                    feature_result
                )
            # Generate recommendations based on activations
            if top_features:
                max_activation = max(f["max_activation"] for f in top_features)
                if max_activation > 0.8:
                    results["recommendations"].append(
                        f"Strong activation detected in feature {top_features[0]['feature_id']}. "
                        "Consider exploring this aspect further."
                    )
                elif max_activation < 0.3:
                    results["recommendations"].append(
                        "Low feature activations overall. Content might benefit from more distinctive elements."
                    )
        except Exception as e:
            logger.error(f"Error analyzing content: {str(e)}")
            raise
        return results

    def _interpret_activation(self, activation: float, feature_id: int) -> str:
        """Map an activation level to a short human-readable description."""
        if activation > 0.8:
            return f"Very strong activation of feature {feature_id}"
        elif activation > 0.5:
            return f"Moderate activation of feature {feature_id}"
        else:
            return f"Limited activation of feature {feature_id}"
def create_gradio_interface():
    """Build the Gradio app for the marketing content analyzer.

    Returns:
        A ``gr.Blocks`` app with a text input, a Markdown report and an
        embedded Neuronpedia dashboard for the strongest feature. If
        ``MarketingAnalyzer()`` raises, the error is logged. In that case a
        minimal ``gr.Interface`` that only reports the failure is returned
        instead.
    """
    # Neuronpedia pages for the layer-20, 16k-width Gemma Scope SAE. Layer 20
    # is the layer the analyzer's SAE is loaded for.
    neuronpedia_base = "https://neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k"
    placeholder_html = "Analysis results will appear here"

    try:
        analyzer = MarketingAnalyzer()
    except Exception as e:
        logger.error(f"Failed to initialize analyzer: {str(e)}")
        return gr.Interface(
            fn=lambda x: "Error: Failed to initialize model. Please check authentication.",
            inputs=gr.Textbox(),
            outputs=gr.Textbox(),
            title="Marketing Content Analyzer (Error)",
            description="Failed to initialize.",
        )

    def analyze(text):
        """Return ``(markdown_report, dashboard_url, top_feature_id)`` for ``text``.

        ``dashboard_url`` and ``top_feature_id`` are ``None`` when the analysis
        produced no features. This replaces a ``ValueError`` from ``max()``
        on an empty dict.
        """
        results = analyzer.analyze_content(text)
        parts = ["# Content Analysis Results\n\n", "## Category Scores\n"]
        for category, features in results["categories"].items():
            if features:
                avg_score = np.mean([f["activation_score"] for f in features])
                parts.append(f"**{category.title()}**: {avg_score:.2f}\n")
        parts.append("\n## Feature Details\n")
        for feature_id, feature in results["features"].items():
            parts.append(f"\n### {feature['name']} (Feature {feature_id})\n")
            parts.append(f"**Score**: {feature['activation_score']:.2f}\n\n")
            parts.append(f"**Interpretation**: {feature['interpretation']}\n\n")
            # Link to Neuronpedia for a human-readable explanation of the feature.
            parts.append(
                f"[View feature details on Neuronpedia]({neuronpedia_base}/{feature_id})\n\n"
            )
        if results["recommendations"]:
            parts.append("\n## Recommendations\n")
            parts.extend(f"- {rec}\n" for rec in results["recommendations"])
        output = "".join(parts)

        if not results["features"]:
            return output, None, None
        # The dashboard shows the feature with the highest mean activation score.
        feature_id = max(
            results["features"].items(), key=lambda x: x[1]["activation_score"]
        )[0]
        dashboard_url = (
            f"{neuronpedia_base}/{feature_id}"
            "?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
        )
        return output, dashboard_url, feature_id

    def update_dashboard(text):
        """Run the analysis and return values for (report, dashboard HTML, caption)."""
        if not text or not text.strip():
            # Skip the model run entirely for blank input.
            return (
                "Please enter some marketing content to analyze.",
                placeholder_html,
                "",
            )
        output, dashboard_url, feature_id = analyze(text)
        if dashboard_url is None:
            return (
                output + "\n_No features could be scored for this content._\n",
                placeholder_html,
                "",
            )
        return (
            output,
            f"<iframe src='{dashboard_url}' width='100%' height='600px' frameborder='0' style='border: 1px solid #eee; border-radius: 8px;'></iframe>",
            f"Currently viewing Feature {feature_id} - Most active feature in your content",
        )

    with gr.Blocks(
        theme=gr.themes.Default(
            font=[gr.themes.GoogleFont("Open Sans"), "Arial", "sans-serif"],
            primary_hue="indigo",
            secondary_hue="blue",
            neutral_hue="gray",
        )
    ) as interface:
        gr.Markdown("# Marketing Content Analyzer")
        gr.Markdown(
            "Analyze your marketing content using Gemma Scope's neural features"
        )
        with gr.Row():
            with gr.Column(scale=1):
                input_text = gr.Textbox(
                    lines=5,
                    placeholder="Enter your marketing content here...",
                    label="Marketing Content",
                )
                analyze_btn = gr.Button("Analyze", variant="primary")
                gr.Examples(
                    examples=[
                        "WordLift is an AI-powered SEO tool",
                        "Our advanced machine learning algorithms optimize your content",
                        "Simple and effective website optimization",
                    ],
                    inputs=input_text,
                )
            with gr.Column(scale=2):
                output_text = gr.Markdown(label="Analysis Results")
                with gr.Group():
                    gr.Markdown("## Feature Dashboard")
                    feature_id_text = gr.Text(
                        label="Currently viewing feature", show_label=False
                    )
                    dashboard_frame = gr.HTML(
                        value=placeholder_html,
                        label="Feature Dashboard",
                    )
        analyze_btn.click(
            fn=update_dashboard,
            inputs=input_text,
            outputs=[output_text, dashboard_frame, feature_id_text],
        )
    return interface
if __name__ == "__main__":
    # Build the UI (or its error-fallback variant) and start the Gradio server.
    app = create_gradio_interface()
    app.launch()