페이지

2022년 8월 19일 금요일

13.3 Square 클래스 구현

 지금까지 Variable과 Function클래스가 가변 길이 입출력을 지원하도록 개선했습니다. 그리고 구체적인 함수로서 Add 클래스를 구현했습니다. 마지막으로 Square클래스도 새로운 Variable과 Function클래스에 맞게 수정하겠습니다. 수정할 곳은 단 하나뿐입니다. 다음 코드에서 음영을 입힌 부분이죠.

class Square(Function):
  def forward(selfx):
    y = x ** 2
    return y

  def backward(selfgy):
    x = self.inputs[0].data #수정전: x = self.input.data
    gx = 2 * x * gy
    return gx

Function 클래스의 인스턴스 변수 이름이 단수형인 input에서 복수형인 inputs로 변경되었으니 바뀐 변수에서 입력 변수 x 를 가져오도록 코드를 수정해주면 됩니다. 이것으로 새로운 Square 클래스도 완성입니다. 그럼 add 함수와 sequare 함수를 실제로 사용해봅시다. 다음은 z = x**2 + y**2를 계산하는 코드입니다.

x = Variable(np.array(2.0))
y = Variable(np.array(3.0))

z = add(square(x), square(y))
z.backward()
print(z.data)
print(x.grad)
print(y.grad)

13.0 4.0 6.0


보다시피 DeZero를 사용하여 z = x**2  + y**2이라는 계산을 z = add(square(x), square(y))라는 코드로 풀어냈습니다. 그런 다음 z.backward()를 호출하기만 하면 미분 계산이 자동으로 이루어집니다!

이상에서 복수의 입출력에 대응한 자동 미분 구조를 완성했습니다. 이제 다른 함수들도 적절히 구현해주면 더 복잡한 계산도 가능할 것입니다. 그러나 사실 지금의 DeZero에는 몇 가지 문제가 숨어 있습니다. 다음 단계에서는 이 문제들을 먼저 해결하겠습니다.

댓글 없음: