[Java] 직접 구현해본 나이브 베이즈 분류기 #1

    이 포스팅은 머신러닝 알고리즘 중 하나인 나이브 베이즈 분류기를 자바(Java)로 구현해본 것으로 개발자가 쉽게 접근할 수 있는 것을 목표로 합니다. 

     

    쉽고 강력한 머신러닝, 나이브 베이즈 분류 (Naive Bayes Classification)

     

    쉽고 강력한 머신러닝, 나이브 베이즈 분류 (Naive Bayes Classification)

    ※ 베이즈 정리를 모르는 분들은 나이브 베이즈를 알기에 앞서 베이즈 정리에 대해서 먼저 이해해야 한다. 확률의 함정을 간파, 베이즈 정리(Bayes' Theorem) 확률의 함정을 간파, 베이즈 정리(Bayes' T

    needjarvis.tistory.com

    포스팅에서 설명하는 스텝 즉 나이브베이즈의 정석인 사전확률, 우도, 사후확률을 쉽게 이해하고자 각각 메소드로 구현을 하였고 자바를 어느정도 알고 있는 개발자에 해당 포스팅을 이해한 사람이라면 충분히 소스 분석이 가능합니다. 다만 해당 코드는 러프하게 짠 것이기 때문에 에러가 있을 수 있으니 참고하기 바랍니다.

     

    자바로 구현하는 나이브베이즈 분류기 #1

     

    사용한 데이터

    play.csv
    0.00MB

     

     

    날씨에 따른 테니스 가능 여부의 14개의 행 데이터입니다. 이 예제는 위 나이브 베이즈 포스팅에 올려놓은 데이터와 동일합니다.

     

    play.csv

    no,outlook,temperature,humidity,windy,play
    1,sunny,hot,high,False,no
    2,sunny,hot,high,True,no
    3,overcast,hot,high,False,yes
    4,rainy,mild,high,False,yes
    5,rainy,cool,normal,False,yes
    6,rainy,cool,normal,True,no
    7,overcast,cool,normal,True,yes
    8,sunny,mild,high,False,no
    9,sunny,cool,normal,False,yes
    10,rainy,mild,normal,False,yes
    11,sunny,mild,normal,True,yes
    12,overcast,mild,high,True,yes
    13,overcast,hot,normal,False,yes
    14,rainy,mild,high,True,no

     

    play.csv 데이터 읽기

    List<Map<String, String>> csvData 
    			= fileModule.readCsvToListMap("./sample/play.csv", true, ",");

     

    데이터를 행과 열로 표현하기 위해서 list안에 map을 담았고 데이터는 일단 string 형으로 모두 읽습니다. object 등으로 받아서 처리를 한다면 feature별로 형을 설정해야 하는데 play.csv 파일을 보면 알겠지만 모든 데이터가 문자형입니다.

     

    readCsvToListMap

    /**
     * CSV 파일을 읽은 후, LisMap 형태로 변환
     * 
     * @param fullPath
     * @param useLower
     * @param sepa
     * @return
     */
    public List<Map<String, String>> readCsvToListMap(
    		String fullPath, boolean useLower, String sepa) {
    	
    	List<Map<String, String>> rtnList = new ArrayList<Map<String, String>> ();
    	BufferedReader br = null;
    	try {
    		br = new BufferedReader(
    					new InputStreamReader(
    						new FileInputStream(fullPath), "UTF8"));
    		
    		String line = "";
    		int lineCnt = 0;
    		String[] cols = null;
    		
    		while((line = br.readLine()) != null) {
    			// header
    			if(lineCnt == 0) {
    				cols = line.trim().split(sepa);
    			} else {
    				Map<String, String> map = new LinkedHashMap<String, String> ();
    				if(line.trim().length() > 0) {
    					String[] values = line.trim().split(sepa);
    					for(int i = 0; i < cols.length; i++) {
    						String val = (useLower) ? values[i].toLowerCase() : values[i];
    						map.put(cols[i], val);							
    					}
    					rtnList.add(map);
    				}
    			}
    			lineCnt++;
    		}				
    	} catch (Exception e) {
    		System.out.println("readCsvToListMap error : " + e.getMessage());
    		return null;
    	} finally {
    		if(br != null) {
    			try {
    				br.close();
    			} catch (Exception eo) {}
    		}
    	}
    }

    첫번째 인자는 파일 경로, 두번째 인자는 소문자 형태로 변환할지의 여부, 마지막은 csv의 구분자를 지정합니다.

     

    System.out.println(csvData);	// CSV값 출력

    CSV값 출력결과

    [
    {no=1, outlook=sunny, temperature=hot, humidity=high, windy=false, play=no},
    {no=2, outlook=sunny, temperature=hot, humidity=high, windy=true, play=no},
    {no=3, outlook=overcast, temperature=hot, humidity=high, windy=false, play=yes},
    {no=4, outlook=rainy, temperature=mild, humidity=high, windy=false, play=yes}, 
    {no=5, outlook=rainy, temperature=cool, humidity=normal, windy=false, play=yes}, 
    {no=6, outlook=rainy, temperature=cool, humidity=normal, windy=true, play=no}, 
    {no=7, outlook=overcast, temperature=cool, humidity=normal, windy=true, play=yes}, 
    {no=8, outlook=sunny, temperature=mild, humidity=high, windy=false, play=no}, 
    {no=9, outlook=sunny, temperature=cool, humidity=normal, windy=false, play=yes}, 
    {no=10, outlook=rainy, temperature=mild, humidity=normal, windy=false, play=yes}, 
    {no=11, outlook=sunny, temperature=mild, humidity=normal, windy=true, play=yes}, 
    {no=12, outlook=overcast, temperature=mild, humidity=high, windy=true, play=yes}, 
    {no=13, outlook=overcast, temperature=hot, humidity=normal, windy=false, play=yes}, 
    {no=14, outlook=rainy, temperature=mild, humidity=high, windy=true, play=no}
    ]

     

    CSV를 리스트 맵형태로 담은 변수를 기반으로 나이브 베이즈 학습을 진행해야 합니다. 

     

     

    train 메소드

    /**
     * 나이브베이즈 학습
     * 
     * @param map
     * @param X (특징 벡터)
     * @param Y (클래스)
     * @param useSmooth (라플라스 스무딩)
     */
    public boolean train(
    		List<Map<String, String>> trainData,
    		String[] X, String Y, boolean useSmooth) {
    	
    	// 빈도맵을 생성
    	Map<String, Map<String, Map<String, Integer>>> freqMap = makeFreqMap(trainData, X, Y);
    	// debug
    	for(String feat : freqMap.keySet()) {
    		System.out.println(feat + "=>" + freqMap.get(feat));
    	}
    	
    	// 빈도맵을 기반으로 사전 확률을 구한다
    	PRIOR_PROB_VO = calcPriorProb(freqMap);
    	System.out.println("class : " + PRIOR_PROB_VO.getClssPriorProb());
    	System.out.println("feat : " + PRIOR_PROB_VO.getFeatValPriorProb());
    	
    	// 빈도맵을 기반으로 우도를 구한다
    	LIKELIHOOD = calcLikelihood(freqMap, PRIOR_PROB_VO.getClssCnt());
    	System.out.println("======================= likelihood ====================");
    	System.out.println(LIKELIHOOD);
    	
    	if(PRIOR_PROB_VO == null || LIKELIHOOD == null && LIKELIHOOD.size() == 0)
    		return false;
    	
    	return true;
    }

    train메소드의 첫번째 인자값은 csv를 읽은 리스트맵 데이터이며, 두번째 인자값 X는 학습에 사용될 feature 배열, 세번째 인자값 Y는 class key 값이고, 마지막은 스무딩(Smoothing)을 사용할지 여부인데 본 포스팅에서는 인자값만 넣고 스무딩을 구현하지는 않았습니다. 추후 텍스트를 기반으로 나이브 베이즈 분류를 하는 것을 설명할 때 스무딩 기능을 추가할 예정입니다.

     

    train메소드에서 시작하자마자 makeFreqMap을 호출하는데 이 메소드는 빈도 테이블을 만드는 메소드입니다.

     

    makeFreqMap

    /**
     * 원본 맵을 빈도 맵으로 변경
     * 
     * @param map
     * @return
     */
    public Map<String, Map<String, Map<String, Integer>>> makeFreqMap(
    		List<Map<String, String>> list, String[] X, String Y) {
    	
    	// feature(temperature), val(cool,hot,mild), (clss=cnt, clss=cnt, clss=cnt)
    	Map<String, Map<String, Map<String, Integer>>> freqMap 
    					= new HashMap<String, Map<String, Map<String, Integer>>> ();
    	
    	for(Map<String, String> map : list) {
    		for(String feat : X) {
    			Map<String, Map<String, Integer>> valMap = null;
    			Map<String, Integer> clssCntMap = null;
    			
    			// 해당 feature가 존재하는 경우
    			if(freqMap.containsKey(feat)) {
    				// feature 맵을 읽는다
    				valMap = freqMap.get(feat);
    				
    				// 해당 feature맵 안에 value 값이 존재하는 경우
    				if(valMap.containsKey(map.get(feat))) {	
    					
    					// value 맵을 읽어서, clssCntMap을 가져온다
    					clssCntMap = valMap.get(map.get(feat));
    					
    					int cnt = 1;
    					// value에 clss 카운트가 존재하면 가져온 후 + 1
    					if(clssCntMap.containsKey(map.get(Y))) {
    						cnt += clssCntMap.get(map.get(Y));	
    					}						
    					clssCntMap.put(map.get(Y), cnt);
    				} else {
    					// feature 안에 해당 value가 없을 경우, 우선 해당 클래스에 1을 생성
    					clssCntMap = new HashMap<String, Integer> ();
    					clssCntMap.put(map.get(Y), 1);
    				}					
    				// 클래스 카운트맵을 value맵에 세팅
    				valMap.put(map.get(feat), clssCntMap);
    				
    			} else {
    				// 해당 feature가 존재하지 않는 경우 1로 세팅한 map 생성
    				valMap = new HashMap<String, Map<String, Integer>> ();
    				// feature 안에 해당 value가 없을 경우, 우선 해당 클래스에 1을 생성
    				clssCntMap = new HashMap<String, Integer> ();
    				clssCntMap.put(map.get(Y), 1); 					
    				valMap.put(map.get(feat), clssCntMap);
    			}
    			
    			freqMap.put(feat, valMap);								
    		}
    		
    	}
    	
    	return freqMap;
    }

    메소드 안에 주석을 자세히 달았으니 주석을 보고 용도가 어떤지를 판단하면 될 것 같네요. 더 효과적인 방법이 있을 수 있으나 최대한 이전 포스팅을 기반으로 만들어 봤습니다.

     

    내용이 길어져서 사전확률과 우도를 구하고 최종적으로 테스트를 하는 부분은 다음 포스팅에서 작성하도록 하겠습니다. 

     

    댓글

    Designed by JB FACTORY