본론으로 돌아와서 Variable 클래스의 backward 메서드를 구현하겠습니다. 이전과 달라진 부분(음영)에 주목해서 살펴보죠.
class Variable:
def __init__(self, data):
if data is not None:
if not isinstance(data, np.ndarray):
raise TypeError('{}은(는) 지원하지 않습니다.' .format(type(data)))
self.data = data
self.grad = None
self.creator = None
self.generation = 0 # 세대 수를 기록하는 변수
def set_creator(self, func):
self.creator = func
self.generation = func.generation + 1 # 세대를 기록한다(부모 세대 + 1).
def backward(self):
if self.grad is None:
self.grad = np.ones_like(self.data)
funcs = []
seen_set = set()
def add_func(f):
if f not in seen_set:
funcs.append(f)
seen_set.add(f)
funcs.sort(key=lambda x: x.generation)
add_func(self.creator)
while funcs:
f = funcs.pop() #
gys = [output.grad for output in f.outputs]
gxs = f.backward(*gys)
if not isinstance(gxs, tuple):
gxs = (gxs,)
for x, gx in zip(f.inputs, gxs):
if x.grad is None:
x.grad = gx
else:
x.grad = x.grad + gx
if x.creator is not None:
add_func(x.creator) #수정 전 funcs.append(x.creator)
def cleargrad(self):
self.grad = None
가장 큰 변화는 새로 추가된 add_func 함수입니다. 그 동안 'DeZero 함수'를 리스트에 추가할 때 funcs.append(f)를 호출했는데, 대신 add_func 함수를 호출하도록 변경했습니다. 이 add_func함수가 DeZero함수 리스트를 세대 순으로 정렬하는 역할을 합니다. 그 결과 funcs.pop()은 자동으로 세대가 가장 큰 DeZero함수를 꺼내게 됩니다.
참고로 add_func 함수를 backward메서드 안에 중첩 함수로 정의했습니다. 중첩 함수는 주로 다음 두 조건을 충족할 때 적합합니다.
1) 감싸는 메서드(backward 메서드)안에서만 이용한다.
2) 감싸는 메서드(backward 메서드)에 정의된 변수(funcs과 seen_set)를 사용해야 한다.
add_func 함수는 이 조건들을 모두 충족하기 때문에 메서드 안에 정의했습니다.
앞의 구현에서는 seen_set이라는 '집합(set)'을 이용하고 있습니다. funcs리스트에 같은 함수를 중복 추가하는 일을 막기위해서 입니다. 덕분에 함수의 backward메서드가 잘못되어 여러번 불리는 일은 발생하지 않습니다.
댓글 없음:
댓글 쓰기