Sunday, July 10, 2011

Comprehensive comparison of recursion, memoization and iteration


In this post, I am going to talk about iteration, recursion and memoization. It is important to understand when one should use recursion. Recursion is sometimes called "code beautification": it often improves readability, but it usually costs performance. Let's take the famous Tower of Hanoi problem. It sounds difficult at first glance, but it can be solved very easily in a recursive fashion. Here goes the code:

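import java.util.Stack;

public class Hanoi {
 static int ctr=0; // number of disc moves made

 public static void main(String[] args) {
  // each stack holds its peg label at the bottom, then the discs (largest first)
  Stack<String> a,b,c;
  a=new Stack<String>();
  b=new Stack<String>();
  c=new Stack<String>();
  Hanoi h = new Hanoi();
  a.push("A");
  int n=5; //no of discs
  for(int i=n;i>0;i--) a.push(i+"");

  b.push("B");
  c.push("C");

  // move all n discs from a to c, using b as the spare peg
  h.doit(a.size()-1,a,c,b);
  System.out.println("Total moves " + ctr);
 }

 // moves the top n discs from peg a to peg c, using peg b as the spare
 void doit(int n,Stack<String> a,Stack<String> c,Stack<String> b) {
  if(n==0) return;
  doit(n-1,a,b,c);            // park the n-1 smaller discs on b
  System.out.print("Move plate "+n+" from "+a+" to "+c);
  c.push(a.pop());            // move disc n itself
  ctr++;                      // count only real moves, not every call
  System.out.println(" --> now "+a+" and "+c);
  doit(n-1,b,c,a);            // bring the n-1 discs from b onto c
 }
}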

Set n >= 30 and you will see the running time rise steeply: the puzzle needs 2^n - 1 moves, so every extra disc doubles the work (and printing every move adds to it).
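For comparison, Tower of Hanoi can also be solved without recursion. The sketch below is not the author's code; it uses a well-known bit-manipulation formulation in which move m (for m = 1 .. 2^n - 1) goes from peg (m & (m-1)) % 3 to peg ((m | (m-1)) + 1) % 3. The class and method names (IterativeHanoi, solve, newPegs) are hypothetical. With pegs numbered 0, 1, 2 and all discs starting on peg 0, the tower ends on peg 2 when n is odd and on peg 1 when n is even. It still makes 2^n - 1 moves, so it is exactly as exponential as the recursive version; only the call stack goes away.

```java
import java.util.ArrayDeque;
import java.util.Deque;

// Sketch (not from the original post): iterative Tower of Hanoi via a bit trick.
class IterativeHanoi {
    // Applies all 2^n - 1 moves to the pegs and returns the move count.
    // Each move is checked: a larger disc is never placed on a smaller one.
    static long solve(int n, Deque<Integer>[] pegs) {
        long total = (1L << n) - 1;
        for (long m = 1; m <= total; m++) {
            int from = (int) ((m & (m - 1)) % 3);
            int to = (int) (((m | (m - 1)) + 1) % 3);
            int disc = pegs[from].pop();
            if (!pegs[to].isEmpty() && pegs[to].peek() < disc)
                throw new IllegalStateException("illegal move " + m);
            pegs[to].push(disc);
        }
        return total;
    }

    // Builds three pegs with discs n..1 on peg 0 (disc 1, the smallest, on top).
    @SuppressWarnings("unchecked")
    static Deque<Integer>[] newPegs(int n) {
        Deque<Integer>[] pegs = new Deque[3];
        for (int i = 0; i < 3; i++) pegs[i] = new ArrayDeque<>();
        for (int d = n; d >= 1; d--) pegs[0].push(d);
        return pegs;
    }

    public static void main(String[] args) {
        int n = 5;
        Deque<Integer>[] pegs = newPegs(n);
        System.out.println("Total moves " + solve(n, pegs)); // 31
        int target = (n % 2 == 1) ? 2 : 1;
        System.out.println("Discs on target peg: " + pegs[target].size()); // 5
    }
}
```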

So an important question arises: when should recursion be avoided? There is no general rule, but I would suggest a rule of thumb based on my experience. If, at any point during the computation (say at state/value C), you can make a decision (somewhat greedy, generally rule-based with some temporary data) about the next value P to compute, and there is no need to come back to state C or reuse the value at C, you should go for iteration. Generally, use recursion when you need to do a whole state-space search to find the global optimum.
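A small illustration of this rule of thumb (not one of the post's examples): Euclid's gcd. From state (a, b) you move straight to (b, a % b) and never look back at (a, b), so the recursive call is the very last thing the method does and the loop version is a mechanical rewrite. The class name Gcd is hypothetical. Note that Java does not eliminate tail calls, so the recursive form still uses one stack frame per step.

```java
// Sketch (not from the original post): the same state transition written
// recursively and iteratively.
class Gcd {
    // Recursive form: the recursive call is a tail call, and the current
    // state (a, b) is never needed again afterwards.
    static long gcdRecursive(long a, long b) {
        if (b == 0) return a;
        return gcdRecursive(b, a % b);
    }

    // Iterative form: the same transition (a, b) -> (b, a % b) in a loop,
    // using constant stack space.
    static long gcdIterative(long a, long b) {
        while (b != 0) {
            long t = a % b;
            a = b;
            b = t;
        }
        return a;
    }

    public static void main(String[] args) {
        System.out.println(gcdRecursive(1071, 462)); // 21
        System.out.println(gcdIterative(1071, 462)); // 21
    }
}
```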

Let's take a few examples to explain it (simultaneously, I am going to compare the performance of iteration, recursion and memoization where possible).
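The timings in this post are taken with java.util.Date, which has millisecond resolution (anything under a millisecond shows up as 0.0) and includes JVM warm-up. If you want to re-run the comparisons, here is a minimal sketch of a slightly more careful timer based on System.nanoTime with warm-up runs. Bench and secondsPerRun are hypothetical names, not from the original code, and this is still far from a rigorous harness such as JMH.

```java
import java.util.function.LongSupplier;

// Sketch (not from the original post): times a computation with System.nanoTime
// after a few warm-up runs, so the JIT has a chance to compile it first.
class Bench {
    // Results are accumulated here so the JIT cannot treat the work as dead code.
    static long sink;

    // Returns the average wall-clock seconds per run after `warmups` untimed runs.
    static double secondsPerRun(LongSupplier work, int warmups, int runs) {
        for (int i = 0; i < warmups; i++) sink += work.getAsLong();
        long start = System.nanoTime();
        for (int i = 0; i < runs; i++) sink += work.getAsLong();
        long elapsed = System.nanoTime() - start;
        return elapsed / 1e9 / runs;
    }

    public static void main(String[] args) {
        // Example workload: sum 0..999,999 in a loop.
        double s = secondsPerRun(() -> {
            long sum = 0;
            for (int i = 0; i < 1_000_000; i++) sum += i;
            return sum;
        }, 5, 20);
        System.out.println(s + " seconds per run");
    }
}
```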

1) Fibonacci Series :

In this case, you can easily store the last 2 values and compute iteratively: F(n) and F(n-1) are used to compute F(n+1), while F(n-2) can safely be discarded because it is never used again.

Type                       | n    | Output                 | Time taken (s) | Comment
Recursion                  | 40   | 1.02334155E8           | 1.484          | n >= 40 gets slow (about 1.6x more time per unit of n)
Iteration                  | 1476 | 1.3069892237633987E308 | 0.0            | n > 1476 overflows double
Recursion with memoization | 1476 | 1.3069892237633987E308 | 0.015          | n > 1476 overflows double
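Why does plain recursion already take over a second at n = 40? A quick call counter shows it. The number of calls satisfies C(n) = 1 + C(n-1) + C(n-2) with C(0) = C(1) = 1, which works out to 2*F(n+1) - 1, so for n = 40 that is 2*165580141 - 1 = 331,160,281 calls. The sketch below (hypothetical class name FibCallCounter, using long instead of the post's double) just counts them.

```java
// Sketch (not from the original post): counts the calls made by naive
// recursive Fibonacci. The count equals 2*F(n+1) - 1.
class FibCallCounter {
    static long calls = 0;

    static long fib(int n) {
        calls++;
        if (n <= 1) return n;
        return fib(n - 1) + fib(n - 2);
    }

    public static void main(String[] args) {
        int n = 30;
        long value = fib(n);
        // prints "F(30) = 832040 after 2692537 calls"
        System.out.println("F(" + n + ") = " + value + " after " + calls + " calls");
    }
}
```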

Here goes the code :
import java.util.Date;
import java.util.HashMap;

public class Fib {

 // memo table for fibMemoizedRecursion: n -> F(n)
 HashMap<Double,Double> fib = new HashMap<Double, Double>();
 public static void main(String args[]) {
  Fib f = new Fib();
  f.analyzeFibRecursion(40);
  f.analyzeFibIteration(1476);
  f.analyzeFibMemoizedRecursion(1476);
 }
 
 // naive recursion: recomputes the same subproblems exponentially often
 public double fibRecursion(double l) {
  if(l <= 1) return l;
  else return fibRecursion(l-1)+fibRecursion(l-2);
 }
 
 public void analyzeFibRecursion(double l) {
  Date start = new Date();
  double value = fibRecursion(l);
  Date end = new Date();
  System.out.println("Final Output : "+value);
  System.out.println((end.getTime()-start.getTime())/1000.0+ " seconds");
 }
 
 // iteration: keeps only the last two values (f1 = F(i-2), f2 = F(i-1))
 public double fibIteration(double l) {
  if(l <= 1) return l;
  else {
   double f1 = 0;
   double f2 = 1;
   double f3 = 0;
   double i = 2;
   while(i<=l) {
    f3 = f2 + f1;
    f1 = f2;
    f2 = f3;
    i++;
   }
   return f3;
  }
 }
 
 public void analyzeFibIteration(double l) {
  Date start = new Date();
  double value = fibIteration(l);
  Date end = new Date();
  System.out.println("Final Output : "+value);
  System.out.println((end.getTime()-start.getTime())/1000.0+ " seconds");
 }
 
 // recursion + memoization: each F(n) is computed once and cached in fib
 public double fibMemoizedRecursion(double l) {
  if(fib.containsKey(l)) return fib.get(l);
  else {
   double v = fibMemoizedRecursion(l-1)+fibMemoizedRecursion(l-2);
   fib.put(l, v);
   return v;
  }
 }
 
 public void analyzeFibMemoizedRecursion(double l) {
  Date start = new Date();
  fib.clear();
  fib.put(0d,0d);
  fib.put(1d,1d);
  double value = fibMemoizedRecursion(l);
  Date end = new Date();
  System.out.println("Final Output : "+value);
  System.out.println((end.getTime()-start.getTime())/1000.0+ " seconds");
 }
}
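One caveat on the double-based code above: a double holds every integer exactly only up to 2^53, so fibIteration is exact only up to F(78); beyond that, the printed values are approximations. A minimal sketch (not from the original post; ExactFib is a hypothetical name) of the same two-variable iteration with java.math.BigInteger keeps every result exact, at the cost of slower arithmetic on large numbers.

```java
import java.math.BigInteger;

// Sketch (not from the original post): two-variable Fibonacci iteration
// with exact arbitrary-precision integers.
class ExactFib {
    static BigInteger fib(int n) {
        BigInteger prev = BigInteger.ZERO; // F(i-1)
        BigInteger curr = BigInteger.ONE;  // F(i)
        if (n == 0) return prev;
        for (int i = 2; i <= n; i++) {
            BigInteger next = prev.add(curr); // F(i) = F(i-1) + F(i-2)
            prev = curr;
            curr = next;
        }
        return curr;
    }

    public static void main(String[] args) {
        System.out.println(fib(100)); // 354224848179261915075
    }
}
```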

2) Binary Search :

At each step, you can choose to search either the lower or the upper partition, so there is no need to recurse: you never need to come back to the current state.

Type      | Array size | Lookups  | Time taken (s)
Recursion | 631900     | 10000000 | 10.797
Iteration | 631900     | 10000000 | 4.281

import java.util.Date;
import java.util.Random;
import java.util.TreeSet;

public class BinarySearch {
 
 static int arr[];
 
 public static int RANGE = 1000000;
 public static int ATTEMPTS = 10000000;

 public static void main(String args[]) {
  Random r = new Random();
  TreeSet<Integer> t = new TreeSet<Integer>();
  for(int i=0;i<RANGE;i++) t.add(r.nextInt(RANGE));
  arr=new int[t.size()];
  Integer[] iarr = t.toArray(new Integer[0]);
  for(int i=0;i<iarr.length;i++) {
   arr[i]=iarr[i];
  }  
  System.out.println(t.size());
  
  analyzeIterativeBinarySearch();
  analyzeRecursiveBinarySearch();
 }
 
 private static void analyzeIterativeBinarySearch() {
  Random r = new Random();
  Date start,end;
  int idx;
  int toFind;
  start = new Date();

  for(int i=0;i<ATTEMPTS;i++) {
   toFind = r.nextInt(RANGE);
   idx = iterativeBinarySearch(arr,toFind);
   //System.out.println(toFind+" "+idx);
  }
  end = new Date();
  System.out.println((end.getTime()-start.getTime())/1000.0+ " seconds");
 }
 
 private static int iterativeBinarySearch(int arr[],int val) {
  int mid,low,high;
  low = 0 ;
  high = arr.length-1;
  while(low<=high) {
   mid=low+(high-low)/2; // same as (low+high)/2 but cannot overflow int
   if(arr[mid]==val) return mid;
   else if(val<arr[mid]) high=mid-1;
   else low=mid+1;
  }
  return -1;
 }
 
 private static void analyzeRecursiveBinarySearch() {
  Random r = new Random();
  Date start,end;
  int idx;
  int toFind;
  start = new Date();

  for(int i=0;i<ATTEMPTS;i++) {
   toFind = r.nextInt(RANGE);
   idx = recursiveBinarySearch(arr,toFind,0,arr.length-1);
   //System.out.println(toFind+" "+idx);
  }
  end = new Date();
  System.out.println((end.getTime()-start.getTime())/1000.0+ " seconds");
 }
 
 public static int recursiveBinarySearch(int[] inArray,int num,int start,int end) {
  int pivot=start+(end-start)/2; // int division already rounds down
  if(num==inArray[pivot]) return pivot;
  if(start==end) return -1; 
  if(num<=inArray[pivot]) return recursiveBinarySearch(inArray,num,start,pivot);
  else return recursiveBinarySearch(inArray,num,pivot+1,end);
 }
}


3) Binary Tree Search/Tree traversal :

BST search is essentially binary search on a tree, so we can go for iteration. The same applies to tries. Tree traversal, on the contrary, needs to visit all the nodes, so an iterative traversal (with an explicit stack) is overkill.
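To make the point about tries concrete, here is a minimal sketch (not from the original post; IterativeTrie, insert and contains are hypothetical names) of a trie whose insert and lookup are plain loops. As in BST search, each step picks exactly one child and never returns to the current node, so no recursion is needed.

```java
import java.util.HashMap;
import java.util.Map;

// Sketch (not from the original post): trie with iterative insert and lookup.
class IterativeTrie {
    private static class Node {
        Map<Character, Node> children = new HashMap<>();
        boolean isWord; // true if a word ends at this node
    }

    private final Node root = new Node();

    void insert(String word) {
        Node n = root;
        for (char ch : word.toCharArray()) {
            // walk down one level, creating the child if it is missing
            n = n.children.computeIfAbsent(ch, k -> new Node());
        }
        n.isWord = true;
    }

    boolean contains(String word) {
        Node n = root;
        for (char ch : word.toCharArray()) {
            n = n.children.get(ch);
            if (n == null) return false; // path breaks off: word not present
        }
        return n.isWord;
    }

    public static void main(String[] args) {
        IterativeTrie t = new IterativeTrie();
        t.insert("tree");
        t.insert("trie");
        System.out.println(t.contains("trie")); // true
        System.out.println(t.contains("tr"));   // false
    }
}
```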

Let's examine the performance of tree search:

Type      | Tree size (nodes) | Lookups | Time taken (s)
Recursion | 631647            | 1000000 | 61.718
Iteration | 631647            | 1000000 | 44.86

For tree traversal, on the other hand, you can clearly see that using an explicit stack is overkill.

Type      | Tree size (nodes) | Time taken (s)
Recursion | 631647            | 0.015
Iteration | 631647            | 0.063

Here goes the code :

import java.util.Date;
import java.util.HashSet;
import java.util.Random;
import java.util.Stack;

public class TreeSearch {
 public static int RANGE = 1000000;
 public static int ATTEMPTS = 1000000;

 public static void main(String[] args) {
  
  Random r = new Random();
  HashSet<Double> h = new HashSet<Double>();
  for(int i=0;i<RANGE;i++) h.add((double)r.nextInt(RANGE));
  System.out.println(h.size());
  Tree t = new TreeSearch().new Tree();

  for(Double d : h) {
   t.insert(d.doubleValue());
  }
  
  analyzeRecursiveSearch(t);
  analyzeIterativeSearch(t);
  
  analyzeRecursiveBrowse(t);
  analyzeIterativeBrowse(t);
 }
 
 private static void analyzeRecursiveSearch(Tree t) {
  Random r = new Random();
  Date start,end;
  boolean found;
  double toFind;
  start = new Date();

  for(int i=0;i<ATTEMPTS;i++) {
   toFind = (double)r.nextInt(RANGE);
   found = t.recursiveSearch(t.root, toFind);
  }
  end = new Date();
  System.out.println((end.getTime()-start.getTime())/1000.0+ " seconds");
 }
 
 private static void analyzeIterativeSearch(Tree t) {
  Random r = new Random();
  Date start,end;
  boolean found;
  double toFind;
  start = new Date();

  for(int i=0;i<ATTEMPTS;i++) {
   toFind = (double)r.nextInt(RANGE);
   found = t.iterativeSearch(t.root, toFind);
  }
  end = new Date();
  System.out.println((end.getTime()-start.getTime())/1000.0+ " seconds");
 }
 
 private static void analyzeRecursiveBrowse(Tree t) {
  Date start = new Date();
  t.inorderRecursive(t.root);
  Date end = new Date();
  System.out.println((end.getTime()-start.getTime())/1000.0+ " seconds");
 }
 
 private static void analyzeIterativeBrowse(Tree t) {
  Date start = new Date();
  t.inOrderIterative(t.root);
  Date end = new Date();
  System.out.println((end.getTime()-start.getTime())/1000.0+ " seconds");
 }

 class Node {
  Node left,right;
  double val;
  
  Node(double val) {
   this.val = val;
  }   
 }

 class Tree {
  Node root;
  Tree() {}

  void insert(double val) {
   root=insert(root,val);
  }
  
  Node insert(Node n,double val) {
   if(n==null) n = new Node(val);
   else if(n.val>val) n.left=insert(n.left,val);
   else n.right=insert(n.right,val);
   return n;
  }
  
  boolean recursiveSearch(Node n , double val) {
   if(n==null) return false;
   else {
    if(n.val==val) return true;
    else if(n.val>val) return recursiveSearch(n.left, val);
    else return recursiveSearch(n.right, val);
   }
  }
  
  boolean iterativeSearch(Node n , double val) {
   while(n!=null) {
    if(n.val==val) return true;
    else if(n.val>val) n=n.left;
    else n=n.right;
   }
   return false;
  }
  
  void inorderRecursive(Node n) {
   if(n==null) return;
   inorderRecursive(n.left);
   //System.out.print(n.val+" ");
   inorderRecursive(n.right);
  }
  
  public void inOrderIterative(Node n) {
   Stack<Node> s = new Stack<Node>(); 
   while (n !=null) {
     s.push(n);
     n = n.left;
   }
   while (!s.isEmpty()) {
    n = s.pop();
    //System.out.print(n.val+" ");
    n = n.right;
    while(n !=null) {
     s.push(n);
     n = n.left;
    } 
   } 
  }
 }
}

4) Matrix Chain Multiplication :

This is a classic example of dynamic programming. An iterative solution works only with a DP (bottom-up) formulation. Otherwise, the brute-force approach is plain recursion, which is very slow because it solves the same subproblems again and again. But you can drastically improve its performance by adding memoization to the recursion.
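All three versions in the code below evaluate the same recurrence. Writing A_i for the i-th matrix, of size p_{i-1} x p_i, and m[i,j] for the minimum number of scalar multiplications needed to compute the product A_i ... A_j:

$$
m[i,j] =
\begin{cases}
0 & \text{if } i = j,\\
\min_{i \le k < j} \bigl( m[i,k] + m[k+1,j] + p_{i-1}\, p_k\, p_j \bigr) & \text{if } i < j.
\end{cases}
$$

Plain recursion re-solves the same (i, j) pairs over and over, which makes it exponential. There are only O(n^2) distinct pairs, each needing O(n) work for the minimum, so both the memoized recursion and the bottom-up DP run in O(n^3) time.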

Type                    | RANGE | No. of p arrays solved | Time taken (s)
Recursion               | 25    | 500                    | 64.999
Recursion + Memoization | 25    | 500                    | 0.016
Iteration               | 25    | 500                    | 0.0

Code :

import java.util.Date;
import java.util.HashSet;
import java.util.Random;

public class MatrixMultiplication {
 public static int ATTEMPTS = 500;
 public static int RANGE = 25;

 public static void main(String args[]) {
  analyze();
 }

 private static void analyze() {
  Random r = new Random();
  Date start, end;
  int val;
  long totalTimeDP = 0, totalTimeRecursive = 0, totalTimeMemoized = 0;
  for (int i = 0; i < ATTEMPTS; i++) {
   HashSet<Integer> h = new HashSet<Integer>();
   for (int j = 0; j < RANGE; j++)
    h.add(r.nextInt(RANGE));
    h.remove(Integer.valueOf(0)); // a zero dimension is not a valid matrix size
   int[] p = new int[h.size()];
   Integer[] iarr = h.toArray(new Integer[0]);
   for (int j = 0; j < iarr.length; j++) {
    p[j] = iarr[j];
   }
   start = new Date();
   val = dp(p);
   // System.out.println(val);
   end = new Date();
   totalTimeDP += (end.getTime() - start.getTime());

   start = new Date();
   int[][] m = new int[p.length][p.length];
   val = recursive(p, 1, p.length - 1, m);
   // System.out.println(val);
   end = new Date();
   totalTimeRecursive += (end.getTime() - start.getTime());

   start = new Date();
   m = new int[p.length][p.length];
   val = memoized(p, m);
   // System.out.println(val);
   end = new Date();
   totalTimeMemoized += (end.getTime() - start.getTime());
  }
  System.out.println(totalTimeDP / 1000.0 + " seconds");
  System.out.println(totalTimeRecursive / 1000.0 + " seconds");
  System.out.println(totalTimeMemoized / 1000.0 + " seconds");

 }

 public static int dp(int p[]) {
  int n = p.length - 1;

  int[][] m = new int[n + 1][n + 1];
  int[][] s = new int[n + 1][n + 1];

  for (int i = 1; i <= n; i++)
   m[i][i] = 0;

  for (int L = 2; L <= n; L++) {
   for (int i = 1; i <= n - L + 1; i++) {
    int j = i + L - 1;
    m[i][j] = Integer.MAX_VALUE;
    for (int k = i; k <= j - 1; k++) {
     // q = cost/scalar multiplications
     int q = m[i][k] + m[k + 1][j] + p[i - 1] * p[k] * p[j];
     if (q < m[i][j]) {
      m[i][j] = q;
      s[i][j] = k;
     }
    }
   }
  }
  return m[1][n];
 }

 public static int recursive(int p[], int i, int j, int[][] m) {
  if (i == j)
   return 0;
  m[i][j] = Integer.MAX_VALUE;

  for (int k = i; k <= j - 1; k++) {
   int q = recursive(p, i, k, m) + recursive(p, k + 1, j, m)
     + p[i - 1] * p[k] * p[j];
   if (q < m[i][j])
    m[i][j] = q;
  }

  return m[i][j];
 }

 public static int memoized(int p[], int[][] m) {
  for (int i = 1; i < m.length; i++) {
   for (int j = 1; j < m.length; j++) {
    m[i][j] = Integer.MAX_VALUE;
   }
  }
  return memoized(p, 1, m.length - 1, m);
 }

 public static int memoized(int p[], int i, int j, int[][] m) {
  if (m[i][j] < Integer.MAX_VALUE)
   return m[i][j];

  if (i == j)
   m[i][j] = 0;
  else {
   for (int k = i; k <= j - 1; k++) {
    int q = memoized(p, i, k, m) + memoized(p, k + 1, j, m)
      + p[i - 1] * p[k] * p[j];
    if (q < m[i][j])
     m[i][j] = q;
   }
  }
  return m[i][j];
 }
}
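The dp() method above fills the split table s[i][j] but never uses it. A minimal, self-contained sketch (not from the original post; MatrixChainOrder and its methods are hypothetical names) that keeps s and rebuilds the optimal parenthesization from it is shown below. It uses long for the costs to leave headroom for larger dimensions, and it is checked against the classic six-matrix example from Cormen et al., Introduction to Algorithms.

```java
// Sketch (not from the original post): bottom-up matrix chain DP that also
// reconstructs the optimal parenthesization from the split table.
class MatrixChainOrder {
    private final long[][] m; // m[i][j] = min scalar multiplications for Ai..Aj
    private final int[][] s;  // s[i][j] = split point k of the optimal solution

    // Matrix Ai has dimensions p[i-1] x p[i], for i = 1..n.
    MatrixChainOrder(int[] p) {
        int n = p.length - 1;
        m = new long[n + 1][n + 1]; // m[i][i] = 0 by default
        s = new int[n + 1][n + 1];
        for (int len = 2; len <= n; len++) {
            for (int i = 1; i <= n - len + 1; i++) {
                int j = i + len - 1;
                m[i][j] = Long.MAX_VALUE;
                for (int k = i; k < j; k++) {
                    long q = m[i][k] + m[k + 1][j] + (long) p[i - 1] * p[k] * p[j];
                    if (q < m[i][j]) {
                        m[i][j] = q;
                        s[i][j] = k;
                    }
                }
            }
        }
    }

    long cost() { return m[1][m.length - 1]; }

    String parenthesization() { return paren(1, m.length - 1); }

    // Recursively prints (Ai..Ak)(Ak+1..Aj) using the recorded split points.
    private String paren(int i, int j) {
        if (i == j) return "A" + i;
        int k = s[i][j];
        return "(" + paren(i, k) + paren(k + 1, j) + ")";
    }

    public static void main(String[] args) {
        int[] p = {30, 35, 15, 5, 10, 20, 25};
        MatrixChainOrder order = new MatrixChainOrder(p);
        System.out.println(order.cost());             // 15125
        System.out.println(order.parenthesization()); // ((A1(A2A3))((A4A5)A6))
    }
}
```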
