Post

PRML [1.1] 예시:다항식 곡선 피팅

데이터셋을 나타내는 모델 찾기

$N$개의 관찰값 $x$로 이루어진 훈련 집합 $\textbf{x} \equiv (x_1, …, x_N)^T$와 그에 해당하는 표적값 $\textbf t \equiv (t_1, …, t_N)^T$가 다음과 같이 주어졌다고 하자. ($N = 10$)

noise data

$t$ 값들은 $\sin(2\pi x)$의 출력값에 노이즈를 더하여 생성한 것이고, 저 훈련 집합 값들만 이용해서 새로운 입력값 $\widehat x$ 가 들어왔을 때 타깃 변수 $\widehat t$ 를 예측하는 것이 목표이다. 그러기 위해서는 원래 그래프에 수렴하는 개형을 찾아야겠다.


다항식 $y(x, \textbf{w})$

$y(x, \textbf{w})$는 우리가 선택한 모델의 예측값이고, $\bf w$는 모델의 파라미터이다.

다음은 곡선을 피팅하는 데 사용하는 다항식.

\[y(x, \textbf{w}) = w_0 + w_1x + w_2x^2 + ... + w_Mx^M=\sum_{j=0}^{M} w_jx^j \tag{식 1.1}\]

$M$이 차수. 계수 $w_0, …, w_M$를 벡터 $w$로 표현. $x$에 대해서는 비선형이지만, 계수 $w$에 대해서는 선형. $\rightarrow$ 선형 모델


Error Function

error function(오차 함수)는 실제 훈련 데이터 셋의 표적값과 우리의 모델이 예측한 값 사이의 차이를 측정하는 함수이고, 위에서 봤듯이 $y(x_n, \textbf{w})$가 우리가 선택한 모델의 예측이다.

따라서, 오차함수를 정의하고, 그 값을 최소화함으로써 훈련 집합에 피팅할 수 있다. 다항식 회귀에서, 훈련 데이터의 표적값들을 가장 잘 설명하는 다항식의 계수를 찾는 것이 목표가 되겠다. 다음은 각각의 데이터 포인트 $x_n$에 대해서 예측치 $y(x_n, \textbf{w})$와 표적값 $t_n$ 사이의 오차를 제곱하여 합산한 오차 함수.

\[E(w)={1\over2} \sum_{n=1}^{N} \{y(x_n, w)-t_n\}^2 \tag{식 1.2}\]

이 함수를 계수에 대해서 미분하면 $\bf w$에 대하여 선형 식이 나온다. 따라서, 오차함수를 최소화할 수 있는 $\bf w^\star$를 찾을 수 있다.

error function


Exercise 1.1

1.122를 풀어서, 오차함수를 최소화하는 $\bf w$가 존재한다는 것을 증명해야 한다.

위와 같이 주어졌을 때, 오차함수 $E(w)$를 살펴보자. $y(x_n, \textbf{w})$를 대입하여 전개하면

\[E(w)={1\over2} \sum_{n=1}^{N} (\sum_{j=0}^{M} w_jx^j-t_n)^2\]

이다. $w_i$ 에 대해 미분하면

\[\frac{\partial E(w)}{\partial w_i}=\sum_{n=1}^{N} \{\sum_{j=0}^{M} w_j (x_n)^{j}- t_n\} x_n^i=0 \tag{㉠}\]

이제 주어진 식 1.122를 풀어보자. 식 1.123을 대입하면 식 1.122는 다음과 같이 표현이 가능하다.

\[\sum_{n=1}^{N} \sum_{j=0}^{M} w_j (x_n)^{i+j}=\sum_{n=1}^{N} (x_n)^i t_n\]

이항하고 $(x_n)^i$에 대해 묶으면

\[\sum_{n=1}^{N} \{\sum_{j=0}^{M} w_j (x_n)^{i+j}-(x_n)^i t_n\}\\=\sum_{n=1}^{N} (x_n)^i \{\sum_{j=0}^{M} w_j (x_n)^{j}- t_n\}=0 \tag{㉡}\]

과 같이 표현할 수 있다. ㉠과 ㉡이 같은 걸 확인할 수 있다. ㉠ 식을 적절히 변경하면 식 1.122가 나온다. (물론 나는 여기서 식 1.122를 변경해서 ㉠이 나왔다. 여러분은 증명을 이렇게 하시면 안 됩니다 ^^)


Model Selection (Model Comparison)

over-fitting

차수 $M$을 이리저리 바꿔가면서 데이터셋에 가장 적합한 그래프 개형을 찾아야 햔다. 다음은 $M$이 0, 1, 3, 9일 때의 그래프 개형이다.

graph according to M

$M=3$일 때는 비교적 잘 나타내는 것 같은데, $M=9$일 때는 살펴보자. 이때는 계수가 10개이다 ($w_0, …, w_9$). 10차의 자유도를 가지고, 데이터 포인트가 10개이므로 완벽한 피팅이 가능하다. 다항식이 모든 데이터 포인트를 지나며, $E(w^\star)=0$이다. 하지만, 초록색 그래프 $\sin(2\pi x)$를 나타내는 데 실패했고, 진동이 너무 심하다.

over-fitting(과적합) 된 케이스이다.


RMS error

우리는 저런 과적합한 개형이 아닌, 일반화를 하고 싶기 때문에 $M$ 값에 따라 변하는 $E(w^{\star})$의 잔차를 관찰하고 싶다. RMS error (root mean square error, 평균 제곱근 오차)를 사용한다.

\[E_{RMS}=\sqrt{2E(w^\star)/N} \tag{식 1.3}\]

E RMS

$M$ 값에 따른 RMS error를 확인할 수 있다. 10개의 데이터 포인트에 대해 10개의 계수로 과적합 피팅을 하면 훈련 집합에서는 오차가 0이다. 대신, 일반화가 되지 않아 테스트 셋에서는 오차가 굉장히 커진다.

w according to M

또한, $M$ 값이 커질 수록 계수의 단위가 커지고 있다. 그래프의 진동폭이 점점 커진다는 뜻이겠지.


데이터셋 크기에 따른 결과 변화

graph according to data size

둘 다 $M=9$이지만, $N=100$일 때는 진동이 줄어든 것을 확인할 수 있다. 모델의 복잡도를 일정하게 유지시킬 때, 데이터셋의 크기를 증가시키면 오버피팅 현상이 완화된다.

오버피팅 현상은 Maximum likelihood 방법이나 베이지안 관점에서 해결할 수 있다. → 1장 뒷부분에서 볼 수 있다.


regularization (정규화)

데이터셋의 크기가 제한되어 있을 때, 비교적 복잡한(차수가 높은) 모델을 사용하고 싶을 때 정규화를 사용할 수 있다. 식 1.2에 페널티항을 추가함으로써, 오차함수의 계수의 크기가 커지는 것을 방지한다.

\[Ẽ (\textbf{w})={1\over2}\sum_{n=1}^{N}\{y(x_n,w)−t_n\}^2+{\lambda\over2}‖\textbf{w}‖^2 \tag{식 1.4}\]

$‖\textbf{w}‖^2 ≡ \textbf{w}^T\textbf{w}≡w_0^2+w_1^2+…+w_M^2$ 이다. coefficient $\lambda$는 정규화항의 가중치(제곱합 오류항에 대한 상대적인 중요도)를 결정한다. $w_0$는 종종 정규화항에서 제외한다. 식 1.4도 미분을 통해 유일해를 구할 수 있다.

다음은 계수 $\lambda$에 따른 결과의 차이이다.

lambda

왼쪽 그림에서는 적절히 정규화된 오차 함수를 사용한 것이고, 오른쪽은 $\lambda$값을 너무 큰 걸 사용한 경우이다. $\ln \lambda=-\infty$가 정규화를 하지 않은 것. 너무 큰 $\lambda$값을 사용하면 언터피팅을 야기한다. 아니 전혀 가깝지 않잖아!

E RMS

$M=9$일 때 $E_{RMS}$는 위와 같이 변한다.

  • 훈련 데이터에 대한 파란색 곡선부터 보자면,

    $\lambda$가 너무 작을 때 (로그 스케일에서 매우 음수일 때) $E_{RMS}$가 매우 낮다. 과적합(모델이 훈련 데이터의 노이즈까지 학습함)을 의미한다.

    또, $\lambda$가 특정 값 이상으로 커지면 다시 점점 오차가 증가하고 있다. 언더피팅이 되고 있는 것을 의미한다.

  • 테스트 데이터에 대한 빨간 곡선을 보면,

    처음에 과적합 상태에서는 $E_{RMS}$가 높다. $\lambda$가 증가하면서 $E_{RMS}$가 감소한다. 적절한 $\lambda$ 값에서 잘 일반화 되었음을 의미한다.

    $\lambda$가 너무 커지면 언더피팅이 일어나서 성능이 감소한다.


Exercise 1.2

E RMS

\[\sum_{j=0}^{M} A_{ij}w_j=T_i \tag{식 1.122}\] \[A_{ij}=\sum_{n=1}^{N} (x_n)^{i+j} \\ T_i=\sum_{n=1}^{N} (x_n)^i t_n \tag{식 1.123}\]

식 1.122를 정규화 식 1.4에 대하여 다시 쓰라는 거다. 정규화된 제곱합 오류 함수를 최소화하는 계수가 $w_i$란다.

이번에는 증명 흐름을 정상적으로 끌고 가보자.

\[Ẽ (w)={1\over2}\sum_{n=1}^{N}\{y(x_n,w)−t_n\}^2+{λ\over2}‖w‖^2\] \[\frac{\partial Ẽ(w)}{\partial w_i}=\sum_{n=1}^{N} \{\sum_{j=0}^{M} w_j (x_n)^{j}- t_n\} x_n^i + λw_i=0\] \[\sum_{n=1}^{N} \{\sum_{j=0}^{M} w_j (x_n)^{i+j}\} + λw_i=\sum_{n=1}^{N} t_n x_n^i\] \[\sum_{j=0}^{M} A_{ij}w_j+ λw_i=T_i\]

정규화된 제곱합 오류 함수를 최소화하는 계수 $w_i$에 대하여 선형 방정식을 나타낸 것이다.

This post is licensed under CC BY 4.0 by the author.