페이지

2022년 8월 21일 일요일

18.2 Function 클래스 복습

 DeZero에서 미분을 하려면 순전파를 수행한 뒤 역전파 해주면 됩니다. 그리고 역전파 시에는 순전파의 계산 결과가 필요하기 때문에 순전파 때 결괐값을 기억해둡니다. 결괏값을 보관하는 로직은 바로  Function클래스의 다음 음영 부분입니다.

import weakref    
class Function(object):
  def __call__(self, *inputs):
    xs = [x.data for x in inputs]
    ys = self.forward(*xs)  
    if not isinstance(ys, tuple):   
      ys = (ys,)
    outputs = [Variable(as_array(y)) for y in ys]

    self.generation = max([x.generation for x in inputs]) 

    for output in outputs:
      output.set_creator(self)  
    self.inputs = inputs  #
    self.outputs = [weakref.ref(output) for output in outputs]  
   
    return outputs if len(outputs) > 1 else outputs[0]

이와 같이 함수는 입력을 inputs라는 '인스턴스 변수'로 참조합니다. 그 결과 inputs가 참조하는 변수의 참조 카운트가 1만큼 증가하고, __call__ 메서드에서 벗어난 뒤에도 메모리에 생존합니다. 만약 인스턴스 변수인 inputs로 참조하지 않았다면 참조 카운트가 0이 되어 메모리에서 삭제됐을 겁니다.


인스턴스 변수 inputs는 역전파 계산 시 사용됩니다.. 따라서 역전파하는 경우라면 참조할 변수들을 inputs에 미리 보관해둬야 합니다. 하지만 때로는 미분값이 필요 없는 경우도 있습니다. 이런 경우라면 중간 계산 결과를 저장할 필요가 없고, 계산의 '연결'또한 만들 이유가 없습니다.


신경망에는 학습(training)(혹은 훈련)과 추론(inference)이라는 두 가지단계가 있습니다. 학습 시에는 미분값을 구해야 하지만 추론 시에는 단순히 순전파만 하기 때문에 중간 계산 결과를 곧바로 버리면 메모리 사용량을 크게 줄일 수 있습니다.

댓글 없음: