페이지

2022년 8월 6일 토요일

2.3 Function 클래스 이용

 Function 클래스를 실제로 사용해보죠. Variable 인스턴스인 x를 Function인스턴스인 f 에 입력해보겠습니다.

x = Variable(np.array(10))
f = Function()
y = f(x)

print(type(y)) # type() 함수는 객체의 클래스를 알려준다.
print(y.data)

이와 같이 Variable과 Function을 연계할 수 있습니다. 실행 결과를 보면 y의 클래스는 Variable이며, 데이터는 y.data에 잘 저장되어 있음을 알 수 있습니다.

그런데 방금 구현한 Function 클래스는 용도가 '입력값의 제곱'으로 고정된 함수입니다. 따라서 Sequare라는 명확한 이름이 더 어울립니다. 앞으로 Sin, Exp 등 당야한 함수가 필요하다는 점을 고려하면 Function클래스는 기반 크래스로 두고 DeZero의 모든 함수가 공통적으로 게족하는 기능만 담아주는 것이 좋겠습니다. 그래서 앞으로 모든 DeZero함수는 다음의 두 사항을 만족하도록 구현하겠습니다.

1) Function 클래스는 기반 클래스로서, 모든 함수에 공통되는 기능을 구현합니다.

2) 구체적인 함수는 Function 클래스를 상속한 캘래스에서 구현합니다.

이를 위해 Function 클래스를 다음처럼 수정합니다.

class Function:
  def __call__(selfinput):
    x = input.data
    y = self.forward(x) #구체적인 계산은 forward 메서드에서 한다.
    output = Variable(y)
    return output
  
  def forward(selfx):
    raise NotImplementedError()


__call__살짝 수정하고 forward라는 메서드를 추가했습니다. __call__메서드는 'Variable에서 ㅁ데이터 찾기'와 '계산 결과를 Variable에 포장하기'라는 두 가지 일을 합니다. 그리고 그 사이의 구체적인 계산은 forward메서드를 호출하여 수행합니다. 마지막으로 forward 메서드의 구체적인 로직은 하위 클래스에서 구현합니다.


Function 클래스의  forward메서드는 예외를 발생시킵니다. 이렇게 해두면 Function클래스의 forward메서드를 직접 호출한 사람에게 '이 메서드는 상속하여 구현해야 한다'는 사실을 알려줄 수 있습니다.


이러서 Function 클래스를 상속하여 입력값을 제곱하는 클래스를 구현하겠습니다. 클래스이름은 Squarea 라고 짓고 다음과 같이 구현합니다.

class Square(Function):
  def forward(selfx):
      return x **2

Square클래스는 Function클래스를 상속하기 때문에 __call__메서드는 그대로 계승됩니다. 따라서 forward메서드에 구체적인 계산 로직을 작성해 넣는 것만으로 구현은 끝입니다. 실제로 잘 동작하는지 Square 클래스를 사용하여 Variable을 처리하는 모습을 보시죠.

x = Variable(np.array(10))
f = Square()
y = f(x)
print(type(y))
print(y.data)

<class '__main__.Variable'> 100

댓글 없음: