뺄셈의 미분은 y = x0 - x1일 때 dy/dx0 = 1, dy/dx1 = -1 입니다. 따라서 역전파는 상류에서 전해지는 미분값에 1을 곱한 값이 x0의 미분 결과가 되며, -1을 곱한 값이 x1의 미분 결과가 됩ㄴㅣ다. 코드로는 다음처럼 구현할 수 있습니다.
class Sub(Function):
def forward(self, x0, x1):
y = x0 - x1
return y
def backward(self, gy):
return gy, -gy
def sub(x0, x1):
x1 = as_array(x1)
return Sub()(x0, x1)
Variable.__sub__ = sub
이제 x0와 x1이 Variable 인스턴스라면 y0 = x0 - x1 계산을 수행할 수 있습니다. 그러나 x0가 Variable 인스턴스가 아닌 경우, 예컨데 y = 2.0 - x 같은 코드는 제대로 처리할 수 없습니다. x의 __rsub__ 메서드가 호출되어 인수가 [그림 22-1]형태로 전달되기 때문이죠.
[그림 22-1]과 같이 __rsub__(self, other)가 호출될 땐느 우항인 x가 인수 self에 전달됩니다. 따라서 다음처럼 구현해야 합니다.
def rsub(x0, x1):
x1 = as_array(x1)
return Sub()(x1, x0) # x0 와 x1의 순서를 바꾼다.
Variable.__rsub__ = rsub
보다시피 함수 rsub(x0, x1)을 정의하고 인수의 순서를 바꿔서 Sub()(x1, x0)를 호출하게 합니다. 그런 다음 특수 메서드인 __rsub__에 함수 rsub를 할당합니다.
덧셈과 곰셈은 좌항과 우항의 순석를 바꿔도 결과가 같기 때문에 둘을 구별할 필요가 없었ㅅ브니다. 하지만 뺄셈에서는 좌우를 구별해야 합니다. (x0 - x1과 x1 - x0의 값은 다릅니다). 따라서 우항을 대상으로 했을 때 적용할 함수인 rsub(x0, x1)을 별도로 준비해야 합니다.
이상으로 뺄셈도 할 수 있게 되었습니다. 이제 다음 코드가 잘 작동합니다.
x = Variable(np.array(2.0))
y1 = 2.0 - x
y2 = x - 1.0
print(y1)
print(y2)
댓글 없음:
댓글 쓰기