Data science/머신러닝

K-means clustering(k-평균 알고리즘)

gokimkh 2022. 7. 23. 15:11

개념

  1. 무작위로 k개의 클러스터 중심을 정한다.
  2. 각 샘플에서 가장 가까운 클러스터 중심을 찾아 해당 클러스터의 샘플로 지정한다.
  3. 클러스터에 속한 샘플의 평균값으로 클러스터 중심을 변경한다.
  4. 클러스터 중심에 변화가 없을 때까지 2번으로 돌아가 반복한다.

https://commons.wikimedia.org/wiki/File:K-means_convergence.gif

 

문법

from sklearn.cluster import KMeans     # K-mean 모델 사용

# n_clusters : 클러스터 개수, n_init : 반복 횟수, max_iter : 최대 반복 횟수
km = KMeans(n_clusters=8, n_init=10, max_iter=200)
km.fit(data)

 

실습

!wget https://bit.ly/fruits_300_data -O fruits_300.npy   # 과일사진 데이터 준비

import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans     # k-means 알고리즘 사용

fruits = np.load('fruits_300.npy')
fruits_2d = fruits.reshape(-1, 100 * 100)     # 행, 열을 한줄로 나타낸다(100 x 100)
km = KMeans(n_clusters=3, random_state=42)    # 클러스터 개수 3개
km.fit(fruits_2d)                    		  # 학습

def draw_fruits(arr, ratio=1):  			  # 그림 그리기
  n = len(arr)
  rows = int(np.ceil(n / 10))
  cols = n if rows < 2 else 10
  fig, axs = plt.subplots(rows, cols, figsize=(cols * ratio, rows * ratio), squeeze = False)

  for i in range(rows):
    for j in range(cols):
      if i * 10 + j < n:  axs[i, j].imshow(arr[i * 10 + j], cmap = 'gray_r')
      axs[i, j].axis('off')
  plt.show()

draw_fruits(fruits[km.labels_==0])
draw_fruits(fruits[km.labels_==1])
draw_fruits(fruits[km.labels_==2])

draw_fruits(km.cluster_centers_.reshape(-1,100,100),ratio=3)