인공지능 및 데이터과학/딥러닝 프레임워크

케라스(Keras) #2 - MNIST 분류 모델 만들기

Steve Jang 2020. 4. 27. 14:52

MNIST(Modified National Institute of Standards and Technology database)는 0~9까지의 숫자를 사람들이 손으로 직접 쓴 손글씨를 이미지화한 데이터셋이다. 


딥러닝(Deep Learning)을 제외한 머신러닝(Machine Learning)에 IRIS(붓꽃 데이터)셋이 있다면 딥러닝에는 MNIST가 있으며 그만큼 성능을 자랑하기에 매우 딥러닝에 효율적이며 기본으로 내장이 되어 있어서 언제든지 불러서 실습을 해볼 수가 있는 데이터셋이다.


MNIST 데이터셋



최근에는 이에 MNIST보다 조금 더 분류가 힘든 fashion mnist라는 데이터셋이 신규로 추가되었다. 둘의 포맷은 완전히 동일하기에 성능을 측정하기에 MNIST보다 더 수월해보인다.


Fashion MNIST 데이터셋


해당 소스는 기본적으로 텐서플로우(Tensorflow) 사이트에서 제공하는 example 소스를 기반으로 불필요한 내용들은 제외하고 오로지 테스트 결과가 어떤지만을 보여주는 소스로 간략화 시켰다.



from tensorflow import keras
import numpy as np


케라스(Keras)와 numpy를 사용하기에 두개의 라이브러리를 임포트한다


# fashion mnist 데이터를 로드
# 0~9 사이의 숫자 손글씨였던 mnist보다 업그레이드된 버전으로 패션 이미지를 맞추는 문제로 mnist와 포맷은 동일
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()


keras.datasets에 fashion_mnist라는 메소드를 호출하면 데이터를 읽어들이며, 해당 데이터를 train_set과 test_set에 나누어서 저장한다.



처음 실행하면, 데이터를 읽어들이는 과정이 진행되며 다음 실행에는 위와 같은 위치에 있는 데이터를 읽어서 빠르게 수행된다.


class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']


숫자로 만들어진 레이블값에 대응하는 실제 레이블명을 배열로 생성한다.


# 값의 범위를 0~1 사이로 조정하기 위해, 255로 나눈다
train_images = train_images / 255.0
test_images = test_images / 255.0


데이터가 0부터 255까지의 값으로 이루어져 있기에 255로 나누게 되면, 0~1사이의 실수값으로 변환된다.


# 모델 구성
model = keras.Sequential([
keras.layers.Flatten(input_shape=(28, 28)), # flatten을 사용하여 2차원의 배열을 1차원의 배열 값으로 변경
keras.layers.Dense(128, activation='relu'), # relu 활성화 함수를 사용하고, 128개의 노드를 가진 신경망 층
keras.layers.Dense(10, activation='softmax') # 10개의 레이블 값을 가진 output layer, softmax로 확률로 변형
])


28 by 28로 이루어진 배열 데이터를 flatten을 사용하여 1차원의 배열값(784개)으로 변형하고, relu 활성화 함수(activation function)로 이루어진, 128개의 노드의 신경망을 만들어서 계산하며, 최종적으로는 10개의 레이블값을 가진 output 레이어를 생성한다.



output 레이어는 0~9값의 확률을 구해야 하기 때문에 softmax 활성화 함수를 사용한다.


레이어를 모두 생성했다면, 해당 모델을 Compile 한다.


model.compile(optimizer='adam',                         # 옵티마이저는 대중적인 adam 사용
loss='sparse_categorical_crossentropy', # 손실함수는 sparse categorical crossentropy 사용
metrics=['accuracy']) # 정확도(accuracy)를 지표로 활용


옵티마이저는 아담(adam)을 사용했고, 손실함수는 sparse categorical crossentropy를 사용했으며 지표는 정확도를 이용하게 된다.



데이터 훈련 단계


fit 메소드를 사용하여 데이터를 훈련한다. parameter값으로 학습 이미지 데이터와 레이블된 데이터를 넣고, 총 5번의 반복을 통해 학습을 시킨다.


# 훈련데이터를 모델에 주입 (train_images train_labels)
# 모델은 이미지와 레이블을 매핑하는 방법을 배움
model.fit(train_images, train_labels, epochs=5)


Epoch 5/5


   32/60000 [..............................] - ETA: 3s - loss: 0.2639 - accuracy: 0.9375

 1184/60000 [..............................] - ETA: 2s - loss: 0.2970 - accuracy: 0.9012

 2336/60000 [>.............................] - ETA: 2s - loss: 0.3054 - accuracy: 0.8964

 3360/60000 [>.............................] - ETA: 2s - loss: 0.2967 - accuracy: 0.8949

 4480/60000 [=>............................] - ETA: 2s - loss: 0.2930 - accuracy: 0.8964

 5344/60000 [=>............................] - ETA: 2s - loss: 0.3025 - accuracy: 0.8903

 6176/60000 [==>...........................] - ETA: 2s - loss: 0.3004 - accuracy: 0.8914

 7040/60000 [==>...........................] - ETA: 2s - loss: 0.2982 - accuracy: 0.8928

 7808/60000 [==>...........................] - ETA: 2s - loss: 0.2945 - accuracy: 0.8925

 8800/60000 [===>..........................] - ETA: 2s - loss: 0.2901 - accuracy: 0.8940

 9888/60000 [===>..........................] - ETA: 2s - loss: 0.2872 - accuracy: 0.8945

11136/60000 [====>.........................] - ETA: 2s - loss: 0.2895 - accuracy: 0.8936

12384/60000 [=====>........................] - ETA: 2s - loss: 0.2846 - accuracy: 0.8946

13568/60000 [=====>........................] - ETA: 2s - loss: 0.2881 - accuracy: 0.8926

14752/60000 [======>.......................] - ETA: 2s - loss: 0.2895 - accuracy: 0.8919

15936/60000 [======>.......................] - ETA: 2s - loss: 0.2870 - accuracy: 0.8926

17088/60000 [=======>......................] - ETA: 2s - loss: 0.2884 - accuracy: 0.8923

17920/60000 [=======>......................] - ETA: 2s - loss: 0.2885 - accuracy: 0.8921

18400/60000 [========>.....................] - ETA: 2s - loss: 0.2900 - accuracy: 0.8915

19072/60000 [========>.....................] - ETA: 2s - loss: 0.2887 - accuracy: 0.8921

19744/60000 [========>.....................] - ETA: 2s - loss: 0.2885 - accuracy: 0.8927

20480/60000 [=========>....................] - ETA: 2s - loss: 0.2907 - accuracy: 0.8926

21344/60000 [=========>....................] - ETA: 2s - loss: 0.2882 - accuracy: 0.8937

22272/60000 [==========>...................] - ETA: 1s - loss: 0.2882 - accuracy: 0.8935

22976/60000 [==========>...................] - ETA: 1s - loss: 0.2892 - accuracy: 0.8933

23488/60000 [==========>...................] - ETA: 1s - loss: 0.2891 - accuracy: 0.8936

23968/60000 [==========>...................] - ETA: 1s - loss: 0.2900 - accuracy: 0.8936

24480/60000 [===========>..................] - ETA: 1s - loss: 0.2895 - accuracy: 0.8939

24896/60000 [===========>..................] - ETA: 2s - loss: 0.2889 - accuracy: 0.8938

25280/60000 [===========>..................] - ETA: 2s - loss: 0.2892 - accuracy: 0.8937

25696/60000 [===========>..................] - ETA: 2s - loss: 0.2887 - accuracy: 0.8939

26048/60000 [============>.................] - ETA: 2s - loss: 0.2893 - accuracy: 0.8938

26720/60000 [============>.................] - ETA: 2s - loss: 0.2897 - accuracy: 0.8939

27296/60000 [============>.................] - ETA: 2s - loss: 0.2885 - accuracy: 0.8942

27968/60000 [============>.................] - ETA: 2s - loss: 0.2887 - accuracy: 0.8942

28800/60000 [=============>................] - ETA: 1s - loss: 0.2885 - accuracy: 0.8942

29632/60000 [=============>................] - ETA: 1s - loss: 0.2888 - accuracy: 0.8942

30144/60000 [==============>...............] - ETA: 1s - loss: 0.2882 - accuracy: 0.8944

30496/60000 [==============>...............] - ETA: 1s - loss: 0.2878 - accuracy: 0.8946

30912/60000 [==============>...............] - ETA: 1s - loss: 0.2876 - accuracy: 0.8945

31552/60000 [==============>...............] - ETA: 1s - loss: 0.2880 - accuracy: 0.8944

32096/60000 [===============>..............] - ETA: 1s - loss: 0.2880 - accuracy: 0.8944

32672/60000 [===============>..............] - ETA: 1s - loss: 0.2889 - accuracy: 0.8941

33312/60000 [===============>..............] - ETA: 1s - loss: 0.2899 - accuracy: 0.8936

34144/60000 [================>.............] - ETA: 1s - loss: 0.2897 - accuracy: 0.8935

35008/60000 [================>.............] - ETA: 1s - loss: 0.2886 - accuracy: 0.8939

35808/60000 [================>.............] - ETA: 1s - loss: 0.2893 - accuracy: 0.8938

36544/60000 [=================>............] - ETA: 1s - loss: 0.2887 - accuracy: 0.8938

37312/60000 [=================>............] - ETA: 1s - loss: 0.2894 - accuracy: 0.8934

38112/60000 [==================>...........] - ETA: 1s - loss: 0.2904 - accuracy: 0.8932

38784/60000 [==================>...........] - ETA: 1s - loss: 0.2906 - accuracy: 0.8931

39488/60000 [==================>...........] - ETA: 1s - loss: 0.2905 - accuracy: 0.8932

40352/60000 [===================>..........] - ETA: 1s - loss: 0.2904 - accuracy: 0.8933

41440/60000 [===================>..........] - ETA: 1s - loss: 0.2912 - accuracy: 0.8932

42592/60000 [====================>.........] - ETA: 1s - loss: 0.2909 - accuracy: 0.8931

43680/60000 [====================>.........] - ETA: 1s - loss: 0.2915 - accuracy: 0.8930

44672/60000 [=====================>........] - ETA: 1s - loss: 0.2914 - accuracy: 0.8931

45856/60000 [=====================>........] - ETA: 0s - loss: 0.2913 - accuracy: 0.8931

47328/60000 [======================>.......] - ETA: 0s - loss: 0.2918 - accuracy: 0.8930

48640/60000 [=======================>......] - ETA: 0s - loss: 0.2916 - accuracy: 0.8932

49984/60000 [=======================>......] - ETA: 0s - loss: 0.2920 - accuracy: 0.8933

51392/60000 [========================>.....] - ETA: 0s - loss: 0.2928 - accuracy: 0.8931

52736/60000 [=========================>....] - ETA: 0s - loss: 0.2927 - accuracy: 0.8930

54016/60000 [==========================>...] - ETA: 0s - loss: 0.2922 - accuracy: 0.8933

55328/60000 [==========================>...] - ETA: 0s - loss: 0.2922 - accuracy: 0.8933

56640/60000 [===========================>..] - ETA: 0s - loss: 0.2926 - accuracy: 0.8931

57856/60000 [===========================>..] - ETA: 0s - loss: 0.2925 - accuracy: 0.8931

59264/60000 [============================>.] - ETA: 0s - loss: 0.2926 - accuracy: 0.8930

60000/60000 [==============================] - 4s 58us/sample - loss: 0.2930 - accuracy: 0.8927


필자의 Train accuracy는 약 89% 정도의 정확도를 보여주었다.



모델 평가


이 학습된 모델에 test 데이터셋을 넣어서 train과 test의 성능 차이를 측정해본다. train은 over-fitting(과적합)이 될 수 있기에 결과는 test 데이터를 넣은 evaluate 메소드로 평가한다.


# 테스트데이터(test_images, test_labels)를 이용하여 모델의 성능 비교
# Verbosity mode. 0 = silent, 1 = progress bar, 2 = one line per epoch.
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
print('\ntest accuracy:', test_acc)


10000/1 - 0s - loss: 0.2942 - accuracy: 0.8758


test accuracy: 0.8758


테스트 데이터를 모델에 넣었을 때, 87~88% 정도의 정확도를 보여주었다. 



분류 예측하기


모델이 성공적으로 만들어졌다면, 이제 실제 각각의 데이터를 호출하여 분류를 어떻게 하고 있는지 확인한다.


# 예측 만들기
predictions = model.predict(test_images)


우선, predict 메소드에 test_images 데이터셋을 넣어, 예측 값을 생성한다.


print(predictions[0])
print(class_names[test_labels[0]], '=>', class_names[np.argmax(predictions[0])])


0번째 데이터의 레이블별 예측 결과와 실제 데이터와 예측한 데이터를 호출한다.

predictions에는 해당 데이터의 모든 레이블값의 확률을 출력하기 때문에 argmax를 사용하여, 가장 높은 값의 index값을 리턴한다.


[1.0713779e-05 1.6287389e-07 1.1720056e-06 1.8306851e-07 4.9521755e-06

 1.6651142e-02 5.3147642e-06 4.2165697e-02 4.6171881e-05 9.4111449e-01]

Ankle boot => Ankle boot


0번째 데이터의 실제값은 Ankle boot이고, 예측값도 Ankle boot로 예측에 성공하였다.


이번에는 약 100개의 데이터를 호출하여 결과의 예측도를 확인해본다.


for i in range(0, 100):
print(class_names[test_labels[i]], '=>', class_names[np.argmax(predictions[i])])


...

T-shirt/top => T-shirt/top

Dress => Dress

Pullover => Pullover

T-shirt/top => T-shirt/top

Shirt => Pullover

Sandal => Sandal

Dress => Dress

Shirt => Shirt

Sneaker => Sneaker

Trouser => Trouser

Bag => Bag

T-shirt/top => T-shirt/top

...



참고자료