def _gaussian_pdf(x, mean, std):
    """Return the normal density N(mean, std**2) evaluated at ``x``.

    Drop-in replacement for ``matplotlib.mlab.normpdf`` (removed in
    Matplotlib 3.1); evaluates the same expression with NumPy broadcasting.
    """
    x = np.asarray(x, dtype=float)
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (np.sqrt(2.0 * np.pi) * std)


def _plot_gmm_diagnostics(intensities, gmm, max_int_val):
    """Show diagnostic figures for a GMM fitted to 1-D intensities.

    Args:
        intensities: flattened voxel intensities the model was fitted on.
        gmm: fitted mixture model exposing ``weights_``, ``means_``,
            ``covars_``, ``score`` and ``predict_proba``.
        max_int_val: upper bound of the intensity range to plot.
    """
    # Integer bin count: a float here (e.g. 512/6 under Python 3) is rejected.
    n_bins = max_int_val // 6
    x = np.linspace(0, max_int_val, 500)
    # The mixture model expects a (n_samples, n_features) matrix.
    x_col = x.reshape(-1, 1)

    # Class-specific probability density functions, each scaled by its weight.
    plt.figure(figsize=(16, 5), dpi=100)
    stds = np.sqrt(gmm.covars_)
    for weight, mean, std in zip(gmm.weights_, gmm.means_, stds):
        plt.plot(x, weight * _gaussian_pdf(x, mean, std), linewidth=4)
    plt.hist(intensities, bins=n_bins, range=(0, max_int_val), density=True)
    plt.title('Class specific probability distribution functions', fontsize=20)
    plt.show()

    # Intensity histogram overlaid with the full mixture density.
    plt.figure(figsize=(16, 5), dpi=100)
    plt.hist(intensities, bins=n_bins, range=(0, max_int_val),
             density=True, color='m', label='Intensity histogram')
    # score() returns per-sample log-likelihoods; exponentiate to get the density.
    plt.plot(x, np.exp(gmm.score(x_col)), linewidth=4, color='k',
             label='Gaussian Mixture Model')
    plt.title('Intensity histogram & Gaussian mixture model', fontsize=20)
    plt.legend(loc=1, shadow=True, fontsize=20)
    plt.show()

    # Posterior probability of each class (component) as a function of intensity.
    plt.figure(figsize=(16, 3), dpi=100)
    plt.plot(x, gmm.predict_proba(x_col), linewidth=4)
    plt.title('Class posterior probability under each Gaussian in the model', fontsize=20)
    plt.show()


def my_em_segmentation(imagefilename, plot_dist=False):
    """Segment an image into 4 intensity classes with an EM-fitted GMM.

    The image is read, cast to float32 and rescaled to [0, 512]. A 4-component
    Gaussian mixture with diagonal covariance is fitted to the voxel
    intensities by expectation-maximization. Each voxel is labelled with its
    most likely component. The middle slice along the third image axis is
    displayed next to its label map.

    Args:
        imagefilename: path of an image readable by ``sitk.ReadImage``.
            The display code indexes a third axis, so a 3-D image is assumed.
        plot_dist: if True, also show histogram / mixture-density / posterior
            diagnostic plots.

    Returns:
        The label image (same geometry as the input), with one integer
        component index per voxel.
    """
    max_int_val = 512

    # Read the brain slice image and normalise its intensity range.
    image = sitk.ReadImage(imagefilename)
    image = sitk.Cast(image, sitk.sitkFloat32)
    image = sitk.RescaleIntensity(image, 0.0, max_int_val)
    image_data = sitk.GetArrayFromImage(image)

    # Flatten once: 1-D view for histograms, (n_voxels, 1) matrix for the model.
    intensities = image_data.ravel()
    samples = intensities.reshape(-1, 1)

    # Estimate the GMM parameters with the expectation-maximization algorithm.
    # A seeded random_state replaces the former global np.random.seed(1) so the
    # fit stays reproducible without mutating the process-wide RNG.
    # NOTE(review): mixture.GMM was deprecated in scikit-learn 0.18 and removed
    # in 0.20. Porting to GaussianMixture (tol/max_iter/covariances_/
    # score_samples) changes convergence semantics and needs re-validation.
    gmm = mixture.GMM(n_components=4, covariance_type='diag', thresh=0.01,
                      n_iter=100, n_init=1, params='wmc', init_params='wmc',
                      random_state=1)
    gmm.fit(samples)

    if plot_dist:
        _plot_gmm_diagnostics(intensities, gmm, max_int_val)

    # Compute the label image and give it the input's spacing/origin/direction.
    label_data = gmm.predict(samples).reshape(image_data.shape)
    label_image = sitk.GetImageFromArray(label_data)
    label_image.CopyInformation(image)

    # Integer division: SimpleITK slicing rejects a float index.
    mid_slice = image.GetSize()[2] // 2
    myshow(image[:, :, mid_slice], title='2D Input image', dpi=42)
    myshow(sitk.LabelToRGB(label_image[:, :, mid_slice]),
           title='EM-GMM segmentation', dpi=42)
    return label_image