페이지

2022년 8월 21일 일요일

18.2 Function 클래스 복습

 DeZero에서 미분을 하려면 순전파를 수행한 뒤 역전파 해주면 됩니다. 그리고 역전파 시에는 순전파의 계산 결과가 필요하기 때문에 순전파 때 결괐값을 기억해둡니다. 결괏값을 보관하는 로직은 바로  Function클래스의 다음 음영 부분입니다.

import weakref    
class Function(object):
  def __call__(self, *inputs):
    xs = [x.data for x in inputs]
    ys = self.forward(*xs)  
    if not isinstance(ys, tuple):   
      ys = (ys,)
    outputs = [Variable(as_array(y)) for y in ys]

    self.generation = max([x.generation for x in inputs]) 

    for output in outputs:
      output.set_creator(self)  
    self.inputs = inputs  #
    self.outputs = [weakref.ref(output) for output in outputs]  
   
    return outputs if len(outputs) > 1 else outputs[0]

이와 같이 함수는 입력을 inputs라는 '인스턴스 변수'로 참조합니다. 그 결과 inputs가 참조하는 변수의 참조 카운트가 1만큼 증가하고, __call__ 메서드에서 벗어난 뒤에도 메모리에 생존합니다. 만약 인스턴스 변수인 inputs로 참조하지 않았다면 참조 카운트가 0이 되어 메모리에서 삭제됐을 겁니다.


인스턴스 변수 inputs는 역전파 계산 시 사용됩니다.. 따라서 역전파하는 경우라면 참조할 변수들을 inputs에 미리 보관해둬야 합니다. 하지만 때로는 미분값이 필요 없는 경우도 있습니다. 이런 경우라면 중간 계산 결과를 저장할 필요가 없고, 계산의 '연결'또한 만들 이유가 없습니다.


신경망에는 학습(training)(혹은 훈련)과 추론(inference)이라는 두 가지단계가 있습니다. 학습 시에는 미분값을 구해야 하지만 추론 시에는 단순히 순전파만 하기 때문에 중간 계산 결과를 곧바로 버리면 메모리 사용량을 크게 줄일 수 있습니다.

2022년 8월 20일 토요일

18.1 필요 없는 미분값 삭제

 첫 번째로 DeZero의 역전파를 개선하겠습니다. 현재의 DeZero에서는 모든 변수가 미분값을 변수에 저장해두고 있습니다. 다음 예를 보시죠.

x0 = Variable(np.array(1.0))
x1 = Variable(np.array(1.0))
t = add(x0, x1)
y = add(x0, t)
y.backward()

print(y.grad, t.grad)
print(x0.grad, x1.grad)

1.0 1.0 2.0 1.0

여기에서 사용자가 제공한 변수는 x0와 x1이며, 다른 변수 t와 y는 계산 결과로 만들어집니다. 그리고 y.backward()를 실행하여 미분하면 모든 변수가 미분 결과를 메모리에 유지합니다. 그러나 많은 경우, 특히 머신러닝에서는 역전파로 구하고 싶은 미분값은 말단 변수(x0, x1)뿐 일때가 대부분입니다. 앞의 예에서는 y와 t같은 중간 변수의 미분값은 필요하지 않습니다. 그래서 중간 변수에 대해서는 미분값을 제거하는 모드를 추가하겠습니다. 현재의 Variabel 클래스의 backward메서드에 다음 음영 부분의 코드를 추가하면 됩니다.

class Variable:

  def __init__(selfdata):
    if data is not None:
      if not isinstance(data, np.ndarray):
        raise TypeError('{}은(는) 지원하지 않습니다.' .format(type(data)))
    self.data = data
    self.grad = None
    self.creator = None
    self.generation = 0     

  def set_creator(selffunc):
    self.creator = func
    self.generation = func.generation + 1     


  def backward(selfretain_grad=False):
    if self.grad is None:
      self.grad = np.ones_like(self.data)


    funcs = []
    seen_set = set()

    def add_func(f):
      if f not in seen_set:
        funcs.append(f)
        seen_set.add(f)
        funcs.sort(key=lambda x: x.generation)
    
    add_func(self.creator)

    while funcs:
      f = funcs.pop()   

      # 수정전 gys = [output.grad for output in f.outputs]  
      gys = [output().grad for output in f.outputs]  #
      gxs = f.backward(*gys)   
      if not isinstance(gxs, tuple):  
        gxs = (gxs,)
      
      for x, gx in zip(f.inputs, gxs):  
        if x.grad is None:
          x.grad = gx
        else:
          x.grad = x.grad + gx

        if x.creator is not None:
          add_func(x.creator)
      
      if not retain_grad:
        for y in f.outputs:
          y().grad = None   # y는 약한 참조(weakref)

  def cleargrad(self):
    self.grad = None


우선 메서드의 인수에 retain_grad 를 추가합니다. 이 retain_grad가 True면 지금까지처럼 모든 변수가 미분 결과(기울기)를 유지합니다. 반면 retain_grad가 False면(기본값) 중간 변수의 미분값을 모두 None으로 재설정합니다. 그 원리는 앞의 코드에서 보듯 backward 메서드의 마지막 for문으로, 각 함수의 출력 변수의 미분값을 유지하지 않도록 y().grad = None으로 설정하는 것입니다. 이렇게 하면 말단 변수 외에는 미분값을 유지하지 않습니다.


앞 코드의 마지막 y().grad = None에서 y에 접근할 때 y()라고 한 이유는 y가 약한 참조이기 때문입니다(약한 참조 구조는 이전 단계에서 도입했습니다). y().grad = None코드가 실행되면 참조 카운트가 0이  되어 미분값 데이터가 메모리에서 삭제됩니다.


이제 앞에서 실행했던 코드를 다시 실행해보죠.

x0 = Variable(np.array(1.0))
x1 = Variable(np.array(1.0))
t = add(x0, x1)
y = add(x0, t)
y.backward()

print(y.grad, t.grad)
print(x0.grad, x1.grad)

None None 2.0 1.0

이와 같이 중간 변수인 y와 t의 미분값이 삭제되어 그만큼의 메모리를 다른 용도로 사용할 수 있게 됩니다. 이렇게 DeZero의 메모리 사용에 관한 첫 번째 개선이 완성되었습니다. 다음은 두 번째 개선 차례지만, 그에 앞서 잠시 현재의 Function클래스를 복습해보겠습니다.

STEP 18 메모리 절약 모드

 이전 단계에서는 파이썬의 메모리 관리 방식에 대해 알아봤습니다. 이번 단계에서는 DeZero의 메모리 사용을 개선할 수 있는 구조 두 가지를 도입합니다. 첫 번째는 역전파 시 사용하는 메모리양을 줄이는 방법으로, 불필요한 미분 결과를 보관하지 않고 즉시 삭제합니다. 두 번째는 '역전파가 필요 없는 경우용 모드'를 제공하는 것입니다. 이 모드에서는 불필요한 계산을 생략합니다.

17.5 동작 확인

 순환 참조가 없어진 새로운 DeZero에서 다음 코드를 실행해보죠.

for i in range(10)
  x = Variable(np.random.randn(10000))    # 거대한 데이터
  y = square(square(square(x)))           # 복잡한 계싼을 수행한다

for 문을 사용하여 계산을 반복해 수행했습니다. 이 반복문은 [그림 17-4]와 같이 복잡한 참조 구조를 만들어냅니다.



그리고 for 문이 두 번째 반복될 때 x와 y가 덮어 써집니다. 그러면 사용자는 이전의 계산 그래프를 더 이상 참조하지 않게 되죠. 참조 카운트가 0이 되므로 이 시점에 계산 그래프에 사용된 메모리가 바로 삭제됩니다. 이것으로 DeZero 순환 참조 문제가 해소되었습니다.


파이썬으로 메모리 사용량을 측정하려면 외부 라이브러리인 memory porfiler등을 사용하면 편리합니다. 방금 전의 코드를 실제로 측정해 보면 메모리 사용량이 전혀 증가하지 않았음을 확인할 수 있을 겁니다.

17.4 weakref 모듈

 파이썬에서는 weakref.ref 함수를 사용하여 약한 참조(weak reference)를 만들 수 있습니다. 약한 참조란 다른 객체를 참조하되 참조 카운트는 증가시키지 않는 기능입니다. 다음은 weakref.ref함수를 사용하는 예입니다.

import weakref
import numpy as np

a = np.array([123])
b = weakref.ref(a)

b

<weakref at 0x7f640483f410; to 'numpy.ndarray' at 0x7f640483f210>

b()
array([1, 2, 3])


a = None
b
<weakref at 0x7f640483f410; to 'numpy.ndarray' at 0x7f640483f210>

nbarray 인스턴스를 대상으로 실험을 해봤습니다. 먼저 a는 일반적인 방식으로 참조하고, 다음으로 b는 약한 참졸르 갖게 했습니다. 이 상태로 b를 출력해보면 ndarray를 가리키는 약함참조(weakref)임을 확인할 수 있습니다. 참고로, 참조된 데이터에 접근하려면 b()라고 쓰면 됩니다.

그럼 앞의 코드에 바로 이어서 a = None을 실행하면 어떻게 될까요? 결과는 다음과 같습니다.


이와 같이 ndarray 인스턴스는 참조 카운트 방식에 따라 메모리에서 삭제됩니다. b도 참조를 가지고 있지만 약한 참조이기 때문에 참조 카운트에 영향을 주지 못하는 것이죠. 그래서 b를 출력하면 dead라는 문자가 나오고, 이것은 ndarray 인스턴스가 삭제됐음을 알 수 있습니다.


지금까지의 약한 참조 실험 코드는 파이썬 인터프리터에서 실행한다고 가정했습니다. IPython과 주피터 노트북(Jupyter Notebook)등의 인터프리터는 인터프리터 자체가 사용자가 모르는 참조를 추가로 유지하기 때문에 앞의 코드에서 b가 여전히 유효한 참조를 유지할 것입니다(dead가 되지 않습니다)


이 weakref 구조를 DeZero에서 도입하려 합니다. 먼저 Function에 다음 음영부분을 추가합니다.


import weakref    #
class Function(object):
  def __call__(self, *inputs):
    xs = [x.data for x in inputs]
    ys = self.forward(*xs)  
    if not isinstance(ys, tuple):   
      ys = (ys,)
    outputs = [Variable(as_array(y)) for y in ys]

    self.generation = max([x.generation for x in inputs]) 

    for output in outputs:
      output.set_creator(self)
    self.inputs = inputs
    self.outputs = [weakref.ref(output) for output in outputs]  #
   
    return outputs if len(outputs) > 1 else outputs[0]

이와 같이 인스턴스 변수 self.outputs가 대상을 약한 참조로 가리키게 변경합니다. 그 결과 함수는 출력 변수를 약하게 참조합니다. 또한 이 변경의 여파로 다른 클래스에서 Function클래스의 outputs를 참조하는 코드로 수정해야 합니다. DeZero에서는 Variable클래스의 backward메서드를 다음처럼 수정하면 됩니다.


class Variable:

  def __init__(selfdata):
    if data is not None:
      if not isinstance(data, np.ndarray):
        raise TypeError('{}은(는) 지원하지 않습니다.' .format(type(data)))
    self.data = data
    self.grad = None
    self.creator = None
    self.generation = 0     

  def set_creator(selffunc):
    self.creator = func
    self.generation = func.generation + 1     


  def backward(self):
    if self.grad is None:
      self.grad = np.ones_like(self.data)


    funcs = []
    seen_set = set()

    def add_func(f):
      if f not in seen_set:
        funcs.append(f)
        seen_set.add(f)
        funcs.sort(key=lambda x: x.generation)
    
    add_func(self.creator)

    while funcs:
      f = funcs.pop()   

      # 수정전 gys = [output.grad for output in f.outputs]  
      gys = [output().grad for output in f.outputs]  #
      gxs = f.backward(*gys)   
      if not isinstance(gxs, tuple):  
        gxs = (gxs,)
      
      for x, gx in zip(f.inputs, gxs):  
        if x.grad is None:
          x.grad = gx
        else:
          x.grad = x.grad + gx

        if x.creator is not None:
          add_func(x.creator)

  def cleargrad(self):
    self.grad = None


이와 같이 [output.grad for ,...] 부분을 [output().grad for ...]로 수정합니다. 이상으로 DeZero의 순환 참조 문제가 해결되었습니다.

17.3 순환 참조

 다음은 순환 참조(circular reference)를 설명하기 위해 준비한 코드입니다.


앞서 보여드린 코드와 거의 같지만, 이번에는 c에서 a로의 참조가 추가됐습니다. 그래서 세 개의 객체가 원 모양을 이루며 서로가 서로를 참조하게 되는데, 이 상태가 바로 순환 참조입니다. 현재의 a, b, c관계는 [그림 17-2]와 같습니다.


[그림 17-2]의 오른쪽에서 a, b, c의 참조 카운트는 모두 1입니다. 하지만 사용자는 이들 세 객체중 어느 것에도 접근할 수 없습니다(즉, 모두 불필요한 객체입니다). 그러나 a = b = c = None을 실행하는 것으로 순환 참조의 참조 카운트가 0이 되지않고, 결과적으로 메모리에서 삭제되지 않습니다. 그래서 또 다른 메모리 관리 방식이 등장합니다. 이 주인공이 GC입니다(정확하게 '세대별 가비지 컬렉션(generational garbage collection).

GC는 참조 카운트보다 영리한 방법으로 불필요한 객체를 찾아냅니다(GC의 구조는 복잡하기 때문에 이 책에서는 설명을 생략합니다). GC는 참조 카운트와 달리 메모리가 부족해지는 시점에 파이썬 인터프리터에 의해 자동으로 호출됩니다. 물론 명시적으로 호출할 수도 있습니다(gc 모듈을 임포트해서 gc.collect()를 실행).

GC는 순환 참조를 올바르게 처리합니다. 따라서 일반적인 파이썬 프로그래밍에서는 순환 참조를 의식할 필요가 특별히 없습니다. 하지만 메모리 해제를 GC에 미루다 보면 프로그램의 전체 메모리의 사용량이(순환 참조가 없을때와 비교해) 커지는 원인이 됩니다(자세한 내용은 문헌[10]참고). 그런데 마침 머신러닝, 특히 신경망에서 메모리는 중요한 자원입니다. 따라서 DeZero를 개발할 때는 순환 참조를 만들지 않는 것이 좋겠지요.

이 정도면 파이썬 메모리 관리에 관한 지식은 충분한 것 같습니다. 그럼 DeZero로 눈을 돌려볼까요? 사실 현재의 DeZero에는 순환 참조가 존재합니다. 바로 [그림 17-3]과 같이 '변수'와 '함수'를 연결하는 방식에 순환 참조가 숨어 있습니다.


[그림 17-3]에서 보듯Function인스턴스는 두 개의 Variable 인스턴스(입력과 출력)를 참조합니다. 그리고 출력 Variable인스턴스는 창조자인 Function인스턴스를 참조합니다. 이때 Function인스턴스와 Variable 인스턴스가 순환 참조 관계를 만듭니다. 다행이 이 순환 참조는 표준 파이썬 모듈인 weakref로 해결할 수 있습니다.

17.2 참조 카운트 방식의 메모리 관리

 파이썬 메모리 관리의 기본은 참조 참조카운트입니다. 참조 카운트는 구조가 간단하고 속도도 빠릅니다. 모든 객체는 참조 카운트가 0인 상태로 생성되고, 다른 객체가 참조할 때마다 1씩 증가합니다. 반대로 객체에 대한 참조가 끊길 때마다 1만금 감소하다가 0이 되면 파이썬 인터프리터가 회수해갑니다. 이런 방식으로 객체가 더 이상 필요 없어지면 즉시 메모리에서 삭제됩니다. 이상이 참조 카운트 방식의 매모리 관리입니다.

참고로 가령 다음과 같은 경우에 참조 카운트가 증가합니다.

1) 대입 연산자를 사용할 때

2) 함수에 인수로 전달할때

3) 컨테이너 타입 객체(리스트, 튜플, 클래스 등)에 추가할때

코드로도 예를 준비했습니다(개념을 설명하기 위한 의사코드라서 동작하지 않습니다).

class obj:
  pass

  def f(x):
    print(x)

a = obj()   # 변수에 대입: 참조 카운트 1
f(a)        # 함수에 전달: 함수 안에서 참조 카운트 2
            # 함수 완료: 빠져나오면 참조 카운트 1
a = None    # 대입 해제: 참조 카운트 0

먼저 obj()에 의해 생성된 객체를 a에 대입했습니다. 그러면 이 객체의 참조 카운트는 1입니다. 다음 줄에서 함수 f(a)를 호출하는데, 이때 a가 인수로 전달되기 때문에 함수 f의 범위 안에서는 참조 카운트가 1 증가합니다(총 2). 그리고 함수의 범위를 벗어나면 참조 카운트가 다시 1 감소합니다. 마지막으로 a = None에서 참조를 끊으면 결국 0이 됩니다(아무도 참조하지 않은 상태). 이렇게 0이 되는 즉시 해당 객체는 메모리에서 삭제됩니다.

보다시피 참조 카운트 방식은 간단합니다. 그리고 이 간단한 방식을 상용하여 수많은 메모리 문제를 해결할 수 있습니다.  다음 코드를 보시죠

a = obj()
b = obj()
c = obj()

a.b = b
b.c = c

a = b = c = None

a, b, c라는 세 개의 객체를 생성했습니다. 그리고 a가 b를 참조하고, b가 c를 참조합니다. 자. 이제 객채의 관곈느 [그림 17-1]의 왼쪽처럼 되었습니다.


그런 다음 a = b = c = None줄을 실행하면 객체의 관계는 [그림 17-1]의 오른쪽처럼 변함니다. 이때 a의 참조 카운트는 0이 됩니다(b와 c의 참조 카운트는 1입니다). 따라서 a는 즉시 삭제됩니다. 그 여파로 b의 참조 카운트가 1에서 0으로 감소하여 b역시 삭제됩니다. 똑같은 원리로 c의 참조 카운트로 0이 되어 삭제됩니다. 이렇게 사용자로부터 참조되지 않는 책체들이 마치 도미노처럼 한꺼번에 삭제되는 것입니다.

이상이 파이썬의 참조 카운트 방식 메모리 관리입니다. 이 기능이 수많은 메모리 관리 문제를 해결해 줍니다. 하지만 참조 카운트로는 해결할 수 없는 문제가 있으니, 바로 순환 참조입니다.