-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathClustering.java
More file actions
89 lines (79 loc) · 2.88 KB
/
Clustering.java
File metadata and controls
89 lines (79 loc) · 2.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import edu.princeton.cs.algs4.CC;
import edu.princeton.cs.algs4.Edge;
import edu.princeton.cs.algs4.EdgeWeightedGraph;
import edu.princeton.cs.algs4.In;
import edu.princeton.cs.algs4.KruskalMST;
import edu.princeton.cs.algs4.Point2D;
import edu.princeton.cs.algs4.StdOut;
/**
 * Groups a set of 2D locations into {@code k} clusters and uses those
 * clusters to collapse an {@code m}-dimensional input into {@code k}
 * dimensions.
 *
 * <p>The clusters are built by computing a minimum spanning tree of the
 * complete graph over the locations (edge weight = squared Euclidean
 * distance), keeping only the {@code m - k} lightest tree edges, and taking
 * the connected components of the resulting forest.
 */
public class Clustering {
    private final int m;  // number of vertices (locations on the map)
    private final int k;  // number of clusters
    private final CC cc;  // connected components of the truncated spanning forest

    /**
     * Runs the clustering algorithm and creates the clusters.
     *
     * @param locations the points to cluster; must not be null or contain null
     * @param k         the number of clusters, {@code 1 <= k <= locations.length}
     * @throws IllegalArgumentException if {@code locations} is null, contains
     *         a null element, or {@code k} is outside {@code [1, locations.length]}
     */
    public Clustering(Point2D[] locations, int k) {
        if (locations == null) {
            throw new IllegalArgumentException("locations is null");
        }
        m = locations.length;
        if (k < 1 || k > m) {
            throw new IllegalArgumentException(
                    "k must be between 1 and " + m + ", but was " + k);
        }
        for (int i = 0; i < m; i++) {
            if (locations[i] == null) {
                throw new IllegalArgumentException("locations[" + i + "] is null");
            }
        }
        this.k = k;

        // Complete graph over the locations, weighted by squared distance.
        // Each unordered pair is added exactly once: the graph is undirected,
        // so adding both (i, j) and (j, i) would only create parallel edges.
        EdgeWeightedGraph graph = new EdgeWeightedGraph(m);
        for (int i = 0; i < m; i++) {
            for (int j = i + 1; j < m; j++) {
                double distance = locations[i].distanceSquaredTo(locations[j]);
                graph.addEdge(new Edge(i, j, distance));
            }
        }

        KruskalMST mst = new KruskalMST(graph);

        // Keep only the first m - k edges reported by the MST.
        // NOTE(review): this relies on KruskalMST.edges() iterating in
        // ascending weight order (the order Kruskal's algorithm accepts
        // edges) - confirm against the algs4 version in use.
        EdgeWeightedGraph forest = new EdgeWeightedGraph(m);
        int kept = 0;
        for (Edge edge : mst.edges()) {
            if (kept == m - k) {
                break;
            }
            forest.addEdge(edge);
            kept++;
        }

        // A forest on m vertices with m - k edges has exactly k components.
        cc = new CC(forest);
    }

    /**
     * Returns the cluster that contains the {@code i}th point.
     *
     * @param i index of the point, {@code 0 <= i < m}
     * @return the component id that {@code CC} assigns to point {@code i}
     * @throws IllegalArgumentException if {@code i} is out of range
     */
    public int clusterOf(int i) {
        if (i < 0 || i > m - 1) {
            throw new IllegalArgumentException(
                    "index must be between 0 and " + (m - 1) + ", but was " + i);
        }
        return cc.id(i);
    }

    /**
     * Uses the clusters to reduce the dimensions of an input: entry
     * {@code c} of the result is the sum of all {@code input[i]} whose point
     * {@code i} belongs to cluster {@code c}.
     *
     * @param input array with one value per location (length {@code m})
     * @return a new array of length {@code k} holding the per-cluster sums
     * @throws IllegalArgumentException if {@code input} is null or its length
     *         differs from the number of locations
     */
    public int[] reduceDimensions(int[] input) {
        if (input == null) {
            throw new IllegalArgumentException("input is null");
        }
        if (input.length != m) {
            throw new IllegalArgumentException(
                    "input length must be " + m + ", but was " + input.length);
        }
        // Indexing by cluster id assumes CC ids lie in 0..k-1.
        int[] reduced = new int[k];
        for (int i = 0; i < input.length; i++) {
            reduced[clusterOf(i)] += input[i];
        }
        return reduced;
    }

    /**
     * Unit testing (required). Reads a point count followed by that many
     * x/y coordinate pairs from the file named by {@code args[0]}, clusters
     * the points, and prints one cluster id and a reduced sample input.
     *
     * @param args command-line arguments; {@code args[0]} is the input file
     */
    public static void main(String[] args) {
        In in = new In(args[0]);
        Point2D[] points = new Point2D[in.readInt()];
        for (int i = 0; i < points.length; i++) {
            points[i] = new Point2D(in.readDouble(), in.readDouble());
        }

        // Never request more clusters than there are points; an empty point
        // set is still rejected by the constructor.
        int clusters = Math.min(5, points.length);
        Clustering cluster = new Clustering(points, clusters);

        // Query index 2 when it exists, otherwise the last valid index.
        StdOut.println(cluster.clusterOf(Math.min(2, points.length - 1)));

        // The sample input must have exactly one entry per point.
        int[] input = new int[points.length];
        for (int i = 0; i < input.length; i++) {
            input[i] = i + 1;
        }

        // Print the values themselves; println on an int[] would print only
        // the array's identity string.
        int[] reduced = cluster.reduceDimensions(input);
        StringBuilder line = new StringBuilder();
        for (int i = 0; i < reduced.length; i++) {
            if (i > 0) {
                line.append(' ');
            }
            line.append(reduced[i]);
        }
        StdOut.println(line.toString());
    }
}