hareshchander committed · verified
Commit 7d667da · 1 Parent(s): f143989

Update app.py

Files changed (1): app.py (+88 -53)
app.py CHANGED
@@ -1,61 +1,96 @@
- import streamlit as st
  import pandas as pd
- import matplotlib.pyplot as plt
  import torch
- from transformers import AutoModelForSequenceClassification, AutoTokenizer

- # Transformer model for anomaly detection
- MODEL_NAME = "your-username/smartgrid-anomaly-transformer"  # replace with your Hugging Face model
-
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
-
- # Thresholds for fallback (optional)
- VOLTAGE_MIN, VOLTAGE_MAX = 215, 235
- CURRENT_MAX = 35
- FREQ_MIN, FREQ_MAX = 49.5, 50.5
-
- st.title("⚡ Smart Grid Anomaly Detection (Transformer-Based)")
-
- uploaded_file = st.file_uploader(
-     "Upload your sensor data file (.csv or .xlsx)", type=["csv", "xlsx"]
  )

- if uploaded_file:
-     if uploaded_file.name.endswith(".csv"):
-         df = pd.read_csv(uploaded_file)
-     else:
-         df = pd.read_excel(uploaded_file)
-
-     df.columns = [c.strip().lower() for c in df.columns]
-     expected_cols = {"timestamp", "voltage", "current", "frequency"}
-     if not expected_cols.issubset(df.columns):
-         st.error("File missing columns: timestamp, voltage, current, frequency")
-     else:
-         # Prepare data for Transformer (as text sequence)
-         sequences = df.apply(
-             lambda row: f"{row['voltage']} {row['current']} {row['frequency']}", axis=1
-         ).tolist()
-
-         inputs = tokenizer(sequences, padding=True, truncation=True, return_tensors="pt")
-         with torch.no_grad():
-             outputs = model(**inputs)
-         predictions = torch.argmax(outputs.logits, dim=1).numpy()
-
-         df["anomaly"] = predictions  # 1 = anomaly, 0 = normal

-         st.subheader("Anomaly Detection Output")
-         st.dataframe(df)

-         st.subheader("Voltage Plot with Anomalies")
-         fig, ax = plt.subplots(figsize=(10, 4))
-         ax.plot(df["timestamp"], df["voltage"], label="Voltage (V)", color="blue")
-         ax.plot(df[df["anomaly"] == 1]["timestamp"], df[df["anomaly"] == 1]["voltage"], "ro", label="Anomaly")
-         ax.set_xlabel("Timestamp")
-         ax.set_ylabel("Voltage")
-         ax.legend()
-         plt.xticks(rotation=45)
-         st.pyplot(fig)

-         st.markdown("---")
-         st.markdown("Developed by **Haresh Chander** | Powered by Python, Transformers & Streamlit")

+ import gradio as gr
  import pandas as pd
+ import numpy as np
  import torch
+ from momentfm import MOMENTPipeline
+ from scipy.stats import zscore

+ # Load the model
+ model = MOMENTPipeline.from_pretrained(
+     "AutonLab/MOMENT-1-large",
+     model_kwargs={"task_name": "reconstruction", "n_channels": 3},
  )
+ model.init()
+ model.eval()

+ def detect_anomalies(data, window_size=512, stride=128):
+     n, num_features = data.shape
+     errors = np.zeros(n)
+     counts = np.zeros(n)
+     with torch.no_grad():
+         for start in range(0, n - window_size + 1, stride):
+             end = start + window_size
+             window = data[start:end]
+             # MOMENT expects channel-first input of shape (batch, n_channels, seq_len)
+             x = torch.from_numpy(window).float().permute(1, 0).unsqueeze(0).contiguous()  # (1, 3, window_size)
+             output = model(x)
+             recon = output.reconstruction  # (1, 3, window_size)
+             # Per-timestep reconstruction error, averaged over the three channels
+             step_errors = ((x - recon) ** 2).mean(dim=1).squeeze(0).numpy()  # (window_size,)
+             errors[start:end] += step_errors
+             counts[start:end] += 1
+     # Average the contributions of overlapping windows at each timestep
+     errors /= np.maximum(counts, 1e-6)
+     return errors

+ def process_csv(file):
+     if not file:
+         return "Please upload a CSV file."
+     try:
+         df = pd.read_csv(file)
+         required_cols = ['timestamp', 'voltage', 'current', 'frequency']
+         if not all(col in df.columns for col in required_cols):
+             return "CSV must contain columns: timestamp, voltage, current, frequency"
+         df['timestamp'] = pd.to_datetime(df['timestamp'])
+         df = df.sort_values('timestamp')
+         features = ['voltage', 'current', 'frequency']
+         data = df[features].values.astype(float)  # (len, 3)
+         # Normalize data per channel
+         data = zscore(data, axis=0)
+         seq_len = len(data)
+         if seq_len < 512:
+             # Pad with zeros up to the model's 512-step context window
+             pad = np.zeros((512 - seq_len, 3))
+             data_padded = np.vstack((data, pad))
+             errors = detect_anomalies(data_padded)[:seq_len]
+         else:
+             errors = detect_anomalies(data)
+         # Flag points whose reconstruction error exceeds mean + 3 std
+         mean_e = np.mean(errors)
+         std_e = np.std(errors)
+         threshold = mean_e + 3 * std_e
+         is_anomaly = errors > threshold
+         # Severity score: 10 at the threshold, larger for stronger anomalies
+         severity = np.clip((errors - mean_e) / (3 * std_e), 0, np.inf) * 10
+         # For explanations, use per-feature statistical deviations
+         means = df[features].mean().values
+         stds = df[features].std().values
+         z_scores = (df[features].values - means) / (stds + 1e-6)
+         explanations = []
+         for i in range(seq_len):
+             if not is_anomaly[i]:
+                 explanations.append("Normal")
+             else:
+                 reasons = []
+                 for j, feat in enumerate(features):
+                     if abs(z_scores[i, j]) > 3:
+                         direction = "High" if z_scores[i, j] > 0 else "Low"
+                         reasons.append(f"{direction} {feat}")
+                 exp = " and ".join(reasons) if reasons else "Unusual pattern detected by the model"
+                 explanations.append(exp)
+         # Add result columns to the dataframe
+         df['Anomaly'] = ['Yes' if a else 'No' for a in is_anomaly]
+         df['Severity Score'] = severity.round(2)
+         df['Explanation'] = explanations
+         return df
+     except Exception as e:
+         return f"Error processing file: {str(e)}"

+ # Define Gradio interface
+ with gr.Blocks(title="Anomaly Detection in Smart Grid Sensor Data") as demo:
+     gr.Markdown("# Anomaly Detection in Smart Grid Sensor Data Using Transformers")
+     gr.Markdown("Upload a CSV file with columns: timestamp, voltage, current, frequency")
+     input_file = gr.File(label="Upload CSV file", file_types=[".csv"])
+     output_df = gr.Dataframe(label="Results")
+     btn = gr.Button("Detect Anomalies")
+     btn.click(process_csv, inputs=input_file, outputs=output_df)

+ # Launch the app
+ demo.launch()
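
For a quick local check of the new pipeline, one option is to generate a small synthetic smart-grid CSV with a few injected faults and upload it through the app's file input. The sketch below is only an illustration: the helper name make_test_csv, the nominal 230 V / 20 A / 50 Hz levels, and the fault positions are assumptions and not part of this commit; running the app itself still requires momentfm, gradio, scipy and pandas, plus a first-time download of the AutonLab/MOMENT-1-large weights.

# Synthetic test data for the app above (illustrative sketch; names and values are assumptions)
import numpy as np
import pandas as pd

def make_test_csv(path="test_grid.csv", n=600, seed=0):
    rng = np.random.default_rng(seed)
    timestamps = pd.date_range("2024-01-01", periods=n, freq="min")
    voltage = 230 + rng.normal(0, 1.5, n)      # nominal 230 V
    current = 20 + rng.normal(0, 2.0, n)       # nominal 20 A
    frequency = 50 + rng.normal(0, 0.05, n)    # nominal 50 Hz
    # Inject a few obvious faults to give the detector something to find
    voltage[100] = 260.0       # voltage spike
    current[250] = 55.0        # overcurrent
    frequency[400] = 48.5      # frequency dip
    pd.DataFrame({
        "timestamp": timestamps,
        "voltage": voltage,
        "current": current,
        "frequency": frequency,
    }).to_csv(path, index=False)
    return path

if __name__ == "__main__":
    print("Wrote", make_test_csv())

Because app.py calls demo.launch() at import time, the simplest way to exercise process_csv is through the UI: run the app, upload test_grid.csv, and the rows around the injected faults should come back flagged Anomaly = Yes with explanations like "High voltage" or "Low frequency", provided their reconstruction error clears the mean + 3 std threshold.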