페이지

2022년 9월 6일 화요일

26.2 계산 그래프에서 DOT 언어로 변환하기

 이상의 내용을 코드로 옮겨봅시다. 몸체인 get_dot_graph 함수는 잠시 뒤로 미루고, _dot_var라는 보조 함수부터 구현하겠습니다. 이름 앞에 밑줄(_)이 붙은 이유는 이 함수를 로컬에서만, 즉 get_dot_graph 함수 전용으로 사용할 것이기 때문입니다. 다음은 _dot_var 함수의 코드와 사용 예입니다.

def _dot_var(v, verbose=False):

    dot_var = '{} [label="{}", color=orange, style=filled]\n'

    

    name = '' if v.name is None else v.name

    if verbose and v.data is not None:

            if v.name is not None:

                    name += ': '

            name += str(v.shape) + ' ' + str(v.dtype)

    return dot_var.format(id(v), name)


# 사용 예

x = Variable(np.random.randn(2, 2))

x.name = 'x'

print(_dot_var(x))

print(_dot_var(x, verbose=True))


2622676553440 [label="x", color=orange, style=filled]

2622676553440 [label="x: (2, 2) float64", color=orange, style=filled]


이와 같이 _dot_var 함수에 Variable인스턴스를 건내면 인스턴스의 내용을 DOT 언어로 작성된 문자열로 바꿔서 반환합니다. 한편 변수 노두에 고유한 ID를 부여하기 위해 파이썬 내장함수인 id를 사용했습니다. id 함수는 주어진 객체의 ID를 반환하는데, 객체 ID는 다른 객체의 중복되지 않기 때문에 노드의 ID로 상용하기에 적합합니다.

또한 마지막 반환 직적에 format메서드를 이용했습니다. format 메서드는 문자열에 등장하는 "{}" 부분을 메서드 인수로 건넨 객체(문자열이나 정수 등)로 차례로 바꿔줍니다. 가령 앞의 코드에서는 dot_var 문자열의 첫 번째{} 자리에는 id(v)의 값이 , 두 번째 {} 자리에는 name의 값이 채워집니다.


_dot_var 함수는 verbose 인수도 받습니다. 이 값을 True로 설정하면 ndarray인스턴스의 '형상'과 '타입'도 함께 레이블로 출력합니다.


이어서 'DeZero 함수'를 DOT 언어로 변환하는 편의 함수를 구현하겠습니다. 이름은 _dot_ func이고 코드는 다음과 같습니다.

def _dot_func(f):

    dot_func = '{} [label="{}", color=lightblue, style=filled, shape=box]\n'

    txt = dot_func.format(id(f), f.__class__.__name__)

    

    dot_edge = '{} -> {}\n'

    for x in f.inputs:

        txt += dot_edge.format(id(x), id(f))

    for y in f.outputs:

        txt += dot_edge.format(id(f), id(y()))   # y는 약한 참조 (weakref, 17.4절 참고)        

    return txt

# 사용 예

x0 = Variable(np.array(1.0))

x1 = Variable(np.array(1.0))

y = x0 +  x1

txt = _dot_func(y.creator)

print(txt)


1991655372544 [label="Add", color=lightblue, style=filled, shape=box]
1991655371872 -> 1991655372544
1991629516560 -> 1991655372544
1991655372544 -> 1991619916176

_dot_func 함수는 'DeZero 함수'를 DOT 언어로 기술합니다. 또한 '함수와 입력 변수의 관계' 그리고 '함수와 출력 변수의 관계'도 DOT 언어로 기술합니다. 복습해보자면, DeZero 함수는 Function 클래스를 상속하고 [그림 26-2] 처럼 inputs와 outputs라는 인스턴스 변수를 가지고 있습니다.


준비가 끝났습니다. 이제 본격적으로 get_dot_graph 함수를 구현할 차례입니다. Variable 클래스의 backward 메서드를 참고하여 다음과 같이 구현할 수 있습니다.


def get_dot_graph(output, verbose=True):

    txt = ''

    funcs = [] 

    seen_set = set()

    

    def add_func(f):

        if f not in seen_set:

            funcs.append(f)

            # funcs.sort(key=lamhbda x: x.generation)

            seen_set.add(f)

            

    add_func(output.creator)

    txt += _dot_var(output, verbose)

    while funcs:

        func = funcs.pop()

        txt += _dot_func(func)

        for x in fuinc.inputs:

            txt += _dot_var(x, verboxe)

            

            if x.creator is not None:

                add_func(x.creator)

                

    return 'digraph g {\n' +  txt + '}'

이 코드의 로직은 Variable 클래스의 backward메서드와 거의 같습니다(backward 메서드 구현에서 달라진 부분은 음영으로 표시했습니다). backward 메서드는 미분값을 전파했지만 여기에서는 미분 대신 DOT언어로 기술한 문자열을 txt에 추가합니다.

또한 실제 역전파에선느 노드를 따라가는 순서가 중요했습니다. 그래서 함수에 generation(세대)이라는 정숫값을 부여하고 그 값이 큰 순서대로 꺼냈죠(자세한 내용은 15-16단계 참고). 하지만 get_dot_graph 함수에는 노드를 추적하는 순서는 문제가 되지 않으므로 generation 값으로 정렬하는 코드를 주석으로 처리했습니다.


계산 그래프를 DOT언어로 변환할 때는 '어떤 노드가 존재하는가' 와 '어떤 노드끼리 연결되는가'가 문제입니다. 즉, 노드의 추적 '순서'는 문제가 되지 않기 때문에 generation을 사용하여 순서대로 꺼내는 구조는 사용하지 않아도 됩니다.


이것으로 계산 그래프 시각화 코드가 완성되었습니다. 이어서 계산 그래프를 더 손쉽게 시각화하는 함수를 추가하겠습니다.


댓글 없음: