페이지

2022년 8월 11일 목요일

STEP7 역전파 자동화

 이전 단계에서 역전파를 동작신키는 데 성공했습니다. 그러나 역전파 계산 코드를 수동으로 조합해야 했습니다. 새로운 계산을 할 때마다 역전파 코드를 직접 작성해야 한다는 뜻이죠. [그림 7-1]처럼 계산 그래프가 여러 개라면 각각의 계산에 맞게 역전파 코드를 수동으로 따로따로 작성해야 합니다. 그러다 보면  실수가 생길 수도 있고, 무엇보다도 지루할 것입니다. 지루한 일은 파이썬에게 시키자고요!


그래서 이제부터 역전파를 자동화하려 합니다. 더 정확히 말하면, 일반적인 계산(순전파)을 한 번만 해주면 어떤 계산이라도 상관없이 역전파가 자동으로 이루어지는 구조를 만들 것입니다. 두둥! 지금부터가 바로 Define-by-Run의 핵심을 건드리는 내용입니다!

Define-by-Run이란 딥러닝에서 수행하는 계산들을 계산 시점에 '연결'하는 방식으로, '동적 계산 그래프'라고도 합니다. Define-by-Run의 개념과 장점은 제2고지 마지막의 '칼럼:Define-by-Run'에서 자세히 설명합니다.

그런데 [그림7-1]의 계산 그래프들은 모두 일직선으로 늘어선 계산입니다. 따라서 함수의 순서를 리스트 형태로 저장해두면 자중에 거꾸로 추적하는 식으로 역전파를 자동화할 수 있습니다. 그러나 분기가 있는 계산 그래프나 같은 변수가 여러 번 사용되는 복잡한 계산 그래프는 단순히 리스트로 저장하는 식으로 풀 수 없습니다. 우리 목표는 아무리 복잡한 계산 그래프라 하더라도 역전파를 자동으로 할 수 있는 구조를 만련하는 것입니다.


사실 리스트 데이터 구조를 응용하면 수행한 계산을 리스트에 추가해 나가는 것만으로 어떠한 계산 그래프의 역전파도 제대로 해낼 수 있습니다. 이 데이터 구조를 웬거트 리스트(Wergert List)(혹은 테이프(tape))라고 합니다. 이 책에서는 웬거트 리스트에 대한 설명은 하지 않으니 관심 있는 분은 참고 문헌 [2]와 [3]을 참고하시고, 웬거튼 리스트를 활용하는 Define-by-Run의 장점은 참고문헌[4]를 참고해 주세요


6.4 역전파 구현

 이상으로 준비 작업이 끝났습니다. 이번 절에서는 [그림 6-1]에 해당하는 계산의 미분을 역전파로 계산해보겠습니다.


먼저[그림 6-1]을 순전파하는 코드부터 보겠습니다.


A = Square()
B = Exp()
C = Square()

x = Variable(np.array(0.5))
a = A(x)
b = B(a)
y = C(b)


이어서 역전파로 y를 미분해보죠. 순전파 때와는 반대 순서로 각 함수의 backward메서드를 호출하면 됩니다. [그림 6-2]는 이때 이루어지는 역전파를 계산 그래프로 그린 모습입니다.


[그림 6-2]를 보면 어떤 순서로 어느 함수의 backward메서드를 호출하면 되는지 알 수 있습니다. 또한 backward메서드의 결과를 어느 변수의 grad로 설정하면 되는지도 알 수 있습니다. 다음은 [그림 6-2]의 계산 그래프를 코드로 옮긴 모습입니다.

y.grad = np.array(1.0)
b.grad = C.backward(y.grad)
a.grad = B.backward(b.grad)
x.grad = A.backward(a.grad)
print(x.grad)

3.297442541400256

역전파는 dy/dy = 1에서 시작합니다. 따라서 출력 y의 미분값을 np.array(1.0)로 설정합니다. 그런 다음 C->B->A순으로 backward 메서드를 호출하기만 하면 됩니다. 이것으로 각 변수의 미분값이 구해집니다.

앞의 코드를 실행하면 x.grad의 값이 3.297442541400256 이라고 나옵니다. 이 값이 y의 x에 대한 미분 결과입니다. 4단계에서 수치 미분으로 구한 값이 3.2974426293330694 였으니 두 결과가 거의 같음을 알 수 있습니다. 역전파를 제대로 구현한 것입니다(더 정확하게는, 올바르게 구현했을 가능성이 큽니다).

이상이 역전파 구현입니다. 제대로 동작하지만 역전파 순서(C->B->A)에 맞춰 호출하는 코드를 우리가 일일이 작성해 넣은 건 영 불편할 것 같습니다. 그래서 다음 단계에서는 이 작업을 자동화하겠습니다.



2022년 8월 7일 일요일

6.3 Square 와 Exp클래스 추가 구현

 이어서 Function을 상속한 구체적인 함수에서 역전파(backward)를 구현해보겠습니다. 첫번째 대상은 제곱을 계산하는 Square 클래스입니다. y = x **2 의 미분은 dy/dx = 2x가 되기 때문에 다음처럼 구현할 수 있습니다.

class Square(Function):
  def forward(selfx):
    y = x ** 2
    return y

  def backward(selfgy):
    x = self.input.data
    gx = 2 * x * gy
    return gx

이와 같이 역전파를 담당하는 backward 메서드를 추가했습니다. 이 메서드의 인수 gy는 ndarray 인스턴스이며, 출력 쪽에서 전해지는 미분값을 전달하는 역할을 합니다. 그리고 인수로 전달된 미분에 'y= x ** 2의 미분'을 곱한 값이 backward의 결과가 됩니다. 역전파에서는 이 결괏값을 입력 쪽에 더 가까운 다음 함수로 전파해나갈 것입니다.

이어서 y =  e ** x계산을 할 Exp 클래스입니다. 이 계산의 미분은 dy/dx = e ** 2이기 때문에 다음과 같이 구현할 수 있습니다.

class Exp(Function):
  def forward(selfx):
    y = np.exp(x)
    return y
  
  def backward(selfgy):
    x = self.input.data
    gx = np.exp(x) * gy
    return gx


6.2 Function클래스 추가 구현

 Function 클래스를 알아볼 차례입니다. 이전 단계까지 Function  클래스는 일반적인 계산을 하는 순전파(forward 메서드) 기능만 지원하는 상태입니다. 이외에 다음 두 기능을 추가하겠습니다.

- 미분을 계산하는 역전파(backward메서드)

- forward메서드 호출 시 건너받은 Variable 인스턴스 유지

class Function:
  def __call__(selfinput):
    x = input.data
    y = self.forward(x)
    output = Variable(y)
    self.input = input # 입력 변수를 기억(보관)한다.
    return output
  
  def forward(selfx):
    raise NotImplementedError()
  
  def backward(self, gy):
    raise NotImplementedError()


코드에서 보듯 __call__메서드에서 입력된 input을 인스턴스 변수인 self.input에 저장합니다. 이렇게 해서 나중에  backward 메서드에서 함수(Function)에 입력한 변수(Variable 인스턴스)가 필요할때 self.input에서 가져와 사용할 수 있습니다.

6.1 Variable 클래스 추가 구현

 역전파에 대응하는  Variable 클래스를 구현하겠습니다. 그러기 위해 통상값(data)과 더붙어 그에 대응하는 미분값(grad)도 저장하도록 확장합니다. 새로 추가된 코드에 음영을 덧쒸웠습니다.

class Variable:
  def __init__(selfdata):
    self.data = data
    self.grad = None

이와 같이 새로 grad라는 인스턴스 변수를 추가했습니다. 인스턴수 변수인 data와 grad는 모두 넘파이의 다차원 배열(ndarray)이라고 가정합니다. 또한 grad는 None으로 초기화해둔 다음, 나중에 실제로 역전파를 하면 미분값을 계산하여 대입합니다.

벡터나 행렬등 다변수에 대한 미분은 기우기(gradent)라고 합니다. Variable 클래스에 새로 추가한 grad변수의 이름은 gradient를 줄인 것입니다.

STEP 6 수동 역전파

 이전 단계에서 역전파의 구동 원리를 설명했습니다. 이번 단계에서는 Variable과 Function클래스를 확장하여 역전파를 이용한 미분을 구현하겠습니다. Variable 클래스부터 살펴 보죠

5.3 계산 그래프로 살펴보기

 다음과 같이 통상적인 계산인 순전파 계산 그래프(그림 5-1)와 미분을 계산하는 역전파 계산 그래프(그림 5-4)를 위아래로 나련히 놓고 살펴봅시다.




이렇게 비교하니 순전파와 역전파의 관계가 명확히 보입니다. 예를 들어 순전파 시의 변수 a는 역전파 시의 미분  dy/da에 대응합니다, 마찬가지로 b와 dy/db가 대응하고 x 와 dy/dx가 대응합니다. 또한 함수에도 대응 관계가 보입니다. 함수 B는 역전파의 B'(a)에 대응하고 A는  A'(x)에 대응하는 식 입니다. 이렇게 변수는 '통상값'과 '미분값'이 존재하고, 함수는 '통상 계산(순전파)'과 '미분값을 구하기 위한 계산(역전파)'이 존재하는 것으로 생각할 수 있습니다. 이를 통해 역전파를 어떻게 구현할지 짐작해볼 수 있을 것 입니다.

마지막으로  [그림 5-5]의 함수 노드 C'(b)를 계산하려면 b값이 필요하다는 사실입니다. 마찬가지로 B'(a)를 구하려면 입력 a의 값이 필요합니다. 무슨 말인고 하니, 역전파 시에는 순전파 시 이용한 데이터가 필요하다는 것입니다. 따라서 역전파를 구현하려면 먼저 순전파를 하고, 이때 각 함수가 입력 변수(앞의 예에서는 x, a, b)의 값을 기억해두지 않으면 안 됩니다. 그런 다음에야 각 함수의 역전파를 계산할 수 있습니다.

이상이 역전파의 이론 설명입니다. 다소 복잡한가요? 좋은 소식을 알려드리겠습니다. 다행히 역전파는 이 책에서 가장 어려운 내용에 속한답니다. 그리고 아직 잘 이해되지 않는 부분도 실제로 코드를 실행해보면 이해될 것입니다. 다음 단계에서는 역전파를 구현하고 실제로 돌려보며 검증하겠습니다.