페이지

2022년 8월 12일 금요일

7.3 backward 메서드 추가

 방금 보여드린 역전파 코드에는 똑같은 처리 흐름이 반복해서 나타났습니다. 변수에서 하나 앞의 변수로 거슬러 올라가는 로직이 그러했습니다. 그러므로 이 반복 작업을 자동화할 수 있도록 Variable 클래스에 badckward라는 새로운 메서드를 추가하겠습니다.


class Variable:
  def __init__(selfdata):
    self.data = data
    self.grad = None
    self.creator = None

  def set_creator(selffunc):
    self.creator = func

  def backward(self):
    f = self.creator    # 1. 함수를 가져온다.
    if f is not None:
      x = f.input       # 2. 함수의 입력을 가져온다.
      x.grad = f.backward(self.grad)    # 3. 함수의 backward 메서드를 호출한다.
      x.backward()      # 하나 앞 변수의 backward 메서드를 호출한다(재귀)


backward메서드는 지금까지 반복한 처리 흐름과 거의 동일합니다. Variable의 creator에서 함수를 얻어오고, 그 함수의 입력 변수를 가져옵니다. 그런 다음 함수의 backward메서드를 호출합니다. 마지막으로 자신보다 하나 앞에 놓인 변수의 backward 메서드를 호출합니다. 이런 식으로 각 변수의 backward 메서드가 재귀적으로 불리게 됩니다.

Variable 인스턴스의 creator 가 None이면 역전파가 중단됩니다. 창조자가 없으므로 이 Variable 인스턴스는 함수 바깥에서 생성했음을 뜻합니다(높은 확률로 사용자가 만들어 건네 변수일 것입니다).

이제 새로워진 Variable을 이용하여 역전파가 자동으로 실행되는 모습을 보겠습니다.

A = Square()
B = Exp()
C = Square()

x = Variable(np.array(0.5))
a = A(x)
b = B(a)
y = C(b)

# 역전파
y.grad = np.array(1.0)
y.backward()
print(x.grad)

3.297442541400256

이와 같이 변수 y의 backward 메서드를 호출하면 역전파가 자동으로 진행됩니다. 실행 결과도 지금까지와 동일합니다. 축하합니다! 여러분은 방금  DeZero에서 가장 중요한 개념인 자동 미분의 기초를 완성했습니다.

댓글 없음: