This repository contains code for detecting IoT botnet attacks (specifically Mirai) using a GraphSAGE model. The approach represents IoT device data as a graph and uses a Graph Neural Network (GNN) for node classification.
The project follows these key steps:
- Data Loading and Preprocessing: Benign and Mirai attack traffic data (in CSV format) are loaded, concatenated, and preprocessed. This includes handling missing values, scaling numerical features, and separating timestamps and labels.
- Graph Construction: The preprocessed time-series data is transformed into a graph structure where each row (representing a snapshot of device activity) becomes a node. Edges are created between consecutive nodes to capture temporal relationships.
- GraphSAGE Model Implementation: A two-layer GraphSAGE model is implemented using `torch_geometric` for node classification. The model learns embeddings for each node by aggregating information from its neighbors.
- Training and Evaluation: The GraphSAGE model is trained on the constructed graph using a supervised learning approach. Training involves optimizing cross-entropy loss. The model's performance is evaluated using metrics like accuracy, precision, recall, F1-score, confusion matrix, and ROC curve.
- Early Stopping: To prevent overfitting and ensure optimal model performance, early stopping is implemented based on validation accuracy.
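Early stopping on validation accuracy can be sketched as a small helper that counts epochs without improvement. This is an illustrative sketch, not the repository's code: the class name `EarlyStopping`, the `patience` and `min_delta` parameters, and the accuracy values are all assumptions.

```python
class EarlyStopping:
    """Track validation accuracy and signal when training should stop."""

    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience      # epochs to wait without improvement
        self.min_delta = min_delta    # smallest gain that counts as improvement
        self.best_acc = float("-inf")
        self.best_epoch = -1
        self.bad_epochs = 0

    def step(self, epoch, val_acc):
        """Record one epoch; return True when training should stop."""
        if val_acc > self.best_acc + self.min_delta:
            self.best_acc = val_acc
            self.best_epoch = epoch
            self.bad_epochs = 0
            return False
        self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Hypothetical validation accuracies, one per epoch, for illustration only.
val_history = [0.90, 0.93, 0.95, 0.94, 0.95, 0.94, 0.93]
stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, acc in enumerate(val_history):
    if stopper.step(epoch, acc):
        stopped_at = epoch
        break
print(stopper.best_epoch, stopper.best_acc, stopped_at)
```

In a real training loop, the epoch at which `best_epoch` is updated is the natural point to save a model checkpoint, so that the best-performing weights are the ones kept.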
The model utilizes two types of IoT network traffic data:
- Benign: Normal IoT device network behavior.
- Mirai: Network traffic generated during a Mirai botnet attack.
These datasets are expected to be in CSV format within `benign/` and `mirai/` directories, respectively. The data is preprocessed to handle `Timestamp` columns, replace infinite values, drop columns with `NaN` values, and scale numerical features using `StandardScaler`.
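The preprocessing steps above can be sketched with pandas as follows. This is a minimal sketch under stated assumptions: the function name `preprocess` and the demo column names `pkt_rate` and `bad_col` are hypothetical, the `Timestamp` column is assumed to be set aside rather than used as a feature, and plain pandas arithmetic stands in for `StandardScaler` (it applies the same zero-mean, unit-variance formula with the population standard deviation).

```python
import numpy as np
import pandas as pd


def preprocess(df, label_col="label", time_col="Timestamp"):
    """Return (scaled features, labels, timestamps) from a combined DataFrame."""
    timestamps = df[time_col] if time_col in df.columns else None
    labels = df[label_col].to_numpy()
    features = df.drop(columns=[c for c in (label_col, time_col) if c in df.columns])

    # Keep numeric columns only, then turn +/-inf into NaN so both are handled together.
    features = features.select_dtypes(include=[np.number])
    features = features.replace([np.inf, -np.inf], np.nan)

    # Drop every column that contains at least one NaN.
    features = features.dropna(axis=1)

    # Standardize to zero mean and unit variance (population std, ddof=0).
    mean = features.mean()
    std = features.std(ddof=0).replace(0.0, 1.0)  # constant columns map to 0
    scaled = (features - mean) / std
    return scaled, labels, timestamps


# Tiny made-up frame; pkt_rate and bad_col are hypothetical feature names.
demo = pd.DataFrame({
    "Timestamp": [1, 2, 3, 4],
    "pkt_rate": [1.0, 2.0, 3.0, 4.0],
    "bad_col": [0.5, np.inf, 0.1, 0.2],
    "label": [0, 0, 1, 1],
})
X, y, ts = preprocess(demo)
print(list(X.columns), y.tolist())
```

Here `bad_col` is dropped because its infinite value becomes `NaN`, leaving `pkt_rate` as the only scaled feature.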
To run the code, you'll need to install the following libraries:
!pip install torch torch-geometric scikit-learn pandas networkx
!pip install torch-scatter -f https://data.pyg.org/whl/torch-2.6.0+cu124.html -q
!pip install torch-sparse -f https://data.pyg.org/whl/torch-2.6.0+cu124.html -q

Note: The torch-scatter and torch-sparse installations are specific to torch==2.6.0+cu124. Adjust the URL if you are using a different PyTorch version.
The main script performs the following actions:
- Imports: Necessary libraries like `torch`, `torch_geometric`, `sklearn`, and `pandas` are imported.
- Google Drive Mounting: If running in Google Colab, it mounts Google Drive to access datasets.
- Data Loading and Preprocessing:
  - Loads CSV files from the specified `benign_dir` and `mirai_dir`.
  - Combines datasets and adds a `label` column (0 for benign, 1 for Mirai).
  - Handles `inf` values, drops columns with `NaN`s, and scales numerical features.
- Graph Creation:
  - Converts scaled features and labels into PyTorch tensors.
  - Constructs a simple sequential graph where each node is connected to its immediate successor and predecessor, forming `edge_index`.
  - Creates a `torch_geometric.data.Data` object and saves it as `iot_data_graph.pth`.
- Model Definition (`GraphSAGE`):
  - Defines a two-layer GraphSAGE model with `SAGEConv` layers, `BatchNorm1d`, `ReLU` activations, and dropout.
  - Includes a final linear layer for classification.
- Training and Evaluation Functions:
  - `train(model, data, optimizer, criterion)`: Performs one training step.
  - `evaluate(model, data, verbose=False)`: Evaluates the model on the test set, printing accuracy and a detailed classification report.
  - `evaluate_model(model, data)`: Provides comprehensive evaluation including accuracy, precision, recall, F1-score, confusion matrix, and ROC curve plotting.
  - `plot_roc_curve(model, data)`: Specifically plots the ROC curve for binary classification.
- Training Loop:
  - Initializes the `GraphSAGE` model, `Adam` optimizer, and `CrossEntropyLoss` (with class weights for imbalance).
  - Splits data into training, validation, and test sets using `train_test_split`.
  - Trains the model for a specified number of epochs with early stopping based on validation accuracy.
  - Saves the best performing model.
- Results Visualization:
  - Plots training loss and test accuracy over epochs.
  - Displays the confusion matrix and ROC curve after final evaluation.
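The sequential graph from the Graph Creation step links every node to its immediate successor and predecessor. The sketch below builds that edge list in plain Python so it runs without PyTorch; `build_sequential_edges` is a hypothetical helper for illustration, not a function from the repository.

```python
def build_sequential_edges(num_nodes):
    """Return [sources, targets] linking each node to its successor and predecessor."""
    sources, targets = [], []
    for i in range(num_nodes - 1):
        # Forward edge i -> i+1 and backward edge i+1 -> i.
        sources += [i, i + 1]
        targets += [i + 1, i]
    return [sources, targets]


edges = build_sequential_edges(4)
print(edges)
# With PyTorch installed, the same nested list converts to an edge_index tensor
# of shape [2, num_edges] via: torch.tensor(edges, dtype=torch.long)
```

For N nodes this produces 2(N - 1) directed edges, so the first and last rows each have one neighbor and every other row has two.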
- Prepare your data: Ensure your benign and Mirai attack CSV files are organized in `/content/drive/My Drive/IOT/benign/` and `/content/drive/My Drive/IOT/mirai/` respectively, or update the `benign_dir` and `mirai_dir` variables in the script.
- Run the notebook/script: Execute the code cells sequentially. The script will:
  - Load and preprocess the data.
  - Construct the graph.
  - Train the GraphSAGE model.
  - Evaluate its performance with various metrics and visualizations.
- Visual Interface: The dashboard visualizes the data fed to the global (fusion classifier) and attack-specific models. It shows class probabilities, a graph plot, accuracy metrics, the confidence scores of both models, and the probable reason behind each model's classification.
The script will output:
- Shape and label counts of the combined dataset.
- Details about feature columns used.
- Dimensions of feature (`x`) and label (`y`) tensors.
- Total number of nodes and edges in the constructed graph.
- Training loss and test accuracy for each epoch.
- A detailed classification report on the test set (accuracy, precision, recall, F1-score).
- Plots of training loss and test accuracy over epochs.
- A confusion matrix visualizing classification performance.
- An ROC curve with AUC score.
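The accuracy, precision, recall, and F1 figures in the report all derive from the four confusion-matrix counts. The sketch below shows that relationship in plain Python, treating Mirai (label 1) as the positive class. `binary_metrics` and the sample labels are hypothetical; the script itself imports `sklearn`, which presumably supplies these metrics, and its report may average scores across classes differently.

```python
def binary_metrics(y_true, y_pred):
    """Compute confusion-matrix counts and derived scores for labels in {0, 1}."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)

    # Guard each ratio against an empty denominator.
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    accuracy = (tp + tn) / len(pairs) if pairs else 0.0
    return {
        "confusion": [[tn, fp], [fn, tp]],  # rows: true class, columns: predicted class
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }


# Made-up labels for illustration: 0 = benign, 1 = Mirai.
y_true = [0, 0, 0, 1, 1, 1, 1, 1]
y_pred = [0, 0, 1, 1, 1, 1, 1, 0]
scores = binary_metrics(y_true, y_pred)
print(scores)
```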
The model achieved strong performance on Mirai botnet detection:
| Metric | Value |
|---|---|
| Accuracy | 96.12% |
| Precision | 96.11% |
| Recall | 96.12% |
| F1-Score | 96.11% |
| AUC-ROC | 0.9935 |
The project relies on the following key libraries:
- Python 3.x
- torch (PyTorch)
- torch-geometric (PyG)
- torch-scatter
- pandas
- numpy
- scikit-learn
- matplotlib
- gradio
git clone https://github.com/spk-22/BotNet-Insight
pip install -r requirements.txt
# (Or manually install: torch, torch-geometric, scikit-learn, pandas, numpy, matplotlib)
# Ensure torch-geometric, torch-scatter, and torch-sparse versions are compatible with your PyTorch version.
python iot_final.py
streamlit run web_app.py

- Explore more sophisticated graph construction methods that capture richer relationships (e.g., based on IP addresses, port numbers, or specific flow features).
- Implement more advanced Graph Neural Network architectures (e.g., GAT, GCN with attention mechanisms).
- Investigate the impact of different hyperparameter settings on model performance.
- Consider causal sampling or sliding window approaches for creating graph snapshots to better capture temporal dynamics in larger datasets.
- Evaluate the model on a wider range of IoT botnet attacks and different IoT device types.