%matplotlib inline
import sys
sys.path.insert(0,'../src')
import numpy as np
from matplotlib import pyplot as plt
import matplotlib.mlab as mlab
from scipy.stats import norm
from sklearn import mixture
import SimpleITK as sitk
from imagedisplay import myshow
For more information, please refer to:
- Section II-D ("Initialization with a Digital Brain Atlas") of http://ieeexplore.ieee.org/xpl/articleDetails.jsp?arnumber=811268
- Section "Probabilistic Atlas" of http://radiology.ucsf.edu/sites/all/files/filemanager/research/Baby_Brain_Web_Site/fulltext-1.pdf
# Upper bound of the intensity range the study image is rescaled to.
max_int_val = 512


def _read_flat(filename):
    """Read an image with SimpleITK and return its voxel array flattened to 1-D."""
    return sitk.GetArrayFromImage(sitk.ReadImage(filename)).flatten()


def _inside(values, mask, what):
    """Return ``values[mask]`` after checking both cover the same number of voxels.

    Raises
    ------
    ValueError
        If *values* and *mask* have different sizes (e.g. an image that was
        not resampled onto the atlas grid).
    """
    if values.size != mask.size:
        raise ValueError("%s has %d voxels but the atlas mask has %d"
                         % (what, values.size, mask.size))
    return values[mask]


# Read the mask image; only voxels with mask > 0 take part in the segmentation.
maskfilename = './data/atlas/mask.nii.gz'
mask_data = _read_flat(maskfilename)
in_mask = mask_data > 0  # computed once, reused for every volume below

# Read the atlas probability maps. They hold the spatial *a priori* probability
# of each tissue class at every voxel location (they are not posteriors).
# Judging by the file names, the columns of labels_d are WM, CSF, GM.
label1filename = "./data/atlas/wm_m.nii.gz"
label1_d = _inside(_read_flat(label1filename), in_mask, label1filename)
label2filename = "./data/atlas/csf_m.nii.gz"
label2_d = _inside(_read_flat(label2filename), in_mask, label2filename)
label3filename = "./data/atlas/gm_m.nii.gz"
label3_d = _inside(_read_flat(label3filename), in_mask, label3filename)
labels_d = np.column_stack([label1_d, label2_d, label3_d])  # (voxels, classes)

# Read the test image, which is already affinely aligned to (spatially
# normalized to) atlas space, and rescale its intensities to [0, max_int_val].
studyfilename = "./data/test.nii.gz"
study_img = sitk.ReadImage(studyfilename)
study_img = sitk.Cast(study_img, sitk.sitkFloat32)
study_img = sitk.RescaleIntensity(study_img, 0.0, max_int_val)
image_data = sitk.GetArrayFromImage(study_img)
nparray_size = image_data.shape  # NumPy array shape, used to rebuild the label volume
image_data = _inside(image_data.flatten(), in_mask, studyfilename)
# Compute the initial values for EM segmentation.
# The class-specific Gaussian parameters are estimated from the study image,
# weighting every voxel by the registered and reformatted a priori class
# probability maps provided by the atlas.
# Vectorized NumPy reductions in float64 replace the builtin sum(), which
# looped in Python over every voxel and accumulated sequentially in the
# arrays' own (possibly float32) precision.
prior_w = np.asarray(labels_d, dtype=np.float64)        # (voxels, classes)
voxel_int = np.asarray(image_data, dtype=np.float64)    # (voxels,)
N_ = prior_w.sum(axis=0)                                # soft voxel count per class
mu = prior_w.T.dot(voxel_int) / N_                      # prior-weighted means
sig = np.sqrt(np.array([np.dot(prior_w[:, k], (voxel_int - mu[k]) ** 2)
                        for k in range(prior_w.shape[1])]) / N_)  # weighted std devs
# With spatial priors the atlas takes the place of the global mixing weights,
# so these start at 1. In the standard EM algorithm one would instead use
# pi_[k] = N_[k] / image_data.size.
pi_ = np.ones(prior_w.shape[1])
# ==== EM With Spatial Priors ====
# Define the mixture model
def pdf_model(image_data, mu, sig, labels_d):
    """Evaluate the atlas-weighted Gaussian mixture density at every voxel.

    For voxel i this returns ``sum_k labels_d[i, k] * N(x_i; mu[k], sig[k]) + 1e-6``.
    The per-voxel atlas probabilities take the place of global mixing weights.

    Parameters
    ----------
    image_data : array_like, shape (N,)
        Voxel intensities.
    mu, sig : array_like, shape (K,)
        Mean and standard deviation of each class-conditional Gaussian.
    labels_d : array_like, shape (N, K)
        Per-voxel class weights (the atlas priors); K is taken from its
        second dimension rather than being fixed to 3.

    Returns
    -------
    numpy.ndarray, shape (N,)
        Mixture density per voxel, offset by 1e-6 so it can be used safely as
        a divisor.
    """
    x = np.asarray(image_data, dtype=float).ravel()
    weights = np.asarray(labels_d, dtype=float)
    # (N, 1) against (K,) broadcasts to the (N, K) matrix of component densities.
    # scipy.stats.norm.pdf replaces matplotlib.mlab.normpdf, removed in Matplotlib 3.1.
    component_pdf = norm.pdf(x[:, None],
                             loc=np.asarray(mu, dtype=float),
                             scale=np.asarray(sig, dtype=float))
    return (weights * component_pdf).sum(axis=1) + 1e-6  # to avoid division by zero
# Perform EM with spatial priors.
# During the iterations the atlas keeps spatially constraining the
# classification: the a priori class probability p(k_i=j) in (2) is set equal
# to the atlas value p(k_i=j|x_i) instead of a global mixing weight.
max_iter = 50   # hard cap on the number of EM iterations
tol = 1e-8      # stop early once the relative log-likelihood gain is below this
N = image_data.size
priors = np.asarray(labels_d, dtype=np.float64).T       # (classes, voxels)
intensities = np.asarray(image_data, dtype=np.float64)  # (voxels,)
mu = np.array(mu, dtype=np.float64)
sig = np.array(sig, dtype=np.float64)
prev_loglik = -np.inf
counter = 0
converged = False
while not converged:
    # E-step: responsibilities of ALL classes, computed with the SAME
    # (previous-iteration) parameters. The original interleaved per-class
    # updates meant gamma[1], gamma[2] already saw an updated mu[0]/sig[0].
    # scipy.stats.norm.pdf replaces matplotlib.mlab.normpdf (removed in 3.1).
    weighted = priors * norm.pdf(intensities, loc=mu[:, None], scale=sig[:, None])
    mixture_pdf = weighted.sum(axis=0) + 1e-6  # epsilon avoids division by zero
    gamma = weighted / mixture_pdf
    # M-step: re-estimate the Gaussian parameters from the soft assignments.
    N_ = gamma.sum(axis=1)
    mu = gamma.dot(intensities) / N_
    sig = np.sqrt((gamma * (intensities - mu[:, None]) ** 2).sum(axis=1) / N_)
    pi_ = N_ / N
    # Responsibilities should (almost) sum to one for every voxel.
    assert abs(N_.sum() - N) / float(N) < 1e-2, "responsibilities do not sum to ~1"
    # Convergence check: iteration cap, or negligible log-likelihood improvement.
    loglik = np.log(mixture_pdf).sum()
    counter += 1
    converged = counter >= max_iter or abs(loglik - prev_loglik) <= tol * abs(loglik)
    prev_loglik = loglik
# Kept for backward compatibility with code that reads these names.
mu_new, sig_new = mu, sig
# Prob. distribution function of each gaussian component
def plot_pdf_models_2(x, mu, si, we):
    """Plot every weighted Gaussian component of a 1-D mixture on the current axes.

    Parameters
    ----------
    x : array_like
        Intensity values at which each component is evaluated.
    mu, si : sequence of float
        Mean and standard deviation of each component.
    we : sequence of float
        Weight applied to each component's density (e.g. mixing proportions).
    """
    # scipy.stats.norm.pdf replaces matplotlib.mlab.normpdf (removed in Matplotlib 3.1).
    for weight, mean, std in zip(we, mu, si):
        plt.plot(x, weight * norm.pdf(x, loc=mean, scale=std), linewidth=4)
# Plot the computed Gaussian mixture model together with intensity histogram:
# each class-conditional Gaussian, weighted by pi_, over the normalized
# histogram of the in-mask voxel intensities.
x = np.linspace(0, max_int_val, 500)
plt.figure(figsize=(16, 5), dpi=100)
plot_pdf_models_2(x, mu, sig, pi_)
# `bins` must be an integer (true division yields a float in Python 3), and
# `normed` was removed in Matplotlib 3.1 in favour of `density`.
plt.hist(image_data.ravel(), bins=max_int_val // 4, range=(0, max_int_val),
         density=True, color='m')
plt.title('Class specific probability distribution functions (After Convergence)', fontsize=20)
plt.show()
# === Prediction ===
# Hard assignment: each in-mask voxel gets the class with the largest
# responsibility (labels 1..K); voxels where the mask is 0 get label 5.
segm_data = np.zeros(nparray_size).flatten()
segm_data[mask_data == 0] = 5.0  # background pixels
segm_data[mask_data > 0] = np.argmax(gamma, axis=0) + 1
segm_data = segm_data.reshape(nparray_size)
segm_img = sitk.GetImageFromArray(segm_data)
segm_img.CopyInformation(study_img)  # copy origin/spacing/direction from the study image
# Display the labels and the original image as three orthogonal slices
# (middle of axis 0, middle of axis 1, and index 90 along axis 2).
study_img = sitk.Cast(study_img, sitk.sitkUInt16)
segm_img = sitk.Cast(segm_img, sitk.sitkUInt16)
imageSize = study_img.GetSize()
# Integer division: SimpleITK indices must be ints (`/` gives a float in Python 3).
mid_x, mid_y = imageSize[0] // 2, imageSize[1] // 2
# NOTE(review): slice 90 is hard-coded; assumes the volume has > 90 slices on axis 2.
slices = [study_img[mid_x, :, ::-1], study_img[:, mid_y, ::-1], study_img[:, :, 90]]
myshow(sitk.Tile(slices, [3, 1]), dpi=20)
slices = [segm_img[mid_x, :, ::-1], segm_img[:, mid_y, ::-1], segm_img[:, :, 90]]
myshow(sitk.LabelToRGB(sitk.Tile(slices, [3, 1])), dpi=20)
# Read the mask image
maskfilename = './data/atlas/mask.nii.gz'
mask_data = sitk.GetArrayFromImage(sitk.ReadImage(maskfilename)).flatten()
# Read the brain image and rescale it to [0, max_int_val], as for the EM run
max_int_val = 512
imagefilename = "./data/test.nii.gz"
image = sitk.ReadImage(imagefilename)
image = sitk.Cast(image, sitk.sitkFloat32)
image = sitk.RescaleIntensity(image, 0.0, max_int_val)
image_data = sitk.GetArrayFromImage(image)
input_shape = image_data.shape
image_data = image_data.flatten()
image_data = image_data[mask_data > 0]
# scikit-learn expects a 2-D (n_samples, n_features) array: every in-mask voxel
# is one sample with a single feature, its intensity.
samples = image_data.reshape(-1, 1)
# Compute the parameters of a 3-component, diagonal-covariance Gaussian mixture
# with the expectation-maximization algorithm (no atlas priors here).
# `mixture.GMM` was removed in scikit-learn 0.20. GaussianMixture is its
# replacement (thresh -> tol, n_iter -> max_iter); random_state=1 replaces the
# global np.random.seed(1) for reproducibility.
g = mixture.GaussianMixture(n_components=3, covariance_type='diag', tol=0.01,
                            max_iter=50, n_init=1, random_state=1)
g.fit(samples)
# Predict the labels: components 0..2 -> labels 1..3.
# NOTE: GMM component order is arbitrary, so label k here need not denote the
# same tissue as label k from the atlas-based EM segmentation.
segm_data = np.zeros(input_shape).flatten()
segm_data[mask_data == 0] = 5.0  # background pixels
segm_data[mask_data > 0] = g.predict(samples) + 1
segm_data = segm_data.reshape(input_shape)
segm_img = sitk.GetImageFromArray(segm_data)
segm_img.CopyInformation(image)
# Display the labels and the original image. Size is taken from `image` (this
# section's input) rather than `study_img` from the previous section.
imageSize = image.GetSize()
segm_img = sitk.Cast(segm_img, sitk.sitkUInt16)
# Integer division: SimpleITK indices must be ints (`/` gives a float in Python 3).
mid_x, mid_y = imageSize[0] // 2, imageSize[1] // 2
slices = [image[mid_x, :, ::-1], image[:, mid_y, ::-1], image[:, :, 90]]
myshow(sitk.Tile(slices, [3, 1]), dpi=20)
slices = [segm_img[mid_x, :, ::-1], segm_img[:, mid_y, ::-1], segm_img[:, :, 90]]
myshow(sitk.LabelToRGB(sitk.Tile(slices, [3, 1])), dpi=20)