페이지

2022년 8월 15일 월요일

12.2 두 번째 개선: 함수를 구현하기 쉽도록

 두 번째는 Add 클래스를 '구현하는 사람'을 위한 개선입니다. 현재 Add 클래스를 구현하려면 [그림 12-2]의 왼쪽 처럼 작성해야 합니다.


왼쪽은 Add 클래스의 forward메서드의 코드입니다. 인수는 리스트로 전달되고 결과는 튜플을 반환하고 있습니다. 물론 오른쪽 코드가 더 바람직해 보입니다. 입력도 변수를 직접 받고 결과도 변수를 직접 돌려주는 것이죠. 이것이 두 번째 개선에서 할 일입니다.


두 번째 개선을 위해 Function클래스에서 다음 부분을 수정합니다.

class Function:
  def __call__(self, *inputs):
    xs = [x.data for x in inputs]
    ys = self.forward(*xs)  # 1 별표를 붙여 언팩
    if not isinstance(ys, tuple):   # 2 튜플이 아닌 경우 추가 지원
      ys = (ys,)
    outputs = [Variable(as_array(y)) for y in ys]

    for output in outputs:
      output.set_creator(self)
    self.inputs = inputs
    self.outputs = outputs

    return outputs if len(outputs) > 1 else outputs[0]

우선 1의 self.forward(*xs) 부분을 보죠. 함수를 '호출'할 때 별표를 붙였는데, 이렇게 하면 리스트 언팩(list unpack)이 이루어집니다. 언팩은 리스트의 원소를 낱개로 풀어서 전달하는 기법입니다. 예를 들어 xs = [x0, x1]일때 self.forward(*xs)를 하면 self.forward(x0, x1)로 호출하는 것과 동일하게 동작합니다.

이어서 2에서는 ys 가 튜플이 아닌 경우 튜플로 변경합니다. 이제 forward메서드는 반환원소가 하나뿐이라면 해당 원소를 직접 반환합니다. 이상의 수정으로 Add클래스를 다음처럼 구현할 수 있습니다.

class Add(Function):
  def forward(selfx0x1):
    y  = x0 + x1
    return y


이와 같이 순전파 메서드를 def forward(self, x0, x1): 이라고 정의할 수 있습니다. 결과는 return y처럼 하여 원소 하나만 반환하죠. 이제 Add클래스를 구현하는 사람에게DeZero는 더 쓰기 편한 프레임워크가 되었습니다. 이상으로 두 번째 개선을 마무리합니다.

댓글 없음: