페이지

2022년 8월 19일 금요일

16.3 Variable 클래스의 backward

 본론으로 돌아와서 Variable 클래스의 backward 메서드를 구현하겠습니다. 이전과 달라진 부분(음영)에 주목해서 살펴보죠.

class Variable:

  def __init__(selfdata):
    if data is not None:
      if not isinstance(data, np.ndarray):
        raise TypeError('{}은(는) 지원하지 않습니다.' .format(type(data)))
    self.data = data
    self.grad = None
    self.creator = None
    self.generation = 0     # 세대 수를 기록하는 변수

  def set_creator(selffunc):
    self.creator = func
    self.generation = func.generation + 1     #  세대를 기록한다(부모 세대 + 1).


  def backward(self):
    if self.grad is None:
      self.grad = np.ones_like(self.data)


    funcs = []
    seen_set = set()

    def add_func(f):
      if f not in seen_set:
        funcs.append(f)
        seen_set.add(f)
        funcs.sort(key=lambda x: x.generation)
    
    add_func(self.creator)

    while funcs:
      f = funcs.pop()   #
      gys = [output.grad for output in f.outputs]  
      gxs = f.backward(*gys)   
      if not isinstance(gxs, tuple):  
        gxs = (gxs,)
      
      for x, gx in zip(f.inputs, gxs):  
        if x.grad is None:
          x.grad = gx
        else:
          x.grad = x.grad + gx

        if x.creator is not None:
          add_func(x.creator) #수정 전 funcs.append(x.creator)   

  def cleargrad(self):
    self.grad = None

가장 큰 변화는 새로 추가된 add_func 함수입니다. 그 동안 'DeZero 함수'를 리스트에 추가할 때 funcs.append(f)를 호출했는데, 대신 add_func 함수를 호출하도록 변경했습니다. 이 add_func함수가 DeZero함수 리스트를 세대 순으로 정렬하는 역할을 합니다. 그 결과 funcs.pop()은 자동으로 세대가 가장 큰 DeZero함수를 꺼내게 됩니다.

참고로 add_func 함수를 backward메서드 안에 중첩 함수로 정의했습니다. 중첩 함수는 주로 다음 두 조건을 충족할 때 적합합니다.

1) 감싸는 메서드(backward 메서드)안에서만 이용한다.

2) 감싸는 메서드(backward 메서드)에 정의된 변수(funcs과 seen_set)를 사용해야 한다.


add_func 함수는 이 조건들을 모두 충족하기 때문에 메서드 안에 정의했습니다.


앞의 구현에서는 seen_set이라는 '집합(set)'을 이용하고 있습니다. funcs리스트에 같은 함수를 중복 추가하는 일을 막기위해서 입니다. 덕분에 함수의 backward메서드가 잘못되어 여러번 불리는 일은 발생하지 않습니다.

댓글 없음: