페이지

2022년 8월 26일 금요일

20.2 연산자 오버로드

먼저 곱셈 연산자 *를 오버로드하겠습니다. 곱셈의 특수 메서드는 __mul__(self, other) 입니다(인수 self와 other에 대해서는 조금 뒤에 설명합니다). __mul__메서드를 정의(구현)하면 * 연산자를 사용할 때 __mul__ 메서드가 호출됩니다. 시험 삼아 Variable 클래스의 __mul__ 메서드를 다음과 같이 구현해보겠습니다.


class Variable:
  def __init__(selfdataname=None):
    if data is not None:
      if not isinstance(data, np.ndarray):
        raise TypeError('{} is not supported'format(type(data)))
    
    self.data = data
    self.name = name
    self.grad = None
    self.creator = None
    self.generation = 0
    
  def set_creator(selffunc):
    self.creator = func
    self.generation = func.generation + 1     


  def backward(selfretain_grad=False):
    if self.grad is None:
      self.grad = np.ones_like(self.data)


    funcs = []
    seen_set = set()

    def add_func(f):
      if f not in seen_set:
        funcs.append(f)
        seen_set.add(f)
        funcs.sort(key=lambda x: x.generation)
    
    add_func(self.creator)

    while funcs:
      f = funcs.pop()   

      # 수정전 gys = [output.grad for output in f.outputs]  
      gys = [output().grad for output in f.outputs]  #
      gxs = f.backward(*gys)   
      if not isinstance(gxs, tuple):  
        gxs = (gxs,)
      
      for x, gx in zip(f.inputs, gxs):  
        if x.grad is None:
          x.grad = gx
        else:
          x.grad = x.grad + gx

        if x.creator is not None:
          add_func(x.creator)
      
      if not retain_grad:
        for y in f.outputs:
          y().grad = None   # y는 약한 참조(weakref)

  def cleargrad(self):
    self.grad = None

  
  @property
  def shape(self):
    return self.data.shape

  @property
  def ndim(self):
    return self.data.ndim
  
  @property
  def size(self):
    return self.data.size

  @property
  def dtype(self):
    return self.data.dtype


  def __len__(self):
    return len(self.data)

  def __repr__(self):
    if self.data is None:
      return 'variable(None)'
    
    p = str(self.data).replace('\n''\n' + ' ' * 9)
    return 'variable('+ p + ')'

  def __mul__(selfother):
    return mul(self, other)


지금까지 구현한 Variable 클래스에 이 __mul__ 메서드를 추가합니다. 이제부터 *를 사용하면 __ㅡmul__ 메서드 대신 불리고, 다시 그 안의 mul 함수가 불리게 됩니다. 시험해볼까요?


a = Variable(np.array(3.0))
b = Variable(np.array(2.0))
y = a * b
print(y)

variable(6.0)



보다시피 y = a * b라는 코드를 문제없이 실행할 수 있습니다. a * b가 실행될 때 인스턴스 a의 __mul__(self, other) 메서드가 호출됩니다. 이때 [그림 20-2]와 같이 연산자 * 왼쪽의 a가 인수 self에 전달되고, 오른쪽 b가 other에 전달됩니다.


앞의 예에서 a * b가 실행되면 먼저 인스턴스 a의 특수 메서드인 __mul__ 호출됩니다. 그런데 만약 a에 __mul__ 메서드가 구현되어 있지 않으면 인스턴스 b의 * 연산자 특수 메서드가 호출됩니다. 이 경우 b는 * 연산자의 오른쪽에 위치하기 때문에 __mul__이 아닌 __mul__이라는 특수 메서드가 호출됩니다(메서드 이름 앞에 오른쪽(right)을 뜻하는 'r'이 붙어 있습니다).


이상으로 * 연산자를 오버로드해봤습니다. 정확히는 Variable 클래스의 __mul__ 메서드를 구현했습니다. 그런데 이와 똑같은 작업을 다음 코드처럼 간단히 처리하는 방법도 있습니다.

class Variable:
  def __init__(selfdataname=None):
    if data is not None:
      if not isinstance(data, np.ndarray):
        raise TypeError('{} is not supported'format(type(data)))
    
    self.data = data
    self.name = name
    self.grad = None
    self.creator = None
    self.generation = 0
    
  def set_creator(selffunc):
    self.creator = func
    self.generation = func.generation + 1     


  def backward(selfretain_grad=False):
    if self.grad is None:
      self.grad = np.ones_like(self.data)


    funcs = []
    seen_set = set()

    def add_func(f):
      if f not in seen_set:
        funcs.append(f)
        seen_set.add(f)
        funcs.sort(key=lambda x: x.generation)
    
    add_func(self.creator)

    while funcs:
      f = funcs.pop()   

      # 수정전 gys = [output.grad for output in f.outputs]  
      gys = [output().grad for output in f.outputs]  #
      gxs = f.backward(*gys)   
      if not isinstance(gxs, tuple):  
        gxs = (gxs,)
      
      for x, gx in zip(f.inputs, gxs):  
        if x.grad is None:
          x.grad = gx
        else:
          x.grad = x.grad + gx

        if x.creator is not None:
          add_func(x.creator)
      
      if not retain_grad:
        for y in f.outputs:
          y().grad = None   # y는 약한 참조(weakref)

  def cleargrad(self):
    self.grad = None

  
  @property
  def shape(self):
    return self.data.shape

  @property
  def ndim(self):
    return self.data.ndim
  
  @property
  def size(self):
    return self.data.size

  @property
  def dtype(self):
    return self.data.dtype


  def __len__(self):
    return len(self.data)

  def __repr__(self):
    if self.data is None:
      return 'variable(None)'
    
    p = str(self.data).replace('\n''\n' + ' ' * 9)
    return 'variable('+ p + ')'


  Variable.__mul__ = mul
  Variable.__add__ = add


Variable 클래스를 정의한 후 Variable.__mul__ = mul이라고 작성하면 끝! 파이썬에서는 함수도 객체이므로 이와 같이 함수 자체를 할당할 수 있습니다. 이렇게 하면 Variagble 인스턴스의 __mul__ 메서드를 호출할 때 mul 함수가 불립니다.

앞의 코드에서 + 연산자의 특수 메서드인 __add__도 설정했습니다. + 연샂자도 함께 오버로드한 것이죠. 그럼 + 와 * 를 모두 사용하여 계산을 해 보겠습니다.


a = Variable(np.array(3.0))
b = Variable(np.array(2.0))
c = Variable(np.array(1.0))

y = add(mul(a, b), c)
# y = a * b + c
y.backward()

print(y)
print(a.grad)
print(b.grad)

variable(7.0) 2.0 3.0


보다시피 y = a * b + c 형태로 코등하는게 가능해졌습니다. 계산 시 + 와  * 를 자유롭게 사용할 수 있게 된 것이죠, / 와 - 같은 다른 연산자도 같은 방식으로 구현할 수 있습니다. 그럼 다음 단계에서도 계속 연산자 오버로드를 살펴보겠습니다.

댓글 없음: