페이지

2022년 8월 12일 금요일

STEP9 함수를 더 편리하게

 DeZero가 역전파를 해낼 수 있게 되었습니다. 또한 Define-by-Run이라고 하는 전체 계산의 각 조각들을 런터임에 '연결'해내는 능력도 같췄습니다. 하지만 사용하기에 조금 불편한 부분이 있어서 이번 단계에서는 DeZero의 함수에 세 가지 개선을 추가하엤습니다.

8.3 동작 확인

 개선된 Variable 클래스를 사용하여 실제로 미분을 해봅시다. 7단계에서와 똑같은 코드를 실행해보겠습니다.

A = Square()
B = Exp()
C = Square()

x = Variable(np.array(0.5))
a = A(x)
b = B(a)
y = C(b)

# 역전파
y.grad = np.array(1.0)
y.backward()
print(x.grad)

3.297442541400256

결과도 이전과 똑같습니다. 이상으로 '재귀'에서 '반복문'으로 구현 방식을 전환했습니다. 반복문 방식의 이점은 15단계에서 알 수 있습니다. 15단계에서는 복잡한 계산 그래프를 다루는데, 받금 전환한 구현 덕분에 부르럽게 확장할 수 있습니다. 처리효율도 반복문 방식이 뛰어납니다.

재귀는 함수를 재귀적으로 호출할 때마다 중간 결과를 메모리에 유지하면서(스택에 쌓으면서)처리를 이어갑니다. 일반적으로 반복문 방식이 효율이 더 좋은 이유입니다. 그러나 요즘 컴퓨터는 메모리가 넉넉한 편이라서 조금 더 사용한느 건 그리 문제가 되지 않습니다 또한 '꼬리 재귀(tail recursion)'기법을 이용하여 재귀를 반복문처럼 실행할 수 있는 경우도 있습니다.


이상으로 역전파 구현의 기반은 완성했습니다. 앞으로는 더욱 복잡한 계산도 가능하도록 현재의  DeZero를 확장해나갈 것입니다. 하지만 그전에 다음 단계에서 DeZero의 '사용자 편의성'부터 개선하겠습니다.

8.2 반복문을 이용한 구현

 이번 절에서는 지금까지의 '재귀를 사용한 구현'을 '반복문을 이용한 구현'으로 고쳐보겠습니다. 코드는 다음과 같습니다.

class Variable:
  def __init__(selfdata):
    self.data = data
    self.grad = None
    self.creator = None

  def set_creator(selffunc):
    self.creator = func
  
  def backward(self):
    funcs = [self.creator]
    while funcs:
      f = funcs.pop()   # 함수를 가져온다.
      x, y = f.input, f.output  # 함수의 입력과 출력을 가져온다.
      x.grad = f.backward(y.grad)   # backward 메서드를 호출한다

      if x.creator is not None:
        funcs.append(x.creator)   # 하나 앞의 함수를 리스트에 추가한다.

이것이 반복문을 이용한 구현입니다. 주묵할 점은 처리해야 할 함수들을 funcs 라는 리스트에 차례로 집어넣는다는 것입니다.  while블럭 안에서 funcs.pop()을 호출하여 처리할 함수 f를 꺼내고, f의 backward메서드를 호출합니다. 이때 f, input과 f, output에서 함수 f의 입력과 출력 변수를 얻음으로써 f.backward()의 인수와 반환값을 올바르게 설정할 수 있습니다.


리스트의 pop 메서드는 리스트에서 마지막 원소를 꺼내줍니다(반환된 요소는 리스트에서 제거됩니다). 예컨대 funcs = [1,2,3] 일때 x = funcs.pop()을 실행하면 3이 반환되고 funcs는 [1,2]가 됩니다.

8.1 현재의 Variable 클래스

 이전 장에서 우리는 Variable 클래스의 backward 메서드를 다음처럼 구현했습니다.


class Variable:
  def __init__(selfdata):
    self.data = data
    self.grad = None
    self.creator = None

  def set_creator(selffunc):
    self.creator = func

  def backward(self):
    f = self.creator
    if f is not None:
      x = f.input
      x.grad = f.backward(self.grad)
      x.backward()

이 backward메서드에서 눈에 밟히는 부분은 (입력 방향으로) 하나 앞 변수의 backward메서드를 호출하는 코드입니다. "backward 메서드에서 backward메서드를 호출하고, 호출된 backward 메서드에서 또 다른  backward 메서드를 호출하고,...' 과정이 계속됩니다(창조자 함수가 없는 변수, 즉 self.creator 가 None인 변수를 찾을 때까지 계속됩니다) 이를 재귀 구조라고 합니다.



STEP8 재귀에서 반복문으로

 앞 단계에서는 Variable 클래스에 backward메서드를 추가했습니다. 이번에는 처리 효율을 개선하고 앞으로의 확장을 대비해 backward메서드의 구현 방식을 바꿔보겠습니다.

7.3 backward 메서드 추가

 방금 보여드린 역전파 코드에는 똑같은 처리 흐름이 반복해서 나타났습니다. 변수에서 하나 앞의 변수로 거슬러 올라가는 로직이 그러했습니다. 그러므로 이 반복 작업을 자동화할 수 있도록 Variable 클래스에 badckward라는 새로운 메서드를 추가하겠습니다.


class Variable:
  def __init__(selfdata):
    self.data = data
    self.grad = None
    self.creator = None

  def set_creator(selffunc):
    self.creator = func

  def backward(self):
    f = self.creator    # 1. 함수를 가져온다.
    if f is not None:
      x = f.input       # 2. 함수의 입력을 가져온다.
      x.grad = f.backward(self.grad)    # 3. 함수의 backward 메서드를 호출한다.
      x.backward()      # 하나 앞 변수의 backward 메서드를 호출한다(재귀)


backward메서드는 지금까지 반복한 처리 흐름과 거의 동일합니다. Variable의 creator에서 함수를 얻어오고, 그 함수의 입력 변수를 가져옵니다. 그런 다음 함수의 backward메서드를 호출합니다. 마지막으로 자신보다 하나 앞에 놓인 변수의 backward 메서드를 호출합니다. 이런 식으로 각 변수의 backward 메서드가 재귀적으로 불리게 됩니다.

Variable 인스턴스의 creator 가 None이면 역전파가 중단됩니다. 창조자가 없으므로 이 Variable 인스턴스는 함수 바깥에서 생성했음을 뜻합니다(높은 확률로 사용자가 만들어 건네 변수일 것입니다).

이제 새로워진 Variable을 이용하여 역전파가 자동으로 실행되는 모습을 보겠습니다.

A = Square()
B = Exp()
C = Square()

x = Variable(np.array(0.5))
a = A(x)
b = B(a)
y = C(b)

# 역전파
y.grad = np.array(1.0)
y.backward()
print(x.grad)

3.297442541400256

이와 같이 변수 y의 backward 메서드를 호출하면 역전파가 자동으로 진행됩니다. 실행 결과도 지금까지와 동일합니다. 축하합니다! 여러분은 방금  DeZero에서 가장 중요한 개념인 자동 미분의 기초를 완성했습니다.

7.2 역전파 도전!

 변수와 함수의 관곌르 이용하여 역전파를 시도해보겠습니다. 우선 y에서 b까지의 역전파를 시도해보죠.


이 부분은 다음과 같이 구현할 수 있습니다.

y.grad = np.array(1.0)

C = y.creator # 1. 함수를 가져온다.
b = C.input # 2. 함수의 입력을 가져온다.
b.grad = C.backward(y.grad) # 3. 함수의 backward 메서드를 호출한다.

y의 인스턴스 변수 creator 에서 함수를 얻어오고, 그 함수의 input에서 입력 변수를 가져왔습니다. 그런 다음 함수의 backward메서드를 호출합니다. 이어서 변수 b에서 a로의 역전파를 보겠습니다.

B = b.creator # 1. 함수를 가져온다.
a = B.input   # 2. 함수의 입력을 가져온다.
a.grad = B.backward(b.grad) # 3. 함수의 backward메서드를 호출한다.

똑같은 흐름입니다. 구체적으로 다음과 같은 순서로 진행됩니다.

1. 함수를 가져온다.

2. 함수의 입력을 가져온다.

3. 함수의 backward메서드를 호출한다


마지막으로 변수 a에서 x로의 역전파까지 진행합니다.\

 A = a.creator  #1. 함수를 가져온다.
 x = A.input    # 함수의 입력을 가져온다.
 x.grad = A.backward(a.grad)    # 3. 함수의 backward메서드를 호출한다.
 print(x.grad)

3.297442541400256

이상으로 모든 역전파가 끝났습니다.