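"""Ryu OpenFlow 1.3 controller that combines a learning switch with an
ensemble intrusion-detection pipeline and federated model updates.

Batch models (RandomForest, XGBoost) are loaded from .joblib files and
retrained on buffered traffic; the 'orf' and 'ht' models are trained
incrementally via learn_one/predict_one.

A minimal run sketch (assumes the .joblib model files and server.py sit
in the working directory; ryu-manager is Ryu's standard launcher):

    ryu-manager sdn.py
"""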
import sys
import json
import logging
import os

import joblib
import pandas as pd
import requests  # Used to HTTP-POST model updates to the federated server

from ryu.base import app_manager
from ryu.controller import ofp_event
from ryu.controller.handler import CONFIG_DISPATCHER, MAIN_DISPATCHER, set_ev_cls
from ryu.ofproto import ofproto_v1_3
from ryu.lib.packet import packet, ethernet, ipv4, tcp, udp

from server import FederatedAggregator
# Set up logging for better debug information
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Address of the federated server (localhost for local testing)
FED_SERVER_IP = '127.0.0.1'
FED_SERVER_PORT = 8000  # Port the federated server listens on

# Increase the recursion limit to avoid RecursionError
sys.setrecursionlimit(10000)
# Full label mapping for attack types, as used during training
label_dict = {
    'Benign': 0,
    'FTP-BruteForce': 1,
    'SSH-Bruteforce': 2,
    'DDOS attack-HOIC': 3,
    'Bot': 4,
    'DoS attacks-GoldenEye': 5,
    'DoS attacks-Slowloris': 6,
    'DDOS attack-LOIC-UDP': 7,
    'Brute Force -Web': 8,
    'Brute Force -XSS': 9,
    'SQL Injection': 10
}
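
# NOTE: The label and feature names below match CSE-CIC-IDS2018 naming
# conventions. The packet IDs carried in the simulated traffic appear to be
# offset by one from label_dict: ID 1 is benign and IDs 2-11 are the known
# attacks (see known_attack_ids below).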
class SimpleSwitch13(app_manager.RyuApp):
    OFP_VERSIONS = [ofproto_v1_3.OFP_VERSION]

    def __init__(self, *args, **kwargs):
        super(SimpleSwitch13, self).__init__(*args, **kwargs)
        self.mac_to_port = {}

        # Load the pre-trained models
        self.rf_model = joblib.load('rf_model.joblib')
        self.xgb_model = joblib.load('xgb_model.joblib')
        self.orf_model = joblib.load('orf_model.joblib')
        self.ht_model = joblib.load('ht_model.joblib')

        # Initialize federated aggregators for the RF and XGB models
        self.rf_aggregator = FederatedAggregator('rf')
        self.xgb_aggregator = FederatedAggregator('xgb')

        self.enable_prediction = True  # Enable prediction during control traffic

        # Known attack IDs (1-11), with 1 being benign
        self.known_attack_ids = list(range(1, 12))

        # Features used by the models
        self.features = [
            'Tot Fwd Pkts', 'TotLen Fwd Pkts', 'Bwd Pkt Len Max', 'Flow Pkts/s',
            'Fwd IAT Mean', 'Bwd IAT Tot', 'Bwd IAT Mean', 'RST Flag Cnt',
            'URG Flag Cnt', 'Init Fwd Win Byts', 'Fwd Seg Size Min', 'Idle Max'
        ]

        # Local buffers for data aggregation before sending updates to the server
        self.local_data_buffer_rf = []
        self.local_data_buffer_xgb = []
        self.local_data_buffer_size = 100  # Buffer threshold that triggers a model update

        # Attack simulation counters
        self.attack_index = label_dict['FTP-BruteForce']  # Start from FTP-BruteForce
        self.current_attack_count = 0
        self.attack_counts = [
            30,   # FTP-BruteForce
            40,   # SSH-Bruteforce
            100,  # DDOS-HOIC
            20,   # Bot
            5,    # GoldenEye
            10,   # Slowloris
            50,   # LOIC-UDP
            25,   # Web Brute Force
            15,   # XSS
            20    # SQL Injection
        ]
        self.attack_names = list(label_dict.keys())[1:]  # Skip 'Benign'

        # Counter for normal-traffic log lines
        self.normal_log_count = 0
        self.max_normal_logs = 10  # Number of normal logs to display before attack processing
    def extract_features(self, payload):
        """Parse a JSON payload into a feature dict; return (features, packet_id)."""
        try:
            features = json.loads(payload.decode('utf-8'))
            packet_id = features.get('id')
            if packet_id is None:
                self.logger.error("Packet ID is missing.")
                return None, None
            # Ensure all expected features are present
            if all(feature in features for feature in self.features):
                return features, packet_id
            self.logger.error("Some features are missing from the payload.")
            return None, packet_id
        except Exception as e:
            self.logger.error(f"Error extracting features from packet: {e}")
            return None, None
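
    # Example of the JSON payload extract_features expects (hypothetical
    # values; the 'id' field identifies the traffic class, 1 = benign):
    #
    #   {"id": 2, "Tot Fwd Pkts": 12, "TotLen Fwd Pkts": 480.0,
    #    "Bwd Pkt Len Max": 120.0, "Flow Pkts/s": 35.2, "Fwd IAT Mean": 0.8,
    #    "Bwd IAT Tot": 3.1, "Bwd IAT Mean": 0.6, "RST Flag Cnt": 0,
    #    "URG Flag Cnt": 0, "Init Fwd Win Byts": 8192, "Fwd Seg Size Min": 20,
    #    "Idle Max": 0.0}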
    @set_ev_cls(ofp_event.EventOFPSwitchFeatures, CONFIG_DISPATCHER)
    def switch_features_handler(self, ev):
        datapath = ev.msg.datapath
        ofproto = datapath.ofproto
        parser = datapath.ofproto_parser
        # Install the table-miss entry so unmatched packets reach the controller
        match = parser.OFPMatch()
        actions = [parser.OFPActionOutput(ofproto.OFPP_CONTROLLER, ofproto.OFPCML_NO_BUFFER)]
        self.add_flow(datapath, 0, match, actions)
        self.logger.info("Switch setup complete. Prediction enabled.")
    def add_flow(self, datapath, priority, match, actions, buffer_id=None):
        ofproto = datapath.ofproto
        parser = datapath.ofproto_parser
        inst = [parser.OFPInstructionActions(ofproto.OFPIT_APPLY_ACTIONS, actions)]
        if buffer_id:
            mod = parser.OFPFlowMod(datapath=datapath, buffer_id=buffer_id,
                                    priority=priority, match=match, instructions=inst)
        else:
            mod = parser.OFPFlowMod(datapath=datapath, priority=priority,
                                    match=match, instructions=inst)
        datapath.send_msg(mod)
    @set_ev_cls(ofp_event.EventOFPPacketIn, MAIN_DISPATCHER)
    def _packet_in_handler(self, ev):
        msg = ev.msg
        datapath = msg.datapath
        ofproto = datapath.ofproto
        parser = datapath.ofproto_parser
        in_port = msg.match['in_port']

        pkt = packet.Packet(msg.data)
        eth = pkt.get_protocols(ethernet.ethernet)[0]
        dst = eth.dst
        src = eth.src
        dpid = datapath.id

        # Standard learning-switch forwarding
        self.mac_to_port.setdefault(dpid, {})
        self.mac_to_port[dpid][src] = in_port
        if dst in self.mac_to_port[dpid]:
            out_port = self.mac_to_port[dpid][dst]
        else:
            out_port = ofproto.OFPP_FLOOD
        actions = [parser.OFPActionOutput(out_port)]

        ipv4_pkt = pkt.get_protocol(ipv4.ipv4)
        tcp_pkt = pkt.get_protocol(tcp.tcp)
        udp_pkt = pkt.get_protocol(udp.udp)
        raw_payload = None
        if tcp_pkt or udp_pkt:
            # The last element of a parsed packet is the raw payload, if present
            raw_payload = pkt[-1] if isinstance(pkt[-1], bytes) else None

        if raw_payload:
            features, packet_id = self.extract_features(raw_payload)
            if packet_id is not None:
                # Normal traffic (ID 1)
                if packet_id == 1:
                    if self.normal_log_count < self.max_normal_logs:
                        self.logger.info(f"Normal traffic from {src}.")
                        self.normal_log_count += 1
                    return  # Skip attack processing until the normal logs are complete

                # Process attacks only after the normal logs have been displayed
                if self.normal_log_count >= self.max_normal_logs:
                    if packet_id not in self.known_attack_ids:
                        self.logger.info(f"Unknown (zero-day) attack detected: Packet ID {packet_id}")
                        if features:
                            features_df = pd.DataFrame([features], columns=self.features)
                            # Learn from the zero-day attack, labeled as -1
                            self.ht_model.learn_one(features, -1)
                            self.orf_model.learn_one(features, -1)
                            self.local_data_buffer_rf.append((features_df, -1))  # Buffer as -1 for zero-day
                    else:
                        # Known attack: predict, learn online, block, and buffer
                        if features:
                            features_df = pd.DataFrame([features], columns=self.features)
                            attack_name = self.attack_names[self.attack_index - 1]
                            attack_label = label_dict[attack_name]
                            self.logger.info(f"Attack detected from {src}, blocking traffic.")
                            self.logger.info(f"Attack type: {attack_name}")

                            rf_pred = self.rf_model.predict(features_df)[0]
                            xgb_pred = self.xgb_model.predict(features_df)[0]
                            orf_pred = self.orf_model.predict_one(features)
                            ht_pred = self.ht_model.predict_one(features)

                            # predict_one may return None before the online models
                            # have seen any data, so filter before taking the max
                            preds = [p for p in (rf_pred, xgb_pred, orf_pred, ht_pred) if p is not None]
                            max_pred = max(preds)
                            self.ht_model.learn_one(features, max_pred)
                            self.orf_model.learn_one(features, max_pred)

                            self.block_traffic(datapath, src)

                            # Federated learning: buffer the sample for RF and XGB
                            # (label captured before the simulation counters advance)
                            self.local_data_buffer_rf.append((features_df, attack_label))
                            self.local_data_buffer_xgb.append((features_df, attack_label))

                            # Advance the attack-simulation counters
                            self.current_attack_count += 1
                            if self.current_attack_count == self.attack_counts[self.attack_index - 1]:
                                self.attack_index += 1
                                self.current_attack_count = 0
                                if self.attack_index > len(self.attack_counts):
                                    self.attack_index = 1

                            # Aggregate and send model updates periodically
                            if len(self.local_data_buffer_rf) >= self.local_data_buffer_size:
                                self.update_models()

        data = None
        if msg.buffer_id == ofproto.OFP_NO_BUFFER:
            data = msg.data
        out = parser.OFPPacketOut(datapath=datapath, buffer_id=msg.buffer_id,
                                  in_port=in_port, actions=actions, data=data)
        datapath.send_msg(out)
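
    # Design note: taking the max over the four model outputs biases the online
    # learners toward the highest predicted label rather than a consensus; a
    # majority vote over the ensemble would be a plausible, more conservative
    # alternative.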
    def update_models(self):
        """Train locally on the buffered data and send model updates to the server."""
        if self.local_data_buffer_rf:
            X_rf, y_rf = zip(*self.local_data_buffer_rf)
            X_rf = pd.concat(X_rf)
            self.rf_model.fit(X_rf, list(y_rf))  # Local training on buffered data
            # Send the updated model to the server
            self.send_model_to_server(self.rf_model, 'rf')
            self.local_data_buffer_rf = []  # Clear the buffer after the update
            self.logger.info("RandomForest model sent to server via federated learning")
        # XGBoost updates are currently disabled
        # if self.local_data_buffer_xgb:
        #     X_xgb, y_xgb = zip(*self.local_data_buffer_xgb)
        #     X_xgb = pd.concat(X_xgb)
        #     self.xgb_model.fit(X_xgb, list(y_xgb))  # Local training on buffered data
        #     self.send_model_to_server(self.xgb_model, 'xgb')
        #     self.local_data_buffer_xgb = []  # Clear the buffer after the update
        #     self.logger.info("XGBoost model sent to server via federated learning")
    def send_model_to_server(self, model, model_type):
        """Save the trained model to a file and POST it to the federated server."""
        model_filename = f"/tmp/{model_type}_model.joblib"
        try:
            # Save the model to a temporary file
            joblib.dump(model, model_filename)
            # Send the file to the server
            with open(model_filename, 'rb') as f:
                files = {'model': f}
                data = {'model_type': model_type}
                response = requests.post(
                    f'http://{FED_SERVER_IP}:{FED_SERVER_PORT}/update_model',
                    files=files, data=data)
            if response.status_code == 200:
                self.logger.info(f"{model_type} model sent successfully to the server.")
            else:
                self.logger.error(f"Failed to send {model_type} model to the server. "
                                  f"Status code: {response.status_code}")
        except Exception as e:
            self.logger.error(f"Error while sending {model_type} model to server: {e}")
        finally:
            # Clean up: remove the temporary model file after sending
            if os.path.exists(model_filename):
                os.remove(model_filename)
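
    # For reference, a minimal sketch of the receiving end of the POST above,
    # assuming a Flask server (the actual server.py / FederatedAggregator may
    # be implemented differently):
    #
    #   from flask import Flask, request
    #   import joblib
    #
    #   app = Flask(__name__)
    #
    #   @app.route('/update_model', methods=['POST'])
    #   def update_model():
    #       model_type = request.form['model_type']   # 'rf' or 'xgb'
    #       model_file = request.files['model']       # uploaded .joblib file
    #       path = f"/tmp/received_{model_type}.joblib"
    #       model_file.save(path)
    #       client_model = joblib.load(path)
    #       # ...hand client_model to the aggregator here...
    #       return 'OK', 200
    #
    #   if __name__ == '__main__':
    #       app.run(host='0.0.0.0', port=8000)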
    def block_traffic(self, datapath, mac):
        ofproto = datapath.ofproto
        parser = datapath.ofproto_parser
        # An empty action list drops all packets matching this source MAC
        match = parser.OFPMatch(eth_src=mac)
        actions = []
        self.add_flow(datapath, 1, match, actions)
        self.logger.info('Blocking traffic from %s', mac)