페이지

2022년 8월 31일 수요일

22.4 거듭제곱

 거듭제곱은 y = x **  c 형태로 표현됩니다. 이때 x를 밑이라고 하고 c를 지수라 합니다. 거듭제곱의 미분은 미분 공식으로부터 dy/dx = cx ** ( c - 1)이 됩니다. dy/dc의 값도 구할 수는 있지만 실정에서는 거의 사용되지 않으니 이 책에서는 밑이 x인 경우만 미분해보겠습니다. 즉, 지수 c는 상수로 취급하여 따로 미분을 계산하지 않기로 합니다. 다음은 이를 구현한 코드입니다.

class Pow(Function):
  def __init__(selfc):
    self.c = c

  def forward(selfx):
    y = y ** self.c
    return y

  
  def backward(Selfgy):
    x = self.inputs[0].data
    c = self.c
    gx = c * x ** ( c - 1 )  * gy
    return gx


def pow(xc):
  return Pow(c)(x)

Variable.__pow__ = pow

코들르 보면 Pow클래스를 초기화할 때 지수 c를 제공할 수 있습니다. 그리고 순전파 메서드인 forward(x)는 밑에 해당하는 x만(즉, 하나의 항만)받게 합니다. 그런 다음 특수 메서드인 __pow__에 함수 pow를 할당합니다. 이제 ** 연산자를 사용하여 거듭제곱을 계산할 수 있습니다.

x = Variable(np.array(2.0))
y = x ** 3

print(y)

이상으로 목표한 연산자를 모두 추가했습니다. 이번 단곈느 다소 단조로운 작업의 연속이었지만 그 덕분에 DeZero의 유용성은 크게 향상됐습니다. 사칙연산 연산자들을 자유롭게 계산에 활용할 수 있게 된거죠. 거듭제ㅐ곱도 가능하기 때문에 제법 고급 계산까지 표현할 수 있답니다.

댓글 없음: