페이지

2018년 8월 7일 화요일

3. 분류 3.1 MNIST

이 장에서는 고등학생과 미국 인구조사국 직원들이 손으로 쓴 70,000개의 작은 숫자 이미지를 모은 MNIST 데이터셋을 상요하겠습니다. 각 이미지에는 어떤 숫자를 나타내는지 레이블되어 있습니다. 이 데이터셋은 학습용으로 아주 많이 사용되기 때문에 머신러닝 분야의 'Hello World'라고 불립니다. 새로운 분류 알고리즘이 나올 때마다 MNIST 데이터셋에서 얼마나 잘 작동하는지 봅시다. 머신러닝을 배우는 사람이라면 머지않아 MNIST 데이터셋을 맞닥뜨리게 될 것입니다.

사이킷런에서 제공하는 여러 헬퍼 함수를 사용해 잘 알려진 데이터셋을 내려받을 수 있습니다. MNIST도 그 중 하나입니다. 다음은 MNIST 데이터셋을 내려받는 코드입니다.

사이킷런에서 읽어 들인 데이터셋들은 일반적으로 비슷한 딕셔너리 구조를 가지고 있습니다. 

3.3.3 정밀도와 재현율
사이킷런은 정밀도와 재현율을 포함하여 분류기의 지표를 계산하는 여러 함수를 제공합니다.

이제 '5-감지기'가 정확도에서 봤을 때만큼 멋져 보이지는 않네요. 5로 판별된 이미지 중 77%만 정확합니다. 더군다나 전체 숫자 5에서 80%만 감지했습니다.
정밀도와 재현율을 F1점수라고 하는 하나의 숫자로 만들면 편리할 때가 많습니다. 특히 두 분류기를 비교할 때 그렇습니다. F1점수는 정밀도와 재현율의 조화평균(harmonic mean)입니다.

F1 = 2/(1/정밀도 + 1/재현율) = 2 * (정밀도 * 재현율)/(정밀도 + 재현율) = TP/ (TP + (FN+FP)/2)

F1 점수를 계싼하려면 f1_score()함수를 호출하면 됩니다.

정밀도와 재현율이 비슷한 분류기에스는 F1점수가 높습니다. 하지만 이게 항상 바람직한 것은 아닙니다. 상황에 따라 정밀도가 중요할 수도 있고 재현율이 중요할수도 있습니다. 예를 들어 어린아이에게 안전한 동영상을 걸러내는 분류기를 훈련시킨다고 가정해보겠습지다. 재현율은 높으나 정말 나쁜 동양상이 몇 개 노출되는 것보다 좋은 동영상이 많이 제외되더라도(낮은 재현율) 안전한 것들만 노출시키는 (높은 정밀도)분류기를 선호할 것입니다(이런 경우에는 분류기의 동양상 선택 결과를 확인하기 위해 사람이 참여하는 분석 파이프라인을 추가할지도 모릅니다). 다른 예로, 감시 카메라를 통해 좀도둑을 잡아내는 분류기를 훈련시킨다고 가정해보겠습니다. 분류기의 재현율이 99%라면 정확도가 30%만 되더라도 괜찮을지 모릅니다(아마도 경비원이 잘못된 호출을 종종 받게 되겠지만, 거의 모든 좀도둑을 잡을 것입니다).

안됐지만 이 둘을 모두 얻을 수는 없습니다. 정밀도를 올리면 재현율이 줄고 그 반대도 마찬가지입니다. 이를 정밀도/재현율 트레이드오프 라고 합니다.

3.3.4 정밀도/재현율 트레이드오프
SGDClassifier가 분류를 어떻게 결정하는지 살펴보며 이 트레이드오프를 이해해보겠습니다.
이 분류기는 결정함수(decision function)를 사용하여 각 샘플의 점수를 계산합니다. 이 점수가 임곗값보다 크면 샘플을 양성 클래스에 할당하고 그렇지 않으면 음성 클래스를 할당합니다.[그림 3-3]에 가장 낮은 점수부터 가장 높은 점수까지 몇 개의 숫자를 나열했습니다. 결정 임곗값이 가운데(두 개의 숫자 5사이)화살표라고 가정해보겠습니다. 임곗값 오른쪽에 4개의 진짜 양성(실제 숫자5)과 하나의 거짓 양성(실제 숫자6)이 있습니다. 그렇기 때문에 이 임곗값에서 정밀도는 80%(5개 중 4개)입니다. 하지만 실제 숫자 5는 6개고 분류기는 4개만 감지했으므로 재션율은 67%(6개중 4개)입니다. 이번 입곗값을 높이면(임곗값을 오른쪽 화살표로 옮기면) 거짓 양성(숫자 6)이 진짜 음성이 되어 정밀도가 높아집니다(이 경우에 100%가 됩니다). 하지만 진짜 양성 하나가 거짓 음성이 되었으므로 재현율이 50%로 줄어듭니다. 반대로 임곗값을 내리면 재현율이 높아지고정밀도가 줄어듭니다.

사이킷런에서 임곗값을 직접 지정할 수는 없지만 예측에 사용한 점수는 확인할 수 있습니다. 분류기의 predict()메서드 대신 decision_function()메서드를 호출하면 각 샘플의 점수를 얻을 수 있습니다. 이 점수를 기반으로 원하는 임곗값을 정해 예측을 만들수 있습니다.

SDGClassifier의 임곗값이 0이므로 위 코드는 predict()메서드와 같은 결과(즉, True)를 반환합니다.

이 결과는 임곗값을 높이면 재현율이 줄어든다는 것을 보여줍니다. 이미지가 실제로 숫자 5이고 임곗값이 0일때는 분류기가 이를 감지했지만, 임곗값을 200,000으로 높이면 이를 놓치게 됩니다.

그렇다면 적절한 임곗값을 어떻게 정할 수 있을까요? 이를 위해서는 먼저 cross_val_predict()함수를 사용해 훈련 세트에 있는 모든 샘플의 점수를 구해야 합니다. 하지만 이번에는 예측 결과가 아니라 결정 점수를 반환받도록 지정해야 합니다.

이 점수로 precision_recall_curve()함수를 사용하여 가능한 모든 임곗값에 대해 정밀도와 재현율을 계산할 수 있습니다.



댓글 없음: