数据集来源：Geolife
加载数据
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from geopy.distance import great_circle
from matplotlib.colors import rgb2hex
from shapely.geometry import MultiPoint
from sklearn.cluster import DBSCAN
from sklearn.cluster import KMeans

# Directory holding one user's .plt trajectory files.
userdata = '../Lab-work/Geolife Trajectories 1.3/Data/001/Trajectory/'
# os.listdir order is arbitrary; sort so the trajectories are concatenated
# (and later drawn as a connected line) in a deterministic order. Presumably
# chronological, since Geolife names files by start time -- confirm.
filelist = sorted(f for f in os.listdir(userdata) if f.lower().endswith('.plt'))
names = ['lat', 'lng', 'zero', 'alt', 'days', 'date', 'time']
# Skip the 6 preamble lines and read the rest as data with the given column
# names. (header=6 would instead consume the first data row as a header line
# and silently drop that GPS fix.)
df_list = [
    pd.read_csv(
        os.path.join(userdata, f),
        skiprows=6,
        header=None,
        names=names,
        index_col=False,
    )
    for f in filelist
]
df = pd.concat(df_list, ignore_index=True)
print(df.head(10))
# Longitude on x, latitude on y so the trace reads like a map.
plt.plot(df.lng, df.lat)
lat lng zero alt days date time 0 39.984198 116.319322 0 492 39744.245208 2008-10-23 05:53:06 1 39.984224 116.319402 0 492 39744.245266 2008-10-23 05:53:11 2 39.984211 116.319389 0 492 39744.245324 2008-10-23 05:53:16 3 39.984217 116.319422 0 491 39744.245382 2008-10-23 05:53:21 4 39.984710 116.319865 0 320 39744.245405 2008-10-23 05:53:23 5 39.984674 116.319810 0 325 39744.245463 2008-10-23 05:53:28 6 39.984623 116.319773 0 326 39744.245521 2008-10-23 05:53:33 7 39.984606 116.319732 0 327 39744.245579 2008-10-23 05:53:38 8 39.984555 116.319728 0 324 39744.245637 2008-10-23 05:53:43 9 39.984579 116.319769 0 309 39744.245694 2008-10-23 05:53:48 [<matplotlib.lines.Line2D at 0x17efc43eac8>]
K-Means
# Cluster every GPS fix by position. coords keeps (lat, lng) column order,
# which the later great-circle / unpacking code relies on.
coords = df[['lat', 'lng']].values
n_clusters = 100
# Explicit n_init and a fixed random_state make the clustering reproducible
# across runs and across sklearn versions (n_init's default changed in 1.4).
cls = KMeans(n_clusters=n_clusters, n_init=10, random_state=0).fit(coords)

# One random but seeded (hence reproducible) colour per cluster, as hex strings.
rng = np.random.default_rng(0)
colors = [rgb2hex(tuple(rng.random(3))) for _ in range(n_clusters)]
for i, color in enumerate(colors):
    members = cls.labels_ == i
    # Column 1 is longitude (x), column 0 is latitude (y): map orientation.
    plt.scatter(coords[members, 1], coords[members, 0], s=60, c=color, alpha=0.5)
plt.show()
获取 K-Means 聚类结果
# Per-point cluster assignment from the fitted model.
cluster_labels = cls.labels_
# Label -1 is DBSCAN's noise marker; K-Means never emits it, but excluding it
# keeps this cell valid if the model is swapped for DBSCAN.
num_clusters = len(set(cluster_labels) - {-1})
# coords is exactly what was clustered, so report its length.
print(f'Clustered {len(coords)} points to {num_clusters} clusters')
# One entry per cluster: the (lat, lng) rows of coords assigned to it.
clusters = pd.Series([coords[cluster_labels == n] for n in range(num_clusters)])
print(clusters)
Clustered 9045 points to 100 clusters 0 [[40.014459, 116.305603], [40.014363, 116.3056... 1 [[39.975246000000006, 116.358976], [39.975244,... 2 [[40.001312, 116.193358], [40.001351, 116.1932... 3 [[39.984559000000004, 116.326696], [39.984669,... 4 [[39.964969, 116.434923], [39.964886, 116.4350... ... 95 [[40.004549, 116.260581], [40.004515999999995,... 96 [[39.97964, 116.323856], [39.979701, 116.32396... 97 [[40.0009, 116.23948500000002], [40.000831, 11... 98 [[39.962336, 116.32817800000001], [39.96223300... 99 [[39.9663, 116.353677], [39.966291999999996, 1... Length: 100, dtype: object
获取每个群集的中心点（距质心最近的实际 GPS 点）
def get_centermost_point(cluster):
    """Return the member of *cluster* closest to the cluster's centroid.

    Parameters
    ----------
    cluster : sequence of (lat, lon) pairs
        The points assigned to one cluster.

    Returns
    -------
    tuple
        The (lat, lon) of the actual data point with the smallest
        great-circle distance to the centroid.
    """
    # Build the geometry once and reuse its centroid. The centroid is taken
    # in plain degree space; it only serves as the reference point for the
    # great-circle search below.
    centroid_geom = MultiPoint(cluster).centroid
    centroid = (centroid_geom.x, centroid_geom.y)
    centermost_point = min(cluster, key=lambda point: great_circle(point, centroid).m)
    return tuple(centermost_point)


centermost_points = clusters.map(get_centermost_point)
lats, lons = zip(*centermost_points)
rep_points = pd.DataFrame({'lon': lons, 'lat': lats})
print(rep_points)
lon lat 0 116.306558 40.013751 1 116.353295 39.975357 2 116.190167 40.004290 3 116.326944 39.986492 4 116.438241 39.961273 .. ... ... 95 116.256309 40.004774 96 116.326462 39.978752 97 116.232672 39.998630 98 116.328847 39.958271 99 116.358655 39.966451 [100 rows x 2 columns]
绘制中心点
# Map-style figure: longitude on x, latitude on y.
fig, ax = plt.subplots(figsize=[10, 6])
# Draw every cluster's representative point in one call, so each label
# annotated below has a marker.
rs_scatter = ax.scatter(
    rep_points['lon'],
    rep_points['lat'],
    c='#99cc99',
    edgecolor='None',
    alpha=0.7,
    s=250,
)
# The full GPS trace as small black dots, drawn after the cluster markers.
df_scatter = ax.scatter(df['lng'], df['lat'], c='k', alpha=0.9, s=3)
ax.set_title('Full GPS trace vs. K-Means cluster centers')
ax.set_xlabel('Longitude')
ax.set_ylabel('Latitude')
ax.legend([df_scatter, rs_scatter], ['GPS points', 'Cluster centers'], loc='upper right')

# Label the centres "cluster1", "cluster2", ... (1-based, in rep_points order).
for i, (x, y) in enumerate(zip(rep_points['lon'], rep_points['lat']), start=1):
    ax.annotate(
        f'cluster{i}',
        xy=(x, y),
        xytext=(-25, -30),
        textcoords='offset points',
        ha='right',
        va='bottom',
        bbox=dict(boxstyle='round,pad=0.5', fc='white', alpha=0.5),
        arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'),
    )
plt.show()