이어서 Function을 상속한 구체적인 함수에서 역전파(backward)를 구현해보겠습니다. 첫번째 대상은 제곱을 계산하는 Square 클래스입니다. y = x **2 의 미분은 dy/dx = 2x가 되기 때문에 다음처럼 구현할 수 있습니다.
class Square(Function):
def forward(self, x):
y = x ** 2
return y
def backward(self, gy):
x = self.input.data
gx = 2 * x * gy
return gx
이와 같이 역전파를 담당하는 backward 메서드를 추가했습니다. 이 메서드의 인수 gy는 ndarray 인스턴스이며, 출력 쪽에서 전해지는 미분값을 전달하는 역할을 합니다. 그리고 인수로 전달된 미분에 'y= x ** 2의 미분'을 곱한 값이 backward의 결과가 됩니다. 역전파에서는 이 결괏값을 입력 쪽에 더 가까운 다음 함수로 전파해나갈 것입니다.
이어서 y = e ** x계산을 할 Exp 클래스입니다. 이 계산의 미분은 dy/dx = e ** 2이기 때문에 다음과 같이 구현할 수 있습니다.
class Exp(Function):
def forward(self, x):
y = np.exp(x)
return y
def backward(self, gy):
x = self.input.data
gx = np.exp(x) * gy
return gx
댓글 없음:
댓글 쓰기