In this notebook we will use Spark's machine learning library MLlib to build a Decision Tree classifier for network attack detection. We will use the complete KDD Cup 1999 datasets in order to test Spark capabilities with large datasets.
Decision trees are a popular machine learning tool in part because they are easy to interpret, handle categorical features, extend to the multiclass classification setting, do not require feature scaling, and are able to capture non-linearities and feature interactions. In this notebook, we will first train a classification tree including every single predictor. Then we will use our results to perform model selection. Once we find out the most important predictors (the main splits in the tree), we will build a minimal tree using just three of them (those in the first two levels of the tree) in order to compare performance and accuracy.
At the time of processing this notebook, our Spark cluster contains:
As we said, this time we will use the complete dataset provided for the KDD Cup 1999, containing nearly five million network interactions. The file is provided as a Gzip file that we will download locally.
import urllib
f = urllib.urlretrieve("http://kdd.ics.uci.edu/databases/kddcup99/kddcup.data.gz", "kddcup.data.gz")
data_file = "./kddcup.data.gz"
raw_data = sc.textFile(data_file)
print "Train data size is {}".format(raw_data.count())
Train data size is 4898431
The KDD Cup 1999 also provides test data that we will load in a separate RDD.
ft = urllib.urlretrieve("http://kdd.ics.uci.edu/databases/kddcup99/corrected.gz", "corrected.gz")
test_data_file = "./corrected.gz"
test_raw_data = sc.textFile(test_data_file)
print "Test data size is {}".format(test_raw_data.count())
Test data size is 311029
In this section we will train a classification tree that, as we did with logistic regression, will predict whether a network interaction is either normal or attack.
Training a classification tree using MLlib requires some parameters:
In the next section we will see how to obtain all the labels within a dataset and convert them to numerical factors.
As we said, in order to benefit from the ability of trees to deal seamlessly with categorical variables, we need to convert them to numerical factors. But first we need to obtain all the possible levels. We will use set transformations on a CSV-parsed RDD.
from pyspark.mllib.regression import LabeledPoint
from numpy import array
csv_data = raw_data.map(lambda x: x.split(","))
test_csv_data = test_raw_data.map(lambda x: x.split(","))
protocols = csv_data.map(lambda x: x[1]).distinct().collect()
services = csv_data.map(lambda x: x[2]).distinct().collect()
flags = csv_data.map(lambda x: x[3]).distinct().collect()
And now we can use these Python lists in our create_labeled_point function. If a factor level is not in the training data, we assign a special level. Remember that we cannot use testing data for training our model, not even the factor levels. The testing data represents the unknown to us in a real case.
def create_labeled_point(line_split):
    # keep the 41 feature columns; column 41 holds the label
    clean_line_split = line_split[0:41]

    # convert protocol to numeric categorical variable
    # (list.index raises ValueError for a level not seen in training)
    try:
        clean_line_split[1] = protocols.index(clean_line_split[1])
    except ValueError:
        clean_line_split[1] = len(protocols)

    # convert service to numeric categorical variable
    try:
        clean_line_split[2] = services.index(clean_line_split[2])
    except ValueError:
        clean_line_split[2] = len(services)

    # convert flag to numeric categorical variable
    try:
        clean_line_split[3] = flags.index(clean_line_split[3])
    except ValueError:
        clean_line_split[3] = len(flags)

    # convert label to binary label
    attack = 1.0
    if line_split[41] == 'normal.':
        attack = 0.0

    return LabeledPoint(attack, array([float(x) for x in clean_line_split]))

training_data = csv_data.map(create_labeled_point)
test_data = test_csv_data.map(create_labeled_point)
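The handling of unseen factor levels can be checked in isolation, without Spark. The following is only a minimal sketch: `encode_level` and the example level list are hypothetical names that stand in for the logic inside create_labeled_point and for the protocols, services, and flags lists collected above.

```python
def encode_level(levels, value):
    """Return the position of value in levels, or len(levels) if unseen."""
    try:
        return levels.index(value)
    except ValueError:
        # list.index raises ValueError when the element is missing
        return len(levels)


# Illustrative levels only; the real lists come from distinct().collect()
example_levels = ["tcp", "udp", "icmp"]

known = encode_level(example_levels, "udp")
unseen = encode_level(example_levels, "not_in_training")
print(known, unseen)
```

Every level that was absent from the training data shares the same extra index, so the model never receives information that was only available in the test set.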
We are now ready to train our classification tree. We will keep the maxDepth value small. This will lead to lower accuracy, but we will obtain fewer splits, so later on we can better interpret the tree. In a production system we would try increasing this value in order to achieve better accuracy.
from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
from time import time
# Build the model
t0 = time()
tree_model = DecisionTree.trainClassifier(training_data, numClasses=2,
                                          categoricalFeaturesInfo={1: len(protocols), 2: len(services), 3: len(flags)},
                                          impurity='gini', maxDepth=4, maxBins=100)
tt = time() - t0
print "Classifier trained in {} seconds".format(round(tt,3))
Classifier trained in 439.971 seconds
In order to measure the classification error on our test data, we use map on the test_data RDD and the model to predict the class of each test point.
predictions = tree_model.predict(test_data.map(lambda p: p.features))
labels_and_preds = test_data.map(lambda p: p.label).zip(predictions)
Classification results are returned in pairs, with the actual test label and the predicted one. This is used to calculate the test accuracy by using filter and count as follows.
t0 = time()
test_accuracy = labels_and_preds.filter(lambda (v, p): v == p).count() / float(test_data.count())
tt = time() - t0
print "Prediction made in {} seconds. Test accuracy is {}".format(round(tt,3), round(test_accuracy,4))
Prediction made in 38.651 seconds. Test accuracy is 0.9155
NOTE: the zip transformation doesn't work properly with pySpark 1.2.1. It does in 1.3
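The filter and count logic above boils down to the fraction of pairs whose two elements match. The plain-Python sketch below shows the same computation on made-up pairs (they are not results from the model), and `accuracy` is a hypothetical helper name. It indexes each pair explicitly because the tuple-unpacking form `lambda (v, p)` is only valid in Python 2.

```python
def accuracy(labels_and_preds):
    """Fraction of (label, prediction) pairs whose elements are equal."""
    pairs = list(labels_and_preds)
    if not pairs:
        raise ValueError("no predictions to score")
    matches = [vp for vp in pairs if vp[0] == vp[1]]
    return len(matches) / float(len(pairs))


# Made-up pairs for illustration: three of the four predictions are correct
toy_pairs = [(1.0, 1.0), (0.0, 0.0), (1.0, 0.0), (0.0, 0.0)]
toy_accuracy = accuracy(toy_pairs)
print(toy_accuracy)
```

On an RDD, the same predicate could be written as `labels_and_preds.filter(lambda vp: vp[0] == vp[1])`, which runs under both Python 2 and Python 3.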
Understanding our tree splits is a great exercise in order to explain our classification labels in terms of predictors and the values they take. Using the toDebugString method on our tree model we can obtain a lot of information regarding splits, nodes, etc.
print "Learned classification tree model:"
print tree_model.toDebugString()
Learned classification tree model:
DecisionTreeModel classifier of depth 4 with 27 nodes
  If (feature 22 <= 89.0)
   If (feature 3 in {2.0,3.0,4.0,7.0,9.0,10.0})
    If (feature 36 <= 0.43)
     If (feature 28 <= 0.19)
      Predict: 1.0
     Else (feature 28 > 0.19)
      Predict: 0.0
    Else (feature 36 > 0.43)
     If (feature 2 in {0.0,3.0,15.0,26.0,27.0,36.0,42.0,58.0,67.0})
      Predict: 0.0
     Else (feature 2 not in {0.0,3.0,15.0,26.0,27.0,36.0,42.0,58.0,67.0})
      Predict: 1.0
   Else (feature 3 not in {2.0,3.0,4.0,7.0,9.0,10.0})
    If (feature 2 in {50.0,51.0})
     Predict: 0.0
    Else (feature 2 not in {50.0,51.0})
     If (feature 32 <= 168.0)
      Predict: 1.0
     Else (feature 32 > 168.0)
      Predict: 0.0
  Else (feature 22 > 89.0)
   If (feature 5 <= 0.0)
    If (feature 11 <= 0.0)
     If (feature 31 <= 253.0)
      Predict: 1.0
     Else (feature 31 > 253.0)
      Predict: 1.0
    Else (feature 11 > 0.0)
     If (feature 2 in {12.0})
      Predict: 0.0
     Else (feature 2 not in {12.0})
      Predict: 1.0
   Else (feature 5 > 0.0)
    If (feature 29 <= 0.08)
     If (feature 2 in {3.0,4.0,26.0,36.0,42.0,58.0,68.0})
      Predict: 0.0
     Else (feature 2 not in {3.0,4.0,26.0,36.0,42.0,58.0,68.0})
      Predict: 1.0
    Else (feature 29 > 0.08)
     Predict: 1.0
For example, a network interaction with the following features (see the KDD Cup 1999 task description) will be classified as an attack by our model:
- count, the number of connections to the same host as the current connection in the past two seconds, is greater than 32.
- dst_bytes, the number of data bytes from destination to source, is 0.
- service is neither level 0 nor level 52.
- logged_in is false.

The thresholds and level indices quoted here come from a particular run and may differ from those in the tree printed above.

From our services list we know that:
print "Service 0 is {}".format(services[0])
print "Service 52 is {}".format(services[52])
Service 0 is urp_i Service 52 is tftp_u
So we can characterise network interactions with more than 32 connections to the same server in the last 2 seconds, transferring zero bytes from destination to source, where service is neither urp_i nor tftp_u, and not logged in, as network attacks. A similar approach can be used for each tree terminal node.
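The path to that terminal node can be written as a single predicate. This is a sketch only: `matches_attack_rule` is a hypothetical name, the thresholds are the ones quoted in the text, and the sample connections are invented for illustration.

```python
def matches_attack_rule(count, dst_bytes, service_index, logged_in):
    """True when a connection satisfies all four conditions quoted in the text."""
    return (
        count > 32
        and dst_bytes == 0
        and service_index not in (0, 52)
        and not logged_in
    )


# Invented connections: the second one uses service level 52, so the rule fails
first = matches_attack_rule(count=100, dst_bytes=0, service_index=7, logged_in=False)
second = matches_attack_rule(count=100, dst_bytes=0, service_index=52, logged_in=False)
print(first, second)
```

Each terminal node of the tree corresponds to one such conjunction of conditions, read from the root down to the leaf.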
We can see that count is the first node split in the tree. Remember that each partition is chosen greedily by selecting the best split from a set of possible splits, in order to maximize the information gain at a tree node (see the MLlib decision trees guide for details). At the second level we find the variables flag (normal or error status of the connection) and dst_bytes (the number of data bytes from destination to source), and so on.
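To make the split criterion concrete, here is a small plain-Python sketch of Gini impurity and of the gain of a candidate split, taken as the parent impurity minus the size-weighted impurity of the two children. The function names and the toy labels are assumptions made for illustration; this is not MLlib's implementation, which works on binned feature statistics.

```python
def gini(labels):
    """Gini impurity of a list of class labels: 1 - sum(p_k ** 2)."""
    n = float(len(labels))
    if n == 0:
        return 0.0
    counts = {}
    for label in labels:
        counts[label] = counts.get(label, 0) + 1
    return 1.0 - sum((c / n) ** 2 for c in counts.values())


def impurity_decrease(parent, left, right):
    """Parent impurity minus the size-weighted impurity of the children."""
    n = float(len(parent))
    weighted = (len(left) / n) * gini(left) + (len(right) / n) * gini(right)
    return gini(parent) - weighted


# Toy labels (0 = normal, 1 = attack), invented for illustration
parent = [0, 0, 0, 1, 1, 1]
perfect_split = impurity_decrease(parent, [0, 0, 0], [1, 1, 1])
weak_split = impurity_decrease(parent, [0, 1, 0], [1, 0, 1])
print(perfect_split > weak_split)
```

A greedy learner evaluates candidate splits in this way at each node and keeps the one with the largest decrease, which is why the most discriminating variables tend to appear near the root.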
This explanatory capability of a classification (or regression) tree is one of its main benefits. Understanding data is a key factor in building better models.
So now that we know the main features predicting a network attack, thanks to our classification tree splits, let's use them to build a minimal classification tree with just the main three variables: count, dst_bytes, and flag.
We need to define the appropriate function to create labeled points.
def create_labeled_point_minimal(line_split):
    # keep only flag (column 3), dst_bytes (column 5) and count (column 22)
    clean_line_split = line_split[3:4] + line_split[5:6] + line_split[22:23]

    # convert flag to numeric categorical variable
    # (list.index raises ValueError for a level not seen in training)
    try:
        clean_line_split[0] = flags.index(clean_line_split[0])
    except ValueError:
        clean_line_split[0] = len(flags)

    # convert label to binary label
    attack = 1.0
    if line_split[41] == 'normal.':
        attack = 0.0

    return LabeledPoint(attack, array([float(x) for x in clean_line_split]))

training_data_minimal = csv_data.map(create_labeled_point_minimal)
test_data_minimal = test_csv_data.map(create_labeled_point_minimal)
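The slicing in create_labeled_point_minimal changes the position of each variable in the feature vector, which matters when declaring categorical features. The sketch below uses a placeholder row (the `col<i>` strings are not real data) and a hypothetical `select_columns` helper to show where each original column ends up.

```python
def select_columns(row, indices):
    """Pick the given column positions from a parsed CSV row, in order."""
    return [row[i] for i in indices]


# Placeholder row with 42 columns named after their original position
row = ["col{}".format(i) for i in range(42)]

sliced = row[3:4] + row[5:6] + row[22:23]
picked = select_columns(row, [3, 5, 22])
print(sliced)
```

Because flag moves from column 3 to position 0 of the reduced vector, the categorical features map for the minimal model must use key 0 rather than 3.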
We then use the minimal labeled points to train the model.
# Build the model
t0 = time()
tree_model_minimal = DecisionTree.trainClassifier(training_data_minimal, numClasses=2,
                                                  categoricalFeaturesInfo={0: len(flags)},
                                                  impurity='gini', maxDepth=3, maxBins=32)
tt = time() - t0
print "Classifier trained in {} seconds".format(round(tt,3))
Classifier trained in 226.519 seconds
Now we can predict on the testing data and calculate accuracy.
predictions_minimal = tree_model_minimal.predict(test_data_minimal.map(lambda p: p.features))
labels_and_preds_minimal = test_data_minimal.map(lambda p: p.label).zip(predictions_minimal)
t0 = time()
test_accuracy = labels_and_preds_minimal.filter(lambda (v, p): v == p).count() / float(test_data_minimal.count())
tt = time() - t0
print "Prediction made in {} seconds. Test accuracy is {}".format(round(tt,3), round(test_accuracy,4))
Prediction made in 23.202 seconds. Test accuracy is 0.909
So we have trained a classification tree with just the three most important predictors in roughly half the time, at the cost of only a small drop in test accuracy (0.909 versus 0.9155). In fact, a classification tree is a very good model selection tool!
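The comparison can be quantified directly from the timings and accuracies printed earlier in this notebook. The variable names below are hypothetical, and the figures are simply copied from the outputs above, so they will differ on another cluster or run.

```python
# Figures copied from the notebook outputs above
full_train_s, minimal_train_s = 439.971, 226.519
full_acc, minimal_acc = 0.9155, 0.909

# Share of the full training time needed by the minimal tree
time_ratio = minimal_train_s / full_train_s
# Absolute loss in test accuracy when using only three predictors
accuracy_drop = full_acc - minimal_acc

print("Minimal tree trains in {:.0%} of the time".format(time_ratio))
print("Accuracy drop: {:.4f}".format(accuracy_drop))
```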