{-# LANGUAGE ExistentialQuantification #-}

module Data.InfiniteTree
  ( Tree
  , mkTree
  , root
  , left
  , right
  , branchF
  , surreals
  , showTree
  , showTreeWide
  , showTree'
  , showWide
  , rotateR
  , rotateL
  ) where

import Control.Arrow ((&&&))
import Control.Comonad

-- | An infinite binary tree represented as an unfold.  The constructor holds
-- a seed of some type @b@, a function giving the value stored at a seed, and
-- functions stepping a seed to its left and right child seeds.
--
-- The seed type is existentially quantified, so it never appears in the type
-- of the tree.  Every node has both children, so the tree has no leaves.
data Tree a = forall b. T b (b -> a) (b -> b) (b -> b)

-- | Build a tree from four parts:
--
-- * a starting seed;
-- * a function producing the value stored at a seed;
-- * the step from a seed to its left child seed;
-- * the step from a seed to its right child seed.
mkTree :: seed -> (seed -> a) -> (seed -> seed) -> (seed -> seed) -> Tree a
mkTree = T

-- | The value at the top of the tree: the value function applied to the
-- current seed.
root :: Tree a -> a
root tree = case tree of
  T seed valueAt _ _ -> valueAt seed

-- | The left subtree.  It keeps the same value and step functions, with the
-- seed advanced one step to the left.
left :: Tree a -> Tree a
left tree = case tree of
  T seed valueAt stepL stepR -> T (stepL seed) valueAt stepL stepR

-- | The right subtree.  It keeps the same value and step functions, with the
-- seed advanced one step to the right.
right :: Tree a -> Tree a
right tree = case tree of
  T seed valueAt stepL stepR -> T (stepR seed) valueAt stepL stepR

-- | Mapping applies the function after the value function.  The seed and
-- both step functions are reused unchanged, so the tree's shape is
-- preserved.
instance Functor Tree where
  fmap g tree = case tree of
    T seed valueAt stepL stepR -> T seed (\s -> g (valueAt s)) stepL stepR

-- | 'extract' reads the root value.
--
-- @'extend' f@ replaces the value at each seed with @f@ applied to the whole
-- subtree rooted at that seed.  Those subtrees are rebuilt from the original
-- value and step functions.
instance Comonad Tree where
  extract = root
  extend f (T seed valueAt stepL stepR) = T seed (f . subtreeAt) stepL stepR
    where
      -- The original (unextended) subtree whose root is the given seed.
      subtreeAt s = T s valueAt stepL stepR

-- | Combine a functor full of trees into one tree of functor values.
--
-- The seed is the whole container.  Each step maps 'root', 'left' or
-- 'right' over every tree at once, so all the trees are walked in lockstep.
branchF :: Functor f => f (Tree a) -> Tree (f a)
branchF trees = mkTree trees (root <$>) (left <$>) (right <$>)

-- | A binary search tree of dyadic rationals, rooted at 0.
--
-- Each seed is a pair of optional (lower, upper) bounds inherited from the
-- ancestors.  A node's value is chosen by which bounds are present:
--
-- * no bounds: @0@;
-- * only a lower bound @x@: @x + 1@;
-- * only an upper bound @y@: @y - 1@;
-- * both bounds: their midpoint.
--
-- The node's value then becomes the upper bound of its left child and the
-- lower bound of its right child.  For example, the second level is -1, 1
-- and the third level is -2, -1/2, 1/2, 2.
surreals :: Fractional a => Tree a
surreals = mkTree (Nothing, Nothing) between (fst &&& asBound) (asBound &&& snd)
  where
    -- The value at a node, wrapped as the new bound handed to a child.
    asBound bounds = Just (between bounds)

    between bounds = case bounds of
      (Nothing, Nothing) -> 0
      (Just lo, Nothing) -> lo + 1
      (Nothing, Just hi) -> hi - 1
      (Just lo, Just hi) -> (lo + hi) / 2
    
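{-
Standalone definition of (&&&) specialised to functions. It can be used in
place of the Control.Arrow import at the top of this module:

infix 1 &&&
f &&& g = \x -> (f x, g x)
-}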

-- | Render the first @depth@ levels of a tree as a 'String'.  An extra
-- connector line is drawn between each node and its subtrees.
showTree :: Show a => Int -> Tree a -> String
showTree depth tree = showTreeWide True depth tree

-- | Render the first @depth@ levels of a tree as a 'String'.  The right
-- subtree is drawn above each node and the left subtree below it.  When
-- @wide@ is 'True', an extra connector line separates each node from its
-- subtrees.
showTreeWide :: Show a => Bool -> Int -> Tree a -> String
showTreeWide wide depth tree = render ""
  where
    -- Start at the root, with no bars in front of any line.
    render = showTree' wide [] [] tree depth

-- | Render at most @d@ levels of a tree as an ASCII diagram.  The right
-- subtree is drawn above each node and the left subtree below it.
--
-- @lbars@ are the bar segments drawn in front of this node's own line.
-- @rbars@ are the segments for the lines above it, where the right subtree
-- is drawn.  Both stacks are innermost-first; the printers reverse them.
--
-- Nodes at the depth cut-off are followed by @...@, because the tree
-- continues below them.  When @wide@ is 'True', a connector line containing
-- only bars is emitted between a node and each of its subtrees.
--
-- A depth of zero or less renders nothing.  Before this guard, a negative
-- depth never reached a base case and recursed forever.
showTree' :: Show a => Bool -> [String] -> [String] -> Tree a -> Int -> ShowS
showTree' wide lbars rbars t d
  | d <= 0    = id
  | d == 1    = showBars lbars . shows (root t) . showString "...\n"
  | otherwise =
      showTree' wide (withBar rbars) (withEmpty rbars) (right t) (d - 1) .
      showWide wide rbars .
      showBars lbars . shows (root t) . showChar '\n' .
      showWide wide lbars .
      showTree' wide (withEmpty lbars) (withBar lbars) (left t) (d - 1)

-- | In wide mode, emit a connector line: the bar stack printed outermost
-- first, then a final @|@ and a newline.  Otherwise emit nothing.
showWide :: Bool -> [String] -> ShowS
showWide wide bars =
  if wide
    then showString (concat (reverse bars) ++ "|\n")
    else id

-- | The prefix of a node's own line.  It prints every bar segment except the
-- innermost one, outermost first, followed by the 'node' connector.  The
-- innermost slot is replaced by the connector itself.  With an empty stack
-- (the tree's root) nothing is emitted.
--
-- Matching on the cons cell replaces the former partial 'tail' call.
showBars :: [String] -> ShowS
showBars []          = id
showBars (_ : outer) = showString (concat (reverse outer)) . showString node

-- | The connector drawn immediately before a node's value.
node :: String
node = "+--"

-- | Push one more indentation segment onto a bar stack.  Both segments are
-- three characters wide.
--
-- * 'withBar' pushes a vertical line that continues past this depth.
-- * 'withEmpty' pushes blank space.
withBar, withEmpty :: [String] -> [String]
withBar   = ("|  " :)
withEmpty = ("   " :)

-- | Which part of a rotated tree a seed describes.  A seed pairs a subtree
-- of the original (unrotated) tree with one of these tags:
--
-- * 'Two': the new root produced by the rotation;
-- * 'One': the old root, demoted to a child of the new root;
-- * 'Zero': an untouched subtree, read directly from the original tree.
--
-- A former fourth constructor was never built.  Each rotation handled it
-- only with 'error', so it has been removed; the helpers below are now
-- total.
data Rot = Zero | One | Two

-- | Rotate the root to the left.  The right child becomes the new root.  The
-- old root becomes its left child and takes over the right child's former
-- left subtree.
--
-- >       t              r
-- >      / \            / \
-- >     a   r    ==>   t   c
-- >        / \        / \
-- >       b   c      a   b
rotateL :: Tree a -> Tree a
rotateL t' = mkTree (t', Two) valueAt stepL stepR
  where
    valueAt (t, Two)  = root (right t)
    valueAt (t, One)  = root t
    valueAt (t, Zero) = root t

    stepL (t, Two)  = (t, One)
    stepL (t, One)  = (left t, Zero)
    stepL (t, Zero) = (left t, Zero)

    stepR (t, Two)  = (right (right t), Zero)
    stepR (t, One)  = (left (right t), Zero)
    stepR (t, Zero) = (right t, Zero)

-- | Rotate the root to the right.  The left child becomes the new root.  The
-- old root becomes its right child and takes over the left child's former
-- right subtree.
--
-- >         t            l
-- >        / \          / \
-- >       l   c  ==>   a   t
-- >      / \              / \
-- >     a   b            b   c
rotateR :: Tree a -> Tree a
rotateR t' = mkTree (t', Two) valueAt stepL stepR
  where
    valueAt (t, Two)  = root (left t)
    valueAt (t, One)  = root t
    valueAt (t, Zero) = root t

    stepL (t, Two)  = (left (left t), Zero)
    stepL (t, One)  = (right (left t), Zero)
    stepL (t, Zero) = (left t, Zero)

    stepR (t, Two)  = (t, One)
    stepR (t, One)  = (right t, Zero)
    stepR (t, Zero) = (right t, Zero)