거듭제곱은 y = x ** c 형태로 표현됩니다. 이때 x를 밑이라고 하고 c를 지수라 합니다. 거듭제곱의 미분은 미분 공식으로부터 dy/dx = cx ** ( c - 1)이 됩니다. dy/dc의 값도 구할 수는 있지만 실정에서는 거의 사용되지 않으니 이 책에서는 밑이 x인 경우만 미분해보겠습니다. 즉, 지수 c는 상수로 취급하여 따로 미분을 계산하지 않기로 합니다. 다음은 이를 구현한 코드입니다.
class Pow(Function):
def __init__(self, c):
self.c = c
def forward(self, x):
y = y ** self.c
return y
def backward(Self, gy):
x = self.inputs[0].data
c = self.c
gx = c * x ** ( c - 1 ) * gy
return gx
def pow(x, c):
return Pow(c)(x)
Variable.__pow__ = pow
코들르 보면 Pow클래스를 초기화할 때 지수 c를 제공할 수 있습니다. 그리고 순전파 메서드인 forward(x)는 밑에 해당하는 x만(즉, 하나의 항만)받게 합니다. 그런 다음 특수 메서드인 __pow__에 함수 pow를 할당합니다. 이제 ** 연산자를 사용하여 거듭제곱을 계산할 수 있습니다.
x = Variable(np.array(2.0))
y = x ** 3
print(y)
이상으로 목표한 연산자를 모두 추가했습니다. 이번 단곈느 다소 단조로운 작업의 연속이었지만 그 덕분에 DeZero의 유용성은 크게 향상됐습니다. 사칙연산 연산자들을 자유롭게 계산에 활용할 수 있게 된거죠. 거듭제ㅐ곱도 가능하기 때문에 제법 고급 계산까지 표현할 수 있답니다.
댓글 없음:
댓글 쓰기