방금 보여드린 역전파 코드에는 똑같은 처리 흐름이 반복해서 나타났습니다. 변수에서 하나 앞의 변수로 거슬러 올라가는 로직이 그러했습니다. 그러므로 이 반복 작업을 자동화할 수 있도록 Variable 클래스에 badckward라는 새로운 메서드를 추가하겠습니다.
class Variable:
def __init__(self, data):
self.data = data
self.grad = None
self.creator = None
def set_creator(self, func):
self.creator = func
def backward(self):
f = self.creator # 1. 함수를 가져온다.
if f is not None:
x = f.input # 2. 함수의 입력을 가져온다.
x.grad = f.backward(self.grad) # 3. 함수의 backward 메서드를 호출한다.
x.backward() # 하나 앞 변수의 backward 메서드를 호출한다(재귀)
backward메서드는 지금까지 반복한 처리 흐름과 거의 동일합니다. Variable의 creator에서 함수를 얻어오고, 그 함수의 입력 변수를 가져옵니다. 그런 다음 함수의 backward메서드를 호출합니다. 마지막으로 자신보다 하나 앞에 놓인 변수의 backward 메서드를 호출합니다. 이런 식으로 각 변수의 backward 메서드가 재귀적으로 불리게 됩니다.
Variable 인스턴스의 creator 가 None이면 역전파가 중단됩니다. 창조자가 없으므로 이 Variable 인스턴스는 함수 바깥에서 생성했음을 뜻합니다(높은 확률로 사용자가 만들어 건네 변수일 것입니다).
이제 새로워진 Variable을 이용하여 역전파가 자동으로 실행되는 모습을 보겠습니다.
A = Square()
B = Exp()
C = Square()
x = Variable(np.array(0.5))
a = A(x)
b = B(a)
y = C(b)
# 역전파
y.grad = np.array(1.0)
y.backward()
print(x.grad)
3.297442541400256
이와 같이 변수 y의 backward 메서드를 호출하면 역전파가 자동으로 진행됩니다. 실행 결과도 지금까지와 동일합니다. 축하합니다! 여러분은 방금 DeZero에서 가장 중요한 개념인 자동 미분의 기초를 완성했습니다.
댓글 없음:
댓글 쓰기