| | | 1 | | using System; |
| | | 2 | | using System.Collections.Generic; |
| | | 3 | | using System.Linq; |
| | | 4 | | using UnityEngine; |
| | | 5 | | |
/// <summary>
/// Partitions hierarchical objects into a fixed number of clusters with k-means: centroids are
/// seeded at the positions of k distinct, randomly chosen objects, then objects are repeatedly
/// assigned to their nearest centroid and centroids moved to the mean of their members, until no
/// centroid moves more than a small threshold or an iteration cap is reached.
/// </summary>
public class KMeansClusterer : ClustererBase {
  // Default maximum number of iterations.
  protected const int _defaultMaxNumIterations = 20;

  // Convergence threshold: iteration stops once no centroid moves farther than this distance
  // (in the same units as IHierarchical.Position) during a single update step.
  protected const float _epsilon = 1e-3f;

  // Number of clusters to produce; always positive.
  private readonly int _k;

  // Maximum number of assign/update iterations.
  private readonly int _maxNumIterations;

  /// <summary>
  /// Creates a clusterer that produces <paramref name="k"/> clusters using the default
  /// iteration cap.
  /// </summary>
  /// <exception cref="ArgumentOutOfRangeException">
  /// Thrown if <paramref name="k"/> is zero or negative.
  /// </exception>
  public KMeansClusterer(int k) : this(k, _defaultMaxNumIterations) {}

  /// <summary>
  /// Creates a clusterer that produces <paramref name="k"/> clusters and runs at most
  /// <paramref name="maxNumIterations"/> assign/update iterations.
  /// </summary>
  /// <exception cref="ArgumentOutOfRangeException">
  /// Thrown if <paramref name="k"/> is zero or negative.
  /// </exception>
  public KMeansClusterer(int k, int maxNumIterations) {
    // With k <= 0 no cluster exists, and assigning any object would index clusters[-1].
    if (k <= 0) {
      throw new ArgumentOutOfRangeException(nameof(k), k,
                                            "The number of clusters must be positive.");
    }
    _k = k;
    _maxNumIterations = maxNumIterations;
  }

  /// <summary>
  /// Generates clusters from the given hierarchical objects.
  /// </summary>
  /// <param name="hierarchicals">Objects to cluster. The sequence is enumerated exactly once.</param>
  /// <returns>
  /// k clusters with every object assigned to the cluster whose centroid is nearest, or an empty
  /// list if <paramref name="hierarchicals"/> is null or empty.
  /// </returns>
  /// <exception cref="InvalidOperationException">
  /// Thrown if there are fewer objects than requested clusters.
  /// </exception>
  public override List<Cluster> Cluster(IEnumerable<IHierarchical> hierarchicals) {
    if (hierarchicals == null) {
      return new List<Cluster>();
    }

    // Materialize once: the objects are needed on every iteration, and the input may be a lazy
    // sequence that is costly or unstable to enumerate repeatedly.
    List<IHierarchical> items = hierarchicals.ToList();
    if (items.Count == 0) {
      return new List<Cluster>();
    }

    // Validate that k is not greater than the number of hierarchical objects.
    if (_k > items.Count) {
      throw new InvalidOperationException(
          $"Cannot create {_k} clusters from {items.Count} hierarchical objects.");
    }

    List<Cluster> clusters = CreateInitialClusters(items);

    bool converged = false;
    for (int iteration = 0; !converged && iteration < _maxNumIterations; ++iteration) {
      AssignToClusters(clusters, items);
      converged = UpdateCentroids(clusters, items);
    }

    // The last update step may have moved centroids, so reassign to make membership match the
    // final centroid positions.
    AssignToClusters(clusters, items);
    return clusters;
  }

  // Seeds _k clusters with centroids at the positions of _k distinct random objects, selected by
  // a partial Fisher-Yates shuffle over a copy so the order of items is left untouched.
  private List<Cluster> CreateInitialClusters(IReadOnlyList<IHierarchical> items) {
    var candidates = new List<IHierarchical>(items);
    var clusters = new List<Cluster>(_k);
    for (int i = 0; i < _k; ++i) {
      // Random.Range(int, int) excludes the upper bound, so j is in [i, Count).
      int j = UnityEngine.Random.Range(i, candidates.Count);
      (candidates[i], candidates[j]) = (candidates[j], candidates[i]);
      clusters.Add(new Cluster { Centroid = candidates[i].Position });
    }
    return clusters;
  }

  // Moves each non-empty cluster's centroid to the mean of its assigned objects and restarts
  // empty clusters at a random object's position. Returns true if no centroid moved farther
  // than _epsilon.
  private static bool UpdateCentroids(IReadOnlyList<Cluster> clusters,
                                      IReadOnlyList<IHierarchical> items) {
    bool converged = true;
    foreach (var cluster in clusters) {
      Vector3 oldCentroid = cluster.Centroid;
      if (cluster.IsEmpty) {
        // An empty cluster has no mean; reseed it so it can attract members next round.
        int seedIndex = UnityEngine.Random.Range(0, items.Count);
        cluster.Centroid = items[seedIndex].Position;
      } else {
        cluster.Recenter();
      }

      // Compare centroid to centroid; the movement is what decides convergence.
      if (Vector3.Distance(oldCentroid, cluster.Centroid) > _epsilon) {
        converged = false;
      }
    }
    return converged;
  }

  // Clears every cluster, then adds each object to the cluster whose centroid is nearest (ties
  // go to the lowest cluster index). Assumes clusters is non-empty, which the constructor's
  // k > 0 check guarantees.
  private static void AssignToClusters(IReadOnlyList<Cluster> clusters,
                                       IEnumerable<IHierarchical> hierarchicals) {
    foreach (var cluster in clusters) {
      cluster.ClearSubHierarchicals();
    }

    foreach (var hierarchical in hierarchicals) {
      Vector3 position = hierarchical.Position;

      // Seed the search with cluster 0 instead of an infinite sentinel so every object is
      // assigned even when distances are infinite or NaN. Squared distances give the same
      // nearest cluster without a square root per pair.
      int closestIndex = 0;
      float closestSqrDistance = (clusters[0].Centroid - position).sqrMagnitude;
      for (int clusterIndex = 1; clusterIndex < clusters.Count; ++clusterIndex) {
        float sqrDistance = (clusters[clusterIndex].Centroid - position).sqrMagnitude;
        if (sqrDistance < closestSqrDistance) {
          closestSqrDistance = sqrDistance;
          closestIndex = clusterIndex;
        }
      }
      clusters[closestIndex].AddSubHierarchical(hierarchical);
    }
  }
}