[R] 인공신경망을 활용한 스팸 필터링 분석 (Spam Filtering using neuralnet in R)

해당 포스트에서는 R에서 인공신경망(nnet)을 이용해 스팸 필터링 문제를 풀이하는 방법에 대해 설명합니다.

spam-filter-using-r

INTRO

스팸 필터링(Spam Filtering) 문제는 기본적으로 자연어 처리(NLP, Natural Language Processing)를 기반으로 하며, 텍스트 데이터가 숫자(임베딩)로 변환된 후에는 다양한 알고리즘 적용이 가능합니다.


아래에서는 인공신경망(ANN)을 사용하여 스팸을 예측하는 분류 문제를 소개합니다. 풀이 절차는 nnet 패키지를 사용하여 모델을 구축하고, 내장된 스팸 데이터셋을 불러와 전처리 한 뒤, 적절한 노드 수를 찾아 모델을 적합시키고 결과를 분석합니다.

[참고] 스팸 필터링(Spam Filtering)에 대한 이론적 이해가 필요하신 분은 아래 링크를 참고해 주세요.

스팸 필터링 분석

1. 환경 설정

먼저, 분석에 필요한 kernlabnnet 패키지를 설치하고 라이브러리를 로드합니다.

# Set the Environment

install.packages("kernlab")
library(kernlab)

install.packages("nnet")
library(nnet)

2. 데이터 로딩 및 전처리

spam 데이터 세트를 로드하고 type 변수에 따라 '스팸 메일(spam_s)'과 '일반 메일(email_s)'로 나눕니다. 이후, 두 데이터를 각각 동일한 비율(5:3:2)로 분리하여 spam.train, spam.test, spam.valid 에 저장하고, class.ind() 함수를 사용하여 type 변수를 범주형으로 변환합니다.

# Load Data

data(spam)
head(spam)
table(spam$type)

# Preprocessing

# sampling
spam_s <- spam[spam$type == "spam", ]
email_s <- spam[spam$type == "nonspam", ]

# Random split
set.seed(1225)
ind1 <- sample(3, nrow(spam_s), replace=T, prob=c(0.5, 0.3, 0.2))

set.seed(0425)
ind2 <- sample(3, nrow(email_s), replace=T, prob=c(0.5, 0.3, 0.2))

spam.train <- rbind(spam_s[ind1==1,], email_s[ind2==1,]) # train data
spam.test <- rbind(spam_s[ind1==2,], email_s[ind2==2,]) # test data
spam.valid <- rbind(spam_s[ind1==3,], email_s[ind2==3,]) # validation data

# transform data
spam.train$type <- class.ind(spam.train$type)
spam.test$type <- class.ind(spam.test$type)
spam.valid$type <- class.ind(spam.valid$type)


myformula에 공식을 저장하고, test.err() 함수를 사용하여, 테스트 오차가 가장 낮은 노드 수를 찾습니다. sapply() 함수를 사용하여 노드 수를 변경하면서 테스트 오차를 계산하고, 그래프를 그려서 최적의 노드 수를 결정합니다.

# formula
myformula <- type ~ .

# Decide the number of nodes by using test error
test.err <- function(h.size, maxit0){
  spam_model_e1 <- nnet(myformula, data=spam.train, size=h.size, decay=5e-4, trace=F, maxit=maxit0)
  y <- spam.test$type
  p <- predict(spam_model_e1, spam.test)
  err <- mean(y != p)
  c(h.size, err)
}

# Comparing test error rates for neural networks with 1-10 hidden units
out <- t(sapply(1:10, FUN = test.err, maxit0=200))
plot(out, type="b", xlab="The number of Hidden units", ylab="Test Error")
# nodes number = 4 (1st : 2 or 10, 2nd : 4 or 8,  3rd : 2 or 4, 4th : 2 or 4, 5th : 2 or 8)

3. 데이터 분석

분석 단계에서는 인공신경망 분석을 위해 nnet 패키지의 nnet() 함수로 신경망 모델을 적합시킵니다. 이 때, size, decay, range, maxit 등의 하이퍼 파라미터(Hyperparameter)를 조정하고, 적합된 모델(model_fit)은 summary() 함수를 사용하여 요약 정보를 확인합니다.

# Data Analysis

# Fitting the model using nnet
model_fit <- nnet(myformula, size=2, decay=5e-4, range=0.1, maxit=200, data=spam.train)
summary(model_fit)
# a 57-2-2 network with 122 weights
# options were - decay=5e-04
# b->h1  i1->h1  i2->h1  i3->h1  i4->h1  i5->h1  i6->h1  i7->h1  i8->h1  i9->h1 i10->h1 i11->h1 i12->h1 i13->h1 i14->h1 i15->h1 i16->h1 i17->h1 i18->h1 i19->h1 i20->h1 i21->h1 
# 1.85    0.11   -0.03   -0.04   -1.29   -1.52   -0.84  -28.02   -0.19   -2.21    0.42    2.48    0.54    0.74    1.35   -0.74   -0.55   -5.69   -0.05    0.13   -6.95   -0.21 
# i22->h1 i23->h1 i24->h1 i25->h1 i26->h1 i27->h1 i28->h1 i29->h1 i30->h1 i31->h1 i32->h1 i33->h1 i34->h1 i35->h1 i36->h1 i37->h1 i38->h1 i39->h1 i40->h1 i41->h1 i42->h1 i43->h1 
# 0.07  -22.44   -1.56   29.12    3.02   27.69   -2.13   14.67   -1.07   -0.11    2.17   -0.12   -3.50    6.98   -2.25    1.05    7.46    1.28    4.86   -0.06    4.76   -0.68 
# i44->h1 i45->h1 i46->h1 i47->h1 i48->h1 i49->h1 i50->h1 i51->h1 i52->h1 i53->h1 i54->h1 i55->h1 i56->h1 i57->h1 
# 8.27    1.39   27.53    2.38   13.85    2.20    1.37   -0.10   -1.82  -27.16    0.14   -1.20    0.02    0.00 
# b->h2  i1->h2  i2->h2  i3->h2  i4->h2  i5->h2  i6->h2  i7->h2  i8->h2  i9->h2 i10->h2 i11->h2 i12->h2 i13->h2 i14->h2 i15->h2 i16->h2 i17->h2 i18->h2 i19->h2 i20->h2 i21->h2 
# 0.08    0.01   -0.10   -0.01    0.00   -0.16    0.27    0.01    0.00   -0.02    0.02    0.01    0.00    0.01   -0.04    0.02   -1.61    0.00   -0.01   -0.08    0.00    1.60 
# i22->h2 i23->h2 i24->h2 i25->h2 i26->h2 i27->h2 i28->h2 i29->h2 i30->h2 i31->h2 i32->h2 i33->h2 i34->h2 i35->h2 i36->h2 i37->h2 i38->h2 i39->h2 i40->h2 i41->h2 i42->h2 i43->h2 
# -0.01    0.00    0.29   -0.36   -0.12   -0.15    0.01   -0.01   -0.01    0.00   -0.02   -0.01    0.00    0.02   -0.07    0.00    0.01    0.02   -0.01    0.01   -0.07    0.01 
# i44->h2 i45->h2 i46->h2 i47->h2 i48->h2 i49->h2 i50->h2 i51->h2 i52->h2 i53->h2 i54->h2 i55->h2 i56->h2 i57->h2 
# 0.00   -0.04    0.01   -0.02    0.01   -0.07   -0.17    0.01    3.09    0.01    0.01    0.13    0.25   -1.74 
# b->o1 h1->o1 h2->o1 
# -3.97  23.68   7.31 
# b->o2 h1->o2 h2->o2 
# 3.98 -23.85  -6.69 

출력 결과 해석

  • 결과에서 i_n은 n번째 input node를 의미하고 h_o는 o번째 hidden node, o_p는 p번째 output node를 의미합니다.
  • 57개의 input node가 첫 번째 hidden node로 향할 때는 i_n→h1만큼의 가중치를 가지고 57개의 input값의 선형결합이 만들어집니다.
  • 두 번째 hidden node로 갈 때에도 마찬가지로 input node각각의 가중치로 선형결합이 만들어집니다.
  • 그리고 세 개의 hidden node 내에서는 Activation function을 통해 값이 변화되고 그 값들이 또다시 가중치를 받아 선형결합을 이루어 output node로 나오게 됩니다.
  • 57개의 input node 때문에 많은 가중치들이 존재하는데, 기준에 따라 다르겠지만 여기서는 가중치의 값이 20이 넘어가면 매우 크다고 보고 영향력이 큰 부분이라고 판단합니다.
  • input node 7, 23, 27, 31, 41, 44에서 첫 번째 hidden node로 가는 시냅스 부분이 가중치가 20 이상입니다.
  • 따라서 7, 23, 27, 31, 41, 44 input node에 해당하는 이 변수들의 가중치가 크다는 것은 다른 변수들에 비해 영향력이 크다는 것입니다.
  • hidden node에서 output node로 가는 가중치를 살펴보면 첫 번째 hidden node에서 첫 번째 output node로 갈 때의 가중치와, 세 번째 hidden node에서 두 번째 output node로 갈 때의 가중치는 음수입니다.
  • 또한, 첫 번째 hidden node에서 두 번째 output node로 갈 때의 가중치와, 세 번째 hidden node에서 첫 번째 output node로 갈 때의 가중치는 양수입니다.
  • 이 값은 최종 모형에서 sigmoid 함수의 계수 값으로 weighted sum을 대입한 activation function의 가중치를 나타낸다.
  • 결국 첫 번째, 세 번째 hidden node의 값이 클수록, 두 번째 hidden node의 값이 작을수록 1로 분류할 확률이 높아진다고 볼 수 있습니다.(여기서 1은 spam)

4. 성능 평가

이번 단계는 분류 모델의 성능을 평가하기 위해 검증(validation) 데이터와 테스트(test) 데이터를 사용하여 혼동 행렬(confusion matrix)을 계산하고, 정확도(accuracy), 민감도(sensitivity), 특이도(specificity)를 계산합니다.

성능 평가 도구로 SDMTools 패키지를 사용하며, spam.train으로 학습한 예측 모델로 spam.valid, spam.test을 각각 예측하고 평가합니다. 각 결과는 아래와 같습니다.

출력 결과 해석 (spam.valid)

# Validation

# Confusion matrix using validation data
install.packages("SDMTools")
library(SDMTools)

spam.pred.valid <- predict(model_fit, new=spam.valid)

obs <- spam.valid$type[,2] > 0.5
pred <- spam.pred.valid[,2] > 0.5
cfm <- table(data.frame(pred, obs)) # confusion matrix of validation data

sensitivity <- cfm[1,1]/(cfm[1,1]+cfm[2,1])
specificity <- cfm[2,2]/(cfm[1,2]+cfm[2,2])
accuracy <- (cfm[1,1]+cfm[2,2])/(cfm[1,1]+cfm[1,2]+cfm[2,1]+cfm[2,2])
> cfm
#        obs
# pred    FALSE TRUE
# FALSE   496   32
# TRUE     34  318

> sensitivity # 0.9358491
> specificity # 0.9085714
> accuracy # 0.925

spam.valid의 경우, sensitivity는 0.9358491, specificity는 0.9085714, accuracy는 0.925입니다. 이는 spam.valid 데이터에서 모델 예측이 양성일 때 실제로 양성인 비율(sensitivity)이 높은 것을 보여주며, 모델 예측이 음성일 때 실제로 음성인 비율(specificity)은 상대적으로 낮게 나타났습니다. 전체적인 예측 정확도(accuracy)는 92.5%입니다.

출력 결과 해석 (spam.test)

# Compare with test & validation
plot(spam.pred.valid[,2] ~ spam.valid$type[,2])

spam.pred.test <- predict(model_fit, new=spam.test)

obs <- spam.test$type[,2] > 0.5
pred <- spam.pred.test[,2] > 0.5
cfm.test <- table(data.frame(pred, obs)) # confusion matrix of validation data

sensitivity <- cfm.test[1,1]/(cfm.test[1,1]+cfm.test[2,1])
specificity <- cfm.test[2,2]/(cfm.test[1,2]+cfm.test[2,2])
accuracy <- (cfm.test[1,1]+cfm.test[2,2])/(cfm.test[1,1]+cfm.test[1,2]+cfm.test[2,1]+cfm.test[2,2])
> cfm.test
#       obs
# pred    FALSE TRUE
# FALSE   792   43
# TRUE     64  518

> sensitivity # 0.9252336
> specificity # 0.9233512
> accuracy #0.9244884

spam.test의 경우, sensitivity는 0.9252336, specificity는 0.9233512, accuracy는 0.9244884입니다. 이는 spam.test 데이터에서 spam.valid 데이터와 비슷한 결과를 보여주는데, sensitivity는 약간 낮아졌지만 specificity는 약간 높아졌고, 전체적인 예측 정확도는 92.4%로 0.1% 차이로 확인되었습니다.

5. 결론

이번 분석에서는 spam.train 데이터셋을 이용하여 학습하고, spam.valid, spam.test 데이터셋을 이용하여 모델을 검증하였습니다. 결과적으로 spam.validspam.test 데이터 모두 예측 정확도가 높게 나타났으며, spam.valid 데이터에서는 sensitivity가 높지만 specificity가 낮은 것에 비해, spam.test 데이터에서는 두 지표가 비슷한 수준으로 나타나는 것을 볼 수 있었습니다.

이러한 결과를 통해, 구축된 모델이 일반화하여 예측하는 능력이 있다는 것을 알 수 있습니다.

관련 링크

[1] 스팸 필터링(Spam Filtering)
[2] 인공신경망(ANN)


banner-request-analysis