이어서 square 함수의 역전파도 테스트해보겠습니다. 방금 구현한 SquareTest클래스에 다음코드를 추가합니다.
class SquareTest(unittest.TestCase):
def test_forward(self):
x = Variable(np.array(2.0))
y = square(x)
expected = np.array(4.0)
self.assertEqual(y.data, expected)
def test_backward(self):
x = Variable(np.array(3.0))
y = square(x)
y.backward()
expected = np.array(6.0)
self.assertEqual(x.grad, expected)
if __name__ == '__main__':
unittest.main(argv=['first-arg-is-ignored'], exit=False)
test_backward메서드를 추가했습니다. 메서드 안에서 y.backward()로 미분값을 구하고, 그 값이 기댓값과 일치하는지 확인합니다. 참고로 여기에서 설정한 기댓값 6.0은 손으로 계산해서 구한 값을 하드코딩한 것입니다.
그럼 다시 테스트를 돌려봄시다. 결과는 다음과 같습니다.
---------------------------------------------------------------------- Ran 2 tests in 0.006s
결과를 보면 2개의 테스트를 통과했음을 알 수 있습니다. 원한다면 지금까지와 같은 요령으로 다른 테스트 케이스(입력과 기댓값)도 추가해나갈 수 있습니다. 테스트 케이스가 많아질수록 SQUARE함수의 신뢰도도 높아질 겁니다. 그리고 코드를 수정할 때마다 즉시즉시 테스트를 실행해주면 square함수의 상태를 반복해서 확인할 수 있습니다.
댓글 없음:
댓글 쓰기