페이지

2022년 8월 11일 목요일

6.4 역전파 구현

 이상으로 준비 작업이 끝났습니다. 이번 절에서는 [그림 6-1]에 해당하는 계산의 미분을 역전파로 계산해보겠습니다.


먼저[그림 6-1]을 순전파하는 코드부터 보겠습니다.


A = Square()
B = Exp()
C = Square()

x = Variable(np.array(0.5))
a = A(x)
b = B(a)
y = C(b)


이어서 역전파로 y를 미분해보죠. 순전파 때와는 반대 순서로 각 함수의 backward메서드를 호출하면 됩니다. [그림 6-2]는 이때 이루어지는 역전파를 계산 그래프로 그린 모습입니다.


[그림 6-2]를 보면 어떤 순서로 어느 함수의 backward메서드를 호출하면 되는지 알 수 있습니다. 또한 backward메서드의 결과를 어느 변수의 grad로 설정하면 되는지도 알 수 있습니다. 다음은 [그림 6-2]의 계산 그래프를 코드로 옮긴 모습입니다.

y.grad = np.array(1.0)
b.grad = C.backward(y.grad)
a.grad = B.backward(b.grad)
x.grad = A.backward(a.grad)
print(x.grad)

3.297442541400256

역전파는 dy/dy = 1에서 시작합니다. 따라서 출력 y의 미분값을 np.array(1.0)로 설정합니다. 그런 다음 C->B->A순으로 backward 메서드를 호출하기만 하면 됩니다. 이것으로 각 변수의 미분값이 구해집니다.

앞의 코드를 실행하면 x.grad의 값이 3.297442541400256 이라고 나옵니다. 이 값이 y의 x에 대한 미분 결과입니다. 4단계에서 수치 미분으로 구한 값이 3.2974426293330694 였으니 두 결과가 거의 같음을 알 수 있습니다. 역전파를 제대로 구현한 것입니다(더 정확하게는, 올바르게 구현했을 가능성이 큽니다).

이상이 역전파 구현입니다. 제대로 동작하지만 역전파 순서(C->B->A)에 맞춰 호출하는 코드를 우리가 일일이 작성해 넣은 건 영 불편할 것 같습니다. 그래서 다음 단계에서는 이 작업을 자동화하겠습니다.



댓글 없음: