이전 단계에서 역전파의 구동 원리를 설명했습니다. 이번 단계에서는 Variable과 Function클래스를 확장하여 역전파를 이용한 미분을 구현하겠습니다. Variable 클래스부터 살펴 보죠
2022년 8월 7일 일요일
5.3 계산 그래프로 살펴보기
다음과 같이 통상적인 계산인 순전파 계산 그래프(그림 5-1)와 미분을 계산하는 역전파 계산 그래프(그림 5-4)를 위아래로 나련히 놓고 살펴봅시다.
이렇게 비교하니 순전파와 역전파의 관계가 명확히 보입니다. 예를 들어 순전파 시의 변수 a는 역전파 시의 미분 dy/da에 대응합니다, 마찬가지로 b와 dy/db가 대응하고 x 와 dy/dx가 대응합니다. 또한 함수에도 대응 관계가 보입니다. 함수 B는 역전파의 B'(a)에 대응하고 A는 A'(x)에 대응하는 식 입니다. 이렇게 변수는 '통상값'과 '미분값'이 존재하고, 함수는 '통상 계산(순전파)'과 '미분값을 구하기 위한 계산(역전파)'이 존재하는 것으로 생각할 수 있습니다. 이를 통해 역전파를 어떻게 구현할지 짐작해볼 수 있을 것 입니다.
마지막으로 [그림 5-5]의 함수 노드 C'(b)를 계산하려면 b값이 필요하다는 사실입니다. 마찬가지로 B'(a)를 구하려면 입력 a의 값이 필요합니다. 무슨 말인고 하니, 역전파 시에는 순전파 시 이용한 데이터가 필요하다는 것입니다. 따라서 역전파를 구현하려면 먼저 순전파를 하고, 이때 각 함수가 입력 변수(앞의 예에서는 x, a, b)의 값을 기억해두지 않으면 안 됩니다. 그런 다음에야 각 함수의 역전파를 계산할 수 있습니다.
이상이 역전파의 이론 설명입니다. 다소 복잡한가요? 좋은 소식을 알려드리겠습니다. 다행히 역전파는 이 책에서 가장 어려운 내용에 속한답니다. 그리고 아직 잘 이해되지 않는 부분도 실제로 코드를 실행해보면 이해될 것입니다. 다음 단계에서는 역전파를 구현하고 실제로 돌려보며 검증하겠습니다.
2022년 8월 6일 토요일
5.2 역전파 원리 도출
이제 [식 5.2]를 차분히 살펴볼 시간입니다. [식 5.2]는 합성 함수의 미분은 구성 함수들의 미분의 곱으로 분해할 수 있음을 뜻합니다. '곱하는 순서'까지 말해주지는 않지만 사실 어떤 순서로 곱해도 상관없습니다. 그러니 [식 5.3]과 같이 출력에서 입력 방향으로(즉, 역방향으로)순서대로 계산해보겠습니다.
dy/dx = ((dy/dy*dy/db)*db/da)da/dx
[식 5.3]과 같이 출력에서 입력 방향으로, 즉 보통의 계산과는 반대 방향으로 미분을 계산합니다. 이때 [식 5.3]의 흐름은 [그림 5-2]와 같습니다.
[그림 5-2]처럼 y에서 입력 x 방향으로 곱하면서 순서대로 미분하면 최종적으로 dx/dy가 구해집니다. 계산 그래프로는 [그림 5-3]처럼 됩니다.
[그림 5-3]의 계산 그래프를 잘 관찰해봅시다. 우선 dx/dy(=1)에서 시작하여 dy/db와 곱합니다. 여기서 dy/db는 함수 y=C(b)의 미분입니다. 따라서 함수 C의 도함수를 C'의 도함수를 C''로 나타내면dy/db=C''(b)라고 쓸 수 있습니다. 마찬가지로 db/da=B'(a)이고 da/dx=A'(x)입니다. 이에 따라 [그림 5-3]의 계산 그래프는 다음과 같이 단순화할 수 있습니다.
[그림 5-4]와 같이 도함수의 곱을 함수 노드 하나로 그릴 수 있습니다. 이제 미분값이 전파되는 흐름이 평확해집니다. [그림 5-4]를 보면 'y의 각 변수에 대한 미분값'이, 즉 변수 y, b, a, x 에 대한 미분값이 오른쪽에서 왼쪽으로 전파되는 것을 알 수 있습니다. 이것이 역전파입니다. 여기서 중요한 점은 전파되는 데이터는 모두 'y의 미분값'이라는 것입니다. 구체적으로 dy/dy, dy/db, dy/da, dy/dx처럼 모두' y의 oo에 대한 미분값'이 전파되고 있습니다.
[식 5.3]과 같이 계산 순서를 출력에서 입력 방향으로 정한 이유는 y의 미분값을 전파하기 위해서입니다. 즉 y를 '중요 요소'로 대우하기 때문입니다. 만약 입력에서 출력방향으로 계산했다면 중요 요소는 입력인 x가 됩니다. 이 경우 전파되는 값은 dx/dx -> da/dx -> db/dx -> dy/dx가 되어 x 에 대한 미분을 전파하게 됩니다.
머신러닝은 주로 대량의 매개변수를 입력받아서 마지막에 손실 함수(loss function)를 거쳐 출력을 내는 형태로 진행됩니다. 손실 함수의 출력은(많은 경우) 단일한 스칼라값이며, 이 값이 '중요 요소'입니다. 즉, 손실 함수의 각 매개변수에 대한 미분을 계산해야 합니다. 이런 경우 미분값을 출려에서 입력 방향으로 전파하면 한 번의 전파만으로 모든 매개 변수에 대한 미분을 계산할 수 있습니다. 이처럼 계산이 효율적으로 이뤄지기 때문에 미분을 반대 방향으로 전파하는 방식(역전파)을 이용하는 것입니다.
5.1 연쇄 법칙
역전파를 이해하는 열쇠는 연쇄법칙(chain rule)입니다. chain은 '사슬'이라는 뜻으로, 여러 함수를 사슬처럼 연결하여 사용하는 모습을 빗댄것입니다. 연쇄 법칙에 따르면 합성 함수(여러함수가 연결된 함수)의 미분은 구성 함수 각각을 미분한 후 곱한 것과 같습니다.
구체적인 예를 하나 들어보죠. y = F(x)라는 함수가 있다고 합시다. 이 함수는 a = A(x), b= B(a), c = C(b)라는 세 함수로 구성되어 있습니다. 계산 그래프로 그리면 [그림 5-1]처럼 됩니다.
이때 x 에 대한 y의 미분은 [식 5.1]로 표현할 수 있습니다.
dy/dx = dy/db * db/da * da/dx
[식 5.1]에서 알수 있듯이 x에 대한 y의 미분은 구성 함수 각각의 미분값을 모두 곱한 값과 같습니다. 즉, 합성 함수의 미분은 각 함수의 국소적인 미분들로 분해 할 수 있ㅆ브니다. 이것이 연쇄 법칙입니다. 또한 [식 5.1]앞에 다음과 같이 dx/dy를 명시할 수도 있습니다.
dy/dx = dy/dy*dy/db*db/da*da/dx
dy/dy는 '자신'에 대한 미분이라 항상 1입니다. 따라서 생략하는 것이 보통이지만 이 책에서는 역전파를 구현할 때를 대비하여 포함하도록 하겠습니다.
dy/dy는 y의 ㅛ에 대한 미분입니다. 이때 y가 작은 값만큼 변하면 자기 자신도 y도 당연히 같은 크기만큼 변합니다. 따라서 변화율은 어떤 함수의 경우에도 항상 1입니다.
STEP 5 역전파 이론
우리는 수치 미분을 이용해 비분을 계산할 수 있게 되었지만 수치 미분은 계산 비용과 정확도면에서 문제가 있습니다. 지금이 바로 역전파(backpropatgation, 오차역전파법)가 구세주로 등장할 시점입니다! 역전파를 이용하면 미분을 효율적으로 계산할 수 있고 결괏값이 오차도 더 작습니다. 이 번 단계에서는 아직 역전파 구현까지는 들어가지 않고 이론 설명에 집중하겠습니다(구현은 다음 단계로 양보했습니다).
4.4 수치 미분의 문제점
수치 미분의 결과에는 오차가 포함되어 있습니다. 대부분의 경우 오차는 매우 작지만 어떤 계산이냐에 따라 커질 수도 있습니다.
수치미분의 결과에 오차가 포함되기 쉬운 이유는 주로 '자릿수 누락' 때문입니다. 중앙차분 등 '차이'를 구하는 계산은 주로 크기가 비슷한 값들을 다루므로 계산 결과에서 자릿수 누락이 생겨 유효 자릿수가 줄어들 수 있습니다. 예를 들어 유효 자릿수가 4일 때 1.234 - 1.2333 이라는 계산(비슷한 값끼리의 뺄셈)을 생각해보죠. 계산 결과는 0.001 되어 유효 자릿수가 1로 줄어듭니다. 원래는 1.234.... - 1.233... = 0.001434..같은 결과였을지도 모르는데, 자릿수 누락 때문에 0.001이 됐다고 볼 수 있습니다. 이와 같은 원리 때문에 수치 미분을 이용하면 자릿수 누락이 발생하여 오차가 포함되기 쉽습니다.
수치 미분의 더 심각한 문제는 계산량이 많다는 점입니다.변수가 여러 개인 계산을 미분할 경우 변수 각각을 미분해야 하기 때문입니다. 신경망에서는 매개변수를 수백만개 시상 사용하는 것일도 아니므로 이 모두를 수치 미분으로 구하는 것은 현실적이지 않습니다. 그래서 등장한 것이 바로 역전파입니다. 다음 단계에서 드디어 역전파를 소개합니다.
덧붙여서, 수치 미분은 구현하기 쉽고 거의 정확한 값을 얻을 수 있습니다. 이에 비해 역전파는 복잡한 알고리즘이라서 구현하면서 버그가 섞여 들어ㅓ가기 쉽습니다. 그래서 역전파를 정확하게 구현했는지 확인하기 휘애 수치 미분의 별과를 이용하곤합니다. 이를 기울기 확인(gradient checking)이라고 하는데, 단순히 수치 미분 결과와 역전파의 결과를 비교하는 것입니다. 기술기 확인은 10단계에서 구현합니다.
4.3 합성 함수의 미분
지금까지는 y = x ** 2 이라는 단순한 함수를 다뤘습니다. 이어서 합성 함수를 미분해 봅시다. y = (e ** x) ** 2 이라는 계산에 대한 미분 dy/dx를 계산할 것입니다. 코드는 다음과 같습니다.
이 코드는 일련의 계산을 f라는 함수로 정리했습니다. 파이썬에서는 함수도 객체이기 때문에 다른 함수에 인수로 전달할 수 있습니다. 실제로 앞의 코드에서는 numerical_diff 합수에 함수 f를 전달했습니다.
실행 결과를 보면 미분한 값이 3.297...입니다. x를 0.5에서 작은 값만큼 변화시키면 y는 작은 값의 3.297...배만큼 변한다는 의미죠.
이상에서 우리는 미분을 '자동으로'계산하는 데 성공했습니다. 원하는 계산을 파이썬 코드로 표현한 다음(앞의 예에서 함수 f로 정의) 미분해달라고 프로그램에 요구했습니다. 이 방식대로 하면 아무리 보=ㄱ잡하게 조립된 함수라도 미분을 자동으로 계산할 수 있습니다! 이제부터는 함수의 종류를 늘려가면서 어떠한 계산도 (미분 가능한 함수라면) 미분할 수 있습니다. 그러나 안타깝게도 수치 미분에는 문제가 있습니다.