페이지

2022년 8월 13일 토요일

9.2 backward 메서드 간소화

 두 번재 개선은 역전파 시 사용자의 번거로움을 줄이기 위한 것입니다. 구체적으로는 방금 작성한 코드에서 y.grad = np.array(1.0) 부분을 생략하려 합니다. 지금까지는 역전파할 때마다 y.grad = np.array(1.0)이라는 코드를 작성해야 했습니다. 이 코드를 생략할 수 있도록 Variable의 backward메서드에 다음 두 줄을 추가합니다.

class Variable:
  def __init__(selfdata):
    self.data = data
    self.grad = None
    self.creator = None

  def set_creator(selffunc):
    self.creator = func
  
  def backward(self):
    if self.grad is None:
      self.grad = np.ones_like(self.data)

    funcs = [self.creator]
    while funcs:
      f = funcs.pop()
      x, y = f.input, f.output
      x.grad = f.backward(y.grad)

      if x.creator is not None:
        funcs.append(x.creator)

이와 같이 만약 변수의 grad가  None이면 자동으로 미분값을 생성합니다. np.ones_like(self, data)코드는 self.data와 형상과 데이터 타입이 같은 ndarray 인스턴스를 생성하는데, 모든 요소를 1로 채워서 돌려줍니다. self.data가 스칼라이면 self.grad도 스칼라가 됩니다.


이전까지는 출력의 미분값을 nparray(1.0)으로 사용했지만, 방금 코드에서는 np.ones_like()를 썼습니다. 그 이유는 Variable의 data와 grad의 데이터 타입을 같게 만들기 위해서입니다. 예를 들어 data의 타입이 32비트 부동소스점 숫자이면 grad의 타입도 32비트 부동소수점 숫자가 됩니다. 참고로 nparray(1.0)은 64비트 부동소수점 숫자 타입으로 만들어 줍니다.


이제 어떤 계산을 하고 난 뒤의 최종 출력 변수에서 backward 메서드를 호출하는 것만으로 미분값이 구해집니다. 실제로 돌려보죠.


x = Variable(np.array(0.5))
y = square(exp(square(x)))
y.backward()
print(x.grad)

3.297442541400256

댓글 없음: