{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE ScopedTypeVariables #-}

 -- the CPP seems to confuse GHC; we have uniplate patterns
{-# OPTIONS_GHC -fno-warn-unused-binds -fno-warn-incomplete-patterns #-}
{-# OPTIONS_HADDOCK show-extensions #-}

-- |
-- Module      :  Yi.Syntax.Tree
-- License     :  GPL-2
-- Maintainer  :  yi-devel@googlegroups.com
-- Stability   :  experimental
-- Portability :  portable

-- Generic syntax tree handling functions
module Yi.Syntax.Tree (IsTree(..), toksAfter, allToks, tokAtOrBefore,
                       toksInRegion, sepBy, sepBy1,
                       getLastOffset, getFirstOffset,
                       getFirstElement, getLastElement,
                       getLastPath,
                       getAllSubTrees,
                       tokenBasedAnnots, tokenBasedStrokes,
                       subtreeRegion,
                       fromLeafToLeafAfter, fromNodeToFinal)
  where

-- Some of this might be replaced by a generic package
-- such as multirec, uniplate, emgm, ...

import           Control.Applicative
import           Control.Arrow (first)
import           Data.Foldable
import           Data.List.NonEmpty (NonEmpty(..))
import qualified Data.List.NonEmpty as NE
import           Data.Maybe
import           Data.Monoid
import           Prelude hiding (concatMap, error)
import           Yi.Buffer.Basic
import           Yi.Debug
import           Yi.Lexer.Alex
import           Yi.Region
import           Yi.String

#ifdef TESTING
import           Test.QuickCheck
import           Test.QuickCheck.Property (unProperty)
#endif

-- Fundamental types

-- | A path from a tree down to one of its descendants: the index of the
-- child to follow at each level, starting at the top.
type Path = [Int]
-- | A tree paired with a 'Path'. Most functions in this module (e.g.
-- 'walkDown') interpret the path as leading downwards from the paired tree.
type Node t = (Path, t)

-- | Trees whose direct children can be listed and rebuilt generically, in
-- the style of uniplate. The tokens of a tree are reached through the
-- 'Foldable' superclass.
class Foldable tree => IsTree tree where
    -- | Direct subtrees of a tree (by default, the first component of
    -- 'uniplate').
    subtrees :: tree t -> [tree t]
    subtrees = fst . uniplate
    -- | Direct subtrees of a tree, paired with a function that rebuilds the
    -- node from a replacement list of subtrees.
    -- NOTE(review): the replacement list presumably must have the same
    -- length as the list it was paired with (the test instance only accepts
    -- that); confirm for other instances.
    uniplate :: tree t -> ([tree t], [tree t] -> tree t)
    -- | A placeholder tree; 'pruneNodesBefore' substitutes it for the
    -- subtrees it prunes.
    emptyNode :: tree t

-- | All tokens of the container, in order. The starting position is
-- accepted for interface compatibility but is currently ignored.
toksAfter :: Foldable t1 => t -> t1 a -> [a]
toksAfter _ container = allToks container

-- | Every element (token) of the container, left to right.
allToks :: Foldable t => t a -> [a]
allToks = foldr (:) []

-- | The last token overlapping the region from offset 0 up to and including
-- @p@, if any.
tokAtOrBefore :: Foldable t => Point -> t (Tok t1) -> Maybe (Tok t1)
tokAtOrBefore p res = case reverse (toksInRegion upToP res) of
    []      -> Nothing
    tok : _ -> Just tok
  where upToP = mkRegion 0 (p + 1)

-- | Tokens overlapping the region. Tokens ending before the region start
-- are skipped. Collection stops at the first token that begins after the
-- region end, which assumes the tokens are in increasing position order.
toksInRegion :: Foldable t1 => Region -> t1 (Tok t) -> [Tok t]
toksInRegion reg =
    takeWhile startsByEnd . dropWhile endsBeforeStart . toksAfter (regionStart reg)
  where
    startsByEnd tok     = tokBegin tok <= regionEnd reg
    endsBeforeStart tok = tokEnd tok < regionStart reg

-- | Apply the annotation function to every token of the tree and keep the
-- 'Just' results, in token order. @begin@ is only forwarded to 'toksAfter'.
tokenBasedAnnots :: (Foldable t1) => (a1 -> Maybe a) -> t1 a1 -> t -> [a]
tokenBasedAnnots tta t begin = mapMaybe tta (toksAfter begin t)

-- | Map the stroke function over every token of the tree, in order. The
-- point and end arguments are unused; @begin@ is only forwarded to
-- 'toksAfter'.
tokenBasedStrokes :: (Foldable t3) => (a -> b) -> t3 a -> t -> t2 -> t1 -> [b]
tokenBasedStrokes tts t _point begin _end = map tts (toksAfter begin t)

-- | Walk down the given path. At every level, replace each sibling to the
-- left of the path whose last offset is before the point with 'emptyNode'.
-- The path node itself and the siblings to its right are kept. The path
-- decides which nodes may be inspected (forced); it must be valid, otherwise
-- the lazy pattern binding below fails when demanded.
pruneNodesBefore :: IsTree tree => Point -> Path -> tree (Tok a) -> tree (Tok a)
pruneNodesBefore _ [] tree = tree
pruneNodesBefore p (i:is) tree = rebuild (prunedLeft <> (focus' : rightSiblings))
    where (kids, rebuild) = uniplate tree
          (leftSiblings, focus : rightSiblings) = splitAt i kids
          focus' = pruneNodesBefore p is focus
          prunedLeft = map dropIfBefore leftSiblings
          dropIfBefore s
            | getLastOffset s < p = emptyNode
            | otherwise           = s

-- | Given an approximate path to a leaf at the end of the region, return a
-- pair of:
--
--   * the path found by 'fromLeafToLeafAfter' for the region end. It is
--     relative to the original root, not to the returned node.
--     NOTE(review): confirm callers expect a root-relative path here.
--   * a small node encompassing the region. This is the deepest node on
--     that path whose first offset is at or before the region start (or the
--     root if none is). Along the path inside that node, siblings ending
--     before the region start are replaced by 'emptyNode'.
--
-- Intermediate values are logged with 'trace'.
fromNodeToFinal :: IsTree tree => Region -> Node (tree (Tok a))
                -> Node (tree (Tok a))
fromNodeToFinal r (xs,root) =
    trace ("r = " <> showT r) $
    trace ("focused ~ " <> showT (subtreeRegion focused) ) $
    trace ("pathFromFocusedToLeaf = " <> showT focusedToLeaf) $
    -- Log the node that is actually returned (previously this showed
    -- 'focused' under the "pruned" label).
    trace ("pruned ~ " <> showT (subtreeRegion pruned)) (xs', pruned)

    where n@(xs',_) = fromLeafToLeafAfter (regionEnd r) (xs,root)
          -- focusedToLeaf: the remaining path from 'focused' down to the leaf.
          (_,(focusedToLeaf,focused)) = fromLeafAfterToFinal p0 n
          p0 = regionStart r
          pruned = pruneNodesBefore p0 focusedToLeaf focused

-- | The first element satisfying the predicate. When no element does, the
-- final element is returned; the predicate is never applied to it.
firstThat :: (a -> Bool) -> NonEmpty a -> a
firstThat p (x :| rest) = case rest of
  []     -> x
  y : ys
    | p x       -> x
    | otherwise -> firstThat p (y :| ys)

-- | The last element of the initial run of elements satisfying the
-- predicate, i.e. the element just before the first violation. When the
-- very first element violates the predicate, it is returned anyway.
lastThat :: (a -> Bool) -> NonEmpty a -> a
lastThat p (x :| xs)
  | p x       = NE.last (x :| takeWhile p xs)
  | otherwise = x

-- | Given a root and a path from it, return the deepest node on that path
-- whose first offset is at or before the point, or the root if there is
-- none. The result is (path from the root to that node, (remaining path
-- from that node, that node)).
fromLeafAfterToFinal :: IsTree tree => Point -> Node (tree (Tok a))
                     -> (Path, Node (tree (Tok a)))
fromLeafAfterToFinal p = firstThat startsByP . NE.reverse . nodesOnPath
  where startsByP (_, (_, subtree)) = getFirstOffset subtree <= p

-- | Starting from an approximate path to a leaf, find the path to a leaf at
-- or after the given point, trying to pick a leaf close to @p@.
--
-- The path is first followed as far as it is valid ('wkDown').
--
-- * If the node reached starts at or before @p@, the candidates are that
--   node's leaves followed by the leaves to its right (selected with
--   'afterChild'). The first one starting at or after @p@ is chosen, or the
--   last candidate if none does ('firstThat').
-- * Otherwise the candidates are that node's leaves followed by leaves to
--   its left (selected with 'beforeChild'). The last of the initial run of
--   candidates starting at or after @p@ is chosen ('lastThat').
--
-- Empty subtrees are never candidates. If there is no candidate at all,
-- the returned path is empty. The result pairs the chosen path with the
-- original root. Intermediate values are logged with 'trace'.
--
-- TODO: rename to fromLeafToLeafAt
fromLeafToLeafAfter :: IsTree tree => Point
                    -> Node (tree (Tok a))
                    -> Node (tree (Tok a))
fromLeafToLeafAfter p (xs, root) =
  trace "fromLeafToLeafAfter:" $
  trace ("xs = " <> showT xs) $
  trace ("xsValid = " <> showT xsValid) $
  trace ("p = " <> showT p) $
  trace ("leafBeforeP = " <> showT leafBeforeP) $
  trace ("leaf ~ " <> showT (subtreeRegion leaf)) $
  trace ("xs' = " <> showT xs') result
  where
    -- Path (from the root) to the chosen leaf; empty when there is no
    -- candidate leaf.
    xs' = case candidateLeaves of
      [] -> []
      c:cs -> fst $ firstOrLastThat (\(_,s) -> getFirstOffset s >= p) (c :| cs)
    -- Non-empty leaves on the chosen side of the node reached by xsValid.
    candidateLeaves = allLeavesRelative relChild n
    -- Search rightwards when the reached node starts at or before p,
    -- leftwards otherwise.
    (firstOrLastThat,relChild) = if leafBeforeP then (firstThat,afterChild)
                                                else (lastThat,beforeChild)
    -- Longest valid prefix of xs, and the node it leads to (not necessarily
    -- a leaf if xs stopped early).
    (xsValid,leaf) = wkDown (xs,root)
    leafBeforeP = getFirstOffset leaf <= p
    n = (xsValid,root)
    result = (xs',root)

-- | All non-empty leaves on one side of the node a path leads to. The
-- selector decides which children are kept (e.g. 'afterChild' or
-- 'beforeChild'). Nodes along the path are visited deepest first. Paths in
-- the result are relative to the root of the given node.
allLeavesRelative :: IsTree tree => (Int -> [(Int, tree a)] -> [(Int, tree a)])
                  -> Node (tree a)
                  -> [Node (tree a)]
allLeavesRelative select node = filter hasTokens leaves
  where
    leaves = allLeavesRelative' select (NE.toList (NE.reverse (nodesAndChildIndex node)))
    -- Empty subtrees are dropped: their region would be [0,0].
    hasTokens = not . nullSubtree . snd

-- | For each (node with its path, index of the already-inspected child),
-- collect the leaves of the node's children chosen by the selector. Each
-- leaf's path is prefixed with the path of the node it was found under.
allLeavesRelative' :: IsTree tree => (Int -> [(Int, tree a)] -> [(Int, tree a)])
                   -> [(Node (tree a), Int)] -> [Node (tree a)]
allLeavesRelative' select = concatMap leavesUnder
  where
    leavesUnder ((pathToNode, node), inspected) =
      [ (pathToNode <> relPath, leaf)
      | (relPath, leaf) <- allLeavesRelativeChild select inspected node ]

-- | Follow a path from a root, pairing every node met on the way with the
-- index of the child taken next. Each node's path is the prefix of the given
-- path leading to it from the root. The final node gets index @-1@, whether
-- the path was exhausted or its next index named a missing child.
nodesAndChildIndex :: IsTree tree => Node (tree a)
                   -> NonEmpty (Node (tree a), Int)
nodesAndChildIndex (path, t) = case path of
  [] -> stop
  x : xs -> case index x (subtrees t) of
    Nothing    -> stop
    Just child -> (([], t), x)
                    NE.<| fmap (first (first (x :))) (nodesAndChildIndex (xs, child))
  where
    stop = (([], t), negate 1) :| []

-- | Every node along a path, from the root downwards. Each entry is
-- (path from the root to the node, (remaining path from the node, node)).
-- Fails with an error if the path names a child that does not exist.
nodesOnPath :: IsTree tree => Node (tree a) -> NonEmpty (Path, Node (tree a))
nodesOnPath ([], t) = ([], ([], t)) :| []
nodesOnPath (path@(x:xs), t) = ([], (path, t)) NE.<| descend (index x (subtrees t))
  where
    descend Nothing      = error "nodesOnPath: non-existent path"
    descend (Just child) = fmap (first (x :)) (nodesOnPath (xs, child))


-- | The children located before the child at the given index, nearest first.
-- An index of @-1@ means no child was inspected, so all children are taken
-- (last first).
beforeChild :: Int -> [a] -> [a]

beforeChild (-1) = reverse -- (-1) indicates that all children should be taken.
-- Children before index c are indices 0..c-1, i.e. @take c@; this mirrors
-- 'afterChild', which keeps indices c+1 onwards.
beforeChild c = reverse . take c

-- | The children located after the child at the given index, in order. An
-- index of @-1@ keeps every child.
afterChild :: Int -> [a] -> [a]
afterChild c children = drop (succ c) children

-- | All leaves under the children of @t@ that the selector keeps relative
-- to child index @c@, each with its path from @t@. Below the top level,
-- all children are considered. A node with no children is itself the
-- (only) leaf, with an empty path.
allLeavesRelativeChild :: IsTree tree => (Int -> [(Int, tree a)]
                                          -> [(Int, tree a)])
                       -> Int
                       -> tree a -> [Node (tree a)]
allLeavesRelativeChild select c t = case subtrees t of
  [] -> [([], t)]
  ts -> [ (i : rel, leaf)
        | (i, child) <- select c (zip [0 ..] ts)
        , (rel, leaf) <- allLeavesIn select child ]


-- | Every leaf (with its path) inside the given root. The selector is
-- applied with index @-1@, meaning no child has been inspected yet.
allLeavesIn :: (IsTree tree) => (Int -> [(Int, tree a)] -> [(Int, tree a)])
            -> tree a -> [Node (tree a)]
allLeavesIn select root = allLeavesRelativeChild select (negate 1) root

-- | For every subtree of the tree, in pre-order, the chain of nodes from
-- that subtree up to the root. Each returned list starts with the subtree
-- itself and ends with the root (the root's own entry is just @[root]@).
getAllPaths :: IsTree tree => tree t -> [[tree t]]
getAllPaths t = map (<> [t]) ([] : concatMap getAllPaths (subtrees t))

-- | The direct child at the given index, if it exists.
goDown :: IsTree tree => Int -> tree t -> Maybe (tree t)
goDown i t = index i (subtrees t)

-- | Safe list indexing: the element at the given 0-based position, or
-- 'Nothing' if the position is negative or past the end of the list.
index :: Int -> [a] -> Maybe a
index n xs
  -- Reject negative indices up front instead of walking the whole list
  -- (which never terminates on an infinite list).
  | n < 0     = Nothing
  | otherwise = listToMaybe (drop n xs)

-- | Follow the whole path from the given tree. Returns 'Nothing' if any
-- step names a child that does not exist.
walkDown :: IsTree tree => Node (tree t) -> Maybe (tree t)
walkDown (path, t) = foldlM (flip goDown) t path

-- | Follow as much of the path as exists. Returns the valid prefix of the
-- path together with the node it leads to.
wkDown :: IsTree tree => Node (tree a) -> Node (tree a)
wkDown (path, t) = case path of
    []     -> ([], t)
    x : xs -> maybe ([], t) (first (x :) . wkDown . (,) xs) (goDown x t)

-- | Walk the subtrees of the given roots in pre-order, considering only
-- subtrees that contain a token. Keep going while their first token starts
-- before the given offset, and return the chain for the last such subtree.
-- The chain starts with that subtree and ends with its root (see
-- 'getAllPaths'). Returns 'Nothing' if the first candidate already starts
-- at or after the offset, or if there are no candidates.
getLastPath :: IsTree tree => [tree (Tok t)] -> Point -> Maybe [tree (Tok t)]
getLastPath roots offset =
    fmap fst . listToMaybe . reverse $ takeWhile startsBefore allSubPathPosn
    where
      startsBefore = (< offset) . posnOfs . snd
      allSubPathPosn = [ (p,posn) | root <- roots
                                  , p@(t':_) <- getAllPaths root
                                  , Just tok <- [getFirstElement t']
                                  , let posn = tokPosn tok
                                  ]

-- | The tree itself followed by all of its descendants, in pre-order.
getAllSubTrees :: IsTree tree => tree t -> [tree t]
getAllSubTrees t = t : (subtrees t >>= getAllSubTrees)

-- | The first element (e.g. the first token of a subtree), or 'Nothing'
-- when the container is empty.
getFirstElement :: Foldable t => t a -> Maybe a
getFirstElement = foldr (\x _ -> Just x) Nothing

-- | Whether the subtree contains no elements (tokens) at all.
nullSubtree :: Foldable t => t a -> Bool
nullSubtree = foldr (\_ _ -> False) True

-- | Token-flavoured names for 'getFirstElement' and 'getLastElement'.
getFirstTok, getLastTok :: Foldable t => t a -> Maybe a
getFirstTok tree = getFirstElement tree
getLastTok tree = getLastElement tree

-- | The last element (e.g. the last token of a subtree), or 'Nothing' when
-- the container is empty.
getLastElement :: Foldable t => t a -> Maybe a
getLastElement = foldr keepLater Nothing
  where keepLater x later = case later of
          Nothing -> Just x
          found   -> found

-- | Start offset of a subtree's first token and end offset of its last
-- token. Both are 0 for a subtree without tokens.
getFirstOffset, getLastOffset :: Foldable t => t (Tok t1) -> Point
getFirstOffset tree = maybe 0 tokBegin (getFirstTok tree)
getLastOffset tree = maybe 0 tokEnd (getLastTok tree)

-- | The region spanned by a subtree's tokens, from the start of the first
-- to the end of the last.
subtreeRegion :: Foldable t => t (Tok t1) -> Region
subtreeRegion t = mkRegion startOfs endOfs
  where startOfs = getFirstOffset t
        endOfs   = getLastOffset t

-- | Given a tree, return the offset of its first token and the difference
-- between the line numbers of its last and first tokens.
--
-- Fails with an error if the tree contains no tokens.
getSubtreeSpan :: (Foldable tree) => tree (Tok t) -> (Point, Int)
getSubtreeSpan tree = (posnOfs firstPosn, posnLine lastPosn - posnLine firstPosn)
    -- Bind each bound directly instead of through list patterns, which were
    -- partial (incomplete uni-patterns) even though they always matched.
    where firstPosn = tokPosn (assertJust (getFirstElement tree))
          lastPosn  = tokPosn (assertJust (getLastElement tree))
          assertJust (Just x) = x
          assertJust _ = error "assertJust: Just expected"

-------------------------------------
-- These combinators should perhaps live in Control.Applicative.

-- | Zero or more occurrences of the item parser, separated by the separator
-- parser. The separators' results are discarded.
sepBy :: (Alternative f) => f a -> f v -> f [a]
sepBy item separator = atLeastOne <|> pure []
  where atLeastOne = sepBy1 item separator

-- | One or more occurrences of the item parser, separated by the separator
-- parser. The separators' results are discarded.
sepBy1 :: (Alternative f) => f a -> f v -> f [a]
sepBy1 item separator = (:) <$> item <*> rest
  where rest = many (separator *> item)


----------------------------------------------------
-- Testing code.

#ifdef TESTING

-- | Region of the subtree that the node's path leads to from the node's
-- tree. Fails with a descriptive error (instead of an opaque pattern-match
-- failure) if the path is invalid.
nodeRegion :: IsTree tree => Node (tree (Tok a)) -> Region
nodeRegion n = subtreeRegion t
    where t = fromMaybe (error "nodeRegion: path does not lead to a subtree") (walkDown n)

-- | Minimal binary tree used to exercise the generic functions above.
data Test a = Empty | Leaf a | Bin (Test a) (Test a) deriving (Show, Eq, Foldable)

-- | 'IsTree' instance for the test tree.
instance IsTree Test where
    -- A 'Bin' has two children; rebuilding expects exactly two replacements.
    uniplate (Bin l r) = ([l,r],\[l',r'] -> Bin l' r')
    -- 'Leaf' and 'Empty' have no children; rebuilding expects an empty list
    -- and returns the node unchanged.
    uniplate t = ([],\[] -> t)
    -- Pruned subtrees become 'Empty'.
    emptyNode = Empty

-- | Token type used in the tests: tokens carrying no payload.
type TT = Tok ()

-- | Random test trees whose leaves are tokens for consecutive indices.
instance Arbitrary (Test TT) where
    -- Trees over leaves 1..size+1 (see 'arbitraryFromList'); never 'Empty'.
    arbitrary = sized $ \size -> arbitraryFromList [1..size+1]
    -- Shrink to either child, or shrink one child in place. 'Leaf' and
    -- 'Empty' cannot shrink; the 'Empty' case makes 'shrink' total.
    shrink Empty = []
    shrink (Leaf _) = []
    shrink (Bin l r) = [l,r] <> [Bin l' r | l' <- shrink l] <> [Bin l r' | r' <- shrink r]

-- | Test token for index @idx@. Its position has offset @idx * 2@ and the
-- other 'Posn' fields are zero.
-- NOTE(review): the second 'Tok' argument (1) is presumably the token length.
tAt :: Point -> TT
tAt idx = Tok () 1 position
  where position = Posn (idx * 2) 0 0

-- | Random binary tree whose leaves, left to right, are the tokens for the
-- given indices. Each 'Bin' splits the list at a random non-trivial point.
-- Errors on an empty list.
arbitraryFromList :: [Int] -> Gen (Test TT)
arbitraryFromList xs = case xs of
  []  -> error "arbitraryFromList expects non empty lists"
  [x] -> pure (Leaf (tAt (fromIntegral x)))
  _   -> do
    splitPoint <- choose (1, length xs - 1)
    let (lefts, rights) = splitAt splitPoint xs
    Bin <$> arbitraryFromList lefts <*> arbitraryFromList rights

-- | A random test tree paired with a random path from its root to a leaf.
newtype NTTT = N (Node (Test TT)) deriving Show

-- | Generate a tree first, then a path into that tree.
instance Arbitrary NTTT where
    arbitrary = do
      tree <- arbitrary
      path <- arbitraryPath tree
      pure (N (path, tree))

-- | Random path from the root of the tree down to a leaf, picking the left
-- (0) or right (1) child at each 'Bin'. 'Leaf' and 'Empty' end the path;
-- handling 'Empty' makes the function total.
arbitraryPath :: Test t -> Gen Path
arbitraryPath (Bin l r) = do
  c <- choose (0,1)
  -- Select the child directly instead of via a partial @let Just@ pattern.
  let child = if c == 0 then l else r
  (c :) <$> arbitraryPath child
arbitraryPath _ = return []

-- | Random (possibly empty) sub-region of the given region.
regionInside :: Region -> Gen Region
regionInside r = do
  let lo = fromIntegral (regionStart r) :: Int
      hi = fromIntegral (regionEnd r)
  b <- choose (lo, hi)
  e <- choose (b, hi)
  pure (mkRegion (fromIntegral b) (fromIntegral e))

-- | Random point within the region, both ends included.
pointInside :: Region -> Gen Point
pointInside r = fromIntegral <$> choose (lo, hi)
  where lo, hi :: Int
        lo = fromIntegral (regionStart r)
        hi = fromIntegral (regionEnd r)

-- | The node chosen by 'fromLeafAfterToFinal' starts at or before the point
-- and contains the region of the node the path leads to.
prop_fromLeafAfterToFinal :: NTTT -> Property
prop_fromLeafAfterToFinal (N n) =
    forAll (pointInside (subtreeRegion (snd n))) checkAt
  where
    initialRegion = nodeRegion n
    checkAt p =
      let final@(_, (_, finalSubtree)) = fromLeafAfterToFinal p n
          finalRegion = subtreeRegion finalSubtree
          report = do
            putStrLn $ "final = " <> show final
            putStrLn $ "final reg = " <> show finalRegion
            putStrLn $ "initialReg = " <> show initialRegion
            putStrLn $ "p = " <> show p
          holds = regionStart finalRegion <= p
                  && initialRegion `includedRegion` finalRegion
      in whenFail report holds

-- | Every leaf returned by 'allLeavesRelative' with 'afterChild' is reached
-- by its path and does not come before the starting path.
prop_allLeavesAfter :: NTTT -> Property
prop_allLeavesAfter (N n@(xs,t)) = property $ do
  (xs',t') <- elements (allLeavesRelative afterChild n)
  let reached = walkDown (xs',t)
      report = do
        putStrLn $ "t' = " <> show t'
        putStrLn $ "t'' = " <> show reached
        putStrLn $ "xs' = " <> show xs'
  unProperty $ whenFail report (Just t' == reached && xs <= xs')

-- | Every leaf returned by 'allLeavesRelative' with 'beforeChild' is reached
-- by its path and does not come after the starting path.
prop_allLeavesBefore :: NTTT -> Property
prop_allLeavesBefore (N n@(xs,t)) = property $ do
  (xs',t') <- elements (allLeavesRelative beforeChild n)
  let reached = walkDown (xs',t)
      report = do
        putStrLn $ "t' = " <> show t'
        putStrLn $ "t'' = " <> show reached
        putStrLn $ "xs' = " <> show xs'
  unProperty $ whenFail report (Just t' == reached && xs' <= xs)

-- | The leaf chosen by 'fromLeafToLeafAfter' starts at or after the point.
prop_fromNodeToLeafAfter :: NTTT -> Property
prop_fromNodeToLeafAfter (N n) = forAll (pointInside (subtreeRegion $ snd n)) checkAt
  where
    checkAt p =
      let after = fromLeafToLeafAfter p n
          afterRegion = nodeRegion after
          report = do
            putStrLn $ "after = " <> show after
            putStrLn $ "after reg = " <> show afterRegion
      in whenFail report (regionStart afterRegion >= p)

-- | The node returned by 'fromNodeToFinal' covers the requested region.
prop_fromNodeToFinal :: NTTT -> Property
prop_fromNodeToFinal (N t) = forAll (regionInside (subtreeRegion $ snd t)) checkRegion
  where
    checkRegion r =
      let final@(_, finalSubtree) = fromNodeToFinal r t
          finalRegion = subtreeRegion finalSubtree
          report = do
            putStrLn $ "final = " <> show final
            putStrLn $ "final reg = " <> show finalRegion
            putStrLn $ "leaf after = " <> show (fromLeafToLeafAfter (regionEnd r) t)
      in whenFail report (r `includedRegion` finalRegion)

#endif