페이지

2022년 8월 13일 토요일

10.2 square 함수의 역전파 테스트

 이어서 square 함수의 역전파도 테스트해보겠습니다. 방금 구현한 SquareTest클래스에 다음코드를 추가합니다.

class SquareTest(unittest.TestCase):
  def test_forward(self):
    x = Variable(np.array(2.0))
    y = square(x)
    expected = np.array(4.0)
    self.assertEqual(y.data, expected)


  def test_backward(self):
    x = Variable(np.array(3.0))
    y = square(x)
    y.backward()
    expected = np.array(6.0)
    self.assertEqual(x.grad, expected)

if __name__ == '__main__':
    unittest.main(argv=['first-arg-is-ignored'], exit=False)

test_backward메서드를 추가했습니다. 메서드 안에서 y.backward()로 미분값을 구하고, 그 값이 기댓값과 일치하는지 확인합니다. 참고로 여기에서 설정한 기댓값 6.0은 손으로 계산해서 구한 값을 하드코딩한 것입니다.

그럼 다시 테스트를 돌려봄시다. 결과는 다음과 같습니다.


---------------------------------------------------------------------- Ran 2 tests in 0.006s



결과를 보면 2개의 테스트를 통과했음을 알 수 있습니다. 원한다면 지금까지와 같은 요령으로 다른 테스트 케이스(입력과 기댓값)도 추가해나갈 수 있습니다. 테스트 케이스가 많아질수록 SQUARE함수의 신뢰도도 높아질 겁니다. 그리고 코드를 수정할 때마다 즉시즉시 테스트를 실행해주면 square함수의 상태를 반복해서 확인할 수 있습니다.

댓글 없음: