In which we discover optimization and cost functions and how to use them.
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
We start with a single slice from the EPI image.
subj_dir = 'sub009'
import nibabel as nib
fname = 'ds107_sub001_run01.nii.gz'
img = nib.load(fname)
data = img.get_fdata()
data.shape
(64, 64, 35, 164)
We convert our data to 32-bit floating point, to make processing easier later:
data = data.astype(np.float32)
Now we get our slices:
mid_vol0 = data[:, :, 17, 0]
plt.imshow(mid_vol0, cmap='gray')
Let's pretend we are doing motion correction. We want to match this slice in volume 0 to the same slice in volume 1.
mid_vol1 = data[:, :, 17, 1]
plt.imshow(mid_vol1, cmap='gray')
To make it a bit harder, let's push mid_vol1 10 voxels down:
moved_mid_vol1 = np.zeros_like(mid_vol1)
moved_mid_vol1.shape
(64, 64)
moved_mid_vol1[10:, :] = mid_vol1[:-10, :]
plt.imshow(moved_mid_vol1, cmap='gray')
Let's say you didn't know how many voxels the image was translated by. You might want a good way to find out. How are we going to do that?
We need to try moving one image (moved_mid_vol1) to match the other image (mid_vol0). Then we need some way of knowing whether our moved image is close to mid_vol0 or not.
First let's do the moving.
def x_trans_slice(img_slice, x_vox_trans):
    """ Make a new copy of `img_slice` translated by `x_vox_trans` voxels

    `x_vox_trans` can be positive or negative
    """
    # Start from zeros, so voxels shifted in from outside the slice are 0
    trans_slice = np.zeros_like(img_slice)
    if x_vox_trans < 0:
        trans_slice[:x_vox_trans, :] = img_slice[-x_vox_trans:, :]
    elif x_vox_trans == 0:
        trans_slice[:, :] = img_slice
    else:
        trans_slice[x_vox_trans:, :] = img_slice[:-x_vox_trans, :]
    return trans_slice
# Test positive translation does what we did before
assert np.allclose(x_trans_slice(mid_vol1, 10), moved_mid_vol1)
# Check negative translation reverses what we did before
back_again = x_trans_slice(moved_mid_vol1, -10)
assert np.allclose(back_again[:-10, :], mid_vol1[:-10, :])
# Check zero translation works
assert np.allclose(x_trans_slice(mid_vol1, 0), mid_vol1)
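The slicing used for moved_mid_vol1 and inside x_trans_slice is easier to follow on a tiny 1D array. This sketch uses made-up toy values, not the EPI. It also contrasts the slicing with np.roll, which shifts too but wraps values from one end round to the other. Here we want the zero fill instead, because voxels pushed off one edge of the slice should not reappear on the opposite edge.

```python
import numpy as np

# A tiny 1D "image" so the effect of the slicing is easy to read
row = np.array([1., 2., 3., 4., 5.])

# Shift by 2 voxels, filling the vacated voxels with zeros
# (the same pattern as moved_mid_vol1[10:, :] = mid_vol1[:-10, :])
shifted = np.zeros_like(row)
shifted[2:] = row[:-2]
print(shifted)  # [0. 0. 1. 2. 3.]

# np.roll also shifts, but wraps the last values round to the start
rolled = np.roll(row, 2)
print(rolled)  # [4. 5. 1. 2. 3.]
```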
Now we need some measure of whether the images are the same or different. This is called a cost function. The measure is high when the images are different, and low when they are similar.
How about subtracting the images?
plt.imshow(mid_vol0 - moved_mid_vol1, cmap='gray')
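Can we just take the sum of this difference image?
No: positive and negative differences cancel out, so two quite different images can still give a sum near zero. We need something better. Here we take the mean of the absolute differences: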
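def my_cost_func(slice0, slice1):
    """ Calculate a measure of difference between two image slices

    Returns the mean absolute difference: low when the slices are similar,
    high when they are different.
    """
    return np.mean(np.abs(slice0 - slice1))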
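A minimal sketch of why the plain sum fails and the mean absolute difference does not. It uses two made-up 2 x 2 arrays, not the EPI slices:

```python
import numpy as np

# Two 2 x 2 "images" that disagree at every voxel
a = np.array([[1., 0.],
              [0., 1.]])
b = np.array([[0., 1.],
              [1., 0.]])

diff = a - b
# The +1 and -1 differences cancel, so the plain sum says "identical"
print(np.sum(diff))           # 0.0
# Taking absolute values first stops the cancellation
print(np.mean(np.abs(diff)))  # 1.0
```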
Now we can check different values of translation with our cost function:
costs = []
translations = range(-25, 15)
for t in translations:
    unmoved = x_trans_slice(moved_mid_vol1, t)
    cost = my_cost_func(unmoved, mid_vol0)
    costs.append(cost)
plt.plot(translations, costs)
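Reading the minimum off the plot by eye works, but np.argmin can pick it for us. The sketch below repeats the search on a synthetic 64 x 64 random image pushed 10 voxels down as before. The random image is an assumption that stands in for mid_vol0. shift_rows is a hypothetical stand-in for x_trans_slice, included so the block runs on its own.

```python
import numpy as np

def shift_rows(img_slice, n):
    """ Hypothetical stand-in for x_trans_slice: shift along axis 0, zero-fill """
    out = np.zeros_like(img_slice)
    if n < 0:
        out[:n, :] = img_slice[-n:, :]
    elif n == 0:
        out[:, :] = img_slice
    else:
        out[n:, :] = img_slice[:-n, :]
    return out

def mean_abs_cost(slice0, slice1):
    """ Mean absolute difference: low when the slices are similar """
    return np.mean(np.abs(slice0 - slice1))

rng = np.random.default_rng(42)
fixed = rng.random((64, 64))      # synthetic stand-in for mid_vol0
moving = shift_rows(fixed, 10)    # pushed 10 voxels down

translations = np.arange(-25, 15)
costs = [mean_abs_cost(shift_rows(moving, t), fixed) for t in translations]

# The best translation is the one with the lowest cost
best_t = translations[np.argmin(costs)]
print(best_t)  # -10
```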
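Let's make this more general by passing in the cost function as an argument.
First, an aside about passing functions as arguments. In Python a function is an object like any other, so we can pass it into another function: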
def my_add(x, y):
    return x + y

def my_mult(x, y):
    return x * y
print(my_add(2, 3))
print(my_mult(2, 3))
5
6
def apply_func(x, y, func):
    return func(x, y)
print(apply_func(2, 3, my_add))
print(apply_func(2, 3, my_mult))
5
6
Then we can pass our cost function as an argument, so we can easily try different cost functions.
def cost_at_t(t, slice0, slice1, cost_func):
    """ Give cost function at translation value `t` and function `cost_func`
    """
    unmoved = x_trans_slice(slice0, t)
    return cost_func(unmoved, slice1)
new_costs = []
for t in translations:
    cost = cost_at_t(t, moved_mid_vol1, mid_vol0, my_cost_func)
    new_costs.append(cost)
assert new_costs == costs
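Since the search code only ever calls cost_func(unmoved, slice1), any function with that signature can be dropped in. This sketch uses a hypothetical helper, compare_costs, to apply a list of cost functions to the same pair of made-up arrays. It compares the mean absolute difference with the mean squared difference, a common alternative that is not used elsewhere in this notebook. The mean absolute difference rates "small errors everywhere" and "one large error" as equally bad here, while the mean squared difference punishes the single large error more heavily.

```python
import numpy as np

def mean_abs_cost(slice0, slice1):
    """ Mean absolute difference, as in my_cost_func """
    return float(np.mean(np.abs(slice0 - slice1)))

def mean_sq_cost(slice0, slice1):
    """ Mean squared difference: big differences count for more """
    return float(np.mean((slice0 - slice1) ** 2))

def compare_costs(slice0, slice1, cost_funcs):
    """ Hypothetical helper: apply each cost function to the same slices """
    # Each cost function is just a value we can loop over and call
    return {func.__name__: func(slice0, slice1) for func in cost_funcs}

target = np.zeros((2, 2))
small_errors = np.full((2, 2), 0.5)             # every voxel off by 0.5
one_big_error = np.array([[2., 0.], [0., 0.]])  # one voxel off by 2

print(compare_costs(small_errors, target, [mean_abs_cost, mean_sq_cost]))
# {'mean_abs_cost': 0.5, 'mean_sq_cost': 0.25}
print(compare_costs(one_big_error, target, [mean_abs_cost, mean_sq_cost]))
# {'mean_abs_cost': 0.5, 'mean_sq_cost': 1.0}
```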
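So we can try other cost functions. Any other good ones? One candidate is correlation. If two slices are well matched, the voxel values in one slice should predict the values in the other, so a scatterplot of one against the other should cluster around a straight line: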
plt.plot(mid_vol1.ravel(), mid_vol0.ravel(), '.')
def correl_cost(slice0, slice1):
    """ Negative correlation between the two images, flattened to 1D """
    correl = np.corrcoef(slice0.ravel(), slice1.ravel())[0, 1]
    return -correl
corr_costs = [cost_at_t(t, moved_mid_vol1, mid_vol0, correl_cost)
              for t in translations]
plt.plot(translations, corr_costs)
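Why might correlation be a good choice? Correlation only asks whether voxel values go up and down together, so it ignores an overall change in image brightness or contrast. This sketch uses a synthetic random image (an assumption, not the EPI) and a copy with its intensities scaled and shifted. The mean absolute difference is large, but the correlation cost is -1, the best possible value.

```python
import numpy as np

def correl_cost(slice0, slice1):
    """ Negative correlation between the two images, flattened to 1D """
    return -np.corrcoef(slice0.ravel(), slice1.ravel())[0, 1]

def mean_abs_cost(slice0, slice1):
    """ Mean absolute difference """
    return np.mean(np.abs(slice0 - slice1))

rng = np.random.default_rng(1)
img = rng.random((64, 64))
# Same structure, different intensity scale and offset
rescaled = 2 * img + 5

print(mean_abs_cost(img, rescaled))  # large (about 5.5 here)
print(correl_cost(img, rescaled))    # -1, up to rounding error
```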
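How about non-integer translations? Will these work? Not with x_trans_slice, because it can only move whole rows of voxels.
We need a more general resampling algorithm, one that can interpolate between voxels. This is like slice timing correction, but in space.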
import scipy.ndimage as snd
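Investigate snd.affine_transform. For each voxel coordinate o in the output image, it calculates an input coordinate as matrix times o, plus offset, and samples the input image there (with interpolation). To move the image by t voxels along the first axis, output voxel o should take its value from input position o - t. So we use ones on the diagonal of the matrix (passed as the 1D sequence [1, 1]) and an offset of (-t, 0). Then: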
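def fancy_x_trans_slice(img_slice, x_vox_trans):
    """ Make a new copy of `img_slice` translated by `x_vox_trans` voxels

    `x_vox_trans` can be positive or negative, and can be a float.
    """
    # Optimizers such as sopt.fmin pass parameters as a length-1 array;
    # convert to a plain number so the offset is a flat 2-element sequence
    x_vox_trans = float(np.squeeze(x_vox_trans))
    # Output voxel (i, j) takes its value from input position (i - x_vox_trans, j)
    trans_slice = snd.affine_transform(img_slice, [1, 1], [-x_vox_trans, 0])
    return trans_slice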
for t in translations:
    assert np.allclose(fancy_x_trans_slice(moved_mid_vol1, t),
                       x_trans_slice(moved_mid_vol1, t))
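To check the matrix and offset arithmetic for snd.affine_transform, here is a sketch on a small made-up 5 x 4 array. It uses order=1 (linear interpolation) so the expected values are easy to work out by hand; fancy_x_trans_slice above uses the default, order=3 (cubic spline). For a half-voxel shift, each output value falls halfway between two input rows, and linear interpolation gives their mean.

```python
import numpy as np
import scipy.ndimage as snd

# A small float image with a different value in every voxel
img = np.arange(20, dtype=float).reshape(5, 4)

# Integer shift of 2 voxels along the first axis.
# Rows 0 and 1 sample positions -2 and -1, outside the input, so they get
# the default fill value (cval=0 with the default mode='constant').
shifted = snd.affine_transform(img, [1, 1], [-2, 0], order=1)
print(shifted)

# Half-voxel shift: output row o samples input position o - 0.5
half = snd.affine_transform(img, [1, 1], [-0.5, 0], order=1)
# Away from the top edge, that is the mean of two neighbouring rows
print(np.allclose(half[1:], (img[:-1] + img[1:]) / 2))  # True
```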
def fancy_cost_at_t(t, slice0, slice1, cost_func):
    """ Give cost function at translation value `t` and function `cost_func`
    """
    unmoved = fancy_x_trans_slice(slice0, t)
    return cost_func(unmoved, slice1)
fine_costs = []
fine_translations = np.linspace(-25, 15, 100)
for t in fine_translations:
    cost = fancy_cost_at_t(t, moved_mid_vol1, mid_vol0, my_cost_func)
    fine_costs.append(cost)
plt.plot(fine_translations, fine_costs)
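We are looking for the best x translation. At the moment we have to sample lots of x translations and then choose the best. Is there a better way? One option is to let an optimization routine do the search. scipy.optimize.fmin uses the downhill simplex (Nelder-Mead) algorithm to look for the parameter values that minimize a function, starting from an initial guess: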
import scipy.optimize as sopt
sopt.fmin(fancy_cost_at_t, [0], args=(moved_mid_vol1, mid_vol0, my_cost_func))
Optimization terminated successfully.
         Current function value: 24.936150
         Iterations: 38
         Function evaluations: 76
array([-9.99996875])
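Here is a self-contained sketch of the same idea on synthetic data. It assumes a smooth Gaussian blob in place of the EPI slice, and the mean squared difference as the cost. Because the translation is done by interpolation, fmin can find a non-integer shift (4.3 voxels here). fmin passes its parameters as a 1D array, so the cost function pulls out params[0].

```python
import numpy as np
import scipy.ndimage as snd
import scipy.optimize as sopt

def shift_rows(img_slice, t):
    """ Translate by (possibly fractional) `t` voxels along the first axis """
    return snd.affine_transform(img_slice, [1, 1], [-t, 0])

def mean_sq_cost(slice0, slice1):
    return np.mean((slice0 - slice1) ** 2)

def cost_at_t(params, moving, fixed):
    # fmin passes the parameters as a length-1 array
    t = params[0]
    return mean_sq_cost(shift_rows(moving, t), fixed)

# Smooth synthetic image: a Gaussian blob in the middle of a 64 x 64 grid
i, j = np.mgrid[0:64, 0:64]
fixed = np.exp(-((i - 32) ** 2 + (j - 32) ** 2) / (2 * 8.0 ** 2))
moving = shift_rows(fixed, 4.3)   # true shift is 4.3 voxels

best = sopt.fmin(cost_at_t, [0], args=(moving, fixed), disp=False)
print(best)  # close to [-4.3]
```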
What actually happened there? Let's track the progress using a callback function:
def my_callback(params):
    print("Trying parameters " + str(params))
sopt.fmin(fancy_cost_at_t, [0], args=(moved_mid_vol1, mid_vol0, my_cost_func), callback=my_callback)
Trying parameters [-0.0005]
Trying parameters [-0.0015]
Trying parameters [-0.0035]
Trying parameters [-0.0075]
Trying parameters [-0.0155]
Trying parameters [-0.0315]
Trying parameters [-0.0635]
Trying parameters [-0.1275]
Trying parameters [-0.2555]
Trying parameters [-0.5115]
Trying parameters [-1.0235]
Trying parameters [-2.0475]
Trying parameters [-4.0955]
Trying parameters [-8.1915]
Trying parameters [-10.2395]
Trying parameters [-10.2395]
Trying parameters [-10.2395]
Trying parameters [-9.9835]
Trying parameters [-9.9835]
Trying parameters [-9.9835]
Trying parameters [-9.9835]
Trying parameters [-9.9835]
Trying parameters [-9.9995]
Trying parameters [-9.9995]
Trying parameters [-9.9995]
Trying parameters [-9.9995]
Trying parameters [-9.9995]
Trying parameters [-9.9995]
Trying parameters [-9.9995]
Trying parameters [-9.99975]
Trying parameters [-9.99975]
Trying parameters [-9.999875]
Trying parameters [-9.999875]
Trying parameters [-9.9999375]
Trying parameters [-9.9999375]
Trying parameters [-9.99996875]
Trying parameters [-9.99996875]
Optimization terminated successfully.
         Current function value: 24.936150
         Iterations: 38
         Function evaluations: 76
array([-9.99996875])
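Printing is one option; another is to have the callback store each parameter vector, so you can plot or inspect the path afterwards. A minimal sketch, on a toy function (a parabola with its minimum at 3, assumed for illustration) rather than the image cost:

```python
import numpy as np
import scipy.optimize as sopt

def parabola(params):
    """ Toy cost function with its minimum at x = 3 """
    return (params[0] - 3.0) ** 2

history = []

def record(params):
    # Copy, in case the optimizer reuses the array it passes in
    history.append(params.copy())

best = sopt.fmin(parabola, [0], callback=record, disp=False)
print(len(history), "iterations recorded")
print(history[0], history[-1], best)
```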
How about adding y translation as well?
def fancy_xy_trans_slice(img_slice, x_y_trans):
    """ Make a new copy of `img_slice` translated by `x_y_trans` voxels

    x_y_trans is a sequence or array of length 2, containing the (x, y)
    translations in voxels.

    Values in `x_y_trans` can be positive or negative, and can be floats.
    """
    x_y_trans = np.array(x_y_trans)
    trans_slice = snd.affine_transform(img_slice, [1, 1], -x_y_trans)
    return trans_slice
for t in translations:
    assert np.allclose(fancy_xy_trans_slice(moved_mid_vol1, (t, 0)),
                       x_trans_slice(moved_mid_vol1, t))
    # Translation along the second axis: transpose, translate, transpose back
    assert np.allclose(fancy_xy_trans_slice(moved_mid_vol1, (0, t)),
                       x_trans_slice(moved_mid_vol1.T, t).T)
def fancy_cost_at_xy_t(xy_t, slice0, slice1, cost_func):
    """ Give cost function at xy translation values `xy_t` and function `cost_func`
    """
    unmoved = fancy_xy_trans_slice(slice0, xy_t)
    return cost_func(unmoved, slice1)
sopt.fmin(fancy_cost_at_xy_t, [0, 0], args=(moved_mid_vol1, mid_vol0, correl_cost), callback=my_callback)
Trying parameters [-0.0005 0.000375]
Trying parameters [-0.0005 0.000375]
Trying parameters [-0.0015 0.001 ]
Trying parameters [-0.002 0.0008125]
Trying parameters [-0.00425 0.00196875]
Trying parameters [-0.006375 0.00217188]
Trying parameters [-0.0119375 0.00458594]
Trying parameters [-0.01896875 0.00619922]
Trying parameters [-0.03360938 0.01183398]
Trying parameters [-0.05499219 0.01787793]
Trying parameters [-0.09496484 0.03216943]
Trying parameters [-0.1577168 0.05140308]
Trying parameters [-0.26903809 0.08960291]
Trying parameters [-0.45020264 0.1471701 ]
Trying parameters [-0.76342749 0.25235336]
Trying parameters [-1.28236902 0.42007939]
Trying parameters [-2.16828949 0.71430892]
Trying parameters [-3.64913278 1.19687574]
Trying parameters [-6.16139537 2.02661822]
Trying parameters [-10.37921325 3.4066231 ]
Trying parameters [-10.37921325 3.4066231 ]
Trying parameters [-8.32072352 2.73408816]
Trying parameters [-8.32072352 2.73408816]
Trying parameters [-9.29954918 3.05288812]
Trying parameters [-9.29954918 3.05288812]
Trying parameters [-9.29954918 3.05288812]
Trying parameters [-9.29954918 3.05288812]
Trying parameters [-9.29954918 3.05288812]
Trying parameters [-9.29954918 3.05288812]
Trying parameters [-9.37842421 3.07973801]
Trying parameters [-9.3007237 3.0500423]
Trying parameters [-9.3007237 3.0500423]
Trying parameters [-9.26363524 3.03092572]
Trying parameters [-9.26363524 3.03092572]
Trying parameters [-9.29033912 3.03433613]
Trying parameters [-9.16918333 2.98051658]
Trying parameters [-9.16201318 2.96042761]
Trying parameters [-9.16201318 2.96042761]
Trying parameters [-8.9659392 2.83952035]
Trying parameters [-9.11021379 2.88670583]
Trying parameters [-8.79020312 2.66848405]
Trying parameters [-8.91874696 2.65374412]
Trying parameters [-8.59873629 2.43552234]
Trying parameters [-8.69581863 2.29693159]
Trying parameters [-8.69581863 2.29693159]
Trying parameters [-8.40996732 1.69241742]
Trying parameters [-8.90706299 1.82660389]
Trying parameters [-8.5839082 0.68466879]
Trying parameters [-9.41652215 0.38207417]
Trying parameters [-9.41652215 0.38207417]
Trying parameters [-9.41652215 0.38207417]
Trying parameters [-10.2491361 0.07947956]
Trying parameters [-10.2491361 0.07947956]
Trying parameters [-9.7403964 0.10694714]
Trying parameters [-9.93708573 -0.16131299]
Trying parameters [-10.04393858 0.02614832]
Trying parameters [-10.04393858 0.02614832]
Trying parameters [-9.94589108 -0.06919882]
Trying parameters [ -9.93018456e+00 -9.21422548e-04]
Trying parameters [ -9.99098820e+00 -4.45590006e-03]
Trying parameters [ -9.99098820e+00 -4.45590006e-03]
Trying parameters [ -9.99098820e+00 -4.45590006e-03]
Trying parameters [ -9.99098820e+00 -4.45590006e-03]
Trying parameters [ -9.99098820e+00 -4.45590006e-03]
Trying parameters [ -9.99098820e+00 -4.45590006e-03]
Trying parameters [ -9.99098820e+00 -4.45590006e-03]
Trying parameters [ -9.99098820e+00 -4.45590006e-03]
Trying parameters [ -9.99098820e+00 -4.45590006e-03]
Trying parameters [ -9.99886824e+00 -4.58503387e-03]
Trying parameters [ -9.99886824e+00 -4.58503387e-03]
Trying parameters [ -9.99886824e+00 -4.58503387e-03]
Trying parameters [ -9.99886824e+00 -4.58503387e-03]
Trying parameters [ -9.99886824e+00 -4.58503387e-03]
Trying parameters [ -9.99886824e+00 -4.58503387e-03]
Trying parameters [ -9.99886824e+00 -4.58503387e-03]
Trying parameters [ -9.99886824e+00 -4.58503387e-03]
Trying parameters [ -9.99939090e+00 -2.75350566e-03]
Trying parameters [ -9.99939090e+00 -2.75350566e-03]
Trying parameters [ -9.99939090e+00 -2.75350566e-03]
Trying parameters [ -9.99981858e+00 -4.12472145e-03]
Trying parameters [ -9.99981858e+00 -4.12472145e-03]
Trying parameters [ -9.99981858e+00 -4.12472145e-03]
Trying parameters [ -9.99981858e+00 -4.12472145e-03]
Trying parameters [ -9.99981858e+00 -4.12472145e-03]
Trying parameters [ -9.99981858e+00 -4.12472145e-03]
Trying parameters [ -9.99996768e+00 -3.25098885e-03]
Trying parameters [ -9.99996768e+00 -3.25098885e-03]
Trying parameters [ -9.99996768e+00 -3.25098885e-03]
Trying parameters [ -9.99996768e+00 -3.25098885e-03]
Trying parameters [ -9.99996768e+00 -3.25098885e-03]
Trying parameters [ -9.99996768e+00 -3.25098885e-03]
Trying parameters [ -9.99996768e+00 -3.25098885e-03]
Trying parameters [ -9.99996768e+00 -3.25098885e-03]
Trying parameters [ -9.99999243e+00 -3.57752448e-03]
Trying parameters [ -9.99998081e+00 -3.10855528e-03]
Trying parameters [ -9.99998081e+00 -3.10855528e-03]
Trying parameters [ -9.99999609e+00 -3.38906539e-03]
Trying parameters [ -9.99999609e+00 -3.38906539e-03]
Trying parameters [ -9.99999609e+00 -3.38906539e-03]
Trying parameters [ -9.99999651e+00 -3.30086135e-03]
Trying parameters [ -9.99999651e+00 -3.30086135e-03]
Optimization terminated successfully.
         Current function value: -0.995571
         Iterations: 102
         Function evaluations: 188
array([ -9.99999651e+00, -3.30086135e-03])
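The same approach on synthetic data: this sketch shifts an off-centre Gaussian blob (an assumed test image) by 3.5 voxels in x and -2 in y, then asks fmin to recover both translations at once, using the mean squared difference as the cost. The shift function mirrors fancy_xy_trans_slice.

```python
import numpy as np
import scipy.ndimage as snd
import scipy.optimize as sopt

def shift_xy(img_slice, xy_t):
    """ Translate by (possibly fractional) (x, y) voxels """
    return snd.affine_transform(img_slice, [1, 1], -np.asarray(xy_t))

def mean_sq_cost(slice0, slice1):
    return np.mean((slice0 - slice1) ** 2)

def cost_at_xy_t(xy_t, moving, fixed):
    return mean_sq_cost(shift_xy(moving, xy_t), fixed)

# Smooth synthetic test image: an off-centre Gaussian blob
i, j = np.mgrid[0:64, 0:64]
fixed = np.exp(-((i - 30) ** 2 + (j - 34) ** 2) / (2 * 8.0 ** 2))
moving = shift_xy(fixed, [3.5, -2.0])   # true shift: 3.5 in x, -2 in y

best = sopt.fmin(cost_at_xy_t, [0, 0], args=(moving, fixed), disp=False)
print(best)  # close to [-3.5, 2.0]
```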
You'll need the example structural from long ago:
anat_fname = 'ds107_sub001_highres.nii.gz'
anat_img = nib.load(anat_fname)
anat_data = anat_img.get_fdata()
anat_data.shape
(256, 256, 192)
Remember the EPI?
data.shape
(64, 64, 35, 164)
fig, axes = plt.subplots(1, 2)
fig.set_size_inches(10.5,18.5)
axes[0].imshow(anat_data[:, :, 96], cmap="gray")
axes[1].imshow(data[:, :, 17, 0], cmap="gray")
How do we make these match up?
Remember the affine?
anat_img.affine
array([[ -1., 0., 0., 125.],
       [ 0., 1., 0., -144.],
       [ 0., 0., 1., -99.],
       [ 0., 0., 0., 1.]])
img.affine
array([[ 3. , 0. , 0. , -93. ],
       [ 0. , 3. , 0. , -103.55609131],
       [ 0. , 0. , 3. , -51.73400116],
       [ 0. , 0. , 0. , 1. ]])
What is the transformation to go from voxels in the functional image img to voxels in the anatomical image anat_img? We can go via millimeters: first from functional voxels to mm, then from mm to anatomical voxels.
funcvox2mm = img.affine
anatvox2mm = anat_img.affine
What is the transformation from mm to anatomical voxels?
mm2anatvox = np.linalg.inv(anatvox2mm)
funcvox2anatvox = np.dot(mm2anatvox, funcvox2mm)
funcvox2anatvox
array([[ -3. , 0. , 0. , 218. ],
       [ 0. , 3. , 0. , 40.44390869],
       [ 0. , 0. , 3. , 47.26599884],
       [ 0. , 0. , 0. , 1. ]])
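As a check on the composition, this sketch copies the two affines printed above into plain arrays and maps one functional voxel coordinate, (10, 20, 5), both through the combined matrix and in two steps via mm. It also splits the 4 x 4 matrix into its 3 x 3 part and its translation column, which is the form snd.affine_transform wants.

```python
import numpy as np

# Affines copied by hand from the outputs above
anatvox2mm = np.array([[-1., 0., 0., 125.],
                       [0., 1., 0., -144.],
                       [0., 0., 1., -99.],
                       [0., 0., 0., 1.]])
funcvox2mm = np.array([[3., 0., 0., -93.],
                       [0., 3., 0., -103.55609131],
                       [0., 0., 3., -51.73400116],
                       [0., 0., 0., 1.]])

# functional voxel -> mm -> anatomical voxel
funcvox2anatvox = np.linalg.inv(anatvox2mm) @ funcvox2mm

# One functional voxel coordinate, in homogeneous form (i, j, k, 1)
func_vox = np.array([10, 20, 5, 1])
via_matrix = funcvox2anatvox @ func_vox
# Same thing, applying the two affines one after the other
via_mm = np.linalg.inv(anatvox2mm) @ (funcvox2mm @ func_vox)
print(via_matrix[:3])  # approximately [188, 100.4439, 62.266]
print(np.allclose(via_matrix, via_mm))  # True

# 3 x 3 part plus translation gives the same answer without the extra 1
RZS = funcvox2anatvox[:3, :3]
offset = funcvox2anatvox[:3, 3]
print(np.allclose(RZS @ func_vox[:3] + offset, via_matrix[:3]))  # True
```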
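How can we use this with scipy.ndimage? snd.affine_transform works backwards: for each voxel coordinate in the output image, it uses the matrix and offset to find where to sample in the input image. We want the output on the functional voxel grid, and the input is the anatomical image, so the mapping we need is exactly funcvox2anatvox, split into its 3 x 3 part (rotations, zooms, shears) and its translation part: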
RZS = funcvox2anatvox[:3, :3]
# Translation part of the affine, used as the offset for affine_transform
offset = funcvox2anatvox[:3, 3]
resamp_anat = snd.affine_transform(anat_data, RZS, offset, output_shape=data.shape[:-1])
resamp_anat.shape
(64, 64, 35)
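To see what snd.affine_transform does with a full 3 x 3 matrix, here is a sketch on a small made-up 6 x 6 x 6 volume. The matrix flips the first axis and steps two input voxels per output voxel along every axis, a scaled-down version of the -3 and 3 entries in funcvox2anatvox. order=1 is chosen so that, at these integer sample points, the result is exactly the input values.

```python
import numpy as np
import scipy.ndimage as snd

vol = np.arange(6 * 6 * 6, dtype=float).reshape(6, 6, 6)

# Output voxel (i, j, k) samples the input at
#   matrix @ (i, j, k) + offset = (5 - 2i, 2j, 2k)
matrix = np.diag([-2., 2., 2.])
offset = [5., 0., 0.]
out = snd.affine_transform(vol, matrix, offset, output_shape=(3, 3, 3), order=1)

# Input rows 5, 3, 1; every other column and slice
expected = vol[5::-2, ::2, ::2]
print(np.allclose(out, expected))  # True
```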
Let's take a mid-slice out of the resampled anatomy and compare to the functional:
resamp_mid = resamp_anat[:, :, 17]
fig, axes = plt.subplots(1, 2)
fig.set_size_inches(10.5,18.5)
axes[0].imshow(resamp_mid, cmap="gray")
axes[1].imshow(data[:, :, 17, 0], cmap="gray")
How about the correlation now?
plt.plot(resamp_mid.ravel(), mid_vol0.ravel(), '.')