페이지

2022년 8월 31일 수요일

22.2 뺄셈

 뺄셈의 미분은 y = x0 - x1일 때 dy/dx0 = 1, dy/dx1 = -1 입니다. 따라서 역전파는 상류에서 전해지는 미분값에 1을 곱한 값이 x0의 미분 결과가 되며, -1을 곱한 값이 x1의 미분 결과가  됩ㄴㅣ다. 코드로는 다음처럼 구현할 수 있습니다.

class Sub(Function):
  def forward(selfx0x1):
    y = x0 - x1
    return y

  def backward(selfgy):
    return gy, -gy

def sub(x0x1):
  x1 = as_array(x1)
  return Sub()(x0, x1)

Variable.__sub__ = sub

이제 x0와 x1이 Variable 인스턴스라면 y0 = x0 - x1 계산을 수행할 수 있습니다. 그러나 x0가 Variable 인스턴스가 아닌 경우, 예컨데 y = 2.0 - x 같은 코드는 제대로 처리할 수 없습니다. x의 __rsub__ 메서드가 호출되어 인수가 [그림 22-1]형태로 전달되기 때문이죠.


[그림 22-1]과 같이 __rsub__(self, other)가 호출될 땐느 우항인 x가 인수 self에 전달됩니다. 따라서 다음처럼 구현해야 합니다.

def rsub(x0x1):
  x1 = as_array(x1)
  return Sub()(x1, x0)    # x0 와 x1의 순서를 바꾼다.

Variable.__rsub__ = rsub

보다시피 함수 rsub(x0, x1)을 정의하고 인수의 순서를 바꿔서 Sub()(x1, x0)를 호출하게 합니다. 그런 다음 특수 메서드인 __rsub__에 함수 rsub를 할당합니다.


덧셈과 곰셈은 좌항과 우항의 순석를 바꿔도 결과가 같기 때문에 둘을 구별할 필요가 없었ㅅ브니다. 하지만 뺄셈에서는 좌우를 구별해야 합니다. (x0 - x1과 x1 - x0의 값은 다릅니다). 따라서 우항을 대상으로 했을 때 적용할 함수인 rsub(x0, x1)을 별도로 준비해야 합니다.


이상으로 뺄셈도 할 수 있게 되었습니다. 이제 다음 코드가 잘 작동합니다.

x = Variable(np.array(2.0))
y1 = 2.0 - x
y2 = x - 1.0
print(y1)
print(y2)

댓글 없음: