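Inputs:
'data' is an N x 2 numpy array of [x, y] points.
'numsegs' and 'numclass' are integer scalars: 'numsegs' is the number of k-means segments the points are split into, and 'numclass' is the number of sub-clusters found inside each segment. The larger these numbers, the greater the 'complexity' of the output and the longer the processing time.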
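If your coordinates live in two separate sequences, they need to be stacked into the N x 2 layout described above before calling patchwork. A minimal sketch, assuming x and y are equal-length 1-D sequences; `as_points` is a hypothetical helper, not part of the original post.

```python
import numpy as np


def as_points(x, y):
    """Hypothetical helper: stack x and y coordinates into an N x 2 float array."""
    pts = np.column_stack([np.asarray(x, dtype=float), np.asarray(y, dtype=float)])
    # patchwork() reads columns 0 and 1, so exactly two columns are expected
    if pts.ndim != 2 or pts.shape[1] != 2:
        raise ValueError("expected an N x 2 array of [x, y] points")
    return pts


# Example: 1000 random points in the unit square
rng = np.random.default_rng(42)
data = as_points(rng.random(1000), rng.random(1000))
print(data.shape)  # (1000, 2)
```

Converting to float up front also avoids handing integer arrays to scipy's k-means routines.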
import numpy as np
from scipy.cluster.vq import kmeans, vq
import matplotlib.pyplot as plt

def patchwork(data, numsegs, numclass):
    data = np.asarray(data, dtype=float)
    # computing K-Means with K = numsegs
    centroids, _ = kmeans(data, numsegs)
    # assign each sample to a cluster
    idx, _ = vq(data, centroids)
    # loop through segments (kmeans can return fewer than numsegs centroids)
    for k in range(len(centroids)):
        # get data in kth segment
        datXY = data[idx == k, :]
        if len(datXY) == 0:
            continue
        # get distances of each point from the segment centroid
        dat1 = np.sqrt((datXY[:, 0] - centroids[k, 0])**2 + (datXY[:, 1] - centroids[k, 1])**2)
        # get up to numclass new centroids within kth segment (1-D k-means on distance)
        cents, _ = kmeans(dat1, min(numclass, len(dat1)))
        # assign each point to a subcluster
        i, _ = vq(dat1, cents)
        # loop through and plot each subcluster
        for p in range(len(cents)):
            plt.plot(datXY[i == p, 0], datXY[i == p, 1], '.')
    plt.savefig("outputs" + str(numsegs) + '_' + str(numclass) + '_res.png')
    plt.close()
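The function above clusters and plots in a single pass. The same two-level scheme can be sketched so it returns labels instead of drawing, which makes it easier to test or to colour the patches yourself. This is a minimal sketch assuming scipy's `kmeans`/`vq` as in the post; `patchwork_labels` is a hypothetical name, and `np.linalg.norm` is used in place of the explicit square root (same Euclidean distance).

```python
import numpy as np
from scipy.cluster.vq import kmeans, vq


def patchwork_labels(data, numsegs, numclass):
    """Hypothetical helper: return per-point (segment, subclass) labels.

    Step 1: k-means on the [x, y] points gives the segments.
    Step 2: inside each segment, 1-D k-means on each point's distance
    to the segment centroid gives the sub-clusters.
    """
    data = np.asarray(data, dtype=float)
    centroids, _ = kmeans(data, numsegs)
    seg, _ = vq(data, centroids)
    sub = np.zeros(len(data), dtype=int)
    for k in range(len(centroids)):
        mask = seg == k
        if not mask.any():
            continue  # skip a segment that ended up with no points
        # Euclidean distance of each point in segment k to its centroid
        dist = np.linalg.norm(data[mask] - centroids[k], axis=1)
        # cannot request more sub-clusters than there are points
        cents, _ = kmeans(dist, min(numclass, len(dist)))
        labels, _ = vq(dist, cents)
        sub[mask] = labels
    return seg, sub
```

Each (segment, subclass) pair then corresponds to one `plot()` call in `patchwork`.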
Some example outputs, in order of increasing complexity:
numsegs=10, numclass=5
numsegs=15, numclass=15
numsegs=20, numclass=50
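One way to read "complexity" for these three settings: patchwork draws at most numsegs × numclass point groups (one `plot()` call per sub-cluster) and runs at most numsegs + 1 k-means calls (one on the full data, one per segment). The small calculation below just tabulates those upper bounds for the settings listed above; actual counts can be lower when k-means returns fewer centroids than requested.

```python
# (numsegs, numclass) pairs used for the example outputs above
settings = [(10, 5), (15, 15), (20, 50)]

# upper bound on the number of plotted point groups
max_groups = [numsegs * numclass for numsegs, numclass in settings]
# upper bound on k-means calls: one global run plus one per segment
kmeans_calls = [numsegs + 1 for numsegs, _ in settings]

for (numsegs, numclass), groups, calls in zip(settings, max_groups, kmeans_calls):
    print(f"numsegs={numsegs}, numclass={numclass}: "
          f"up to {groups} point groups, up to {calls} k-means calls")
```

Note that matplotlib's default colour cycle has only 10 colours, so with hundreds of groups the colours repeat across neighbouring patches.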