{-# LANGUAGE NoMonomorphismRestriction #-}
module BalancedFold where

-- import Test.QuickCheck
-- import Test.QuickCheck.Property
import Control.Exception

-- | Combine the elements of a non-empty list with a binary function, nesting
-- the applications as a balanced binary tree rather than a left or right
-- spine.
--
-- A list of @m@ elements is split with @splitAt (div m 2)@, so the left
-- subtree receives @div m 2@ elements and the right subtree the remaining
-- @div (m+1) 2@. 'ascendFromLeaf' assumes exactly this shape.
--
-- For an associative @f@ the result equals @foldr1 f xs@, but the nesting
-- depth of @f@ applications is only logarithmic in the length of the list.
--
-- Calls 'error' on an empty list, since there is no element to return.
balancedFold :: (a -> a -> a) -> [a] -> a
balancedFold f = go
  where
    -- Only reachable for an empty top-level input: every recursive call
    -- below receives a non-empty half, because the third clause only sees
    -- lists of length >= 2. Without this clause, @splitAt 0 []@ would make
    -- @go []@ recurse forever.
    go [] = error "BalancedFold.balancedFold: empty list"
    go [x] = x
    go xs =
      let (l, r) = splitAt (length xs `div` 2) xs
       in f (go l) (go r)


-- | @ascendFromLeaf l r leaf m i@
-- \"descends\" to the /i/-th leaf -- counted from left to right, zero-based --
-- of a tree with the same structure as
--
-- > balancedFold Node (replicate m (Leaf ()))
--
-- would produce. It then starts with the value /leaf/, and ascends back up,
-- applying /l/ to the value whenever ascending up a left-child edge
-- and /r/ when ascending up a right-child edge.
--
-- A node covering @m@ leaves has @div m 2@ leaves in its left subtree and
-- @div (m+1) 2@ leaves in its right subtree, mirroring the split used by
-- 'balancedFold'.
--
-- Calls 'error' if @m < 1@ (the tree would have no leaves, and the descent
-- would never reach a single-leaf node) or if /i/ is not in @[0, m)@. The
-- index is checked explicitly because 'assert' is dropped in optimized
-- builds.
ascendFromLeaf :: (Integral int) => (t -> t) -> (t -> t) -> t -> int -> int -> t
ascendFromLeaf l r leaf m0 i0
  | m0 < 1 = error "BalancedFold.ascendFromLeaf: tree must have at least one leaf"
  | i0 < 0 || i0 >= m0 = error "BalancedFold.ascendFromLeaf: leaf index out of range"
  | otherwise = go m0 i0
  where
    -- Invariant: m >= 1 and 0 <= i < m. It holds initially, and both
    -- recursive calls preserve it, because nl >= 1 and nr >= 1 whenever
    -- m >= 2.
    go 1 i = assert (i == 0) leaf
    go m i
      | i < nl = l (go nl i)
      | otherwise = r (go nr (i - nl))
      where
        nl = m `div` 2
        nr = (m + 1) `div` 2


-- * Testing

-- data Tree a = Node (Tree a) (Tree a) | Leaf a
--             deriving (Eq)
                     
-- leftChild (Node x _) = x
-- leftChild x = error ("leftChild "++show x)
-- rightChild (Node _ x) = x
-- rightChild x = error ("rightChild "++show x)
                     
-- instance Show a => Show (Tree a) where
--     show = go 0
--         where
--           go 0 x = case x of
--                      Leaf y -> show y
--                      Node x1 x2 -> "+\n"++go 1 x1++"\n"++go 1 x2
--           go n x = concat (replicate (n-1) "| ") ++ "+-" ++
--                    case x of
--                      Leaf y -> show y
--                      Node x1 x2 -> "+\n" ++ go (n+1) x1 ++ "\n" ++ go (n+1) x2

-- prop1 = do
--   n <- choose (1,100)
--   i <- choose (0,n-1)
--   let
--       tree :: Tree Int
--       tree = balancedFold Node [ Leaf i | i <- [0..n-1] ]
             
--       getter :: Tree Int -> Tree Int
--       getter = ascendFromLeaf 
--                (\f -> f . leftChild) 
--                (\f -> f . rightChild)
--                id
--                n
--                i
               
             
--   return (whenFail (print (n,tree,i,getter tree)) $
--           getter tree == Leaf i)
      
---- END TESTING