Data science/머신러닝
K-means clustering(k-평균 알고리즘)
gokimkh
2022. 7. 23. 15:11
개념
- 무작위로 k개의 클러스터 중심을 정한다.
- 각 샘플에서 가장 가까운 클러스터 중심을 찾아 해당 클러스터의 샘플로 지정한다.
- 클러스터에 속한 샘플의 평균값으로 클러스터 중심을 변경한다.
- 클러스터 중심에 변화가 없을 때까지 2번으로 돌아가 반복한다.
문법
from sklearn.cluster import KMeans # K-mean 모델 사용
# n_clusters : 클러스터 개수, n_init : 반복 횟수, max_iter : 최대 반복 횟수
km = KMeans(n_clusters=8, n_init=10, max_iter=200)
km.fit(data)
실습
!wget https://bit.ly/fruits_300_data -O fruits_300.npy # 과일사진 데이터 준비
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans # k-means 알고리즘 사용
fruits = np.load('fruits_300.npy')
fruits_2d = fruits.reshape(-1, 100 * 100) # 행, 열을 한줄로 나타낸다(100 x 100)
km = KMeans(n_clusters=3, random_state=42) # 클러스터 개수 3개
km.fit(fruits_2d) # 학습
def draw_fruits(arr, ratio=1): # 그림 그리기
n = len(arr)
rows = int(np.ceil(n / 10))
cols = n if rows < 2 else 10
fig, axs = plt.subplots(rows, cols, figsize=(cols * ratio, rows * ratio), squeeze = False)
for i in range(rows):
for j in range(cols):
if i * 10 + j < n: axs[i, j].imshow(arr[i * 10 + j], cmap = 'gray_r')
axs[i, j].axis('off')
plt.show()
draw_fruits(fruits[km.labels_==0])
draw_fruits(fruits[km.labels_==1])
draw_fruits(fruits[km.labels_==2])
draw_fruits(km.cluster_centers_.reshape(-1,100,100),ratio=3)