페이지

2022년 8월 20일 토요일

17.4 weakref 모듈

 파이썬에서는 weakref.ref 함수를 사용하여 약한 참조(weak reference)를 만들 수 있습니다. 약한 참조란 다른 객체를 참조하되 참조 카운트는 증가시키지 않는 기능입니다. 다음은 weakref.ref함수를 사용하는 예입니다.

import weakref
import numpy as np

a = np.array([123])
b = weakref.ref(a)

b

<weakref at 0x7f640483f410; to 'numpy.ndarray' at 0x7f640483f210>

b()
array([1, 2, 3])


a = None
b
<weakref at 0x7f640483f410; to 'numpy.ndarray' at 0x7f640483f210>

nbarray 인스턴스를 대상으로 실험을 해봤습니다. 먼저 a는 일반적인 방식으로 참조하고, 다음으로 b는 약한 참졸르 갖게 했습니다. 이 상태로 b를 출력해보면 ndarray를 가리키는 약함참조(weakref)임을 확인할 수 있습니다. 참고로, 참조된 데이터에 접근하려면 b()라고 쓰면 됩니다.

그럼 앞의 코드에 바로 이어서 a = None을 실행하면 어떻게 될까요? 결과는 다음과 같습니다.


이와 같이 ndarray 인스턴스는 참조 카운트 방식에 따라 메모리에서 삭제됩니다. b도 참조를 가지고 있지만 약한 참조이기 때문에 참조 카운트에 영향을 주지 못하는 것이죠. 그래서 b를 출력하면 dead라는 문자가 나오고, 이것은 ndarray 인스턴스가 삭제됐음을 알 수 있습니다.


지금까지의 약한 참조 실험 코드는 파이썬 인터프리터에서 실행한다고 가정했습니다. IPython과 주피터 노트북(Jupyter Notebook)등의 인터프리터는 인터프리터 자체가 사용자가 모르는 참조를 추가로 유지하기 때문에 앞의 코드에서 b가 여전히 유효한 참조를 유지할 것입니다(dead가 되지 않습니다)


이 weakref 구조를 DeZero에서 도입하려 합니다. 먼저 Function에 다음 음영부분을 추가합니다.


import weakref    #
class Function(object):
  def __call__(self, *inputs):
    xs = [x.data for x in inputs]
    ys = self.forward(*xs)  
    if not isinstance(ys, tuple):   
      ys = (ys,)
    outputs = [Variable(as_array(y)) for y in ys]

    self.generation = max([x.generation for x in inputs]) 

    for output in outputs:
      output.set_creator(self)
    self.inputs = inputs
    self.outputs = [weakref.ref(output) for output in outputs]  #
   
    return outputs if len(outputs) > 1 else outputs[0]

이와 같이 인스턴스 변수 self.outputs가 대상을 약한 참조로 가리키게 변경합니다. 그 결과 함수는 출력 변수를 약하게 참조합니다. 또한 이 변경의 여파로 다른 클래스에서 Function클래스의 outputs를 참조하는 코드로 수정해야 합니다. DeZero에서는 Variable클래스의 backward메서드를 다음처럼 수정하면 됩니다.


class Variable:

  def __init__(selfdata):
    if data is not None:
      if not isinstance(data, np.ndarray):
        raise TypeError('{}은(는) 지원하지 않습니다.' .format(type(data)))
    self.data = data
    self.grad = None
    self.creator = None
    self.generation = 0     

  def set_creator(selffunc):
    self.creator = func
    self.generation = func.generation + 1     


  def backward(self):
    if self.grad is None:
      self.grad = np.ones_like(self.data)


    funcs = []
    seen_set = set()

    def add_func(f):
      if f not in seen_set:
        funcs.append(f)
        seen_set.add(f)
        funcs.sort(key=lambda x: x.generation)
    
    add_func(self.creator)

    while funcs:
      f = funcs.pop()   

      # 수정전 gys = [output.grad for output in f.outputs]  
      gys = [output().grad for output in f.outputs]  #
      gxs = f.backward(*gys)   
      if not isinstance(gxs, tuple):  
        gxs = (gxs,)
      
      for x, gx in zip(f.inputs, gxs):  
        if x.grad is None:
          x.grad = gx
        else:
          x.grad = x.grad + gx

        if x.creator is not None:
          add_func(x.creator)

  def cleargrad(self):
    self.grad = None


이와 같이 [output.grad for ,...] 부분을 [output().grad for ...]로 수정합니다. 이상으로 DeZero의 순환 참조 문제가 해결되었습니다.

댓글 없음: