이제부터 순전파만 할 경우를 위한 개선을 DeZero에 추가하겠습니다. 우선 두 가지 모드, 즉 '역전파 활성 모드'와 '역전파 비활성 모드'를 전환하는 구조가 필요합니다. 간단히 다음 Config클래스를 이용할 것입니다.
보시다시피 아주 간단한 클래스입니다. 이 클래스의 속성은 (현재) 불러언 타입인 enable_backprop만 존재합니다. enable_backprop은 역전파가 가능한지 여부를 뜻하고, 이 값이 True면 '역전파 활성 모드'입니다.
설정 데이터는 단 한 군데에만 존재하는 게 좋습니다. 그래서 Config클래스는 인스턴스화하지 않고 '클래스' 상태로 이용합니다. 인스턴스는 여러 개 생성할 수 있지만 클래스는 항상 하나만 존재하기 때문이죠. 따라서 앞 코드에서 Config 클래스가 '클래스 속성'을 갖도록 설정했습니다.
Config클래스를 정의했으니 Function에서 참조하게 하여 모드를 전환할 수 있게 하겠습니다. 코드로는 다음과 같습니다.
import weakref
class Function(object):
def __call__(self, *inputs):
xs = [x.data for x in inputs]
ys = self.forward(*xs)
if not isinstance(ys, tuple):
ys = (ys,)
outputs = [Variable(as_array(y)) for y in ys]
if Config.enable_backprop:
self.generation = max([x.generation for x in inputs]) # 1 세대 설정
for output in outputs:
output.set_creator(self) # 연결 설정
self.inputs = inputs
self.outputs = [weakref.ref(output) for output in outputs]
return outputs if len(outputs) > 1 else outputs[0]
이와 같이 Config.enable_Backprop이 True일 때만 역전파 코드가 실행됩니다. 1 에서 정하는 '세대'는 역전파 시 노드를 따라가는 순서를 정하는 데 사용됩니다. 따라서 '역전파 비활성 모드'에서는 필요하지 않습니다. 또한 2의 output.set_creator(self)는 계산들의 '연결'을 만드는데, 마찬가지로 '역전파 비활성 모드'에서는 필요 없습니다.
댓글 없음:
댓글 쓰기