페이지

2022년 8월 20일 토요일

18.1 필요 없는 미분값 삭제

 첫 번째로 DeZero의 역전파를 개선하겠습니다. 현재의 DeZero에서는 모든 변수가 미분값을 변수에 저장해두고 있습니다. 다음 예를 보시죠.

x0 = Variable(np.array(1.0))
x1 = Variable(np.array(1.0))
t = add(x0, x1)
y = add(x0, t)
y.backward()

print(y.grad, t.grad)
print(x0.grad, x1.grad)

1.0 1.0 2.0 1.0

여기에서 사용자가 제공한 변수는 x0와 x1이며, 다른 변수 t와 y는 계산 결과로 만들어집니다. 그리고 y.backward()를 실행하여 미분하면 모든 변수가 미분 결과를 메모리에 유지합니다. 그러나 많은 경우, 특히 머신러닝에서는 역전파로 구하고 싶은 미분값은 말단 변수(x0, x1)뿐 일때가 대부분입니다. 앞의 예에서는 y와 t같은 중간 변수의 미분값은 필요하지 않습니다. 그래서 중간 변수에 대해서는 미분값을 제거하는 모드를 추가하겠습니다. 현재의 Variabel 클래스의 backward메서드에 다음 음영 부분의 코드를 추가하면 됩니다.

class Variable:

  def __init__(selfdata):
    if data is not None:
      if not isinstance(data, np.ndarray):
        raise TypeError('{}은(는) 지원하지 않습니다.' .format(type(data)))
    self.data = data
    self.grad = None
    self.creator = None
    self.generation = 0     

  def set_creator(selffunc):
    self.creator = func
    self.generation = func.generation + 1     


  def backward(selfretain_grad=False):
    if self.grad is None:
      self.grad = np.ones_like(self.data)


    funcs = []
    seen_set = set()

    def add_func(f):
      if f not in seen_set:
        funcs.append(f)
        seen_set.add(f)
        funcs.sort(key=lambda x: x.generation)
    
    add_func(self.creator)

    while funcs:
      f = funcs.pop()   

      # 수정전 gys = [output.grad for output in f.outputs]  
      gys = [output().grad for output in f.outputs]  #
      gxs = f.backward(*gys)   
      if not isinstance(gxs, tuple):  
        gxs = (gxs,)
      
      for x, gx in zip(f.inputs, gxs):  
        if x.grad is None:
          x.grad = gx
        else:
          x.grad = x.grad + gx

        if x.creator is not None:
          add_func(x.creator)
      
      if not retain_grad:
        for y in f.outputs:
          y().grad = None   # y는 약한 참조(weakref)

  def cleargrad(self):
    self.grad = None


우선 메서드의 인수에 retain_grad 를 추가합니다. 이 retain_grad가 True면 지금까지처럼 모든 변수가 미분 결과(기울기)를 유지합니다. 반면 retain_grad가 False면(기본값) 중간 변수의 미분값을 모두 None으로 재설정합니다. 그 원리는 앞의 코드에서 보듯 backward 메서드의 마지막 for문으로, 각 함수의 출력 변수의 미분값을 유지하지 않도록 y().grad = None으로 설정하는 것입니다. 이렇게 하면 말단 변수 외에는 미분값을 유지하지 않습니다.


앞 코드의 마지막 y().grad = None에서 y에 접근할 때 y()라고 한 이유는 y가 약한 참조이기 때문입니다(약한 참조 구조는 이전 단계에서 도입했습니다). y().grad = None코드가 실행되면 참조 카운트가 0이  되어 미분값 데이터가 메모리에서 삭제됩니다.


이제 앞에서 실행했던 코드를 다시 실행해보죠.

x0 = Variable(np.array(1.0))
x1 = Variable(np.array(1.0))
t = add(x0, x1)
y = add(x0, t)
y.backward()

print(y.grad, t.grad)
print(x0.grad, x1.grad)

None None 2.0 1.0

이와 같이 중간 변수인 y와 t의 미분값이 삭제되어 그만큼의 메모리를 다른 용도로 사용할 수 있게 됩니다. 이렇게 DeZero의 메모리 사용에 관한 첫 번째 개선이 완성되었습니다. 다음은 두 번째 개선 차례지만, 그에 앞서 잠시 현재의 Function클래스를 복습해보겠습니다.

댓글 없음: