Implementing a Binary Search Tree in Java - Part Three


Time to get back to implementing our MyBinarySearchTree class. Now that the remove function is finished, let’s go ahead and implement the getMaxNodeAndParent() and getMinNodeAndParent() functions that were alluded to in the previous post: Implementing a Binary Search Tree in Java - Part Two.

Let’s start with the getMaxNodeAndParent() function. It takes the parent of the starting node (parent) and the node to start searching from (root), in that order. It returns the maximum node in the subtree rooted at root (which may be root itself), together with that node’s parent.

The implementation is fairly simple. There are three cases to worry about:

  1. The root node is null
  2. The root node has a child node on its right side (i.e., root.right != null)
  3. The root node does not have a child node on its right side (i.e., root.right == null)

Here’s how we will handle each case:

  1. Return null, as an empty subtree has no maximum node.
  2. Recursively get the max node and parent, passing in the root (which becomes the new parent) and its right child (which becomes the new root).
  3. Return a pair consisting of the parent and the root as we have found the maximum node.

The function definition looks as follows:

private Pair<TreeNode<T>, TreeNode<T>> getMaxNodeAndParent(TreeNode<T> parent, TreeNode<T> root) {

}

If the root node is null simply return null.

	if(root == null)
		return null;

If there is a node to the right of the root node, recursively return the maximum node and parent, starting from that right node.

	if(root.right != null)
		return getMaxNodeAndParent(root, root.right);

If we make it past that statement (i.e., root.right == null), return a pair consisting of parent and root, as root is the maximum node.

	return new Pair<TreeNode<T>, TreeNode<T>>(parent, root);

The finished getMaxNodeAndParent() function is as follows:

private Pair<TreeNode<T>, TreeNode<T>> getMaxNodeAndParent(TreeNode<T> parent, TreeNode<T> root) {
	if(root == null)
		return null;
	if(root.right != null)
		return getMaxNodeAndParent(root, root.right);
	return new Pair<TreeNode<T>, TreeNode<T>>(parent, root);
}
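Because the function calls itself once per level, a badly unbalanced tree (for example, one built by inserting already-sorted values) makes the recursion as deep as the tree is tall. As a hedged alternative, the same three cases can be handled with a loop. The sketch below is self-contained, so the nested Node and MaxResult classes are hypothetical stand-ins for TreeNode and javafx.util.Pair rather than part of the original class.

```java
class IterativeMax {
    // Hypothetical stand-in for MyBinarySearchTree.TreeNode.
    static class Node<T extends Comparable<T>> {
        Node<T> left, right;
        final T data;
        Node(T data) { this.data = data; }
    }

    // Hypothetical stand-in for javafx.util.Pair; parent may be null.
    static class MaxResult<T extends Comparable<T>> {
        final Node<T> parent;
        final Node<T> max;
        MaxResult(Node<T> parent, Node<T> max) {
            this.parent = parent;
            this.max = max;
        }
    }

    // Same contract as getMaxNodeAndParent(parent, root), but loops instead of recursing.
    static <T extends Comparable<T>> MaxResult<T> getMaxNodeAndParent(Node<T> parent, Node<T> root) {
        if (root == null)
            return null;                 // case 1: empty subtree
        Node<T> p = parent;
        Node<T> n = root;
        while (n.right != null) {        // case 2: keep walking right
            p = n;                       // current node becomes the new parent
            n = n.right;                 // its right child becomes the new root
        }
        return new MaxResult<>(p, n);    // case 3: no right child, so n is the maximum
    }

    public static void main(String[] args) {
        //     5
        //    / \
        //   3   8
        //        \
        //         9
        Node<Integer> root = new Node<>(5);
        root.left = new Node<>(3);
        root.right = new Node<>(8);
        root.right.right = new Node<>(9);

        MaxResult<Integer> r = getMaxNodeAndParent(null, root);
        System.out.println(r.parent.data + " -> " + r.max.data); // prints "8 -> 9"
    }
}
```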

Onto the getMinNodeAndParent() function. It is similar to the getMaxNodeAndParent() function except that it returns the minimum node instead of the maximum node.

The implementation is similar, and again there are three cases to worry about:

  1. The root node is null
  2. The root node has a child node on its left side (i.e., root.left != null)
  3. The root node does not have a child node on its left side (i.e., root.left == null)

Here’s how we will handle each case:

  1. Return null, as an empty subtree has no minimum node.
  2. Recursively get the min node and parent, passing in the root (which becomes the new parent) and its left child (which becomes the new root).
  3. Return a pair consisting of the parent and the root as we have found the minimum node.

The function definition is as follows:

private Pair<TreeNode<T>, TreeNode<T>> getMinNodeAndParent(TreeNode<T> parent, TreeNode<T> root) {

}

Return null if the root node is null.

	if(root == null)
		return null;

If root.left is not null, recursively get the minimum node, starting from root.left.

	if(root.left != null)
		return getMinNodeAndParent(root, root.left);

If root.left is null, simply return a pair consisting of the parent and the root node, as the root node is the minimum node.

	return new Pair<TreeNode<T>, TreeNode<T>>(parent, root);

The finished getMinNodeAndParent() function is as follows:

private Pair<TreeNode<T>, TreeNode<T>> getMinNodeAndParent(TreeNode<T> parent, TreeNode<T> root) {
	if(root == null)
		return null;
	if(root.left != null)
		return getMinNodeAndParent(root, root.left);
	return new Pair<TreeNode<T>, TreeNode<T>>(parent, root);
}
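One detail is worth seeing concretely: if the starting node has no left child, the function returns that starting node as the minimum, paired with whatever parent was passed in. In remove() the call is getMinNodeAndParent(toDelete, toDelete.right), so when toDelete.right has no left child the returned parent is toDelete itself and the minimum hangs off its right pointer, not its left. That is why remove() checks toMoveParent.left == toMove before relinking. The standalone sketch below mirrors the recursive function above; its nested Node and Result classes are hypothetical stand-ins for TreeNode and javafx.util.Pair.

```java
class MinDemo {
    // Hypothetical stand-in for MyBinarySearchTree.TreeNode.
    static class Node<T extends Comparable<T>> {
        Node<T> left, right;
        final T data;
        Node(T data) { this.data = data; }
    }

    // Hypothetical stand-in for javafx.util.Pair<TreeNode<T>, TreeNode<T>>.
    static class Result<T extends Comparable<T>> {
        final Node<T> parent, min;
        Result(Node<T> parent, Node<T> min) {
            this.parent = parent;
            this.min = min;
        }
    }

    // Same logic as getMinNodeAndParent(): keep going left while possible.
    static <T extends Comparable<T>> Result<T> getMinNodeAndParent(Node<T> parent, Node<T> root) {
        if (root == null)
            return null;
        if (root.left != null)
            return getMinNodeAndParent(root, root.left);
        return new Result<>(parent, root);
    }

    public static void main(String[] args) {
        // Node to delete: 5 (no left child), whose right child 8 has no left child.
        //   5
        //    \
        //     8
        //      \
        //       9
        Node<Integer> toDelete = new Node<>(5);
        toDelete.right = new Node<>(8);
        toDelete.right.right = new Node<>(9);

        Result<Integer> r = getMinNodeAndParent(toDelete, toDelete.right);
        // The minimum of the right subtree is 8; its parent is toDelete,
        // and 8 is reached through the parent's RIGHT pointer.
        System.out.println(r.parent.data + " -> " + r.min.data); // prints "5 -> 8"
        System.out.println(r.parent.left == r.min);               // prints "false"
    }
}
```

If 8 had a left child (say 6), the same call would instead return 8 as the parent and 6 as the minimum, reached through 8’s left pointer.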

Next, let’s implement a function that recursively prints our tree into a StringBuilder. It takes the node representing the current location in the tree and the StringBuilder to append to, and it will later be used by our toString() method. The toStringRecursive() function definition is as follows:

private void toStringRecursive(TreeNode<T> node, StringBuilder s) {

}

If the node is null, simply return, as there is nothing to do.

	if(node == null)
		return;

If the node has a left child (i.e., node.left != null), we need to print that subtree first. We do this by calling toStringRecursive() on node.left. After that call we append “, ” to the StringBuilder so that the values are separated.

	if(node.left != null) {
		toStringRecursive(node.left, s);
		s.append(", ");
	}

Next append the value of node.data to the StringBuilder.

	s.append(node.data);

Now we’ll handle the right side of the node. If the node has a right child (i.e., node.right != null), we print that subtree now, after the left subtree and the current node. First we append “, ” to the StringBuilder to separate it from the current value, and then we call toStringRecursive() on node.right.

	if(node.right != null) {
		s.append(", ");
		toStringRecursive(node.right, s);
	}

The finished toStringRecursive() function is as follows:

private void toStringRecursive(TreeNode<T> node, StringBuilder s) {
	if(node == null)
		return;
	if(node.left != null) {
		toStringRecursive(node.left, s);
		s.append(", ");
	}
	s.append(node.data);
	if(node.right != null) {
		s.append(", ");
		toStringRecursive(node.right, s);
	}
}
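toStringRecursive() is an in-order traversal (left subtree, then the node, then the right subtree), which is why the values come out sorted. Note that “, ” is only ever written between two values, never before the first or after the last. For comparison, here is a sketch of the same traversal done iteratively with an explicit stack (java.util.ArrayDeque), which sidesteps deep recursion on unbalanced trees. The nested Node class is a hypothetical stand-in for TreeNode, and appendInOrder is an illustrative name, not part of the original class.

```java
import java.util.ArrayDeque;
import java.util.Deque;

class IterativeInOrder {
    // Hypothetical stand-in for MyBinarySearchTree.TreeNode.
    static class Node<T> {
        Node<T> left, right;
        final T data;
        Node(T data) { this.data = data; }
    }

    // Appends the tree's values in order, separated by ", ", like toStringRecursive().
    static <T> void appendInOrder(Node<T> root, StringBuilder s) {
        Deque<Node<T>> stack = new ArrayDeque<>();
        Node<T> current = root;
        boolean first = true;
        while (current != null || !stack.isEmpty()) {
            // Push the whole left spine of the current subtree.
            while (current != null) {
                stack.push(current);
                current = current.left;
            }
            // The top of the stack is the next node in sorted order.
            Node<T> node = stack.pop();
            if (!first)
                s.append(", ");
            s.append(node.data);
            first = false;
            // Continue with the right subtree of the node just visited.
            current = node.right;
        }
    }

    public static void main(String[] args) {
        Node<Integer> root = new Node<>(5);
        root.left = new Node<>(3);
        root.left.right = new Node<>(4);
        root.right = new Node<>(8);
        StringBuilder s = new StringBuilder("[");
        appendInOrder(root, s);
        s.append("]");
        System.out.println(s); // prints "[3, 4, 5, 8]"
    }
}
```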

With that out of the way, we can write a toString() method that returns a String representation of our MyBinarySearchTree class. It lists all the items in the tree in order, from smallest to largest. The implementation is straightforward, as most of the heavy lifting is already done in the toStringRecursive() function.

We simply need to:

  1. declare a new StringBuilder
  2. append ‘[’ to the StringBuilder
  3. call toStringRecursive() with the rootNode and the newly declared StringBuilder
  4. append ‘]’ to the StringBuilder (this makes it so the content of the tree is contained within the characters ‘[]’)
  5. return the String representation of the StringBuilder

So here is the finished toString() function:

public String toString() {
    StringBuilder s = new StringBuilder();
    s.append("[");
    toStringRecursive(rootNode, s);
    s.append("]");
    return s.toString();
}
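On Java 8 or later, the bracket and separator bookkeeping could alternatively be delegated to java.util.StringJoiner, whose (delimiter, prefix, suffix) constructor produces the same “[a, b, c]” shape and yields “[]” when nothing was added. The sketch below is an alternative, not the approach used in this post; the Node class and the treeToString and addInOrder helpers are hypothetical.

```java
import java.util.StringJoiner;

class JoinerToString {
    // Hypothetical stand-in for MyBinarySearchTree.TreeNode.
    static class Node<T> {
        Node<T> left, right;
        final T data;
        Node(T data) { this.data = data; }
    }

    // In-order walk that adds each value to the joiner; the joiner inserts ", ".
    private static <T> void addInOrder(Node<T> node, StringJoiner joiner) {
        if (node == null)
            return;
        addInOrder(node.left, joiner);
        joiner.add(String.valueOf(node.data));
        addInOrder(node.right, joiner);
    }

    // Same output shape as the post's toString(): "[" + values + "]".
    static <T> String treeToString(Node<T> root) {
        StringJoiner joiner = new StringJoiner(", ", "[", "]");
        addInOrder(root, joiner);
        return joiner.toString();
    }

    public static void main(String[] args) {
        Node<String> root = new Node<>("m");
        root.left = new Node<>("c");
        root.right = new Node<>("x");
        System.out.println(treeToString(root)); // prints "[c, m, x]"
        System.out.println(treeToString(null)); // prints "[]"
    }
}
```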

The finished MyBinarySearchTree class is as follows:

package com.spartanengineer.datastructures;
import java.util.*;
import javafx.util.Pair;
public class MyBinarySearchTree<T extends Comparable<T>> {
    private class TreeNode<U extends Comparable<U>> {
        public TreeNode<U> left = null;
        public TreeNode<U> right = null;
        public U data = null;
        public TreeNode(U data) {
            this.data = data;
        }
    }
    private TreeNode<T> rootNode = null;
    private int size = 0;
    public MyBinarySearchTree() {
    }
    public void insert(T data) {
		TreeNode<T> newNode = new TreeNode<T>(data);
		if(rootNode == null) {
			rootNode = newNode;
		} else {
			TreeNode<T> parentNode = rootNode;
			while(true) {
				if(data.compareTo(parentNode.data) <= 0) {
					if(parentNode.left == null) {
						parentNode.left = newNode;
						break;
					}
					parentNode = parentNode.left;
				} else {
					if(parentNode.right == null) {
						parentNode.right = newNode;
						break;
					}
					parentNode = parentNode.right;
				}
			}
		}
		size++;
    }
    public boolean contains(T data) {
		TreeNode<T> node = rootNode;
		while(node != null) {
			if(data.compareTo(node.data) == 0)
				return true;
			else if(data.compareTo(node.data) < 0)
				node = node.left;
			else
				node = node.right;
		}
		return false;
    }
    private Pair<TreeNode<T>, TreeNode<T>> getMaxNodeAndParent(TreeNode<T> parent, TreeNode<T> root) {
		if(root == null)
			return null;
		if(root.right != null)
			return getMaxNodeAndParent(root, root.right);
		return new Pair<TreeNode<T>, TreeNode<T>>(parent, root);
    }
    private Pair<TreeNode<T>, TreeNode<T>> getMinNodeAndParent(TreeNode<T> parent, TreeNode<T> root) {
		if(root == null)
			return null;
		if(root.left != null)
			return getMinNodeAndParent(root, root.left);
		return new Pair<TreeNode<T>, TreeNode<T>>(parent, root);
    }
    public boolean remove(T data) {
		if(rootNode == null)
			return false;
		if(data.compareTo(rootNode.data) == 0 && rootNode.left == null && rootNode.right == null) {
			size = 0;
			rootNode = null;
			return true;
		}
		TreeNode<T> parentNode = null;
		TreeNode<T> toDelete = rootNode;
		while(toDelete != null) {
			if(data.compareTo(toDelete.data) == 0) {
				//this is where we remove the node
				Pair<TreeNode<T>, TreeNode<T>> pair = getMaxNodeAndParent(toDelete, toDelete.left);
				TreeNode<T> toMove = null;
				TreeNode<T> toMoveParent = null;
				if(pair != null) {
					toMoveParent = pair.getKey();
					toMove = pair.getValue();
					if(toMoveParent.left == toMove)
						toMoveParent.left = toMove.left;
					else
						toMoveParent.right = toMove.left;
				} else {
					pair = getMinNodeAndParent(toDelete, toDelete.right);
					if(pair != null) {
						toMoveParent = pair.getKey();
						toMove = pair.getValue();
						if(toMoveParent.left == toMove)
							toMoveParent.left = toMove.right;
						else
							toMoveParent.right = toMove.right;
					}
				}
				if(toMove != null) {
					toMove.left = toDelete.left;
					toMove.right = toDelete.right;
				}
				if(parentNode != null)
					if(parentNode.left == toDelete)
						parentNode.left = toMove;
					else
						parentNode.right = toMove;
				else
					rootNode = toMove;
				size--;
				return true;
			} else if(data.compareTo(toDelete.data) < 0) {
				parentNode = toDelete;
				toDelete = toDelete.left;
			} else {
				parentNode = toDelete;
				toDelete = toDelete.right;
			}
		}
		return false;
    }
    private void toStringRecursive(TreeNode<T> node, StringBuilder s) {
		if(node == null)
			return;
		if(node.left != null) {
			toStringRecursive(node.left, s);
			s.append(", ");
		}
		s.append(node.data);
		if(node.right != null) {
			s.append(", ");
			toStringRecursive(node.right, s);
		}
    }
    public String toString() {
        StringBuilder s = new StringBuilder();
        s.append("[");
        toStringRecursive(rootNode, s);
        s.append("]");
        return s.toString();
    }
}
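A portability note on the javafx.util.Pair import: JavaFX was bundled with Oracle’s JDK 8, but it was split out of the JDK starting with Java 11, so on newer JDKs that import may not resolve unless JavaFX is added as a dependency. One workaround (an assumption, not part of the original post) is a small Pair class of your own that provides the getKey() and getValue() methods the tree calls. Placed in the same package as MyBinarySearchTree, it lets the javafx.util.Pair import simply be dropped.

```java
// Hypothetical replacement for javafx.util.Pair: an immutable key/value pair
// exposing the two accessors MyBinarySearchTree calls (getKey and getValue).
final class Pair<K, V> {
    private final K key;
    private final V value;

    Pair(K key, V value) {
        this.key = key;
        this.value = value;
    }

    K getKey() { return key; }

    V getValue() { return value; }

    // Readable form for debugging.
    @Override
    public String toString() { return key + "=" + value; }
}

class PairDemo {
    public static void main(String[] args) {
        Pair<String, Integer> p = new Pair<String, Integer>("parent", 8);
        System.out.println(p.getKey() + " / " + p.getValue()); // prints "parent / 8"
    }
}
```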

This concludes the implementation of MyBinarySearchTree.
