scikit-learn
This example shows how to classify imagery (for example, from Landsat) using scikit-learn. There are many classification methods available, but for this example we will use K-Means clustering (an unsupervised method) because it is simple and fast. For imagery I grabbed the North Carolina dataset raster sample, and I'm using the red, green and blue bands of the Landsat 7 imagery in this package. I'm using rasterio to read in the data.
import numpy as np
import matplotlib.pyplot as plt
import rasterio
import sklearn.cluster
red_path = r"C:\projects\quick_scripts\gis_se\landsat\ncrast\lsat7_2002_30.tif"
green_path = r"C:\projects\quick_scripts\gis_se\landsat\ncrast\lsat7_2002_20.tif"
blue_path = r"C:\projects\quick_scripts\gis_se\landsat\ncrast\lsat7_2002_10.tif"
with rasterio.open(red_path) as red, rasterio.open(green_path) as green, rasterio.open(blue_path) as blue:
    # Read band 1 of each file and stack them as (bands, rows, columns)
    data = np.array([red.read(1), green.read(1), blue.read(1)])
data.shape  # (bands, rows, columns): a three-band image gives us a 3-dimensional array
(3L, 475L, 527L)
plt.figure(figsize=(12, 12))
plt.imshow(np.dstack(data))
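`plt.imshow` expects an RGB image as a (rows, columns, 3) array, either integers in the range 0 to 255 or floats in the range 0 to 1; `np.dstack` moves the band axis to the end. If your bands are not 8-bit (for example 16-bit data), matplotlib clips values outside that range. The sketch below assumes that case and applies a simple per-band percentile stretch; the helper `stretch_to_unit` and its 2/98 percentile defaults are hypothetical choices, not part of the original recipe.

```python
import numpy as np


def stretch_to_unit(data, low=2, high=98):
    """Turn a (bands, rows, cols) stack into a (rows, cols, bands) float image in [0, 1].

    Hypothetical helper: each band is stretched linearly between its `low`
    and `high` percentiles and then clipped to [0, 1].
    """
    rgb = np.dstack(data).astype(float)  # band-first -> band-last
    for i in range(rgb.shape[-1]):
        lo, hi = np.percentile(rgb[..., i], (low, high))
        span = hi - lo if hi > lo else 1.0  # avoid dividing by zero on flat bands
        rgb[..., i] = (rgb[..., i] - lo) / span
    return np.clip(rgb, 0.0, 1.0)


# Synthetic 16-bit, 3-band stack standing in for the Landsat bands.
rng = np.random.default_rng(0)
fake = rng.integers(0, 4096, size=(3, 4, 5), dtype=np.uint16)
rgb = stretch_to_unit(fake)
print(rgb.shape)  # (4, 5, 3)
```

With the real stack you would call `plt.imshow(stretch_to_unit(data))` in place of `plt.imshow(np.dstack(data))`.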
scikit-learn expects data in two dimensions, with each row being a sample and each column being a variable. Here each pixel is a sample and each band is a variable, so the (bands, rows, columns) array has to be reshaped into (pixels, bands).
samples = data.reshape((3, -1)).T
samples.shape
(250325L, 3L)
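To check that this reshape keeps each pixel's band values together (and can be undone), here is a small round trip on a toy array; the sizes are arbitrary and chosen only for illustration.

```python
import numpy as np

# Toy stand-in for the (bands, rows, cols) stack: 3 bands, 2 rows, 4 columns.
bands, rows, cols = 3, 2, 4
data = np.arange(bands * rows * cols).reshape((bands, rows, cols))

# One row per pixel, one column per band, pixels in row-major order.
samples = data.reshape((bands, -1)).T
print(samples.shape)  # (8, 3)

# Row r * cols + c of `samples` holds the band values of pixel (r, c).
r, c = 1, 2
assert (samples[r * cols + c] == data[:, r, c]).all()

# Undo the reshape to recover the original stack.
restored = samples.T.reshape((bands, rows, cols))
assert (restored == data).all()
```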
Next, build the classifier with the number of groups (clusters) we want, e.g. 4.
clf = sklearn.cluster.KMeans(n_clusters=4)
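The value 4 is a choice; K-Means does not work it out for you. One common heuristic (an addition here, not part of the original answer) is to fit several values of `n_clusters` and look for an "elbow" in `inertia_`, the within-cluster sum of squared distances. This sketch uses synthetic pixels with three color groups in place of `samples`.

```python
import numpy as np
import sklearn.cluster

# Synthetic "pixels": three well-separated color groups with a little noise.
rng = np.random.default_rng(0)
centres = np.array([[20.0, 40.0, 60.0], [120.0, 90.0, 30.0], [200.0, 200.0, 200.0]])
samples = np.vstack([c + rng.normal(scale=3.0, size=(200, 3)) for c in centres])

inertias = {}
for k in range(1, 7):
    km = sklearn.cluster.KMeans(n_clusters=k, n_init=10, random_state=0)
    km.fit(samples)
    inertias[k] = km.inertia_  # sum of squared distances to the nearest centre

for k, value in inertias.items():
    print(k, round(value))
```

Inertia always falls as k grows, so the useful signal is where the drop levels off (around k = 3 for this synthetic data). On a full scene it may be cheaper to run this on a random subset of the rows of `samples`.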
Finally, we can classify the data. The result is a label array with one entry per sample (pixel), which can be reshaped back to the original image dimensions.
labels = clf.fit_predict(samples)
labels.shape
(250325L,)
plt.figure(figsize=(12, 12))
plt.imshow(labels.reshape((475, 527)), cmap="Set3")
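The same steps on a tiny synthetic image, as a sketch: it passes `n_init` and `random_state` (standard `KMeans` parameters, not used in the original) so that reruns give the same labels, and it reshapes with `data.shape[1:]` instead of hard-coding 475 and 527.

```python
import numpy as np
import sklearn.cluster

# Synthetic 3-band image standing in for `data`: left half dark, right half bright.
data = np.zeros((3, 6, 8))
data[:, :, 4:] = 200.0

samples = data.reshape((data.shape[0], -1)).T  # (pixels, bands)
clf = sklearn.cluster.KMeans(n_clusters=2, n_init=10, random_state=0)
labels = clf.fit_predict(samples)

# Reshape using the stack's own (rows, cols) rather than hard-coded numbers.
label_image = labels.reshape(data.shape[1:])
print(label_image.shape)  # (6, 8)
```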
For a quick comparison of results, we can use a different clustering method, such as agglomerative clustering, optionally with connectivity constraints so that only spatially adjacent groups of pixels can be merged.
import sklearn.feature_extraction.image
connectivity = sklearn.feature_extraction.image.grid_to_graph(475, 527)
clf = sklearn.cluster.AgglomerativeClustering(n_clusters=4, connectivity=connectivity)
labels = clf.fit_predict(samples)
plt.figure(figsize=(12, 12))
plt.imshow(labels.reshape((475, 527)), cmap="Set3")
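Two details are worth checking on a small grid, shown here as a sketch with synthetic data. First, `grid_to_graph(n_x, n_y)` numbers pixel (r, c) as `r * n_y + c`, which matches the row order produced by `data.reshape((3, -1)).T`, so passing the image's (rows, columns) lines the graph up with `samples`. Second, cluster numbers are arbitrary, so the two label images may use different colors for the same regions; `adjusted_rand_score` (from `sklearn.metrics`, an addition here) compares the groupings independently of the numbering, with 1.0 meaning identical partitions.

```python
import numpy as np
import sklearn.cluster
import sklearn.feature_extraction.image
from sklearn.metrics import adjusted_rand_score

# 1) Check the pixel ordering of the connectivity graph on a 2 x 3 grid.
graph = sklearn.feature_extraction.image.grid_to_graph(2, 3).toarray()
assert graph[0, 1] and graph[0, 3]  # right and lower neighbors of pixel (0, 0)
assert not graph[2, 3]              # end of row 0 is not linked to start of row 1
assert not graph[0, 4]              # diagonal neighbors are not linked

# 2) Compare K-Means with connectivity-constrained agglomerative clustering.
rows, cols = 8, 8
data = np.zeros((3, rows, cols))  # synthetic stand-in for the band stack
data[0, :4, 4:] = 100.0           # top-right quadrant: red
data[1, 4:, :4] = 100.0           # bottom-left quadrant: green
data[2, 4:, 4:] = 100.0           # bottom-right quadrant: blue (top-left stays black)
samples = data.reshape((3, -1)).T

connectivity = sklearn.feature_extraction.image.grid_to_graph(rows, cols)
km_labels = sklearn.cluster.KMeans(n_clusters=4, n_init=10, random_state=0).fit_predict(samples)
agg_labels = sklearn.cluster.AgglomerativeClustering(
    n_clusters=4, connectivity=connectivity
).fit_predict(samples)

# Label numbers differ between methods; the adjusted Rand index compares only the groupings.
ari = adjusted_rand_score(km_labels, agg_labels)
print(f"adjusted Rand index: {ari:.3f}")
```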