페이지

2022년 8월 19일 금요일

13.2 Variable 클래스 수정

 그럼 Variable 클래스의 backward메서드를 살펴보겠습니다. 복습할 겸 Variable클래스의 현재 코드를 먼저 보여드리죠

class Variable:
  def __init__(selfdata):
    if data is not None:
      if not isinstance(data, np.ndarray):
        raise TypeError('{}은(는) 지원하지 않습니다.' .format(type(data)))
    self.data = data
    self.grad = None
    self.creator = None

  def set_creator(selffunc):
    self.creator = func
    
  def backward(self):
    if self.grad is None:
      self.grad = np.ones_like(self.data)
    
    funcs = [self.creator]
    while funcs:
      f = funcs.pop()
      x, y = f.input, f.output  # 1 함수의 입출력을 얻는다.
      x.grad = f.backward(y.grad)   # 2 backward 메서드를 호출한다.

      if x.creator is not None:
        funcs.append(x.creator)

여기서 주목할 곳은 음영 부분입니다. 우선 while 블럭 안의 1)에서 함수의 입출력 변수를 꺼냅니다. 그리고 2)에서 함수의 backward메서드를 호출 합니다. 지금까지 우리는 1)에서 함수의 입출력이 하나씩이라고 한정했습니다. 이부분을 여러 개의 변수에 대응할 수 있도록 수정하겠습니다.

class Variable:

  def __init__(selfdata):
    if data is not None:
      if not isinstance(data, np.ndarray):
        raise TypeError('{}은(는) 지원하지 않습니다.' .format(type(data)))
    self.data = data
    self.grad = None
    self.creator = None

  def set_creator(selffunc):
    self.creator = func

    
  def backward(self):
    if self.grad is None:
      self.grad = np.ones_like(self.data)

    funcs = [self.creator]
    while funcs:
      f = funcs.pop()
      gys = [output.grad for output in f.outputs]   # 1)
      gxs = f.backward(*gys)    # 2)
      if not isinstance(gxs, tuple):  # 3)
        gxs = (gxs,)
      
      for x, gx in zip(f.inputs, gxs):  # 4)
        x.grad = gx

        if x.creator is not None:
          funcs.append(x.creator)


총 네 군데를 수정했습니다. 우선 1)에서 출력 변수인 outputs에 담계 잇는 미분값들을 리스트에 담습니다. 그리고 2)에서 함수 f의 역전파를 호출합니다. 이때 f.backward(*gys)처럼 인수에 별표를 붙여 호출하여 리스트를 풀어줍니다(리스트 언팩). 3) 에서 gxs가 튜플이 아니라면 튜플로 변환합니다.


2)와 3)은 이전 단계에서 순전파 개선 시 활용한 관례와 같습니다. 2)에서 Add클래스의 backward메서드를 호출할 때  인술르 플어서 전달합니다. 3)에서는 Add 클래스의 backward 메서드가 튜플이 아닌 해당 원소를 직접 반환할 수 있게 합니다.


4)에서는  역전파로 전파되는 미분값을 Variable의 인스턴스 변수 grad에 저장해둡니다. 여리게서 gxs와 f.inputs의 각 원소는 서로 대응 관계에 있습니다. 더 정확히 말하면 i번째 원소에 대해 f.inputs[i]의 미분값은 gxs[i]에 대응합니다. zip 함수와 for 문을 이용해서 모든 Variable  인스턴스 각각에 알맞은 미분값을 설정한 것입니다. 이상이 Variable 클래스의 새로운 backward 메서드입니다.

댓글 없음: