케라스(Keras) #2 - MNIST 분류 모델 만들기
MNIST(Modified National Institute of Standards and Technology database)는 0~9까지의 숫자를 사람들이 손으로 직접 쓴 손글씨를 이미지화한 데이터셋이다.
딥러닝(Deep Learning)을 제외한 머신러닝(Machine Learning)에 IRIS(붓꽃 데이터)셋이 있다면 딥러닝에는 MNIST가 있으며 그만큼 성능을 자랑하기에 매우 딥러닝에 효율적이며 기본으로 내장이 되어 있어서 언제든지 불러서 실습을 해볼 수가 있는 데이터셋이다.
MNIST 데이터셋
최근에는 이에 MNIST보다 조금 더 분류가 힘든 fashion mnist라는 데이터셋이 신규로 추가되었다. 둘의 포맷은 완전히 동일하기에 성능을 측정하기에 MNIST보다 더 수월해보인다.
Fashion MNIST 데이터셋
해당 소스는 기본적으로 텐서플로우(Tensorflow) 사이트에서 제공하는 example 소스를 기반으로 불필요한 내용들은 제외하고 오로지 테스트 결과가 어떤지만을 보여주는 소스로 간략화 시켰다.
from tensorflow import keras
import numpy as np
케라스(Keras)와 numpy를 사용하기에 두개의 라이브러리를 임포트한다
# fashion mnist 데이터를 로드
# 0~9 사이의 숫자 손글씨였던 mnist보다 업그레이드된 버전으로 패션 이미지를 맞추는 문제로 mnist와 포맷은 동일
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
keras.datasets에 fashion_mnist라는 메소드를 호출하면 데이터를 읽어들이며, 해당 데이터를 train_set과 test_set에 나누어서 저장한다.
처음 실행하면, 데이터를 읽어들이는 과정이 진행되며 다음 실행에는 위와 같은 위치에 있는 데이터를 읽어서 빠르게 수행된다.
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
숫자로 만들어진 레이블값에 대응하는 실제 레이블명을 배열로 생성한다.
# 값의 범위를 0~1 사이로 조정하기 위해, 255로 나눈다
train_images = train_images / 255.0
test_images = test_images / 255.0
데이터가 0부터 255까지의 값으로 이루어져 있기에 255로 나누게 되면, 0~1사이의 실수값으로 변환된다.
# 모델 구성
model = keras.Sequential([
keras.layers.Flatten(input_shape=(28, 28)), # flatten을 사용하여 2차원의 배열을 1차원의 배열 값으로 변경
keras.layers.Dense(128, activation='relu'), # relu 활성화 함수를 사용하고, 128개의 노드를 가진 신경망 층
keras.layers.Dense(10, activation='softmax') # 10개의 레이블 값을 가진 output layer, softmax로 확률로 변형
])
28 by 28로 이루어진 배열 데이터를 flatten을 사용하여 1차원의 배열값(784개)으로 변형하고, relu 활성화 함수(activation function)로 이루어진, 128개의 노드의 신경망을 만들어서 계산하며, 최종적으로는 10개의 레이블값을 가진 output 레이어를 생성한다.
output 레이어는 0~9값의 확률을 구해야 하기 때문에 softmax 활성화 함수를 사용한다.
레이어를 모두 생성했다면, 해당 모델을 Compile 한다.
model.compile(optimizer='adam', # 옵티마이저는 대중적인 adam 사용
loss='sparse_categorical_crossentropy', # 손실함수는 sparse categorical crossentropy 사용
metrics=['accuracy']) # 정확도(accuracy)를 지표로 활용
옵티마이저는 아담(adam)을 사용했고, 손실함수는 sparse categorical crossentropy를 사용했으며 지표는 정확도를 이용하게 된다.
데이터 훈련 단계
fit 메소드를 사용하여 데이터를 훈련한다. parameter값으로 학습 이미지 데이터와 레이블된 데이터를 넣고, 총 5번의 반복을 통해 학습을 시킨다.
# 훈련데이터를 모델에 주입 (train_images와 train_labels)
# 모델은 이미지와 레이블을 매핑하는 방법을 배움
model.fit(train_images, train_labels, epochs=5)
Epoch 5/5
32/60000 [..............................] - ETA: 3s - loss: 0.2639 - accuracy: 0.9375
1184/60000 [..............................] - ETA: 2s - loss: 0.2970 - accuracy: 0.9012
2336/60000 [>.............................] - ETA: 2s - loss: 0.3054 - accuracy: 0.8964
3360/60000 [>.............................] - ETA: 2s - loss: 0.2967 - accuracy: 0.8949
4480/60000 [=>............................] - ETA: 2s - loss: 0.2930 - accuracy: 0.8964
5344/60000 [=>............................] - ETA: 2s - loss: 0.3025 - accuracy: 0.8903
6176/60000 [==>...........................] - ETA: 2s - loss: 0.3004 - accuracy: 0.8914
7040/60000 [==>...........................] - ETA: 2s - loss: 0.2982 - accuracy: 0.8928
7808/60000 [==>...........................] - ETA: 2s - loss: 0.2945 - accuracy: 0.8925
8800/60000 [===>..........................] - ETA: 2s - loss: 0.2901 - accuracy: 0.8940
9888/60000 [===>..........................] - ETA: 2s - loss: 0.2872 - accuracy: 0.8945
11136/60000 [====>.........................] - ETA: 2s - loss: 0.2895 - accuracy: 0.8936
12384/60000 [=====>........................] - ETA: 2s - loss: 0.2846 - accuracy: 0.8946
13568/60000 [=====>........................] - ETA: 2s - loss: 0.2881 - accuracy: 0.8926
14752/60000 [======>.......................] - ETA: 2s - loss: 0.2895 - accuracy: 0.8919
15936/60000 [======>.......................] - ETA: 2s - loss: 0.2870 - accuracy: 0.8926
17088/60000 [=======>......................] - ETA: 2s - loss: 0.2884 - accuracy: 0.8923
17920/60000 [=======>......................] - ETA: 2s - loss: 0.2885 - accuracy: 0.8921
18400/60000 [========>.....................] - ETA: 2s - loss: 0.2900 - accuracy: 0.8915
19072/60000 [========>.....................] - ETA: 2s - loss: 0.2887 - accuracy: 0.8921
19744/60000 [========>.....................] - ETA: 2s - loss: 0.2885 - accuracy: 0.8927
20480/60000 [=========>....................] - ETA: 2s - loss: 0.2907 - accuracy: 0.8926
21344/60000 [=========>....................] - ETA: 2s - loss: 0.2882 - accuracy: 0.8937
22272/60000 [==========>...................] - ETA: 1s - loss: 0.2882 - accuracy: 0.8935
22976/60000 [==========>...................] - ETA: 1s - loss: 0.2892 - accuracy: 0.8933
23488/60000 [==========>...................] - ETA: 1s - loss: 0.2891 - accuracy: 0.8936
23968/60000 [==========>...................] - ETA: 1s - loss: 0.2900 - accuracy: 0.8936
24480/60000 [===========>..................] - ETA: 1s - loss: 0.2895 - accuracy: 0.8939
24896/60000 [===========>..................] - ETA: 2s - loss: 0.2889 - accuracy: 0.8938
25280/60000 [===========>..................] - ETA: 2s - loss: 0.2892 - accuracy: 0.8937
25696/60000 [===========>..................] - ETA: 2s - loss: 0.2887 - accuracy: 0.8939
26048/60000 [============>.................] - ETA: 2s - loss: 0.2893 - accuracy: 0.8938
26720/60000 [============>.................] - ETA: 2s - loss: 0.2897 - accuracy: 0.8939
27296/60000 [============>.................] - ETA: 2s - loss: 0.2885 - accuracy: 0.8942
27968/60000 [============>.................] - ETA: 2s - loss: 0.2887 - accuracy: 0.8942
28800/60000 [=============>................] - ETA: 1s - loss: 0.2885 - accuracy: 0.8942
29632/60000 [=============>................] - ETA: 1s - loss: 0.2888 - accuracy: 0.8942
30144/60000 [==============>...............] - ETA: 1s - loss: 0.2882 - accuracy: 0.8944
30496/60000 [==============>...............] - ETA: 1s - loss: 0.2878 - accuracy: 0.8946
30912/60000 [==============>...............] - ETA: 1s - loss: 0.2876 - accuracy: 0.8945
31552/60000 [==============>...............] - ETA: 1s - loss: 0.2880 - accuracy: 0.8944
32096/60000 [===============>..............] - ETA: 1s - loss: 0.2880 - accuracy: 0.8944
32672/60000 [===============>..............] - ETA: 1s - loss: 0.2889 - accuracy: 0.8941
33312/60000 [===============>..............] - ETA: 1s - loss: 0.2899 - accuracy: 0.8936
34144/60000 [================>.............] - ETA: 1s - loss: 0.2897 - accuracy: 0.8935
35008/60000 [================>.............] - ETA: 1s - loss: 0.2886 - accuracy: 0.8939
35808/60000 [================>.............] - ETA: 1s - loss: 0.2893 - accuracy: 0.8938
36544/60000 [=================>............] - ETA: 1s - loss: 0.2887 - accuracy: 0.8938
37312/60000 [=================>............] - ETA: 1s - loss: 0.2894 - accuracy: 0.8934
38112/60000 [==================>...........] - ETA: 1s - loss: 0.2904 - accuracy: 0.8932
38784/60000 [==================>...........] - ETA: 1s - loss: 0.2906 - accuracy: 0.8931
39488/60000 [==================>...........] - ETA: 1s - loss: 0.2905 - accuracy: 0.8932
40352/60000 [===================>..........] - ETA: 1s - loss: 0.2904 - accuracy: 0.8933
41440/60000 [===================>..........] - ETA: 1s - loss: 0.2912 - accuracy: 0.8932
42592/60000 [====================>.........] - ETA: 1s - loss: 0.2909 - accuracy: 0.8931
43680/60000 [====================>.........] - ETA: 1s - loss: 0.2915 - accuracy: 0.8930
44672/60000 [=====================>........] - ETA: 1s - loss: 0.2914 - accuracy: 0.8931
45856/60000 [=====================>........] - ETA: 0s - loss: 0.2913 - accuracy: 0.8931
47328/60000 [======================>.......] - ETA: 0s - loss: 0.2918 - accuracy: 0.8930
48640/60000 [=======================>......] - ETA: 0s - loss: 0.2916 - accuracy: 0.8932
49984/60000 [=======================>......] - ETA: 0s - loss: 0.2920 - accuracy: 0.8933
51392/60000 [========================>.....] - ETA: 0s - loss: 0.2928 - accuracy: 0.8931
52736/60000 [=========================>....] - ETA: 0s - loss: 0.2927 - accuracy: 0.8930
54016/60000 [==========================>...] - ETA: 0s - loss: 0.2922 - accuracy: 0.8933
55328/60000 [==========================>...] - ETA: 0s - loss: 0.2922 - accuracy: 0.8933
56640/60000 [===========================>..] - ETA: 0s - loss: 0.2926 - accuracy: 0.8931
57856/60000 [===========================>..] - ETA: 0s - loss: 0.2925 - accuracy: 0.8931
59264/60000 [============================>.] - ETA: 0s - loss: 0.2926 - accuracy: 0.8930
60000/60000 [==============================] - 4s 58us/sample - loss: 0.2930 - accuracy: 0.8927
필자의 Train accuracy는 약 89% 정도의 정확도를 보여주었다.
모델 평가
이 학습된 모델에 test 데이터셋을 넣어서 train과 test의 성능 차이를 측정해본다. train은 over-fitting(과적합)이 될 수 있기에 결과는 test 데이터를 넣은 evaluate 메소드로 평가한다.
# 테스트데이터(test_images, test_labels)를 이용하여 모델의 성능 비교
# Verbosity mode. 0 = silent, 1 = progress bar, 2 = one line per epoch.
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
print('\ntest accuracy:', test_acc)
10000/1 - 0s - loss: 0.2942 - accuracy: 0.8758
test accuracy: 0.8758
테스트 데이터를 모델에 넣었을 때, 87~88% 정도의 정확도를 보여주었다.
분류 예측하기
모델이 성공적으로 만들어졌다면, 이제 실제 각각의 데이터를 호출하여 분류를 어떻게 하고 있는지 확인한다.
# 예측 만들기
predictions = model.predict(test_images)
우선, predict 메소드에 test_images 데이터셋을 넣어, 예측 값을 생성한다.
print(predictions[0])
print(class_names[test_labels[0]], '=>', class_names[np.argmax(predictions[0])])
0번째 데이터의 레이블별 예측 결과와 실제 데이터와 예측한 데이터를 호출한다.
predictions에는 해당 데이터의 모든 레이블값의 확률을 출력하기 때문에 argmax를 사용하여, 가장 높은 값의 index값을 리턴한다.
[1.0713779e-05 1.6287389e-07 1.1720056e-06 1.8306851e-07 4.9521755e-06
1.6651142e-02 5.3147642e-06 4.2165697e-02 4.6171881e-05 9.4111449e-01]
Ankle boot => Ankle boot
0번째 데이터의 실제값은 Ankle boot이고, 예측값도 Ankle boot로 예측에 성공하였다.
이번에는 약 100개의 데이터를 호출하여 결과의 예측도를 확인해본다.
for i in range(0, 100):
print(class_names[test_labels[i]], '=>', class_names[np.argmax(predictions[i])])
...
T-shirt/top => T-shirt/top
Dress => Dress
Pullover => Pullover
T-shirt/top => T-shirt/top
Shirt => Pullover
Sandal => Sandal
Dress => Dress
Shirt => Shirt
Sneaker => Sneaker
Trouser => Trouser
Bag => Bag
T-shirt/top => T-shirt/top
...
참고자료