フィッティングする多項式の数order(係数 $c$ の数)と, xおよびyの点列を与えると最小二乗フィッティングする(やや)汎用的な関数fit_polyを作る.
def fit_poly(order=2, p_xs, p_ys):
pass # ここに実装
まずSymbolを生成するが, これは与えられた次数に合わせて可変的にする必要がある.
係数$ c_n, c_{n-1},\ ...,c_2, c_1\ $ として、文字列によって生成するので
cs = ["c_%d"%i for i in range(order, 0, -1)]
としよう. 添字を逆順に生成しているのは多項式の記法(次数の大きい順に添字の大きな係数で記述する)を意識してのことである.
続いてこの文字列からシンボルのリストを生成する.
syms = [Symbol(c) for c in cs]
さて, 最小二乗法 $$ J = \frac{1}{2} \sum_{\alpha=1}^{N}(y_{\alpha} - (c_n x_{\alpha}^n + c_{n-1} x_{\alpha}^{n-1} + ... + c_1))^2 \to min $$ は,
J = sum(term(x, y, syms) for x, y in zip(p_xs, p_ys))/2
で計算する. ただし内包表記中の関数termは,
def term(x, y, syms):
poly_x = sum(syms[i]*x**i for i in range(len(syms)-1, -1, -1))
return ((y-poly_x)**2)
である. $ c_k\ $ での偏微分
$$ \frac{\partial J}{\partial c_k} = \sum_{\alpha=1}^{N} (y_{\alpha} - c_n x_{\alpha}^n - c_{n-1} x_{\alpha}^{n-1} - ... - c_1)(-x_{\alpha}^{n-k}) $$はSymPyのdiffが面倒を見てくれる.
diffs = [diff(J, sym) for sym in syms]
正規方程式を解く.
sols = solve(diffs, syms)
最後に解のリストを返す.
return [sols[sym].evalf() for sym in reversed(syms)]
まとめると次のようになる.
def term(x, y, syms):
poly_x = sum(syms[i]*x**i for i in range(len(syms)-1, -1, -1))
return ((y-poly_x)**2)
def fit_poly(order, p_xs, p_ys):
cs = ["c_%d"%i for i in range(order, 0, -1)]
syms = [Symbol(c) for c in cs]
J = sum(term(x, y, syms) for x, y in zip(p_xs, p_ys))/2
diffs = [diff(J, sym) for sym in syms]
sols = solve(diffs, syms)
return [sols[sym].evalf() for sym in reversed(syms)]
fit_polyを用いて『これなら分かる 応用数学教室』の【例 1.3】を次数 1, 5, 9, 13で解いた様子を以下に示す.
import numpy as np
from matplotlib import pyplot as plt
from sympy import *
def term(x, y, syms):
poly_x = sum(syms[i]*x**i for i in range(len(syms)-1, -1, -1))
return ((y-poly_x)**2)
def fit_poly(order, p_xs, p_ys):
cs = ["c_%d"%i for i in range(order, 0, -1)]
syms = [Symbol(c) for c in cs]
J = sum(term(x, y, syms) for x, y in zip(p_xs, p_ys))/2
diffs = [diff(J, sym) for sym in syms]
sols = solve(diffs, syms)
return [sols[sym].evalf() for sym in reversed(syms)]
def main():
p_xs = [5.6, 5.8, 6.0, 6.2, 6.4, 6.4, 6.4, 6.6, 6.8]
p_ys = [ 30, 26, 33, 31, 33, 35, 37, 36, 33]
for idx in range(2, 15, 4):
sols = fit_poly(idx, p_xs, p_ys)
xs = np.linspace(5.5, 7.0, 1000)
ys = [sum(c*x**i for i, c in enumerate(reversed(sols))) for x in xs]
plt.plot(xs, ys, label="dim=%d" % (idx-1))
plt.plot(p_xs, p_ys, ".")
plt.xlim(5.5, 7.0)
plt.ylim(28, 40)
plt.legend(loc="upper left")
main()
あまり次数を増やしすぎると, 補間での値が大きく振動してしまうことが見て取れる(ルンゲ現象).