프로그래밍/기타2015.06.12 15:45

나이브 베이지안 알고리즘에 대해 생각해 볼 일이 생겼다.

일단 대체 이게 어디에 쓰이는 알고리즘인지 알아보자.


구글 검색을 해보니 역시 조대협님의 블로그에서 잘 정리된 정보를 찾을 수 있었다. (역시 보물창고)

http://bcho.tistory.com/1010


한 줄 요약하면, 머신 런닝 분야에서 분류 알고리즘으로서 널리 쓰이고 있으며, 이 알고리즘을 통해 문서 분류기 같은 것을 만들 수 있다. 예를 들어 어떤 메일이 있을 때 이 메일이 스팸이냐 아니냐를 분류하거나, 어떤 뉴스 기사가 있을 때 해당 기사가 경제 기사냐, 스포츠 기사냐를 분류한다.


자세한 수학적 이론과 예제는 위에 조대협님 블로그에서 참고하기로 하고 여기에서는 나이브 베이지안 알고리즘을 적용한 문서 분류기 코드를 작성해 보자.


문제) 다음과 같이 5개의 학습 문서가 존재하고, 분류가 comedy(코메디 영화), action(액션 영화) 두개가 존재한다고 하자. 이제 어떤 문서에 fun, furious, fast 라는 3개의 단어만 있는 문서가 있을 때, 이 문서가 코메디인지 액션 영화인지 분리를 해보자. (문제 예시는 조대협님의 블로그 예시를 그대로 가져옴)


 영화 

 단어

 분류 

 1

 fun, couple, love, love

 Comedy

 2

 fast, furious, shoot

 Action

 3

 couple, fly, fast, fun, fun

 Comedy

 4

 furious, shoot, shoot, fun

 Action

 5

 fly, fast, shoot, love

 Action


자 일단 전체 코드 부터 보자.

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

public class NaiveBayesianClassification {

    private String[] dataSet;
    private Map<String, Long> classifies = new HashMap<>();
    private Map<String, Map<String, Long>> counter = new HashMap<>();

    public NaiveBayesianClassification(String[] dataSet) {
        if (dataSet == null || dataSet.length == 0)
            throw new IllegalArgumentException("Empty dataSet");
        this.dataSet = dataSet;
    }

    private String getClassify(String input) {
        int divide = input.indexOf('|');
        return divide > - 1 ? input.substring(0, divide) : null;
    }

    private String[] getWords(String input) {
        int divide = input.indexOf('|');
        return divide > -1 ? input.substring(divide+1).split(",") : null;
    }

    public void training() {
        Arrays.stream(dataSet).forEach(data -> {
            String classify = getClassify(data);
            String[] words = getWords(data);
            //-- 분류명과 분류명이 나타난 횟수를 classifies에 저장한다.
            if (classify != null) {
                Long count = classifies.get(classify);
                if (count == null)
                    count = 1L;
                else
                    count++;
                classifies.put(classify, count);
                //-- 각 분류명에 대해 특정 단어가 나타난 횟수를 couner에 저장한다.
                if (words != null) {
                    Arrays.stream(words).forEach(word -> {
                        Map<String, Long> wordCounter = counter.get(classify);
                        if (wordCounter == null) {
                            wordCounter = new HashMap<>();
                            counter.put(classify, wordCounter);
                        }
                        Long wordCount = wordCounter.get(word);
                        if (wordCount == null)
                            wordCount = 1L;
                        else
                            wordCount++;
                        wordCounter.put(word, wordCount);
                    });
                }
            }
        });
    }

    public String judgment(String[] words) {
        Map<String, Double> results = new HashMap<>();
        long classifiesTotalCount = classifies.values().stream().mapToLong(Long::longValue).sum();
        classifies.forEach((classify, count) -> {
            double[] points = Arrays.stream(words).mapToDouble(word -> {
                Map<String, Long> wordCounter = counter.get(classify);
                if (wordCounter == null)
                    return 0.0f;
                Long wordCount = wordCounter.get(word);
                if (wordCount == null)
                    return 0.0f;
                long wordTotalCount = wordCounter.values().stream().mapToLong(Long::longValue).sum();
                return (double)wordCount / wordTotalCount;
            }).toArray();
            double total = (double)classifies.get(classify) / classifiesTotalCount;
            total = Arrays.stream(points).reduce(total, (x, y) -> x * y);
            results.put(classify, total);
        });
        results.entrySet().forEach(entry ->
                System.out.println(String.format("%s : %f", entry.getKey(), entry.getValue())));
        return results.entrySet().stream().max(Map.Entry.comparingByValue(Double::compareTo)).get().getKey();
    }

    public static void main(String[] args) throws Exception {
        //-- 학습 데이터
        String[] dataSet = {
                "Comedy|fun,couple,love,love",
                "Action|fast,furious,shoot",
                "Comedy|couple,fly,fast,fun,fun",
                "Action|furious,shoot,shoot,fun",
                "Action|fly,fast,shoot,love"
        };
        //-- 테스트 데이터
        String[] words = {"fun", "furious", "fast"};

        NaiveBayesianClassification classifier = new NaiveBayesianClassification(dataSet);
        classifier.training();
        String classify = classifier.judgment(words);
        System.out.println(classify);
    }

}

trainning 메소드는 dataSet을 통해 문서 분류기를 학습 시키는 역할을 한다. 먼저 데이터로부터 분류와 단어들을 추출하고, classifies에는 분류명과 해당 분류명의 나타난 횟수를 기록한다. counter는 분류별로 특정 단어가 나타난 횟수를 기록하는데 사용한다.


judgment 메소드는 단어들이 주어졌을 때, 주어진 단어를 통해 해당 문서가 어떤 분류에 속할지 계산한다. 이를 위헤 classifies에 포함된 모든 분류들에 대해 확률값을 계산한 후 그 중 가장 큰 확률값을 지닌 분류를 선택한다.


위 분류기에 의한 결과값은 Action : 0.001803, Comedy : 0.000000 으로 "fun", "furious", "fast" 단어들을 포함하는 문서는 Action 영화으로 분류된다.

Posted by devop

댓글을 달아 주세요