| | | 1 | | using System; |
| | | 2 | | using System.Collections.Generic; |
| | | 3 | | using System.Linq; |
| | | 4 | | using UnityEngine; |
| | | 5 | | |
// A single node of a 2D K-D tree: one stored element plus links to the
// subtrees on either side of the splitting line through that element.
// Setters are internal, so only code in the same assembly can populate or
// relink nodes; other callers can only read them.
public class KDNode<T> {
  private T _data;
  private KDNode<T> _left;
  private KDNode<T> _right;

  // Element stored at this node.
  public T Data {
    get { return _data; }
    internal set { _data = value; }
  }

  // Subtree on the "lower" side of this node's splitting line (null if none).
  public KDNode<T> Left {
    get { return _left; }
    internal set { _left = value; }
  }

  // Subtree on the "upper" side of this node's splitting line (null if none).
  public KDNode<T> Right {
    get { return _right; }
    internal set { _right = value; }
  }
}
| | | 12 | | |
// K-D tree over 2D coordinates.
//
// Elements of any type are indexed through a caller-supplied projection to
// their 2D position. The tree is built once, balanced on the median, in the
// constructor and is not modified afterwards.
public class KDTree<T> {
  // Number of dimensions; the splitting axis cycles x (0), y (1), x, ... by depth.
  private const int Dimensions = 2;

  // Root node; null when the tree was built from an empty list.
  private readonly KDNode<T> _root;

  // Function to get the coordinates from the tree elements.
  private readonly Func<T, Vector2> _getCoordinates;

  // Builds the tree from `points`.
  //
  // points: elements to index; the list itself is only read, never modified.
  // getCoordinates: maps an element to its 2D position.
  // Throws ArgumentNullException if either argument is null.
  public KDTree(IReadOnlyList<T> points, Func<T, Vector2> getCoordinates) {
    if (points == null) {
      throw new ArgumentNullException(nameof(points));
    }
    _getCoordinates = getCoordinates ?? throw new ArgumentNullException(nameof(getCoordinates));
    _root = BuildTree(points, depth: 0);
  }

  // Find the element nearest (Euclidean distance) to `target`.
  // Returns default(T) if the tree is empty. Among equally distant elements,
  // the one reached first by the search is kept.
  public T NearestNeighbor(in Vector2 target) {
    float bestDistance = float.MaxValue;
    KDNode<T> neighbor = SearchNearest(_root, target, depth: 0, bestNode: null, ref bestDistance);
    return neighbor == null ? default(T) : neighbor.Data;
  }

  // Recursive nearest-neighbor search.
  //
  // node: subtree root to search; null ends the recursion.
  // target: query point.
  // depth: depth of `node`, which selects the splitting axis.
  // bestNode: best candidate found so far, or null if none yet.
  // bestDistance: distance from `target` to `bestNode` (float.MaxValue if none).
  //   It is updated in place, so each visited node's distance is computed
  //   exactly once instead of being re-derived from `bestNode` at every level.
  // Returns the best candidate after searching `node`'s subtree.
  private KDNode<T> SearchNearest(KDNode<T> node, in Vector2 target, int depth,
                                  KDNode<T> bestNode, ref float bestDistance) {
    if (node == null) {
      return bestNode;
    }

    Vector2 nodeCoordinates = _getCoordinates(node.Data);
    float currentDistance = Vector2.Distance(nodeCoordinates, target);
    if (currentDistance < bestDistance) {
      bestNode = node;
      bestDistance = currentDistance;
    }

    int axis = depth % Dimensions;
    float targetValue = axis == 0 ? target.x : target.y;
    float nodeValue = axis == 0 ? nodeCoordinates.x : nodeCoordinates.y;
    bool targetIsBelow = targetValue < nodeValue;
    KDNode<T> nearBranch = targetIsBelow ? node.Left : node.Right;
    KDNode<T> farBranch = targetIsBelow ? node.Right : node.Left;

    // Explore the side of the splitting line that contains the target first.
    bestNode = SearchNearest(nearBranch, target, depth + 1, bestNode, ref bestDistance);

    // Every point in the far branch lies on the other side of the splitting
    // line, so it is at least |targetValue - nodeValue| away. That branch can
    // only hold a closer point if the line is nearer than the current best.
    if (Mathf.Abs(targetValue - nodeValue) < bestDistance) {
      bestNode = SearchNearest(farBranch, target, depth + 1, bestNode, ref bestDistance);
    }
    return bestNode;
  }

  // Construct a subtree from the list of points: the median along the axis
  // selected by `depth` becomes the node, the points before it form the left
  // subtree and the points after it form the right one. Returns null for an
  // empty list.
  private KDNode<T> BuildTree(IReadOnlyList<T> points, int depth) {
    if (points.Count == 0) {
      return null;
    }

    int axis = depth % Dimensions;

    // Sort the points by axis and find the median. OrderBy is stable; points
    // sharing the median's axis value may land on either side, which is why
    // the search also visits the far branch when the target is on the line.
    List<T> sortedPoints = points
        .OrderBy(point => {
          Vector2 coordinates = _getCoordinates(point);
          return axis == 0 ? coordinates.x : coordinates.y;
        })
        .ToList();
    int medianIndex = sortedPoints.Count / 2;

    List<T> leftPoints = sortedPoints.GetRange(0, medianIndex);
    List<T> rightPoints =
        sortedPoints.GetRange(medianIndex + 1, sortedPoints.Count - medianIndex - 1);

    return new KDNode<T> {
      Data = sortedPoints[medianIndex],
      Left = BuildTree(leftPoints, depth + 1),
      Right = BuildTree(rightPoints, depth + 1),
    };
  }
}