Commit
·
f10a6b7
1
Parent(s):
994be6c
Update model_training.py
Browse files- model_training.py +42 -0
model_training.py
CHANGED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# model_training.py
|
| 2 |
+
|
| 3 |
+
import tensorflow as tf
|
| 4 |
+
from watermarking_functions import embed_watermark_LSB
|
| 5 |
+
|
| 6 |
+
# Sample data for text classification (replace with your data)
|
| 7 |
+
texts = [
|
| 8 |
+
"This is a positive statement.",
|
| 9 |
+
"I love working on machine learning projects.",
|
| 10 |
+
# Add more texts for training
|
| 11 |
+
]
|
| 12 |
+
|
| 13 |
+
# Labels: 0 - Negative sentiment, 1 - Positive sentiment
|
| 14 |
+
labels = [1, 1] # Sample labels (binary classification)
|
| 15 |
+
|
| 16 |
+
# Tokenizing and preparing the data
|
| 17 |
+
max_words = 1000
|
| 18 |
+
tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=max_words)
|
| 19 |
+
tokenizer.fit_on_texts(texts)
|
| 20 |
+
sequences = tokenizer.texts_to_sequences(texts)
|
| 21 |
+
data = tf.keras.preprocessing.sequence.pad_sequences(sequences)
|
| 22 |
+
|
| 23 |
+
# Define the model architecture (simple example)
|
| 24 |
+
model = tf.keras.Sequential([
|
| 25 |
+
tf.keras.layers.Embedding(max_words, 16),
|
| 26 |
+
tf.keras.layers.GlobalAveragePooling1D(),
|
| 27 |
+
tf.keras.layers.Dense(1, activation='sigmoid')
|
| 28 |
+
])
|
| 29 |
+
|
| 30 |
+
# Compile the model
|
| 31 |
+
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
|
| 32 |
+
|
| 33 |
+
# Train the model
|
| 34 |
+
model.fit(data, labels, epochs=10, batch_size=32)
|
| 35 |
+
|
| 36 |
+
# Save the trained model
|
| 37 |
+
model.save('text_classification_model.h5')
|
| 38 |
+
|
| 39 |
+
# Embed watermark into the trained model
|
| 40 |
+
watermark_data = "MyWatermark" # Replace with your watermark data
|
| 41 |
+
model_with_watermark = embed_watermark_LSB(model, watermark_data)
|
| 42 |
+
model_with_watermark.save('text_classification_model_with_watermark.h5')
|