페이지

2022년 8월 12일 금요일

8.2 반복문을 이용한 구현

 이번 절에서는 지금까지의 '재귀를 사용한 구현'을 '반복문을 이용한 구현'으로 고쳐보겠습니다. 코드는 다음과 같습니다.

class Variable:
  def __init__(selfdata):
    self.data = data
    self.grad = None
    self.creator = None

  def set_creator(selffunc):
    self.creator = func
  
  def backward(self):
    funcs = [self.creator]
    while funcs:
      f = funcs.pop()   # 함수를 가져온다.
      x, y = f.input, f.output  # 함수의 입력과 출력을 가져온다.
      x.grad = f.backward(y.grad)   # backward 메서드를 호출한다

      if x.creator is not None:
        funcs.append(x.creator)   # 하나 앞의 함수를 리스트에 추가한다.

이것이 반복문을 이용한 구현입니다. 주묵할 점은 처리해야 할 함수들을 funcs 라는 리스트에 차례로 집어넣는다는 것입니다.  while블럭 안에서 funcs.pop()을 호출하여 처리할 함수 f를 꺼내고, f의 backward메서드를 호출합니다. 이때 f, input과 f, output에서 함수 f의 입력과 출력 변수를 얻음으로써 f.backward()의 인수와 반환값을 올바르게 설정할 수 있습니다.


리스트의 pop 메서드는 리스트에서 마지막 원소를 꺼내줍니다(반환된 요소는 리스트에서 제거됩니다). 예컨대 funcs = [1,2,3] 일때 x = funcs.pop()을 실행하면 3이 반환되고 funcs는 [1,2]가 됩니다.

8.1 현재의 Variable 클래스

 이전 장에서 우리는 Variable 클래스의 backward 메서드를 다음처럼 구현했습니다.


class Variable:
  def __init__(selfdata):
    self.data = data
    self.grad = None
    self.creator = None

  def set_creator(selffunc):
    self.creator = func

  def backward(self):
    f = self.creator
    if f is not None:
      x = f.input
      x.grad = f.backward(self.grad)
      x.backward()

이 backward메서드에서 눈에 밟히는 부분은 (입력 방향으로) 하나 앞 변수의 backward메서드를 호출하는 코드입니다. "backward 메서드에서 backward메서드를 호출하고, 호출된 backward 메서드에서 또 다른  backward 메서드를 호출하고,...' 과정이 계속됩니다(창조자 함수가 없는 변수, 즉 self.creator 가 None인 변수를 찾을 때까지 계속됩니다) 이를 재귀 구조라고 합니다.



STEP8 재귀에서 반복문으로

 앞 단계에서는 Variable 클래스에 backward메서드를 추가했습니다. 이번에는 처리 효율을 개선하고 앞으로의 확장을 대비해 backward메서드의 구현 방식을 바꿔보겠습니다.

7.3 backward 메서드 추가

 방금 보여드린 역전파 코드에는 똑같은 처리 흐름이 반복해서 나타났습니다. 변수에서 하나 앞의 변수로 거슬러 올라가는 로직이 그러했습니다. 그러므로 이 반복 작업을 자동화할 수 있도록 Variable 클래스에 badckward라는 새로운 메서드를 추가하겠습니다.


class Variable:
  def __init__(selfdata):
    self.data = data
    self.grad = None
    self.creator = None

  def set_creator(selffunc):
    self.creator = func

  def backward(self):
    f = self.creator    # 1. 함수를 가져온다.
    if f is not None:
      x = f.input       # 2. 함수의 입력을 가져온다.
      x.grad = f.backward(self.grad)    # 3. 함수의 backward 메서드를 호출한다.
      x.backward()      # 하나 앞 변수의 backward 메서드를 호출한다(재귀)


backward메서드는 지금까지 반복한 처리 흐름과 거의 동일합니다. Variable의 creator에서 함수를 얻어오고, 그 함수의 입력 변수를 가져옵니다. 그런 다음 함수의 backward메서드를 호출합니다. 마지막으로 자신보다 하나 앞에 놓인 변수의 backward 메서드를 호출합니다. 이런 식으로 각 변수의 backward 메서드가 재귀적으로 불리게 됩니다.

Variable 인스턴스의 creator 가 None이면 역전파가 중단됩니다. 창조자가 없으므로 이 Variable 인스턴스는 함수 바깥에서 생성했음을 뜻합니다(높은 확률로 사용자가 만들어 건네 변수일 것입니다).

이제 새로워진 Variable을 이용하여 역전파가 자동으로 실행되는 모습을 보겠습니다.

A = Square()
B = Exp()
C = Square()

x = Variable(np.array(0.5))
a = A(x)
b = B(a)
y = C(b)

# 역전파
y.grad = np.array(1.0)
y.backward()
print(x.grad)

3.297442541400256

이와 같이 변수 y의 backward 메서드를 호출하면 역전파가 자동으로 진행됩니다. 실행 결과도 지금까지와 동일합니다. 축하합니다! 여러분은 방금  DeZero에서 가장 중요한 개념인 자동 미분의 기초를 완성했습니다.

7.2 역전파 도전!

 변수와 함수의 관곌르 이용하여 역전파를 시도해보겠습니다. 우선 y에서 b까지의 역전파를 시도해보죠.


이 부분은 다음과 같이 구현할 수 있습니다.

y.grad = np.array(1.0)

C = y.creator # 1. 함수를 가져온다.
b = C.input # 2. 함수의 입력을 가져온다.
b.grad = C.backward(y.grad) # 3. 함수의 backward 메서드를 호출한다.

y의 인스턴스 변수 creator 에서 함수를 얻어오고, 그 함수의 input에서 입력 변수를 가져왔습니다. 그런 다음 함수의 backward메서드를 호출합니다. 이어서 변수 b에서 a로의 역전파를 보겠습니다.

B = b.creator # 1. 함수를 가져온다.
a = B.input   # 2. 함수의 입력을 가져온다.
a.grad = B.backward(b.grad) # 3. 함수의 backward메서드를 호출한다.

똑같은 흐름입니다. 구체적으로 다음과 같은 순서로 진행됩니다.

1. 함수를 가져온다.

2. 함수의 입력을 가져온다.

3. 함수의 backward메서드를 호출한다


마지막으로 변수 a에서 x로의 역전파까지 진행합니다.\

 A = a.creator  #1. 함수를 가져온다.
 x = A.input    # 함수의 입력을 가져온다.
 x.grad = A.backward(a.grad)    # 3. 함수의 backward메서드를 호출한다.
 print(x.grad)

3.297442541400256

이상으로 모든 역전파가 끝났습니다.


7.1 역전파 자동화의 시작

 역전파 자동화로 가는 길은 변수와 함수의 '관계'를 이해하는 데서 출발합니다. 우선 함수 관점에서 '함수는 변수를 어떻게 바라볼까'를 생각해봅시다. 함수 입장에서 변수는 '입력'과 '출력'에 쓰입니다. 즉, [그림 7-2]의 왼쪽과 같이 함수에서 변수는 '입력 변수 (input)'와 '출력 변수(output)'로서 존재합니다(그림의 점선은 참조(reference)를 뜻합니다).


변수 관점에서 함수는 어떤 존재일까요? 여기서 눈여겨볼 점은 변수는 함수에 의해 '만들어진다'라는 것입니다. 즉, 변수에게 있어 함수는 '창조자(creator)'혹은 '부모'입니다. 창조자인 함수가 존재하지 않는 변수는 함수 이외의 존재, 예컨대 사용자에 의해 만들어진 변수로 간주됩니다.

일단 [그림 7-2]와 같은 함수와 변수의 관계를 DeZero코드에 녹여볼까요? 여기에서는 일반적인 계산(순전파)이 이루어지는 시점에 '관계'를 맺어주도록(즉, 함수와 변수를 연결 짓도록) 만들겠습니다. 이를 위해 우선 Variable 클래스에 다음 코드를 추가합니다.

class Variable:
  def __init__(selfdata):
    self.data = data
    self.grad = None
    self.creator = None
  
  def set_creator(selffunc):
    self.creator = func
    

creator 라는 인스턴스 변수를 추가했습니다. 그리고 creator 를 설정할 수 있도록 set_creator 메서드로 추가합니다. 이어서 Function클래스에 다음 코드를 추가합니다.

class Function:
  def __call__(selfinput):
    x = input.data
    y = self.forward(x)
    output = Variable(y)
    output.set_creator(self)  # 출력 변수에 창조자를 설정한다.
    self.input = input
    self.output = output # 출력도 저장한다.
    return output
  

순전파를 계산하면 그 결과로 output이라는 Variable 인스턴스가 생성됩니다. 이때 생성된 output에 '내가 너의 창조자임'을 기억시킵니다. 이 부분이 '연결'을 동적으로 만드는 기법의 핵심입니다. 그런 다음 앞으로를 위해 output을 인스턴스 변수에 저장했습니다.


DeZero의 동적 계산 그래프(Dynamic Computational Graph)는 실제 계산이 이루어질 때 변수(상자)에 관련 '연결'을 기록하는 방식으로 만들어집니다. 체이너와 파이토치의 방식도 이와 비슷합니다.


이와 같이'연결' 된 Variable과 Function이 있다면 계산 그래프를 거꾸로 거슬러 올라갈 수 있습니다. 구체적인 코드로 나타내면 다음과 같습니다.

A = Square()
B = Exp()
C = Square()

x = Variable(np.array(0.5))
a = A(x)
b = B(a)
y = C(b)
#계산 그래프의 노드들을 거꾸로 거슬러 올라간다.

assert y.creator == C
assert y.creator.input == b
assert y.creator.input.creator == B
assert y.creator.input.creator.input == a
assert y.creator.input.creator.input.creator == A
assert y.creator.input.creator.input.creator.input == x

우선 assert문이 무엇이지 설명해야겠네요. 먼저 'assert'는 우리말로 '단호하게 주장하다', '단언하다'라는 뜻입니다. assert 문은 assert ... 형태로 사용합니다. 여기서 ... 부분이 '주장'에 해당하는 내용으로, 그 평가 결과가 True가 아니면 예외가 발생합니다. 따라서 assert문은 조건을 충족하는지 여부를 확인하는데 사용할 수 있습니다. 참고로 앞의 코드는 문제없이(예외가 발생하지 않고) 실행되므로 assert문의 조건을 모두 충족함을 알 수 있습니다.

앞의 코드가 보여주듯 Variable의 인스턴스 변수 creator 에서 바로 앞의 Function으로 건너갑니다. 그리고 그 Function의 인스턴스 변수 input에서 다시 하나 더 앞의 Variable로 건너가죠. [그림 7-3]은 이 관계를 잘 보여줍니다.


[그림 7-3]과 같이 우리 계산 그래프는 함수와 변수 사이의 연결로 구성됩니다. 그리고 중요한 점은 이 '연결'이 실제로 계산을 수행하는 싲넘에(순전파로 데이터를 흘려보낸 때) 만들어진다는 것입니다. 이러한 특성에 이름을 붙인 것이 Define-by-Run입니다. 데이터를 흘려보냄으로써(Run함으로써)연결이 규정된다는 (Define된다는) 뜻입니다.

또한 [그림 7-3]과 같이 노드들의 연결로 이루어진 데이터 구조를 '링크드 리스트(linked list)라고 합니다. 노드는 그래프를 구성하는 요소이며, 링크(link)는 다른 노드를 가리키는 참조를 뜻합니다. 결국 우리는 '링크드 리스트'라는 데이터 구조를 이용해 계산 그래프를 표현하고 있는 것입니다.

2022년 8월 11일 목요일