#!/usr/bin/env python
# coding: utf-8

# # Unsupervised classification of imagery using `scikit-learn`
#
# This example shows how to classify imagery (for example from LANDSAT) using
# [`scikit-learn`](http://scikit-learn.org/stable/). There are many
# classification methods available, but for this example we will use
# [K-Means](http://scikit-learn.org/stable/modules/clustering.html#k-means) as
# it's simple and fast. For imagery I grabbed the
# [North Carolina dataset](http://grassbook.org/datasets/datasets-3rd-edition/)
# raster sample and I'm using the red, green and blue bands of the Landsat 7
# imagery within this package. I'm using
# [`rasterio`](https://github.com/mapbox/rasterio) to read in the data.

# In[4]:

import matplotlib.pyplot as plt
import numpy as np
import rasterio
import sklearn.cluster
import sklearn.feature_extraction.image

# Number of groups both clustering methods are asked to find.
N_CLUSTERS = 4
# Fixed seed so the K-Means labelling is identical between runs.
RANDOM_STATE = 0

# In[5]:

# Absolute paths to the local copy of the sample data; adjust as needed.
red_path = r"C:\projects\quick_scripts\gis_se\landsat\ncrast\lsat7_2002_30.tif"
green_path = r"C:\projects\quick_scripts\gis_se\landsat\ncrast\lsat7_2002_20.tif"
blue_path = r"C:\projects\quick_scripts\gis_se\landsat\ncrast\lsat7_2002_10.tif"

# In[6]:

# Read the first band of each file, keeping red, green, blue order.
bands = []
for band_path in (red_path, green_path, blue_path):
    with rasterio.open(band_path) as src:
        bands.append(src.read(1))

# np.array() only builds a proper 3-D array when every band has the same
# dimensions, so fail loudly instead of producing a ragged result.
if len({band.shape for band in bands}) != 1:
    raise ValueError(f"Band shapes differ: {[band.shape for band in bands]}")

data = np.array(bands)
print(data.shape)

# Note that this is a three band image giving us a 3 dimensional array,
# laid out as (band, row, column).
n_bands, n_rows, n_cols = data.shape

# In[7]:

plt.figure(figsize=(12, 12))
# dstack moves the band axis last, giving the (rows, cols, 3) RGB layout that
# imshow expects.
# NOTE(review): imshow expects RGB data as uint8 (0-255) or float (0-1) --
# confirm the band dtype, otherwise values are clipped.
plt.imshow(np.dstack(data))

# `scikit-learn` expects data in two dimensions, with each row being a
# sample and each column being a variable. As such you have to reshape your
# data accordingly.

# In[8]:

# One row per pixel, one column per band (pixels in row-major order).
samples = data.reshape((n_bands, -1)).T
print(samples.shape)

# Next build the classifier with the number of groups we want to have,
# e.g. `4`.

# In[9]:

clf = sklearn.cluster.KMeans(n_clusters=N_CLUSTERS, random_state=RANDOM_STATE)

# Finally we can classify the data. The results will be returned as a label
# array with the same length as the number of samples. This can be reshaped
# back to the original dimensions of the data.

# In[10]:

labels = clf.fit_predict(samples)
print(labels.shape)

# In[11]:

plt.figure(figsize=(12, 12))
plt.imshow(labels.reshape((n_rows, n_cols)), cmap="Set3")

# For a quick comparison of results, we can use a different clustering
# method, such as
# [Agglomerative clustering](http://scikit-learn.org/stable/modules/clustering.html#hierarchical-clustering),
# potentially with connectivity constraints.

# In[13]:

# Pixel-adjacency graph over the (n_rows, n_cols) grid. Its node order is
# row-major, matching the pixel order of `samples`, so merges are restricted
# to neighbouring pixels.
connectivity = sklearn.feature_extraction.image.grid_to_graph(n_rows, n_cols)

# In[14]:

clf = sklearn.cluster.AgglomerativeClustering(
    n_clusters=N_CLUSTERS,
    connectivity=connectivity,
)

# In[15]:

labels = clf.fit_predict(samples)

# In[16]:

plt.figure(figsize=(12, 12))
plt.imshow(labels.reshape((n_rows, n_cols)), cmap="Set3")

# Outside a notebook the figures are only displayed on an explicit show().
plt.show()