{-# LANGUAGE CPP #-}

-- | The Zipper is a data structure which maintains a location in 
-- a tree, and allows O(1) movement and local changes
-- (to be more precise, in our case it is O(k) where k is the number
-- of children of the node in question; typically this is a very small number).
--
module Data.Generics.Fixplate.Zipper where

--------------------------------------------------------------------------------

import Prelude hiding (foldl,foldr,mapM,mapM_,concat,concatMap)
import Data.Foldable
import Data.Traversable
import Data.Maybe

import Text.Show
import Text.Read 

import Data.Generics.Fixplate.Base
import Data.Generics.Fixplate.Open
import Data.Generics.Fixplate.Misc

#ifdef WITH_QUICKCHECK
import Test.QuickCheck
import Data.Generics.Fixplate.Attributes
import Data.Generics.Fixplate.Traversals
import Data.Generics.Fixplate.Test.Tools
import Control.Monad (liftM)
#endif 

--------------------------------------------------------------------------------
-- * Types

-- | A context node: one child slot of a parent node stored in a 'Path'.
-- 'Left' holds an ordinary subtree (a sibling of the focus), while 'Right'
-- holds the rest of the path, and marks the slot through which we descended.
type Node f  =  Either (Mu f) (Path f)

-- | The context or path type. 'Top' means that the focus is the root of the
-- tree; otherwise 'Path' stores the parent node, whose children are the
-- siblings of the focus (as 'Left') and the remaining path upwards (as 'Right',
-- in the slot of the focus). The invariant we must respect is that there is 
-- exactly one child with the 'Right' constructor.
data Path f  =  Top
             |  Path { unPath :: f (Node f) } 
               
-- | The zipper type itself, which encodes a location in the tree @Mu f@:
-- the subtree at the 'focus', together with the 'path' leading back to the root.
data Loc f   =  Loc { focus :: Mu f , path :: Path f } 

--------------------------------------------------------------------------------

-- Two paths are equal if both are 'Top', or if both are 'Path' and the
-- stored parent layers agree according to 'equalF'.
instance EqF f => Eq (Path f) where
  a == b = case (a, b) of
    (Top     , Top     ) -> True
    (Path xs , Path ys ) -> equalF xs ys
    _                    -> False
  
-- Locations are equal if both the focused subtrees and the paths agree.
instance EqF f => Eq (Loc f) where
  Loc t1 ctx1 == Loc t2 ctx2
    | t1 == t2   =  ctx1 == ctx2
    | otherwise  =  False
  
-- Mimics a derived instance: @Top@ is shown bare, while @Path@ with its 
-- argument is parenthesized when the context precedence is above 10.
instance ShowF f => Show (Path f) where
  showsPrec outer ctx = case ctx of
    Top      -> showString "Top"
    Path ns  -> showParen (outer > 10) (showString "Path " . showsPrecF 11 ns)

-- Mimics a derived instance for a two-argument constructor written without
-- record syntax: @Loc \<focus\> \<path\>@.
instance ShowF f => Show (Loc f) where
  showsPrec outer (Loc t ctx) = showParen (outer > 10) body where
    body = showString "Loc " . showsPrec 11 t . showChar ' ' . showsPrec 11 ctx

-- Parses the output of the 'Show' instance: either the bare constructor
-- @Top@, or @Path@ applied to a single argument.
instance ReadF f => Read (Path f) where 
#ifdef __GLASGOW_HASKELL__
  readPrec = parens $ 
    (do
      { Ident "Top" <- lexP
      ; return Top            
      })
    +++
    (prec app_prec $ do
      { Ident "Path" <- lexP
      ; p <- step readPrecF
      ; return (Path p)            
      })    
#else                                  
  -- Same scheme as a derived instance: the nullary constructor never requires
  -- parentheses, while the application of @Path@ requires them whenever the
  -- surrounding precedence is above that of function application.
  readsPrec d r = 
     readParen False
       (\inp -> [ (Top, s) 
                | ("Top", s) <- lex inp ]) r    
     ++
     readParen (d > app_prec)
       (\inp -> [ (Path p, t) 
                | ("Path", s) <- lex inp
                , (p,t) <- readsPrecF (app_prec+1) s ]) r    
#endif
    
-- Parses the output of the 'Show' instance: the constructor @Loc@ followed
-- by the focused tree and the path, both read at argument precedence.
instance ReadF f => Read (Loc f) where 
#ifdef __GLASGOW_HASKELL__
  readPrec = parens . prec app_prec $ do
    Ident "Loc" <- lexP
    foc <- step readPrec
    pth <- step readPrec
    return (Loc foc pth)
#else                                  
  readsPrec d = readParen (d > app_prec) $ \inp ->
    [ (Loc foc pth, rest) 
    | ("Loc", s1)  <- lex inp
    , (foc, s2)    <- readsPrec (app_prec+1) s1
    , (pth, rest)  <- readsPrec (app_prec+1) s2 
    ]
#endif

--------------------------------------------------------------------------------
-- * Converting to and from zippers

-- | Wraps a tree into a zipper whose focus is the whole tree, that is, 
-- the path is 'Top'.
root :: Mu f -> Loc f
root tree = Loc { focus = tree , path = Top }

-- | Restores a tree from a zipper, by repeatedly plugging the current
-- subtree back into the unique 'Right' slot of the enclosing context,
-- until the top is reached.
--
-- If the 'Path' invariant is violated (a context without any 'Right' slot),
-- this fails with an explicit error message.
defocus :: Traversable f => Loc f -> Mu f
defocus (Loc foc pth) = go foc pth where
  go t Top        = t
  go t (Path xs)  = 
    case mapAccumL plug Nothing xs of
      (Just above , s) -> go (Fix s) above
      (Nothing    , _) -> error "defocus: shouldn't happen"
    where
      -- siblings are kept; the 'Right' slot is replaced by the subtree @t@,
      -- and the path stored there is remembered in the accumulator
      plug  old  (Left   y)  =  (old     , y)
      plug  _    (Right  p)  =  (Just p  , t)

-- | Annotates every node of the tree with the zipper focused at that node.
-- The location of the @j@-th child is obtained from the location of its 
-- parent via 'unsafeMoveDown'.
locations :: Traversable f => Mu f -> Attr f (Loc f)
locations tree = annotate (root tree) tree where
  annotate here (Fix layer) = Fix (Ann here (enumerateWith_ descend layer)) where
    descend j child = annotate (unsafeMoveDown j here) child

-- | The list of all locations of a tree, collected from the annotations
-- computed by 'locations'.
locationsList :: Traversable f => Mu f -> [Loc f]
locationsList tree = toList (Attrib (locations tree))
    
-- | The zipper version of 'forget': drops the annotations both from the 
-- focused subtree and from every layer stored in the path.
locForget :: Functor f => Loc (Ann f a) -> Loc f    
locForget (Loc foc pth) = Loc (forget foc) (stripPath pth) where

  -- removes the annotation of each stored parent layer
  stripPath :: Functor f => Path (Ann f a) -> Path f    
  stripPath ctx = case ctx of
    Top               -> Top
    Path (Ann _ kids) -> Path (fmap stripNode kids)
  
  -- siblings are plain subtrees; the remaining slot holds the rest of the path
  stripNode :: Functor f => Node (Ann f a) -> Node f    
  stripNode = either (Left . forget) (Right . stripPath)
    
--------------------------------------------------------------------------------
-- * Manipulating the subtree at focus

-- | Returns the subtree at the focus; the same as the field selector 'focus'.
extract :: Loc f -> Mu f
extract (Loc t _) = t

-- | Puts a new subtree at the focus, leaving the path untouched. 
replace :: Mu f -> Loc f -> Loc f
replace new (Loc _ ctx) = Loc new ctx

-- | Applies a function to the subtree at the focus, leaving the path untouched. 
modify :: (Mu f -> Mu f) -> Loc f -> Loc f
modify h (Loc t ctx) = Loc (h t) ctx

--------------------------------------------------------------------------------
-- * Safe movements

-- | Moves down to the child with the given index.
-- The leftmost child has index @0@. The result is 'Nothing' if there is no 
-- child with that index.
moveDown :: Traversable f => Int -> Loc f -> Maybe (Loc f)
moveDown pos (Loc foc ctx) = fmap descend found where
  descend child = Loc child (Path slots)
  -- a single traversal: we count the children, pick the one at @pos@, and 
  -- put the old path into its slot
  ((found,_),slots)  =  mapAccumL pick (Nothing,0) (unFix foc)    
  pick (sofar,j) x
    | j==pos     =  ((Just x  , j+1),  Right  ctx  ) 
    | otherwise  =  ((sofar   , j+1),  Left   x    )      

-- | Moves down to the leftmost child, or returns 'Nothing' if the focus 
-- has no children.
moveDownL :: Traversable f => Loc f -> Maybe (Loc f)
moveDownL (Loc foc ctx) = fmap descend first where
  descend child = Loc child (Path slots)
  (first,slots)  =  mapAccumL pick Nothing (unFix foc)    
  -- the first child visited becomes the focus; the others stay siblings
  pick Nothing x  =  (Just x  ,  Right  ctx  ) 
  pick seen    x  =  (seen    ,  Left   x    )      

-- | Moves down to the rightmost child, or returns 'Nothing' if the focus 
-- has no children.
moveDownR :: Traversable f => Loc f -> Maybe (Loc f)
moveDownR (Loc foc ctx) = fmap descend final where
  descend child = Loc child (Path slots)
  (final,slots)  =  mapAccumR pick Nothing (unFix foc)    
  -- traversing from the right, the first child visited becomes the focus
  pick Nothing x  =  (Just x  ,  Right  ctx  ) 
  pick seen    x  =  (seen    ,  Left   x    )      
    
--------------------------------------------------------------------------------

-- | Moves up to the parent node, or returns 'Nothing' if we are already at
-- the top. The parent is rebuilt by plugging the current focus back into 
-- the unique 'Right' slot of the stored context; the path found in that slot
-- becomes the new path.
moveUp :: Traversable f => Loc f -> Maybe (Loc f)
moveUp (Loc foc ctx) = case ctx of
  Top         -> Nothing
  Path nodes  -> 
    case mapAccumL plug Nothing nodes of
      (Just above , layer) -> Just (Loc (Fix layer) above)
      (Nothing    , _    ) -> error "moveUp: shouldn't happen"
    where      
      -- siblings are kept as they are; the 'Right' slot receives the focus
      plug old ei = case ei of
        Right  p  -> (Just p  , foc)
        Left   x  -> (old     , x  )

--------------------------------------------------------------------------------

-- | Moves to the next sibling on the right. The result is 'Nothing' at the 
-- top, and also when the focus is already the rightmost child.
moveRight :: Traversable f => Loc f -> Maybe (Loc f)
moveRight (Loc _   Top         ) = Nothing
moveRight (Loc foc (Path nodes)) = 
  case mapAccumL shift Empty nodes of
    (Two next , nodes') -> Just (Loc next (Path nodes'))
    _                   -> Nothing
  where
    -- the 'Right' slot receives the old focus, and the slot visited directly 
    -- after it receives the path, its subtree becoming the new focus
    shift st slot = case (slot, st) of
      (Right p , _     ) -> (One p , Left foc )
      (Left  x , One p ) -> (Two x , Right p  )
      _                  -> (st    , slot     )

-- | Moves to the next sibling on the left. The result is 'Nothing' at the 
-- top, and also when the focus is already the leftmost child.
moveLeft :: Traversable f => Loc f -> Maybe (Loc f)
moveLeft (Loc _   Top         ) = Nothing
moveLeft (Loc foc (Path nodes)) = 
  case mapAccumR shift Empty nodes of
    (Two next , nodes') -> Just (Loc next (Path nodes'))
    _                   -> Nothing
  where
    -- traversing from the right: the 'Right' slot receives the old focus, and
    -- the slot visited directly after it receives the path
    shift st slot = case (slot, st) of
      (Right p , _     ) -> (One p , Left foc )
      (Left  x , One p ) -> (Two x , Right p  )
      _                  -> (st    , slot     )

--------------------------------------------------------------------------------
-- * Testing for borders

-- | Checks whether the focus is the root, that is, whether the path is 'Top'.
isTop :: Loc f -> Bool
isTop (Loc _ Top) = True
isTop _           = False

-- | Checks whether we cannot move down, that is, whether 'moveDownL' fails.
isBottom :: Traversable f => Loc f -> Bool
isBottom loc = case moveDownL loc of
  Nothing -> True
  Just _  -> False

-- | Checks whether we cannot move left, that is, whether 'moveLeft' fails
-- (which is also the case at the top).
isLeftmost :: Traversable f => Loc f -> Bool
isLeftmost loc = case moveLeft loc of
  Nothing -> True
  Just _  -> False

-- | Checks whether we cannot move right, that is, whether 'moveRight' fails
-- (which is also the case at the top).
isRightmost :: Traversable f => Loc f -> Bool
isRightmost loc = case moveRight loc of
  Nothing -> True
  Just _  -> False

--------------------------------------------------------------------------------
-- * Location queries

-- | The index of the focus among the children of its parent, counted from 
-- zero. At the root (where there is no parent) the result is zero, too.
horizontalPos :: Foldable f => Loc f -> Int
horizontalPos (Loc _ Top         ) = 0
horizontalPos (Loc _ (Path nodes)) = 
  either (const (error "horizontalPos: shouldn't happen")) id 
         (foldl count (Left 0) nodes)
  where
    -- @Left j@ means that we have passed @j@ siblings so far; @Right j@ means
    -- that the slot of the focus was found at index @j@
    count (Left j) (Left  _)  =  Left (j+1)
    count (Left j) (Right _)  =  Right j
    count done     _          =  done

-- | The full path from the root to the focus, as a sequence of child indices
-- (the reverse of 'fullPathUp'). This means that
-- 
-- > loc == foldl (flip unsafeMoveDown) (moveTop loc) (fullPathDown loc)
--
fullPathDown :: Foldable f => Loc f -> [Int]
fullPathDown loc = reverse (fullPathUp loc)

-- | The full path from the focus up to the root, as a sequence of child 
-- indices (the index of the focus first). The following equations hold for 
-- 'fullPathUp' and 'fullPathDown':
-- 
-- > fullPathUp == reverse . fullPathDown
-- > loc == foldr unsafeMoveDown (moveTop loc) (fullPathUp loc)
--
fullPathUp :: Foldable f => Loc f -> [Int]
fullPathUp (Loc _ pth) = climb pth where    

  climb Top          = []
  climb (Path nodes) = 
    case foldl locate (Left 0) nodes of
      Right (pos,parent) -> pos : climb parent
      Left _             -> error "fullPathUp: shouldn't happen"

  -- @Left j@ means that we have passed @j@ siblings so far; @Right (j,p)@ 
  -- means that the 'Right' slot, storing the path @p@, was found at index @j@
  locate (Left j) (Left  _)  =  Left (j+1)
  locate (Left j) (Right p)  =  Right (j,p)    
  locate found    _          =  found

--------------------------------------------------------------------------------
-- * Compound movements
  
-- | Moves to the top, by applying 'moveUp' through 'tillNothing'.
moveTop :: Traversable f => Loc f -> Loc f
moveTop loc = tillNothing moveUp loc

-- | Moves left as far as possible, that is, to the leftmost sibling
-- (at the top, the location is returned unchanged).
-- It should be faster than repeated left steps.
leftmost :: Traversable f => Loc f -> Loc f
leftmost orig@(Loc foc path) = case path of
  Top         -> orig
  Path nodes  -> 
    case both of
      Both {}  -> Loc foc' (Path nodes')
      _        -> error "leftmost: shouldn't happen"
    where  
      -- this tricky implementation uses laziness 
      -- so that we only need a single traversal:
      -- the first slot may receive @pnew@, the path stored in the old 'Right'
      -- slot, before the traversal has actually reached that slot
      (foc',pnew) = case both of { Both f p -> (f,p) ; _ -> error "leftmost: shouldn't happen" }    
      (both,nodes') = mapAccumL g None nodes 
      -- state: 'None' before the first slot; @First f@ after the first slot 
      -- (whose subtree @f@ is the new focus); @Both f p@ once the old 'Right'
      -- slot with its path @p@ has been seen, too
      g old ei = case old of
        None -> case ei of
          Left  x  -> (First x    , Right pnew)
          Right p  -> (Both foc p , ei        )   -- we are already at the leftmost position
        First f -> case ei of
          Left  x  -> (old        , ei        )
          Right p  -> (Both f p   , Left  foc )   -- the old focus becomes a plain sibling
        Both {} -> (old, ei)

-- | Moves right as far as possible, that is, to the rightmost sibling
-- (at the top, the location is returned unchanged).
-- It should be faster than repeated right steps.
rightmost :: Traversable f => Loc f -> Loc f
rightmost orig@(Loc foc path) = case path of
  Top         -> orig
  Path nodes  -> 
    case both of
      Both {}  -> Loc foc' (Path nodes')
      _        -> error "rightmost: shouldn't happen"
    where  
      -- this tricky implementation uses laziness 
      -- so that we only need a single traversal:
      -- the last slot may receive @pnew@, the path stored in the old 'Right'
      -- slot, before the traversal (which goes from the right) has reached it
      (foc',pnew) = case both of { Both f p -> (f,p) ; _ -> error "rightmost: shouldn't happen" }    
      (both,nodes') = mapAccumR g None nodes 
      -- state: 'None' before the last slot; @First f@ after the last slot 
      -- (whose subtree @f@ is the new focus); @Both f p@ once the old 'Right'
      -- slot with its path @p@ has been seen, too
      g old ei = case old of
        None -> case ei of
          Left  x  -> (First x    , Right pnew)
          Right p  -> (Both foc p , ei        )   -- we are already at the rightmost position
        First f -> case ei of
          Left  x  -> (old        , ei        )
          Right p  -> (Both f p   , Left  foc )   -- the old focus becomes a plain sibling
        Both {} -> (old, ei)
          
--------------------------------------------------------------------------------
-- * Unsafe movements

-- | Variant of 'moveDown' without the 'Maybe', built with 'unsafe'
-- (presumably failing with the given message when the move is impossible).
unsafeMoveDown :: Traversable f => Int -> Loc f -> Loc f
unsafeMoveDown i loc = unsafe (moveDown i) "unsafeMoveDown: cannot move down" loc
  
-- | Variant of 'moveDownL' without the 'Maybe', built with 'unsafe'.
unsafeMoveDownL :: Traversable f => Loc f -> Loc f
unsafeMoveDownL loc = unsafe moveDownL "unsafeMoveDownL: cannot move down" loc

-- | Variant of 'moveDownR' without the 'Maybe', built with 'unsafe'.
unsafeMoveDownR :: Traversable f => Loc f -> Loc f
unsafeMoveDownR loc = unsafe moveDownR "unsafeMoveDownR: cannot move down" loc

-- | Variant of 'moveUp' without the 'Maybe', built with 'unsafe'.
unsafeMoveUp :: Traversable f => Loc f -> Loc f
unsafeMoveUp loc = unsafe moveUp "unsafeMoveUp: cannot move up" loc

-- | Variant of 'moveLeft' without the 'Maybe', built with 'unsafe'.
unsafeMoveLeft :: Traversable f => Loc f -> Loc f
unsafeMoveLeft loc = unsafe moveLeft "unsafeMoveLeft: cannot move left" loc

-- | Variant of 'moveRight' without the 'Maybe', built with 'unsafe'.
unsafeMoveRight :: Traversable f => Loc f -> Loc f
unsafeMoveRight loc = unsafe moveRight "unsafeMoveRight: cannot move right" loc

--------------------------------------------------------------------------------
#ifdef WITH_QUICKCHECK
-- * Tests

-- | Zippers over trees built with the functor 'TreeF'; the properties below
-- are stated for this type.
type LocT a = Loc (TreeF a)

{-
data Step
  = StepUp
  | StepLeft
  | StepRight
  | StepDown Int
  | StepDownL
  | StepDownR
  deriving (Eq,Ord,Show)

newtype Walk = Walk [Step] deriving (Eq,Ord,Show)
  
walk :: Traversable f => Walk -> Loc f -> Loc f  
walk (Walk steps) loc = foldl (flip singleStep) loc steps

singleStep :: Traversable f => Step -> Loc f -> Loc f
singleStep s loc = case stepMaybe s loc of { Nothing -> loc ; Just new -> new }

stepMaybe :: Traversable f => Step -> Loc f -> Maybe (Loc f)
stepMaybe s = case s of
  StepUp     -> moveUp
  StepLeft   -> moveLeft
  StepRight  -> moveRight
  StepDown j -> moveDown j
  StepDownL  -> moveDownL
  StepDownR  -> moveDownR
  
instance Arbitrary Step where
  arbitrary = oneof
    [ return StepUp
    , return StepLeft
    , return StepRight
    , do { j <- choose (1,7) ; return (StepDown j) }
    , return StepDownL
    , return StepDownR
    ]

instance Arbitrary Walk where
  arbitrary = liftM Walk arbitrary
  shrink (Walk steps) = map Walk (shrink steps)
-}

-- | Assuming a left-to-right canonical numbering (every node is annotated
-- with its own index), we walk from the given location to the node whose
-- annotation is @k@: first we try to go down, then right, and otherwise we
-- climb up until a step to the right becomes possible.
findLoc :: Traversable f => Int -> Loc (Ann f Int) -> Loc (Ann f Int) 
findLoc k = search where

  search loc 
    | here == k  =  loc
    | here >  k  =  error "findLoc: shouldn't happen?"
    | otherwise  =  case moveDownL loc of
        Just down -> search down
        Nothing   -> climb loc
    where
      Fix (Ann here _) = focus loc

  -- continues with the nearest right sibling of the location or of one of
  -- its ancestors
  climb loc = maybe (climb (unsafeMoveUp loc)) search (moveRight loc)

----
-- A small hand-written example tree: a root with the three children @a@, 
-- @b@ and @c@, having two, zero and three leaf children, respectively.
-- NOTE(review): none of the tests visible in this module refers to it; 
-- presumably it is meant for interactive experiments.
tmp = treeF "root"
  [ treeF "a" [ treeF "a1" [] , treeF "a2" [] ]
  , treeF "b" []
  , treeF "c" [ treeF "c1" [] , treeF "c2" [] , treeF "c3" [] ]
  ]
----
  
-- A random location: we generate a random tree, number its nodes with
-- 'enumerateNodes', choose one of the indices, and focus on that node.
instance Arbitrary a => Arbitrary (LocT a) where
  arbitrary = arbitrary >>= \tree -> 
    let (size,numbered) = enumerateNodes tree
        focusAt k = locForget (findLoc k (root numbered))
    in  liftM focusAt (choose (0,size-1))

-- | A random location for interactive experiments: the element with 
-- index 7 of the list of samples produced by @sample'@.
rndLoc :: IO (LocT Label)
rndLoc = do
  samples <- sample' arbitrary
  return (samples !! 7)
  
-- | A candidate child index for the movement properties; its 'Arbitrary'
-- instance draws the index from the range @0..7@.
newtype ChildIndex = ChildIndex Int deriving Show

-- Child indices are chosen from the range @0..7@.
instance Arbitrary ChildIndex where
  arbitrary = do
    j <- choose (0,7)
    return (ChildIndex j)
  
--------------------------------------------------------------------------------

-- | Runs all the zipper properties with 'quickCheck', one after the other.
runtests_Zipper :: IO ()
runtests_Zipper = do
  conversionChecks
  positionChecks
  movementChecks
  where

    -- reading, showing, and converting between trees and locations
    conversionChecks = do
      quickCheck prop_ReadShowLoc
      quickCheck prop_findLoc
      quickCheck prop_locationsList
      quickCheck prop_contextList
      quickCheck prop_Top
      quickCheck prop_defocus

    -- position queries
    positionChecks = do
      quickCheck prop_horizontalPos
      quickCheck prop_fullPathDown
      quickCheck prop_fullPathUp
      quickCheck prop_fullPathUp2

    -- relations between the different movements
    movementChecks = do
      quickCheck prop_leftmost
      quickCheck prop_rightmost
      quickCheck prop_DownLUp
      quickCheck prop_DownRUp
      quickCheck prop_UpDownL 
      quickCheck prop_UpDownR
      quickCheck prop_DownL
      quickCheck prop_DownR
      quickCheck prop_UpDownJ
      quickCheck prop_LeftRight
      quickCheck prop_RightLeft

----------------------------------------

-- | Reference implementation of 'leftmost': 'moveLeft' applied through 
-- 'tillNothing'.
leftmostNaive :: Traversable f => Loc f -> Loc f
leftmostNaive loc = tillNothing moveLeft loc

-- | Reference implementation of 'rightmost': 'moveRight' applied through 
-- 'tillNothing'.
rightmostNaive :: Traversable f => Loc f -> Loc f
rightmostNaive loc = tillNothing moveRight loc

-- | Reference implementation of 'fullPathUp': we record the horizontal 
-- position and then move up, until the top is reached.
fullPathUpNaive :: Traversable f => Loc f -> [Int]
fullPathUpNaive loc 
  | isTop loc  =  []
  | otherwise  =  horizontalPos loc : fullPathUpNaive (unsafeMoveUp loc)

----------------------------------------

-- | Reading back a shown location gives the original location.
prop_ReadShowLoc :: LocT Label -> Bool
prop_ReadShowLoc loc = roundTrip == loc where
  roundTrip = read (show loc)

-- | 'locationsList' lists the locations in the order of the canonical 
-- numbering used by 'findLoc'.
prop_locationsList :: FixT Label -> Bool
prop_locationsList tree = locationsList tree == map nth [0..n-1] where
  (n,numbered) = enumerateNodes tree
  nth i = locForget (findLoc i (root numbered))

-- | 'findLoc' arrives at the node annotated with the requested index.
prop_findLoc :: FixT Label -> Bool
prop_findLoc tree = indices == map labelAt indices where
  (n,numbered) = enumerateNodes tree
  indices   = [0..n-1]
  labelAt i = attribute (focus (findLoc i (root numbered)))

-- | Changing the label of a single node through 'contextList' agrees with 
-- changing it through the zipper ('modify' followed by 'defocus').
prop_contextList :: FixT Label -> Bool  
prop_contextList tree = viaContexts == viaZipper where
  viaContexts = [ plug (mark sub) | (sub,plug) <- contextList tree ]
  viaZipper   = [ defocus (modify mark (locForget (findLoc i top))) | i <- [0..n-1] ]
  -- prepends an underscore to the label of the root of a subtree
  mark (Fix (TreeF l ts)) = Fix (TreeF (h l) ts)
  h (Label xs) = Label ('_':xs)
  top = root numbered
  (n,numbered) = enumerateNodes tree
  
-- | Moving to the top is the same as rebuilding the tree and focusing on 
-- its root.
prop_Top :: LocT Label -> Bool
prop_Top loc = rebuilt == moveTop loc where
  rebuilt = root (defocus loc)

-- | Wherever the focus is, 'defocus' restores the whole tree.
prop_defocus :: FixT Label -> Bool
prop_defocus tree = Prelude.all restores [0..n-1] where
  restores i = defocus (findLoc i top) == numbered
  top = root numbered
  (n,numbered) = enumerateNodes tree

----------------------------------------

-- | Starting from the leftmost sibling, 'horizontalPos' steps to the right 
-- lead back to the location.
prop_horizontalPos :: LocT Label -> Bool
prop_horizontalPos loc = loc == walked where
  walked = iterateN (horizontalPos loc) unsafeMoveRight (leftmost loc)

-- | Following 'fullPathDown' from the top leads back to the location.
prop_fullPathDown :: LocT Label -> Bool
prop_fullPathDown loc = loc == replayed where
  replayed = foldl (\here j -> unsafeMoveDown j here) (moveTop loc) (fullPathDown loc)

-- | 'fullPathUp' agrees with the naive reference implementation.
prop_fullPathUp :: LocT Label -> Bool
prop_fullPathUp loc = fast == naive where
  fast  = fullPathUp loc
  naive = fullPathUpNaive loc

-- | Replaying 'fullPathUp' with a right fold from the top leads back to 
-- the location.
prop_fullPathUp2 :: LocT Label -> Bool 
prop_fullPathUp2 loc = loc == replayed where
  replayed = foldr unsafeMoveDown (moveTop loc) (fullPathUp loc)  
  
----------------------------------------
    
-- | 'leftmost' agrees with repeated left steps.
prop_leftmost :: LocT Label -> Bool
prop_leftmost loc = fast == naive where
  fast  = leftmost loc
  naive = leftmostNaive loc

-- | 'rightmost' agrees with repeated right steps.
prop_rightmost :: LocT Label -> Bool
prop_rightmost loc = fast == naive where
  fast  = rightmost loc
  naive = rightmostNaive loc
    
-- | Going down to the leftmost child and then up again changes nothing.
prop_DownLUp :: LocT Label -> Property
prop_DownLUp loc = canDescend ==> roundTrip == loc where
  canDescend = not (isBottom loc)
  roundTrip  = unsafeMoveUp (unsafeMoveDownL loc)

-- | Going down to the rightmost child and then up again changes nothing.
prop_DownRUp :: LocT Label -> Property  
prop_DownRUp loc = canDescend ==> roundTrip == loc where
  canDescend = not (isBottom loc)
  roundTrip  = unsafeMoveUp (unsafeMoveDownR loc)

-- | Going up and then down to the leftmost child leads to the leftmost 
-- sibling.
prop_UpDownL :: LocT Label -> Property
prop_UpDownL loc = hasParent ==> viaParent == leftmost loc where
  hasParent = not (isTop loc)
  viaParent = unsafeMoveDownL (unsafeMoveUp loc)

-- | Going up and then down to the rightmost child leads to the rightmost 
-- sibling.
prop_UpDownR :: LocT Label -> Property
prop_UpDownR loc = hasParent ==> viaParent == rightmost loc where
  hasParent = not (isTop loc)
  viaParent = unsafeMoveDownR (unsafeMoveUp loc)

-- | The leftmost child is the child with index zero.
prop_DownL :: LocT Label -> Property
prop_DownL loc = canDescend ==> unsafeMoveDownL loc == unsafeMoveDown 0 loc where
  canDescend = not (isBottom loc)

-- | The rightmost child is the child with the largest index.
prop_DownR :: LocT Label -> Property
prop_DownR loc = canDescend ==> unsafeMoveDownR loc == unsafeMoveDown lastIx loc where
  canDescend = not (isBottom loc)
  lastIx     = length (children (focus loc)) - 1

-- | Going up and then down to the @j@-th child is the same as making @j@ 
-- steps to the right from the leftmost sibling (whenever the parent has 
-- more than @j@ children).
prop_UpDownJ :: ChildIndex -> LocT Label -> Property
prop_UpDownJ (ChildIndex j) loc = hasParent ==> inRange ==> viaParent == viaSiblings where
  hasParent   = not (isTop loc)
  parent      = unsafeMoveUp loc
  inRange     = j < length (children (focus parent))
  viaParent   = unsafeMoveDown j parent
  viaSiblings = iterateN j unsafeMoveRight (leftmost loc)

-- | A step to the left followed by a step to the right changes nothing.
prop_LeftRight :: LocT Label -> Property
prop_LeftRight loc = canGoLeft ==> roundTrip == loc where
  canGoLeft = not (isLeftmost loc)
  roundTrip = unsafeMoveRight (unsafeMoveLeft loc)

-- | A step to the right followed by a step to the left changes nothing.
prop_RightLeft :: LocT Label -> Property
prop_RightLeft loc = canGoRight ==> roundTrip == loc where
  canGoRight = not (isRightmost loc)
  roundTrip  = unsafeMoveLeft (unsafeMoveRight loc)

--------------------------------------------------------------------------------

#endif