Let's find the dominant colors in images

In [1]:
# we will use the python imaging library to read color info from the image
from PIL import Image
In [2]:
# let's start by pulling down the image data
import urllib2
fh = urllib2.urlopen('http://media.charlesleifer.com/blog/photos/thumbnails/akira_650x650.jpg')
img_data = fh.read()
fh.close()
In [3]:
# what does img_data look like?  here are the first 10 bytes -- looks like a header and some null bytes
print img_data[:10]
����JFIF

In [4]:
# let's load up this image data
from StringIO import StringIO
img_buf = StringIO(img_data)
img = Image.open(img_buf)
In [5]:
# what is img? it should be a jpeg image file
print img
<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=650x334 at 0x9B9CFCC>

What happened

We just pulled down some image data over the wire and created an Image object. This Image object gives us a nice, fast way to extract things like raw pixel data, which we will use to determine dominant colors.
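To see what "raw pixel data" means here, this is a minimal sketch of the two pixel-access methods we'll rely on, using a tiny in-memory image so no download is needed:

```python
from PIL import Image

demo = Image.new('RGB', (2, 2), (255, 0, 0))  # a 2x2 solid red image
print(demo.getpixel((0, 0)))   # one pixel as an RGB tuple: (255, 0, 0)
print(demo.getcolors(2 * 2))   # every color with its count: [(4, (255, 0, 0))]
```

getcolors takes a maximum number of colors to return; passing width * height guarantees we never hit that cap.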

Now that we have the image data, we will resize the image down so that neither side exceeds 200px -- this makes calculations faster since there are fewer pixels to count.

In [6]:
# let's resize the image -- PIL does the work in C, so this is fast, and we won't lose much accuracy
img.thumbnail((200, 200))
In [7]:
# let's load up some modules which will be useful while we're extracting color info and clustering
from collections import namedtuple
import random

We want a nice way to represent the various color points in the image. I chose to use namedtuple, which has lower memory overhead than a full Python class. The Point class will store the coordinate data, the number of dimensions (always 3 in this case -- red, green, and blue), and the count of pixels that have that color.

The clusters will be a collection of points, and have the additional property of a "center".

In [8]:
# these classes will represent the data we extract -- I use namedtuples as they have lower memory overhead than full classes
Point = namedtuple('Point', ('coords', 'n', 'ct'))
Cluster = namedtuple('Cluster', ('points', 'center', 'n'))
In [9]:
# let's extract all the color points from the image -- the red/green/blue channels will be treated as points in a 3-dimensional space
def get_points(img):
    points = []
    w, h = img.size
    for count, color in img.getcolors(w * h):
        points.append(Point(color, 3, count))
    return points
In [10]:
img_points = get_points(img)
In [11]:
# when we're clustering we will need a way to find the distance between two points
def point_distance(p1, p2):
    return sum([
        (p1.coords[i] - p2.coords[i]) ** 2 for i in range(p1.n)
    ])
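Note that this is the squared euclidean distance -- we skip the square root since we only ever compare distances, and squaring preserves the ordering. A quick standalone check (repeating the definitions so the snippet runs on its own):

```python
from collections import namedtuple

Point = namedtuple('Point', ('coords', 'n', 'ct'))

def point_distance(p1, p2):
    # squared euclidean distance between two color points
    return sum((p1.coords[i] - p2.coords[i]) ** 2 for i in range(p1.n))

black = Point((0, 0, 0), 3, 1)
dark = Point((1, 2, 2), 3, 1)
print(point_distance(black, dark))  # 1 + 4 + 4 = 9
```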
In [12]:
# we also need a way to calculate the center of a cluster of points -- this is done by taking
# the weighted average across all dimensions, weighting each point by its pixel count
def calculate_center(points, n):
    vals = [0.0 for i in range(n)]
    plen = 0
    for p in points:
        plen += p.ct
        for i in range(n):
            vals[i] += (p.coords[i] * p.ct)
    return Point([(v / plen) for v in vals], n, 1)
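To see the weighting in action, here is a standalone check (repeating the definitions): a color that occurs three times pulls the center three times as hard as a color that occurs once.

```python
from collections import namedtuple

Point = namedtuple('Point', ('coords', 'n', 'ct'))

def calculate_center(points, n):
    vals = [0.0 for i in range(n)]
    plen = 0
    for p in points:
        plen += p.ct
        for i in range(n):
            vals[i] += (p.coords[i] * p.ct)
    return Point([(v / plen) for v in vals], n, 1)

a = Point((0, 0, 0), 3, 3)   # a color seen 3 times
b = Point((4, 8, 12), 3, 1)  # a color seen once
print(calculate_center([a, b], 3).coords)  # [1.0, 2.0, 3.0]
```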

What have we done so far

So far we have extracted the points from the image and created a few helper functions for things like calculating the center of a cluster of points and calculating the distance between two points.

The next step -- running the algorithm

The code can be found in the next cell; the algorithm is k-means.

Our goal is to find where the points tend to form “clumps”. Since we want to group the points into k clusters, we’ll pick k points randomly from the data to use as the initial “clusters”.

We’ll iterate over every point in the data and calculate its distance to each of the k clusters. Find the nearest cluster and associate that point with the cluster. When you’ve iterated over all the points they should all be assigned to one of the clusters. Now, for each cluster recalculate its center by averaging the coordinates of all the associated points (weighted by their pixel counts) and start over.

When the centers stop moving very much we can stop looping. To find the dominant colors, simply take the centers of the clusters!
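Before running it on real pixels, the loop described above can be sketched on a toy 1-D dataset. This is an illustration only -- plain numbers instead of color points, and explicit starting centers instead of a random sample so the run is deterministic:

```python
def kmeans_1d(values, centers, min_diff=0.5):
    # simplified 1-D k-means; 'centers' are the initial guesses
    k = len(centers)
    while True:
        groups = [[] for _ in range(k)]
        # assign every value to its nearest center
        for v in values:
            idx = min(range(k), key=lambda i: abs(v - centers[i]))
            groups[idx].append(v)
        # recompute each center as the mean of its group
        new_centers = [sum(g) / float(len(g)) if g else centers[i]
                       for i, g in enumerate(groups)]
        diff = max(abs(a - b) for a, b in zip(centers, new_centers))
        centers = new_centers
        if diff < min_diff:
            break
    return sorted(centers)

data = [1, 2, 3, 10, 11, 12, 20, 21, 22]
print(kmeans_1d(data, [1, 2, 20]))  # [2.0, 11.0, 21.0] -- one center per clump
```

Even starting with two of the three guesses inside the same clump, the centers migrate until each one sits at the mean of a clump.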

In [13]:
# finally, here is our algorithm -- 'kmeans'
def kmeans(points, k, min_diff):
    clusters = [Cluster([p], p, p.n) for p in random.sample(points, k)]
 
    while 1:
        plists = [[] for i in range(k)]
 
        for p in points:
            smallest_distance = float('Inf')
            for i in range(k):
                distance = point_distance(p, clusters[i].center)
                if distance < smallest_distance:
                    smallest_distance = distance
                    idx = i
            plists[idx].append(p)
 
        diff = 0
        for i in range(k):
            old = clusters[i]
            center = calculate_center(plists[i], old.n)
            new = Cluster(plists[i], center, old.n)
            clusters[i] = new
            diff = max(diff, point_distance(old.center, new.center))
 
        if diff < min_diff:
            break
 
    return clusters
In [14]:
print 'Calculating clusters -- this may take a few seconds'
# run k-means on the color points, calculating 3 clusters (3 dominant colors),
# and stopping when the cluster centers move less than 1 unit
clusters = kmeans(img_points, 3, 1)
rgbs = [map(int, c.center.coords) for c in clusters]
print 'Done'
Calculating clusters -- this may take a few seconds
Done

In [15]:
print rgbs
[[36, 24, 33], [86, 100, 112], [192, 87, 42]]

In [16]:
# let's create a function to convert RGBs into hex color code
rtoh = lambda rgb: '#%s' % ''.join(('%02x' % p for p in rgb))
In [17]:
color_codes = map(rtoh, rgbs)
print color_codes
['#241821', '#566470', '#c0572a']
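The '%02x' format writes each channel as two lowercase hex digits, zero-padding small values so every code is exactly six digits; for example:

```python
rtoh = lambda rgb: '#%s' % ''.join('%02x' % p for p in rgb)
print(rtoh([255, 0, 128]))  # #ff0080
```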

In [18]:
# now, let's display those colors using HTML
from IPython.core.display import HTML
In [19]:
HTML('<div style="width: 40px; height: 40px; background-color: %s">&nbsp;</div>' % color_codes[0])
Out[19]:
(renders a 40x40 swatch of the first dominant color, #241821)
In [20]:
HTML('<div style="width: 40px; height: 40px; background-color: %s">&nbsp;</div>' % color_codes[1])
Out[20]:
(renders a 40x40 swatch of the second dominant color, #566470)
In [21]:
HTML('<div style="width: 40px; height: 40px; background-color: %s">&nbsp;</div>' % color_codes[2])
Out[21]:
(renders a 40x40 swatch of the third dominant color, #c0572a)