【超初心者向け】フィッシャーの線形判別分析法をpythonで実装してみた。
【超初心者向け】フィッシャーの線形判別分析法をpythonで実装してみた。

【超初心者向け】フィッシャーの線形判別分析法をpythonで実装してみた。

[chat face="kubi_kashigeru_josei-01-1.png" name="" align="le

\begin &\boldsymbol& \propto \boldsymbol_w^(\boldsymbol_2 – \boldsymbol_1) \\ &\boldsymbol_w^& = \sum_^T + \sum_^T \end

def cal_sw(x1, m1, x2, m2): sw = ((x1 - m1).T @ (x1 - m1)) + ((x2 - m2).T @ (x2 - m2)) return sw

実際に計算

w = LA.inv(cal_sw(x1, m1, x2, m2)) @ (m2 - m1)

グラフ出力の準備

xlist = np.arange(-5,10,0.1) ylist = m[1] + (w[1]/w[0]) * (xlist - m[0]) ydisc = m[1] + (-w[0]/w[1]) * (xlist - m[0])

グラフ出力

plt.plot(xlist, ylist, color=cm(4)) plt.plot(xlist, ydisc, linestyle='dashed', color='black') plt.plot(x1[:,0], x1[:,1], 'o', color=cm(0)) plt.plot(x2[:,0], x2[:,1], 'o', color=cm(1)) plt.axis('equal') plt.ylim(-6,12) plt.show()

考察

フィッシャーの線形判別分析は,名前の通りクラス判別のために使用される手法です。実際に,出力されたグラフの点線で2つのクラスがきれいに判別できることが分かります。フィッシャーの判別基準では, クラス内の分散が小さく ,かつ クラス間の分散が大きく なるように線形な識別面を設定しました。

これは,単純に「射影先でクラスの平均値がよく離れている」という基準で設定した識別面よりも良い結果を与えます。なぜなら,平均だけで分離の良さを比較しようとすると,データの広がり(分散)を無視することになるからです。 データの広がり(分散)を使って分離の良さを記述したもの が,フィッシャーの判別基準になります。

もし理論にモヤモヤがあれば

¥3,080 (2026/03/29 02:30:23時点 楽天市場調べ- 詳細)

全コード

import numpy as np import matplotlib.pyplot as plt cm = plt.get_cmap("tab10") m1 = np.array([3, 1]) s1 = np.array([[1, 2], [2, 5]]) m2 = np.array([1, 3]) s2 = np.array([[1, 2], [2, 5]]) N = 100 x1 = np.random.multivariate_normal(m1, s1, N) x2 = np.random.multivariate_normal(m2, s2, N) plt.plot(x1[:,0], x1[:,1], 'o', color=cm(0)) plt.plot(x2[:,0], x2[:,1], 'o', color=cm(1)) plt.axis('equal') plt.show() ​ m1 = np.array([3, 1]) s1 = np.array([[1, 2], [2, 5]]) m2 = np.array([1, 3]) s2 = np.array([[1, 2], [2, 5]]) ​ N = 100 ​ x1 = np.random.multivariate_normal(m1, s1, N) x2 = np.random.multivariate_normal(m2, s2, N) plt.plot(x1[:,0], x1[:,1], 'o', color=cm(0)) plt.plot(x2[:,0], x2[:,1], 'o', color=cm(1)) plt.axis('equal') plt.show() def cal_sw(x1, m1, x2, m2): sw = ((x1 - m1).T @ (x1 - m1)) + ((x2 - m2).T @ (x2 - m2)) return sw w = LA.inv(cal_sw(x1, m1, x2, m2)) @ (m2 - m1) xlist = np.arange(-5,10,0.1) ylist = m[1] + (w[1]/w[0]) * (xlist - m[0]) ydisc = m[1] + (-w[0]/w[1]) * (xlist - m[0]) plt.plot(xlist, ylist, color=cm(4)) plt.plot(xlist, ydisc, linestyle='dashed', color='black') plt.plot(x1[:,0], x1[:,1], 'o', color=cm(0)) plt.plot(x2[:,0], x2[:,1], 'o', color=cm(1)) plt.axis('equal') plt.ylim(-6,12) plt.show() 京都大学で機械学習を学んでいます。

【第6章カーネル法】PRML演習問題解答を全力で分かりやすく解説<6.1>

2019年6月3日 zuka

【競プロ精進日記】ABC128-C

2020年6月18日 zuka

【初学者向け】情報セキュリティ<著作権編>

2019年8月1日 zuka

【第8章】PRML演習問題解答を全力で分かりやすく解説<8.19>

2019年6月19日 zuka

【超初心者向け】TensorFlowのチュートリアルを読み解く。

2019年6月20日 zuka

【第3章線形回帰モデル】PRML演習問題解答を全力で分かりやすく解説<3.19>

2019年5月6日 zuka POSTED COMMENT リトルトゥース より:

zukaさん 判別分析を、既存の関数に頼らずPythonで書こうとしておりました。 なので、大変助かりました! zukaさんのコードを実行するとエラーになりした。 ylist = m[1] + (w[1]/w[0]) * (xlist – m[0]) ydisc = m[1] + (-w[0]/w[1]) * (xlist – m[0]) のところで、mが定義されていないためかと思われます。 今後とも、よろしくお願い致します!