[Java] 직접 구현해본 나이브 베이즈 분류기 #2 (코드포함)

    이 포스팅은 직접 구현해본 나이브 베이즈 분류기 #1에 연이은 포스팅으로 #1을 아직 못보신 분들은 이전 포스팅을 읽고 오셔야 이해가 될 것이다. #1 포스팅을 보고 싶으면 본 포스팅의 끝에 있는 연관 글을 찾고 해당 글을 클릭하면 된다. 

     

    자바로 구현하는 나이브베이즈 분류기 #2

     

    빈도테이블의 값 출력(디버깅)

    for(String feat : freqMap.keySet()) {
    	System.out.println(feat + "=>" + freqMap.get(feat));
    }
    temperature=>{mild={no=2, yes=4}, cool={no=1, yes=3}, hot={no=2, yes=2}}
    humidity=>{normal={no=1, yes=6}, high={no=4, yes=3}}
    outlook=>{rainy={no=2, yes=3}, overcast={yes=4}, sunny={no=3, yes=2}}
    windy=>{false={no=2, yes=6}, true={no=3, yes=3}}

    이전 포스팅에서 빈도테이블값을 구하는 메소드까지 봤는데 해당 메소드를 호출 후 데이터를 출력해보면 위와 같이 Feature의 값별로 클래스 카운트가 세팅되어 있는 것을 확인할 수 있다.

     

    calcPriorProb (사전 확률을 구한다)

    /**
     * 사전 확률을 연산 한 후, 결과를 리턴
     * 
     * @param map
     * @return
     */
    public NBPriorProbVO calcPriorProb(
    		Map<String, Map<String, Map<String, Integer>>> freqMap) {
    
    	NBPriorProbVO vo = new NBPriorProbVO ();
    	Map<String, Map<String, Double>> featValPriorProb 
    						= new HashMap<String, Map<String, Double>> ();
    	
    	int total = 0;
    	for(String feat : freqMap.keySet()) {
    		// {mild={no=2, yes=4}, cool={no=1, yes=3}, hot={no=2, yes=2}
    		Map<String, Map<String, Integer>> valMap = freqMap.get(feat);
    		
    		// 확률을 계산하기 위해 total값 세팅
    		if(total == 0) {
    			Map<String, Integer> clssCntMap = new HashMap<String, Integer> ();
    			for(String val : valMap.keySet()) {
    				Map<String, Integer> cntMap = valMap.get(val);
    				for(String clss : cntMap.keySet()) {
    					total += cntMap.get(clss);
    					
    					int cnt = cntMap.get(clss);
    					if(clssCntMap.containsKey(clss)) {
    						cnt += clssCntMap.get(clss);							
    					}
    					clssCntMap.put(clss, cnt);				
    				}
    			}				
    			vo.setClssCnt(clssCntMap);	// 클래스별 카운트 값을 세팅(우도 연산 위함)
    			
    			// 클래스 사전 확률 세팅
    			Map<String, Double> clssPriorProb = new HashMap<String, Double> ();
    			for(String clss : clssCntMap.keySet()) {
    				clssPriorProb.put(clss, (double)clssCntMap.get(clss)/total);
    			}
    			vo.setClssPriorProb(clssPriorProb);
    		}
    		
    		// 특징별 사전확률 세팅
    		Map<String, Double> valPriorProb = new HashMap<String, Double> (); 
    		for(String val : valMap.keySet()) {
    			int valCnt = 0;
    			Map<String, Integer> cntMap = valMap.get(val);
    			
    			for(String clss : cntMap.keySet()) {
    				valCnt += cntMap.get(clss);
    			}
    			
    			valPriorProb.put(val, (double)valCnt/total);
    			//System.out.println(val + "=>" + valCnt);				
    		}			
    		featValPriorProb.put(feat, valPriorProb);
    	}		
    	vo.setFeatValPriorProb(featValPriorProb);
    	
    	return vo;
    }	

    빈도테이블을 사전 확률의 인자값으로 던지게 되면, 사전 확률 Value Object에 값이 담겨져서 리턴된다.

     

    사전확률 디버깅

    System.out.println("class : " + PRIOR_PROB_VO.getClssPriorProb());
    System.out.println("feat : " + PRIOR_PROB_VO.getFeatValPriorProb());
    class : {no=0.35714285714285715, yes=0.6428571428571429}
    feat : {temperature={mild=0.42857142857142855, cool=0.2857142857142857, hot=0.2857142857142857}, humidity={normal=0.5, high=0.5}, outlook={rainy=0.35714285714285715, overcast=0.2857142857142857, sunny=0.35714285714285715}, windy={false=0.5714285714285714, true=0.42857142857142855}}

     

    빈도맵으로 우도 계산

    // 빈도맵을 기반으로 우도를 구한다
    LIKELIHOOD = calcLikelihood(freqMap, PRIOR_PROB_VO.getClssCnt());
    System.out.println("======================= likelihood ====================");
    System.out.println(LIKELIHOOD);
    /**
     * 우도 연산
     * 
     * @param freqMap
     * @param clssCntMap
     * @return
     */
    public Map<String, Double> calcLikelihood(
    		Map<String, Map<String, Map<String, Integer>>> freqMap,
    		Map<String, Integer> clssCntMap) {
    	
    	Map<String, Double> likelihoodMap = new HashMap<String, Double> ();
    	
    	for(String feat : freqMap.keySet()) {
    		Map<String, Map<String, Integer>> valMap = freqMap.get(feat);
    		for(String val : valMap.keySet()) {
    			String key = feat + "_" + val;	// feature와 value를 합쳐 키 생성
    			
    			Map<String, Integer> clssMap = valMap.get(val);
    			for(String clss : clssMap.keySet()) {
    				int cnt = clssMap.get(clss);
    				
    				// P(key | clss)
    				likelihoodMap.put(key + "|" + clss, (double)cnt/clssCntMap.get(clss));					
    			}					
    		}
    	}
    	
    	
    	return likelihoodMap;		
    }	

    우도 디버깅

    ======================= likelihood ====================
    {outlook_rainy|no=0.4, temperature_mild|no=0.4, temperature_cool|no=0.2, outlook_overcast|yes=0.4444444444444444, temperature_hot|no=0.4, humidity_high|no=0.8, temperature_mild|yes=0.4444444444444444, outlook_sunny|no=0.6, windy_false|no=0.4, humidity_high|yes=0.3333333333333333, outlook_rainy|yes=0.3333333333333333, temperature_cool|yes=0.3333333333333333, temperature_hot|yes=0.2222222222222222, humidity_normal|no=0.2, humidity_normal|yes=0.6666666666666666, outlook_sunny|yes=0.2222222222222222, windy_true|yes=0.3333333333333333, windy_true|no=0.6, windy_false|yes=0.6666666666666666}

    우도까지 모두 구했다면, Train 과정이 모두 완료가 되었다. 이제 이 Train 데이터를 기반으로 test를 통해 결과를 예측해본다.

     

    Test 과정

    Test 호출

    if(naiveBayes.train(csvData, "outlook,temperature,humidity,windy".split(","),"play",false)) {
    	Map<String, String> paramMap = new HashMap<String, String> ();
    	paramMap.put("outlook", "rainy");
    	paramMap.put("temperature", "cool");
    	paramMap.put("humidity", "normal");
    	
    	NBResultVO vo = naiveBayes.test(paramMap, "yes,no".split(","));
    	System.out.println(vo.getPostProbMap());
    	System.out.println("max prob class => " + vo.getClss());
    	System.out.println("max prob => " + vo.getProb());
    } else {
    	System.out.println("Fail");
    }

    Train이 성공하면, 테스트를 진행하는데 naiveBayes.test에는 Feature에 Value를 넣은 map이 있고, 이 맵과 클래스값을 인자값으로 test에 던졌다. 

     

    test 메소드

    /**
     * 모델을 기반으로 결과 생성
     * 
     * @param test X 
     * @param clss
     */
    public NBResultVO test(Map<String, String> paramMap, String[] clssArr) {
    	NBResultVO vo = new NBResultVO();
    	Map<String, Map<String, Double>> featValPriorProb = PRIOR_PROB_VO.getFeatValPriorProb();
    	double probX = 0.0;
    	boolean probXFlag = true;
    	int cnt = 0;
    	
    	// result
    	Map<String, Double> postProbMap = new HashMap<String, Double> ();
    	double maxProb = 0.0;
    	String maxClss = "";
    	
    	for(String clss : clssArr) {
    		double clssXLike = COMMON.formatDouble(PRIOR_PROB_VO.getClssPriorProb().get(clss), 2);
    		
    		for(String feat : paramMap.keySet()) {
    			// 우도값을 가져온다
    			String key = feat + "_" + paramMap.get(feat) + "|" + clss;
    			
    			// P( X | Clss )				
    			//System.out.println(key + "=>" + LIKELIHOOD.get(key));
    			clssXLike *= COMMON.formatDouble(LIKELIHOOD.get(key), 2);
    			
    			Map<String, Double> valPriorProb = featValPriorProb.get(feat);
    			
    			if(probXFlag) {
    				if(cnt == 0) {
    					probX = COMMON.formatDouble(valPriorProb.get(paramMap.get(feat)), 2);
    				} else {
    					probX *= COMMON.formatDouble(valPriorProb.get(paramMap.get(feat)), 2);
    				}
    			}
    			cnt++;
    		}			
    		probXFlag = false;
    		
    		System.out.println("=============" + clss + "===============");
    		System.out.println("clssXlike => " + COMMON.formatDouble(clssXLike, 4));
    		System.out.println("probX => " + COMMON.formatDouble(probX, 4));
    		double prob = COMMON.formatDouble(clssXLike, 5)/COMMON.formatDouble(probX, 4);
    		postProbMap.put(clss, prob);	// 사후확률 계산
    		
    		if(prob >= maxProb) {
    			maxProb = prob;
    			maxClss = clss;
    		}
    	}
    	
    	vo.setClss(maxClss);
    	vo.setProb(maxProb);
    	vo.setPostProbMap(postProbMap);
    	
    	return vo;
    }

    연산을 사용할 떄 COMMON.formatDouble이라는 메소드는 가독성을 위한 메소드로, 소수점의 자리를 끊어주는 처리를 한다. 즉 foramtDouble(값, 2) 라고 적으면 소수점 2번째자리까지 가져오라는 의미이다.

    /**
    * 더블형을 받아온 후, 소수점을 자르고 리턴
    * 
    * @param value
    * @param point
    */
    public double formatDouble(double value, int point) {
    	return Double.parseDouble(String.format("%." + point + "f", value));
    }

     

    test 메소드의 디버깅

    =============yes===============
    clssXlike => 0.0467
    probX => 0.0522
    =============no===============
    clssXlike => 0.0058
    probX => 0.0522

    테스트까지 모두 연산이 끝나면, 최종 결과를 확인해보도록 한다. 

     

    최종 분류 결과

    {no=0.1103448275862069, yes=0.8946360153256704}
    max prob class => yes
    max prob => 0.8946360153256704

    test 메소드를 거친 NBResultVO의 값을 출력하면, no는 0.11, yes는 0.89의 확률이라는 것을 알 수 있다. 그래서 최종적으로는 분류가 yes 즉 테니스를 칠 수 있다고 결론이 난다. 소스를 이렇게 포스팅에 하나씩 설명하는게 얼마나 힘든지 이번에 알게되어서 다음에는 핵심만 쓰던지 아니면 설명없이 붙이는 형태가 낫지 않을까 싶다.

     

    기계학습(머신러닝)을 만약에 처음 접하는 사람이라면 이것이 머신러닝이 맞어?라고 생각할 수 있다. 하지만 우리는 학습 데이터를 토대로 모델(베이즈 통계기반 모델)을 만들었고 이 모델을 토대로 특징과 값을 넣으니 결과가 나왔다. 이것이 머신러닝이고, 신경망도 사실 매우 간단한 신경망은 개념 역시 쉽다. 딥러닝으로 갈수록 복잡할 뿐이지...

     

    그리고 이 몇시간도 안돼서 짠, 자바 소스가 매우 강력한 예측모델까지 가지고 있으니 개발자라면 왠만해서는 머신러닝 모델을 직접 짜보는 것을 추천드린다.

     

    최종소스

    steel-analyzer.zip
    0.05MB

     

    연관포스팅

    확률의 함정을 간파, 베이즈 정리(Bayes' Theorem)

    쉽고 강력한 머신러닝, 나이브 베이즈 분류 (Naive Bayes Classification)

    [Python] 파이썬으로 나이브베이즈 구현하기

    [Java] 직접 구현해본 나이브 베이즈 분류기 #1

    댓글

    Designed by JB FACTORY