페이지

2022년 8월 14일 일요일

11.2 Add 클래스 구현

 이번 절에서는 Add클래스의 forward메서드를 구현합니다. 주의할 점은 인수와 반환값이 리스트(또는 튜플)여야 한다는 것입니다. 이 조건을 반영하여 다음 처럼 구현할 수 있습니다.

class Add(Function):
  def forward(selfxs):
    x0, x1 = xs
    y = x0 + x1
    return (y,)

Add 클래스의 인수는 변수가 두개 담긴 리스트입니다. 따라서 x0, x1 = xs형태로 리스트 xs에서 원소 두 개를 꺼냈습니다. 그런 다음 꺼낸 원소들을 사용하여 계산합니다. 결과를 반환할 때는 return(y,)형태로 튜플을 반환합니다(return y,처럼 괄호는 생략해도 됩니다). 수정한 Add클래스는 다음과 같이 사용할 수 있습니다.

xs = [Variable(np.array(2)), Variable(np.array(3))]   # 리스트로 준비
f = Add()
ys = f(xs)      # ys 튜플
y = ys[0]
print(y.data)

5

보시는 것처럼 2 + 3 = 5 계산을 DeZero로 재대로 처리할 수 있게 되었습니다. 입력을 리스트로 바꿔서 여러 개의 변수를 다룰 수 있게 하였고, 출력은 튜플로 바꿔서 역시 여러 개의 변수에 대응할 수 있게 했습니다. 이제 순전파에 한해서는 가변 길이 인수와 반환값에 대응할 수 있을 것입니다. 그런데 앞의 코드를 보면 다소 귀찮은 느낌이 듭니다. 왜냐하면 Add클래스를 사용하는 사람에게 입력 변수를 리스트에 담아 건네주라고 요구하거나 반환값으로 튜플을 받게 하는 것은 자연스럽지 않기 때문입니다. 그래서 다음 단계에서는 더 자연스러운 코드로 쓸 수 있도록 지금의 구현을 개선하겠습니다.

댓글 없음: