%matplotlib inline import sys sys.path.insert(0,'../src') import numpy as np from matplotlib import pyplot as plt import matplotlib.mlab as mlab from sklearn import mixture import SimpleITK as sitk from imagedisplay import myshow def my_em_segmentation(imagefilename, plot_dist=False): # Read the brain slice image max_int_val = 512; image = sitk.ReadImage(imagefilename) image = sitk.Cast(image, sitk.sitkFloat32 ) image = sitk.RescaleIntensity(image,0.0,max_int_val) image_data = sitk.GetArrayFromImage(image); # Compute the parameters for the Gaussian Mixture model np.random.seed(1) g = mixture.GMM(n_components=4, covariance_type='diag', thresh=0.01, n_iter=100, n_init=1, params='wmc', init_params='wmc') # Estimate model parameters with the expectation-maximization algorithm. g.fit(image_data.flatten()) # pdf of each gaussian model def plot_pdf_models(x, g): we = g.weights_ mu = g.means_ si = np.sqrt(g.covars_) for ind in range(0,we.shape[0]): plt.plot(x,we[ind]*mlab.normpdf(x, mu[ind], si[ind]),linewidth=4) if plot_dist: # Class probability distribution function x = np.linspace(0,max_int_val,500) plt.figure(figsize=(16, 5), dpi=100) plot_pdf_models(x,g) plt.hist(image_data.flatten(), bins=max_int_val/6, range=(0, max_int_val), normed=True) plt.title('Class specific probability distribution functions',fontsize=20) plt.show() # Plot the histogram and the approximated mixture model plt.figure(figsize=(16, 5), dpi=100) plt.hist( image_data.flatten(), bins=max_int_val/6, range=(0, max_int_val), normed=True, color='m', label='Intensity histogram') plt.plot(x,np.exp(g.score(x)), linewidth=4, color='k', label='Gaussian Mixture Model') plt.title('Intensity histogram & Gaussian mixture model',fontsize=20) plt.legend(loc=1, shadow=True, fontsize=20) plt.show() # Single components (class posterioe) plt.figure(figsize=(16, 3), dpi=100) plt.plot(x,g.predict_proba(x), linewidth=4) plt.title('Class posterior probability under each Gaussian in the model',fontsize=20) plt.show() # Compute label image label_data = g.predict(image_data.flatten()) label_data = label_data.reshape(image_data.shape) label_image = sitk.GetImageFromArray(label_data) sitk.Image.CopyInformation(label_image,image) myshow(image[:,:, image.GetSize()[2]/2], title='2D Input image', dpi=42) myshow(sitk.LabelToRGB(label_image[:,:, image.GetSize()[2]/2]), title='EM-GMM segmentation', dpi=42) imagefilename1 = "./data/test_slice.nii.gz" my_em_segmentation(imagefilename1) # Atlas image (high-quality image - obtained by registering and averaging inter-patient T1 brain images) imagefilename2 = "./data/atlas_slicez90.nii.gz" my_em_segmentation(imagefilename2, plot_dist=True)