
/**
 * An unbalanced binary search tree of {@code int} values.
 *
 * <p>Values smaller than a node are stored in its left subtree; values greater
 * than or equal to it (so duplicates too) are stored in its right subtree.
 * Besides insertion, the class offers a collection of tree queries (height,
 * node/leaf counts, path sums, k-th smallest value, diameter and shape
 * comparisons). The static queries work on arbitrary binary trees and do not
 * rely on the search-tree ordering.
 */
public class BST {

  /** Value returned by {@link #findMax()} for an empty tree (kept for backward compatibility). */
  private static final int EMPTY_TREE_MAX = -1000;

  /** Value returned by {@link #findK(int)} when the requested rank does not exist. */
  private static final int NO_SUCH_RANK = -1;

  /** Root of the tree, or {@code null} while the tree is empty. */
  TreeNode root = null;

  /** A single tree node holding one value and links to its two subtrees. */
  public class TreeNode {
    public int myValue;
    public TreeNode left; // holds smaller tree nodes
    public TreeNode right; // holds larger (or equal) tree nodes

    /**
     * Creates a leaf node.
     *
     * @param val the value stored in the node
     */
    public TreeNode(int val) {
      myValue = val;
    }
  }

  /**
   * Inserts a value into the tree.
   *
   * @param newValue the value to insert; duplicates are allowed
   */
  public void add(int newValue) {
    if (root == null) {
      root = new TreeNode(newValue);
    } else {
      add(newValue, root);
    }
  }

  /**
   * Inserts a value into the subtree rooted at {@code current}.
   *
   * <p>The descent is iterative so that inserting an already sorted sequence
   * (which produces a chain-shaped tree) does not deepen the call stack.
   *
   * @param newValue the value to insert
   * @param current root of the subtree to insert into; must not be {@code null}
   * @throws NullPointerException if {@code current} is {@code null}
   */
  public void add(int newValue, TreeNode current) {
    if (current == null) {
      throw new NullPointerException("current must not be null");
    }
    TreeNode node = current;
    while (true) {
      if (newValue < node.myValue) {
        if (node.left == null) {
          node.left = new TreeNode(newValue);
          return;
        }
        node = node.left;
      } else {
        // newValue >= myValue: equal values go to the right
        if (node.right == null) {
          node.right = new TreeNode(newValue);
          return;
        }
        node = node.right;
      }
    }
  }

  /**
   * Renders the whole tree, one node per line, with children indented under
   * their parent and prefixed by {@code L:} / {@code R:}.
   *
   * @return the textual form of the tree, or {@code "null"} for an empty tree
   */
  @Override
  public String toString() {
    return toString(root, "");
  }

  /**
   * Renders the subtree rooted at {@code current}.
   *
   * @param current root of the subtree to render; {@code null} renders as {@code "null"}
   * @param level indentation placed in front of the child lines of {@code current}
   * @return the textual form of the subtree
   */
  public String toString(TreeNode current, String level) {
    if (current == null) {
      return "null";
    }
    String childLevel = level + "   ";
    String leftString = toString(current.left, childLevel);
    String rightString = toString(current.right, childLevel);

    return current.myValue + "\n" + level + "L: " + leftString + "\n" + level + "R: " + rightString;
  }

  /**
   * Computes the height of the tree, counted in nodes.
   *
   * @return the number of nodes on the longest root-to-leaf path; 0 for an empty tree
   */
  public int computeHeight() {
    return height(root);
  }

  /**
   * Counts all nodes of the tree.
   *
   * @return the number of nodes; 0 for an empty tree
   */
  public int countNodes() {
    return countNodes(root);
  }

  /**
   * Counts the nodes of the subtree rooted at {@code current}.
   *
   * @param current root of the subtree, may be {@code null}
   * @return the number of nodes in the subtree; 0 if {@code current} is {@code null}
   */
  public int countNodes(TreeNode current) {
    if (current == null) {
      return 0;
    }
    return 1 + countNodes(current.left) + countNodes(current.right);
  }

  /**
   * Tells whether some node of the tree holds the given value.
   *
   * @param value the value to look for
   * @return {@code true} if at least one node stores {@code value}
   */
  public boolean containsNode(int value) {
    return containsNode(value, root);
  }

  /**
   * Searches the subtree rooted at {@code current} for {@code value}. Both
   * subtrees are considered, so the result does not depend on the nodes being
   * in search-tree order; the search stops as soon as a match is found.
   */
  private boolean containsNode(int value, TreeNode current) {
    if (current == null) {
      return false;
    }
    return current.myValue == value
        || containsNode(value, current.left)
        || containsNode(value, current.right);
  }

  /**
   * Finds the largest value stored in the tree.
   *
   * @return the largest value, or -1000 if the tree is empty (legacy sentinel)
   */
  public int findMax() {
    if (root == null) {
      return EMPTY_TREE_MAX;
    }
    return findMax(root);
  }

  /**
   * Returns the largest value in the subtree rooted at the non-null node
   * {@code current}. Only real nodes take part in the comparison, so values
   * of any magnitude (including those below -1000) are handled correctly.
   */
  private int findMax(TreeNode current) {
    int best = current.myValue;
    if (current.left != null) {
      best = Math.max(best, findMax(current.left));
    }
    if (current.right != null) {
      best = Math.max(best, findMax(current.right));
    }
    return best;
  }

  /**
   * Counts the leaves (nodes without children) of the tree.
   *
   * @return the number of leaves; 0 for an empty tree
   */
  public int numLeaves() {
    return numLeaves(root);
  }

  /** Counts the leaves of the subtree rooted at {@code current}; a null subtree has none. */
  private int numLeaves(TreeNode current) {
    if (current == null) {
      return 0;
    }
    if (current.left == null && current.right == null) {
      // we are a leaf
      return 1;
    }
    return numLeaves(current.left) + numLeaves(current.right);
  }

  /**
   * Counts the nodes on one level of the tree; the root is level 0.
   *
   * @param target the level to count
   * @return the number of nodes on that level; 0 if the level does not exist
   */
  public int levelCount(int target) {
    return levelCount(root, target);
  }

  /**
   * Returns number of nodes at specified level in t, where level >= 0.
   * @param t is the tree whose level-count is determined
   * @param level specifies the level, >= 0; a negative level yields 0
   * @return number of nodes at given level in t
   */
  public int levelCount(TreeNode t, int level) {
    if (t == null || level < 0) {
      // A negative level can never be reached by descending, so stop right away.
      return 0;
    }
    if (level == 0) {
      return 1;
    }
    return levelCount(t.left, level - 1) + levelCount(t.right, level - 1);
  }

  /**
   * Tells whether the tree has a root-to-leaf path whose values sum to {@code target}.
   *
   * @param target the sum searched for
   * @return {@code true} if such a path exists; always {@code false} for an empty tree
   */
  public boolean hasPathSum(int target) {
    return hasPathSum(root, target);
  }

  /**
   * Return true if and only if the subtree has a root-to-leaf path that sums to target.
   * @param current is the root of a binary tree, may be null
   * @param target is the value whose sum is searched for on some root-to-leaf path
   * @return true if and only if there is a root-to-leaf path summing to target
   */
  private boolean hasPathSum(TreeNode current, int target) {
    if (current == null) {
      return false;
    }
    if (current.left == null && current.right == null) {
      // A path ends here: it matches exactly when the leaf supplies what is left.
      return target == current.myValue;
    }
    int remaining = target - current.myValue;
    return hasPathSum(current.right, remaining) || hasPathSum(current.left, remaining);
  }

  /**
   * Finds the k-th value of the tree in in-order (left, node, right) sequence,
   * counting from 1. For a tree built with {@link #add(int)} this is the k-th
   * smallest value.
   *
   * @param k 1-based rank of the value wanted
   * @return the value with that rank, or -1 if {@code k} is smaller than 1 or
   *     larger than the number of nodes (note that -1 may also be a stored value)
   */
  public int findK(int k) {
    return findK(root, k);
  }

  /** Resolves rank {@code k} within the subtree rooted at {@code current}. */
  private int findK(TreeNode current, int k) {
    if (k < 1) {
      return NO_SUCH_RANK;
    }
    int[] remaining = {k};
    TreeNode hit = kthInOrder(current, remaining);
    return hit == null ? NO_SUCH_RANK : hit.myValue;
  }

  /**
   * Walks the subtree in-order, decrementing {@code remaining[0]} once per
   * visited node, and returns the node at which the counter reaches zero.
   * Returns {@code null} when the subtree is exhausted first; the counter then
   * carries the rank still to be consumed by the caller.
   */
  private static TreeNode kthInOrder(TreeNode current, int[] remaining) {
    if (current == null) {
      return null;
    }
    TreeNode hit = kthInOrder(current.left, remaining);
    if (hit != null) {
      return hit;
    }
    remaining[0]--;
    if (remaining[0] == 0) {
      return current;
    }
    return kthInOrder(current.right, remaining);
  }

  /**
   * Computes the height of a binary tree, counted in nodes.
   *
   * @param t root of the tree, may be {@code null}
   * @return the number of nodes on the longest root-to-leaf path; 0 for {@code null}
   */
  public static int height(TreeNode t) {
    if (t == null) {
      return 0;
    }
    return 1 + Math.max(height(t.left), height(t.right));
  }

  /**
   * Computes the diameter of a binary tree: the number of nodes on the longest
   * path between any two nodes.
   *
   * <p>Delegates to {@link #diameterHelper(TreeNode)}, which visits every node
   * once instead of recomputing subtree heights at each node.
   *
   * @param t root of the tree, may be {@code null}
   * @return the diameter in nodes; 0 for {@code null}
   */
  public static int diameter(TreeNode t) {
    return diameterHelper(t)[1];
  }

  /**
   * Return both height and diameter of t, return height as
   * value with index 0 and diameter as value with index 1
   * in the returned array.
   * @param t is a binary tree
   * @return array containing both height (index 0) and diameter (index 1)
   */
  public static int[] diameterHelper(TreeNode t) {
    int[] ret = new int[2]; // both entries start at 0, the answer for an empty tree

    if (t == null) {
      return ret;
    }
    int[] left = diameterHelper(t.left);
    int[] right = diameterHelper(t.right);

    ret[0] = 1 + Math.max(left[0], right[0]); // this is height

    // The longest path either passes through t (both subtree heights plus t
    // itself) or lies completely inside one of the two subtrees.
    int throughRoot = left[0] + right[0] + 1;
    ret[1] = Math.max(throughRoot, Math.max(left[1], right[1]));

    return ret;
  }

  /**
   * Computes the diameter of a binary tree in a single traversal.
   *
   * @param t root of the tree, may be {@code null}
   * @return the diameter in nodes; 0 for {@code null}
   */
  public static int diameterTwo(TreeNode t) {
    int[] ret = diameterHelper(t);
    return ret[1];
  }

  /**
   * Returns true if s and t are isomorphic, i.e., have same shape.
   * @param s is a binary tree (not necessarily a search tree)
   * @param t is a binary tree
   * @return true if and only if s and t are isomorphic
   */
  public static boolean isIsomorphic(TreeNode s, TreeNode t) {
    if (s == null && t == null) {
      return true;
    }
    if (s == null || t == null) {
      return false;
    }
    return isIsomorphic(s.left, t.left) && isIsomorphic(s.right, t.right);
  }

  /**
   * Returns true if s and t are quasi-isomorphic, i.e., have same quasi-shape:
   * the shapes match when the children of any node may be swapped.
   * @param s is a binary tree (not necessarily a search tree)
   * @param t is a binary tree
   * @return true if and only if s and t are quasi-isomorphic
   */
  public static boolean isQuasiIsomorphic(TreeNode s, TreeNode t) {
    if (s == null && t == null) {
      return true;
    }
    if (s == null || t == null) {
      return false;
    }

    // Children line up as they are ...
    if (isQuasiIsomorphic(s.left, t.left) && isQuasiIsomorphic(s.right, t.right)) {
      return true;
    }
    // ... or after swapping the children of one of the two nodes.
    return isQuasiIsomorphic(s.left, t.right) && isQuasiIsomorphic(s.right, t.left);
  }

  /**
   * Small demonstration: builds a few trees and prints the results of the queries.
   *
   * @param args ignored
   */
  public static void main(String[] args) {
    BST bst = new BST();
    int[] data = {6,8,2,4,1,7,5,3,9};
    for (int i : data) {
      bst.add(i);
    }
    System.out.println(bst.toString());
    System.out.printf("Number of leaves: %d\n", bst.numLeaves()); // 5
    System.out.printf("Number of nodes on level 2: %d\n", bst.levelCount(2)); // 4
    System.out.println("Has path sum 17: " + bst.hasPathSum(17));//true
    System.out.println("Has path sum 10: " + bst.hasPathSum(10));//false
    System.out.printf("The value in node 5 is: %d\n", bst.findK(5)); // 5

    int[] dataOne = {10,2,1,7,6,15,14,13}; // isomorphic to 3
    int[] dataTwo = {10,1,2,3,15,11,12,17}; //quasi-isomorphic to 1 & 3
    int[] dataThree = {11,3,2,8,7,16,15,14}; // isomorphic to 1
    BST one = new BST();
    BST two = new BST();
    BST three = new BST();
    for (int i : dataOne) {
      one.add(i);
    }
    for (int i : dataTwo) {
      two.add(i);
    }
    for (int i : dataThree) {
      three.add(i);
    }

    System.out.println(diameter(one.root)); // 7
    System.out.println(diameterTwo(one.root)); // also 7
    System.out.println(isIsomorphic(one.root, three.root)); //true
    System.out.println(isQuasiIsomorphic(one.root, two.root)); //true
  }
}
