지금까지 Variable과 Function클래스가 가변 길이 입출력을 지원하도록 개선했습니다. 그리고 구체적인 함수로서 Add 클래스를 구현했습니다. 마지막으로 Square클래스도 새로운 Variable과 Function클래스에 맞게 수정하겠습니다. 수정할 곳은 단 하나뿐입니다. 다음 코드에서 음영을 입힌 부분이죠.
class Square(Function):
def forward(self, x):
y = x ** 2
return y
def backward(self, gy):
x = self.inputs[0].data #수정전: x = self.input.data
gx = 2 * x * gy
return gx
Function 클래스의 인스턴스 변수 이름이 단수형인 input에서 복수형인 inputs로 변경되었으니 바뀐 변수에서 입력 변수 x 를 가져오도록 코드를 수정해주면 됩니다. 이것으로 새로운 Square 클래스도 완성입니다. 그럼 add 함수와 sequare 함수를 실제로 사용해봅시다. 다음은 z = x**2 + y**2를 계산하는 코드입니다.
x = Variable(np.array(2.0))
y = Variable(np.array(3.0))
z = add(square(x), square(y))
z.backward()
print(z.data)
print(x.grad)
print(y.grad)
13.0
4.0
6.0
보다시피 DeZero를 사용하여 z = x**2 + y**2이라는 계산을 z = add(square(x), square(y))라는 코드로 풀어냈습니다. 그런 다음 z.backward()를 호출하기만 하면 미분 계산이 자동으로 이루어집니다!
이상에서 복수의 입출력에 대응한 자동 미분 구조를 완성했습니다. 이제 다른 함수들도 적절히 구현해주면 더 복잡한 계산도 가능할 것입니다. 그러나 사실 지금의 DeZero에는 몇 가지 문제가 숨어 있습니다. 다음 단계에서는 이 문제들을 먼저 해결하겠습니다.
댓글 없음:
댓글 쓰기