페이지

2022년 8월 19일 금요일

16.1 세대 추가

 먼저 Variable 클래스와 Function 클래스에 인스턴스 변수 generation을 추가하겠습니다. 몇 번째 '세대'의 함수(혹은 변수)인지 나타내는 변수죠. Variable 클래스부터 시작하겠습니다.


class Variable:

  def __init__(selfdata):
    if data is not None:
      if not isinstance(data, np.ndarray):
        raise TypeError('{}은(는) 지원하지 않습니다.' .format(type(data)))
    self.data = data
    self.grad = None
    self.creator = None
    self.generation = 0     # 세대 수를 기록하는 변수

  def set_creator(selffunc):
    self.creator = func
    self.generation = func.generation + 1     #  세대를 기록한다(부모 세대 + 1).


  def backward(self):
    if self.grad is None:
      self.grad = np.ones_like(self.data)

    funcs = [self.creator]
    while funcs:
      f = funcs.pop()   #
      gys = [output.grad for output in f.outputs]  
      gxs = f.backward(*gys)   
      if not isinstance(gxs, tuple):  
        gxs = (gxs,)
      
      for x, gx in zip(f.inputs, gxs):  
        if x.grad is None:
          x.grad = gx
        else:
          x.grad = x.grad + gx

        if x.creator is not None:
          funcs.append(x.creator)   #

  def cleargrad(self):
    self.grad = None


Variable 클래스는 generation을 0으로 초기화합니다. 그리고 set_creator 메서드가 호출될 때 부모 함수의 세대보다 1만큼 큰 값을 설정합니다. 예를 들어 [그림 16-1]처럼 f.generation이 2인 함수에서 만들어진 변수인 y의 generation 은 3이 됩니다. 이상이 Variable클래스에 추가되는 구현입니다.


다음 차례는 Function 클래스입니다.  Function클래스의 generation은 입력 변수와 같은 값으로 설정합니다. 예를 들어 [그림 16-2]의 왼쪽처럼 입력 변수의 generation 이 4라면 함수의 generation도 4가 됩니다.


입력 변수가 둘 이사이라면 가장 큰 generation의 수를 선택합니다. 예를 들어 [그림 16-2]의 오른쪽처럼 입력 변수가 2개고 각각의 generation이 3과 4라면 함수 D의 generation은 4로 설정합니다. 다음은 이상의 설계를 반영한 Function 클래스의 코드입니다.

class Function(object):
  def __call__(self, *inputs):
    xs = [x.data for x in inputs]
    ys = self.forward(*xs)  
    if not isinstance(ys, tuple):   
      ys = (ys,)
    outputs = [Variable(as_array(y)) for y in ys]

    self.generation = max([x.generation for x in inputs])  # generation 설정

    for output in outputs:
      output.set_creator(self)
    self.inputs = inputs
    self.outputs = outputs

    return outputs if len(outputs) > 1 else outputs[0]

이 코드의 음영 부분에서 Function의 generation을 설정했습니다.

댓글 없음: