segunda-feira, 25 de junho de 2012

k-means clustering com R e ggplot2

Olá pessoal, tudo bem? Neste post vamos falar um pouco sobre análise de clusters, uma técnica bastante poderosa e que possui aplicações nos mais diversos campos: biologia, medicina, marketing, para citar alguns. Nossa tarefa consiste em classificar um grupo de países em 3 categorias distintas (desenvolvidos, emergentes e subdesenvolvidos), analisando uma tabela contendo indicadores socio-econômicos para cada país. Ah sim, usaremos R para a análise!

Visão geral

O conceito de cluster ou agrupamento é bem genérico e intuitivo: podemos pensar em um cluster de galáxias, computadores, pessoas, etc. De maneira geral, dado um conjunto de objetos do mesmo tipo, uma análise de clusters consiste em agrupar esses objetos em grupos distintos, de maneira que objetos no mesmo grupo possuam características similares.

As técnicas de clustering são bastante diversificadas, e os algoritmos diferem entre si justamente no que se refere a noção do que constitui um cluster. Os modelos mais comuns de clustering incluem: connectivity based, centroid based, distribution based, density based, entre outros. Para simplificar as coisas, usaremos o algoritmo k-means que é baseado no conceito de centróide. Aqui você encontra uma boa introdução sobre as técnicas citadas acima.

Algoritmo k-means

O objetivo do k-means é achar a melhor divisão de n objetos em k grupos, de modo que a distância total entre os membros de um grupo (cluster) e o seu respectivo centróide seja minimizada. Imagine cada centróide como sendo o 'representante médio' dos objetos em um mesmo cluster, ao minimizar esta distância, estamos maximizando a similaridade entre os objetos deste cluster. De modo simplificado, o algoritmo segue os seguintes passos:

  1. Defina o centróide inicial de cada grupo (normalmente de modo randômico)
  2. Atribua cada objeto ao cluster com centróide mais próximo
  3. Recalcule os centróides (já que novos objetos foram atribuídos no passo anterior)
  4. Repita os passos 2 e 3 até não serem detectadas mudanças nos grupos
O k-means é um algoritmo 'greedy', sendo bastante popular e eficiente do ponto de vista computacional.

Cenário

Considere o problema de dividir uma lista de países em 3 grupos distintos (desenvolvidos, emergentes ou subdesenvolvidos), considerando os seguintes atributos:

  • Rendar per capita
  • Taxa de alfabetização
  • Mortalidade infantil
  • Expectativa de vida
Nesse caso, um algoritmo de clustering parece ser uma boa alternativa.

Enter R

Faça o download do dataset aqui. O primeiro passo é criar um objeto da classe 'data.frame' com as informações deste arquivo:

  paises = read.csv(file = 'countries.csv', header = TRUE, sep = '\t')    
 
No argumento file podemos especificar um caminho (absoluto ou relativo), ou uma URL externa. Ao especificar um caminho relativo, R vai procurar o arquivo a partir do seu working directory atual (digite getwd() caso não saiba onde fica o diretório de trabalho atual). No segundo argumento, especificamos a existência de headers na primeira linha do arquivo. Por fim, especificamos o caractere de separação ('tab' no nosso caso).

Agora já podemos executar o k-means:

   k = 3
   iter = 15
   km = kmeans (x = paises, centers = k, iter.max = iter)
  
O argumento x aceita qualquer objeto que possa ser convertido em uma matrix numérica, como nosso data frame, em que todas as colunas são numéricas. O argumento centers aceita um inteiro especificando o número de clusters que desejamos identificar, ou uma matrix numérica para inicialização dos centróides. Por fim, especificamos o número máximo de iterações. Existe um terceiro argumento, algorithm, que nos permite especificar o algoritmo a ser usado - ficamos com o default.

O código acima retorna um objeto da classe 'kmeans'. Podemos obter o centróide de cada cluster:

   centroids = km$centers 
  
E também o vetor que indica o cluster em que objeto foi alocado:
   classif = km$cluster
  
Os resultados parecem fazer sentido, com algumas exceções. Observe que Namíbia, Georgia, Paquistão e India estão no mesmo grupo dos países mais pobres, quando deveriam estar no grupo dos países emergentes.

Visualizando

Nesta seção vamos usar o pacote ggplot2 para tentar visualizar melhor os resultados obtidos. Caso ainda não o tenha instalado, dê uma olhada neste post. Observando os resultados acima, vemos que os clusters 1, 2 e 3 foram atribuídos aos países subdesenvolvidos, desenvolvidos e emergentes, respectivamente. Para facilitar a visualização do gráfico, vamos substituir esses números pelas respectivas categorias:

   # cria um vetor de strings
   categorias = c('underdeveloped', 'developed', 'emergent')
   
   # aloca um vetor numérico de tamanho igual ao número de países
   cluster = numeric(nrow(paises))
   
   # atribui a categoria[i] de acordo com o cluster
   for (i in 1:3)
    cluster[km$cluster == i] = categorias[i]
  
Agora podemos usar o ggplot2:
   library(ggplot2)
   plotmatrix(paises, aes(colour = cluster)) + 
   theme_bw() +
   opts(axis.title.x = theme_blank()) + 
   opts(axis.title.y = theme_blank())
  

Para um gráfico mais simples:

   pairs(paises, bg = km$cluster, pch = 21, cex = 1.5)
Repare no agrupamento dos países, de acordo com o esperado. Parece que o k-means funcionou bem neste caso. Era isso, até a próxima!

2 comentários:

  1. estou utilizando suas linhas para minhas aulas de Pesquisa Operacional e Métodos Quantitativos. Estarei visitando mais os seus blogs e indicando aos meus alunos. Obrigado.

    ResponderExcluir
    Respostas
    1. Obrigado! Preciso arrumar mais tempo para postar mais coisas, aceito sugestões.
      Grande abraço,
      Fernando T.

      Excluir