{-# LANGUAGE MultiParamTypeClasses, ScopedTypeVariables #-}

-- |
-- Module: Language.KURE.Walker
-- Copyright: (c) 2012--2013 The University of Kansas
-- License: BSD3
--
-- Maintainer: Neil Sculthorpe <neil@ittc.ku.edu>
-- Stability: beta
-- Portability: ghc
--
-- This module provides combinators that traverse a tree.
--
-- Note that all traversals take place on the node, its children, or its descendants.
-- Deliberately, there is no mechanism for \"ascending\" the tree.

module Language.KURE.Walker
        (
        -- * Shallow Traversals

        -- ** Tree Walkers
          Walker(..)
        -- ** Child Transformations
        , childR
        , childT

        -- * Deep Traversals

        -- ** Rewrite Traversals
        , alltdR
        , allbuR
        , allduR
        , anytdR
        , anybuR
        , anyduR
        , onetdR
        , onebuR
        , prunetdR
        , innermostR
        , allLargestR
        , anyLargestR
        , oneLargestR

        -- ** Translate Traversals
        , foldtdT
        , foldbuT
        , onetdT
        , onebuT
        , prunetdT
        , crushtdT
        , crushbuT
        , collectT
        , collectPruneT
        , allLargestT
        , oneLargestT

        -- * Utility Translations
        , numChildrenT
        , hasChildT
        , summandIsTypeT

        -- * Paths
        -- ** Absolute Paths
        , AbsolutePath
        , rootAbsPath
        , PathContext(..)
        , absPathT
        -- ** Relative Paths
        , Path
        , rootPath
        , rootPathT
        , pathsToT
        , onePathToT
        , oneNonEmptyPathToT
        , prunePathsToT
        , uniquePathToT
        , uniquePrunePathToT

        -- ** Building Lenses from Paths
        , pathL
        , exhaustPathL
        , repeatPathL
        , rootL

        -- ** Applying transformations at the end of Paths
        , pathR
        , pathT

        -- ** Testing Paths
        , testPathT
) where

import Prelude hiding (id)

import Data.Maybe (isJust)
import Data.Monoid
import Data.List
import Data.DList (singleton, toList)

import Control.Monad
import Control.Arrow
import Control.Category hiding ((.))

import Language.KURE.MonadCatch
import Language.KURE.Translate
import Language.KURE.Lens
import Language.KURE.Injection
import Language.KURE.Combinators

-------------------------------------------------------------------------------

-- | 'Walker' captures the ability to walk over a tree containing nodes of type @g@,
--   using a specific context @c@.
--
--   Minimal complete definition: 'allR'.
--
--   Default definitions are provided for 'anyR', 'oneR', 'allT', 'oneT', and 'childL',
--   but they may be overridden for efficiency.
--   Each default for 'allT', 'oneT', 'anyR' and 'oneR' runs 'allR' on a wrapped argument
--   and then unwraps the result. 'childL' defaults to 'childL_default'.

class Walker c g where

  -- | Apply a 'Rewrite' to all immediate children, succeeding if they all succeed.
  allR :: MonadCatch m => Rewrite c m g -> Rewrite c m g

  -- | Apply a 'Translate' to all immediate children, succeeding if they all succeed.
  --   The results are combined in a 'Monoid'.
  allT :: (MonadCatch m, Monoid b) => Translate c m g b -> Translate c m g b
  -- Default: run 'allR' in the 'AllT' monad, which accumulates each child's result.
  allT = unwrapAllT . allR . wrapAllT
  {-# INLINE allT #-}

  -- | Apply a 'Translate' to the first immediate child for which it can succeed.
  oneT :: MonadCatch m => Translate c m g b -> Translate c m g b
  -- Default: run 'allR' in the 'OneT' monad, which keeps only the first successful result.
  oneT = unwrapOneT . allR . wrapOneT
  {-# INLINE oneT #-}

  -- | Apply a 'Rewrite' to all immediate children, succeeding if any succeed.
  anyR :: MonadCatch m => Rewrite c m g -> Rewrite c m g
  anyR = unwrapAnyR . allR . wrapAnyR
  {-# INLINE anyR #-}

  -- | Apply a 'Rewrite' to the first immediate child for which it can succeed.
  oneR :: MonadCatch m => Rewrite c m g -> Rewrite c m g
  oneR = unwrapOneR . allR . wrapOneR
  {-# INLINE oneR #-}

  -- | Construct a 'Lens' to the n-th child node.
  childL :: MonadCatch m => Int -> Lens c m g g
  childL = childL_default
  {-# INLINE childL #-}

------------------------------------------------------------------------------------------

-- | Count the immediate children of the current node.
--   Every child contributes @'Sum' 1@ via 'allT', and the total is unwrapped at the end.
numChildrenT :: (Walker c g, MonadCatch m) => Translate c m g Int
numChildrenT = liftM getSum (allT (return (Sum 1)))
{-# INLINE numChildrenT #-}

-- | Check whether the current node has a child at the given zero-based index.
--   Handy when writing a custom version of 'childL'.
hasChildT :: (Walker c g, MonadCatch m) => Int -> Translate c m g Bool
hasChildT n = liftM inRange numChildrenT
  where
    inRange count = n >= 0 && n < count
{-# INLINE hasChildT #-}

-------------------------------------------------------------------------------

-- | Run a 'Translate' on the child with the given index, focusing through 'childL'.
childT :: (Walker c g, MonadCatch m) => Int -> Translate c m g b -> Translate c m g b
childT = focusT . childL
{-# INLINE childT #-}

-- | Run a 'Rewrite' on the child with the given index, focusing through 'childL'.
childR :: (Walker c g, MonadCatch m) => Int -> Rewrite c m g -> Rewrite c m g
childR = focusR . childL
{-# INLINE childR #-}

-------------------------------------------------------------------------------

-- | Top-down fold: combine the node's own result with the folds of its children
--   (node first), using the same 'Translate' everywhere.
foldtdT :: (Walker c g, MonadCatch m, Monoid b) => Translate c m g b -> Translate c m g b
foldtdT t = prefixFailMsg "foldtdT failed: " go
  where
    go = t `mappend` allT go
{-# INLINE foldtdT #-}

-- | Bottom-up fold: combine the folds of the children with the node's own result
--   (children first), using the same 'Translate' everywhere.
foldbuT :: (Walker c g, MonadCatch m, Monoid b) => Translate c m g b -> Translate c m g b
foldbuT t = prefixFailMsg "foldbuT failed: " go
  where
    go = allT go `mappend` t
{-# INLINE foldbuT #-}

-- | Top-down search: try the 'Translate' at the current node, and only if that fails
--   descend into the first child where the search succeeds.
onetdT :: (Walker c g, MonadCatch m) => Translate c m g b -> Translate c m g b
onetdT t = setFailMsg "onetdT failed" go
  where
    go = t <+ oneT go
{-# INLINE onetdT #-}

-- | Bottom-up search: look for a success among the children first, falling back to
--   the 'Translate' at the current node.
onebuT :: (Walker c g, MonadCatch m) => Translate c m g b -> Translate c m g b
onebuT t = setFailMsg "onebuT failed" go
  where
    go = oneT go <+ t
{-# INLINE onebuT #-}

-- | Top-down traversal that stops descending wherever the 'Translate' succeeds.
--   Where it fails, the children are visited and their results combined.
prunetdT :: (Walker c g, MonadCatch m, Monoid b) => Translate c m g b -> Translate c m g b
prunetdT t = setFailMsg "prunetdT failed" go
  where
    go = t <+ allT go
{-# INLINE prunetdT #-}

-- | Top-down fold that never fails: each failing application contributes 'mempty'.
crushtdT :: (Walker c g, MonadCatch m, Monoid b) => Translate c m g b -> Translate c m g b
crushtdT = foldtdT . mtryM
{-# INLINE crushtdT #-}

-- | Bottom-up fold that never fails: each failing application contributes 'mempty'.
crushbuT :: (Walker c g, MonadCatch m, Monoid b) => Translate c m g b -> Translate c m g b
crushbuT = foldbuT . mtryM
{-# INLINE crushbuT #-}

-- | Gather the results of every successful application of a 'Translate' into a list,
--   in top-down order. This traversal itself never fails.
collectT :: (Walker c g, MonadCatch m) => Translate c m g b -> Translate c m g [b]
collectT t = toList ^<< crushtdT (singleton ^<< t)
{-# INLINE collectT #-}

-- | Variant of 'collectT' that does not look below nodes where the 'Translate' succeeded.
collectPruneT :: (Walker c g, MonadCatch m) => Translate c m g b -> Translate c m g [b]
collectPruneT t = toList ^<< prunetdT (singleton ^<< t)
{-# INLINE collectPruneT #-}

-------------------------------------------------------------------------------

-- | Rewrite every node top-down (node before children). Fails if any application fails.
alltdR :: (Walker c g, MonadCatch m) => Rewrite c m g -> Rewrite c m g
alltdR r = prefixFailMsg "alltdR failed: " go
  where
    go = r >>> allR go
{-# INLINE alltdR #-}

-- | Rewrite every node bottom-up (children before node). Fails if any application fails.
allbuR :: (Walker c g, MonadCatch m) => Rewrite c m g -> Rewrite c m g
allbuR r = prefixFailMsg "allbuR failed: " go
  where
    go = allR go >>> r
{-# INLINE allbuR #-}

-- | Rewrite every node twice in a single pass: once on the way down and once on the
--   way back up. Fails if any application fails.
allduR :: (Walker c g, MonadCatch m) => Rewrite c m g -> Rewrite c m g
allduR r = prefixFailMsg "allduR failed: " go
  where
    go = r >>> allR go >>> r
{-# INLINE allduR #-}

-- | Rewrite top-down (node before children). Succeeds if at least one application succeeds.
anytdR :: (Walker c g, MonadCatch m) => Rewrite c m g -> Rewrite c m g
anytdR r = setFailMsg "anytdR failed" go
  where
    go = r >+> anyR go
{-# INLINE anytdR #-}

-- | Rewrite bottom-up (children before node). Succeeds if at least one application succeeds.
anybuR :: (Walker c g, MonadCatch m) => Rewrite c m g -> Rewrite c m g
anybuR r = setFailMsg "anybuR failed" go
  where
    go = anyR go >+> r
{-# INLINE anybuR #-}

-- | Rewrite each node on the way down and again on the way up, in one pass.
--   Succeeds if at least one application succeeds.
anyduR :: (Walker c g, MonadCatch m) => Rewrite c m g -> Rewrite c m g
anyduR r = setFailMsg "anyduR failed" go
  where
    go = r >+> anyR go >+> r
{-# INLINE anyduR #-}

-- | Rewrite the first node, in top-down order, at which the 'Rewrite' succeeds.
onetdR :: (Walker c g, MonadCatch m) => Rewrite c m g -> Rewrite c m g
onetdR r = setFailMsg "onetdR failed" go
  where
    go = r <+ oneR go
{-# INLINE onetdR #-}

-- | Rewrite the first node, in bottom-up order, at which the 'Rewrite' succeeds.
onebuR :: (Walker c g, MonadCatch m) => Rewrite c m g -> Rewrite c m g
onebuR r = setFailMsg "onebuR failed" go
  where
    go = oneR go <+ r
{-# INLINE onebuR #-}

-- | Top-down rewrite that stops descending wherever the 'Rewrite' succeeds.
--   Where it fails, the search continues in the children via 'anyR'.
prunetdR :: (Walker c g, MonadCatch m) => Rewrite c m g -> Rewrite c m g
prunetdR r = setFailMsg "prunetdR failed" go
  where
    go = r <+ anyR go
{-# INLINE prunetdR #-}

-- | A fixed-point traversal, starting with the innermost term: after each successful
--   rewrite, the result is itself traversed again (if possible).
innermostR :: (Walker c g, MonadCatch m) => Rewrite c m g -> Rewrite c m g
innermostR r = setFailMsg "innermostR failed" go
  where
    go = anybuR (r >>> tryR go)
{-# INLINE innermostR #-}

-------------------------------------------------------------------------------

-- | A path from the root.
--   Descents are stored most-recent-first (see the 'PathContext' instance, which conses),
--   so 'show' reverses them to print the path root-first.
newtype AbsolutePath = AbsolutePath [Int] deriving Eq

instance Show AbsolutePath where
  show (AbsolutePath steps) = let rootFirst = reverse steps
                               in show rootFirst
  {-# INLINE show #-}

-- | The 'AbsolutePath' of the root itself: no descents at all.
rootAbsPath :: AbsolutePath
rootAbsPath = AbsolutePath mempty
{-# INLINE rootAbsPath #-}


-- | A context is a 'PathContext' when it records the current 'AbsolutePath'.
--   User-defined combinators (typically 'allR' and congruence combinators) must keep that
--   path up to date by descending with '@@'.
class PathContext c where
  -- | Retrieve the current absolute path.
  absPath :: c -> AbsolutePath

  -- | Extend the current absolute path by one descent.
  (@@) :: c -> Int -> c

-- | 'AbsolutePath' is itself the simplest 'PathContext': descending conses the child
--   index onto the front of the stored list.
instance PathContext AbsolutePath where
  absPath p = p
  {-# INLINE absPath #-}

  (AbsolutePath steps) @@ i = AbsolutePath (i : steps)
  {-# INLINE (@@) #-}

-- | Read the current 'AbsolutePath' out of the context, as a 'Translate'.
absPathT :: (PathContext c, Monad m) => Translate c m a AbsolutePath
absPathT = liftM absPath contextT
{-# INLINE absPathT #-}

-------------------------------------------------------------------------------

-- | A path is a route to descend the tree from an arbitrary node.
--   Each element is a child index passed to 'childL' (see 'pathL'), and the indices are
--   listed outermost first ('rootPath' reverses the most-recent-first 'AbsolutePath').
type Path = [Int]

-- | The 'Path' from the root down to the current node, read from the context.
--   The stored 'AbsolutePath' is most-recent-first, so it is reversed here.
rootPath :: PathContext c => c -> Path
rootPath c = case absPath c of
               AbsolutePath steps -> reverse steps
{-# INLINE rootPath #-}

-- | 'rootPath' of the current context, as a 'Translate'.
rootPathT :: (PathContext c, Monad m) => Translate c m a Path
rootPathT = liftM rootPath contextT
{-# INLINE rootPathT #-}

--  When the first 'AbsolutePath' leads to an ancestor of (or the same node as) the end
--  of the second, return the relative 'Path' between them; otherwise 'Nothing'.
--  Both paths are stored most-recent-first, so "prefix" becomes a suffix test.
rmPathPrefix :: AbsolutePath -> AbsolutePath -> Maybe Path
rmPathPrefix (AbsolutePath shorter) (AbsolutePath longer)
  | shorter `isSuffixOf` longer = Just (drop (length shorter) (reverse longer))
  | otherwise                   = Nothing
{-# INLINE rmPathPrefix #-}

--  Turn an 'AbsolutePath' into a 'Path' relative to the current node.
--  Fails if the 'AbsolutePath' does not pass through the current node.
abs2pathT :: (PathContext c, Monad m) => AbsolutePath -> Translate c m a Path
abs2pathT there = absPathT >>= \ here ->
                    case rmPathPrefix here there of
                      Nothing  -> fail "Absolute path does not pass through current node."
                      Just rel -> return rel
{-# INLINE abs2pathT #-}

-- | Relative 'Path's to every node that satisfies the predicate.
pathsToT :: (PathContext c, Walker c g, MonadCatch m) => (g -> Bool) -> Translate c m g [Path]
pathsToT q = do hits <- collectT (acceptR q >>> absPathT)
                mapM abs2pathT hits
{-# INLINE pathsToT #-}

-- | Relative 'Path' to the first node, in pre-order, that satisfies the predicate.
onePathToT :: (PathContext c, Walker c g, MonadCatch m) => (g -> Bool) -> Translate c m g Path
onePathToT q = setFailMsg "No matching nodes found." $
               do there <- onetdT (acceptR q >>> absPathT)
                  abs2pathT there
{-# INLINE onePathToT #-}

-- | Relative 'Path' to the first proper descendant, in pre-order, that satisfies the predicate.
--   The current node itself is excluded by comparing absolute paths.
oneNonEmptyPathToT :: (PathContext c, Walker c g, MonadCatch m) => (g -> Bool) -> Translate c m g Path
oneNonEmptyPathToT q = setFailMsg "No matching nodes found." $
                       absPathT >>= \ start ->
                         onetdT (acceptR q >>> absPathT >>> acceptR (/= start)) >>= abs2pathT
{-# INLINE oneNonEmptyPathToT #-}

-- | Relative 'Path's to every node satisfying the predicate, not searching below a match.
prunePathsToT :: (PathContext c, Walker c g, MonadCatch m) => (g -> Bool) -> Translate c m g [Path]
prunePathsToT q = mapM abs2pathT =<< collectPruneT (acceptR q >>> absPathT)
{-# INLINE prunePathsToT #-}

-- Internal helper for 'uniquePathToT' and 'uniquePrunePathToT': succeed with the single
-- path in the input list, failing when the list is empty or has several entries.
requireUniquePath :: Monad m => Translate c m [Path] Path
requireUniquePath = contextfreeT pick
  where
    pick []  = fail "No matching nodes found."
    pick [p] = return p
    pick ps  = fail $ "Ambiguous: " ++ show (length ps) ++ " matching nodes found."
{-# INLINE requireUniquePath #-}

-- | Relative 'Path' to the one node satisfying the predicate. Fails if there is no match
--   or more than one match.
uniquePathToT :: (PathContext c, Walker c g, MonadCatch m) => (g -> Bool) -> Translate c m g Path
uniquePathToT q = requireUniquePath <<< pathsToT q
{-# INLINE uniquePathToT #-}

-- | Like 'uniquePathToT', but matches below an earlier match are ignored when checking
--   for uniqueness.
uniquePrunePathToT :: (PathContext c, Walker c g, MonadCatch m) => (g -> Bool) -> Translate c m g Path
uniquePrunePathToT q = requireUniquePath <<< prunePathsToT q
{-# INLINE uniquePrunePathToT #-}

-------------------------------------------------------------------------------

-- Use the given 'Lens', or fall back to the identity lens if it fails.
tryL :: MonadCatch m => Lens c m g g -> Lens c m g g
tryL l = catchL l (const id)
{-# INLINE tryL #-}

-- | A 'Lens' that follows a 'Path', one 'childL' step per index.
pathL :: (Walker c g, MonadCatch m) => Path -> Lens c m g g
pathL p = serialise (map childL p)
{-# INLINE pathL #-}

-- | A 'Lens' to the deepest node reachable by following a prefix of the 'Path'.
--   It stops at the first index that cannot be followed.
exhaustPathL :: (Walker c g, MonadCatch m) => Path -> Lens c m g g
exhaustPathL p = foldr (\ i rest -> tryL (childL i >>> rest)) id p
{-# INLINE exhaustPathL #-}

-- | Repeat as many iterations of the 'Path' as possible.
--   The empty 'Path' yields 'id' (stay at the current node). Following an empty path does
--   not descend, so if it succeeds, repeating it would never make progress and the
--   recursion would not terminate.
repeatPathL :: (Walker c g, MonadCatch m) => Path -> Lens c m g g
repeatPathL [] = id
repeatPathL p  = let go = tryL (pathL p >>> go)
                  in go
{-# INLINE repeatPathL #-}

-- | A 'Lens' from the root to the node identified by an 'AbsolutePath'.
rootL :: (Walker c g, MonadCatch m) => AbsolutePath -> Lens c m g g
rootL p = pathL (rootPath p)
{-# INLINE rootL #-}

-------------------------------------------------------------------------------

-- | Run a 'Rewrite' on the node at the end of the given 'Path'.
pathR :: (Walker c g, MonadCatch m) => Path -> Rewrite c m g -> Rewrite c m g
pathR p = focusR (pathL p)
{-# INLINE pathR #-}

-- | Run a 'Translate' on the node at the end of the given 'Path'.
pathT :: (Walker c g, MonadCatch m) => Path -> Translate c m g b -> Translate c m g b
pathT p = focusT (pathL p)
{-# INLINE pathT #-}

-------------------------------------------------------------------------------

-- | Whether a 'Lens' along this 'Path' can be built from the current node.
testPathT :: (Walker c g, MonadCatch m) => Path -> Translate c m g Bool
testPathT p = testLensT (pathL p)
{-# INLINE testPathT #-}

-------------------------------------------------------------------------------

-- | Rewrite the largest node(s) satisfying the predicate, without looking below them.
--   Every such rewrite must succeed.
allLargestR :: (Walker c g, MonadCatch m) => Translate c m g Bool -> Rewrite c m g -> Rewrite c m g
allLargestR p r = prefixFailMsg "allLargestR failed: " go
  where
    go = ifM p r (allR go)
{-# INLINE allLargestR #-}

-- | Rewrite the largest node(s) satisfying the predicate, succeeding if any rewrite succeeds.
anyLargestR :: (Walker c g, MonadCatch m) => Translate c m g Bool -> Rewrite c m g -> Rewrite c m g
anyLargestR p r = setFailMsg "anyLargestR failed" go
  where
    go = ifM p r (anyR go)
{-# INLINE anyLargestR #-}

-- | Among the largest node(s) satisfying the predicate, rewrite the first one where the
--   'Rewrite' succeeds.
oneLargestR :: (Walker c g, MonadCatch m) => Translate c m g Bool -> Rewrite c m g -> Rewrite c m g
oneLargestR p r = setFailMsg "oneLargestR failed" go
  where
    go = ifM p r (oneR go)
{-# INLINE oneLargestR #-}

-- | Run a 'Translate' on the largest node(s) satisfying the predicate, combining the
--   results with the 'Monoid' instance.
allLargestT :: (Walker c g, MonadCatch m, Monoid b) => Translate c m g Bool -> Translate c m g b -> Translate c m g b
allLargestT p t = prefixFailMsg "allLargestT failed: " go
  where
    go = ifM p t (allT go)
{-# INLINE allLargestT #-}

-- | Among the largest node(s) satisfying the predicate, run the 'Translate' on the first
--   one where it succeeds.
oneLargestT :: (Walker c g, MonadCatch m) => Translate c m g Bool -> Translate c m g b -> Translate c m g b
oneLargestT p t = setFailMsg "oneLargestT failed" go
  where
    go = ifM p t (oneT go)
{-# INLINE oneLargestT #-}

-- | Whether the current node can be projected to the type of the argument.
--   The argument is only a type proxy: its value is never looked at.
summandIsTypeT :: forall c m a g. (MonadCatch m, Injection a g) => a -> Translate c m g Bool
summandIsTypeT _ = arr isSummand
  where
    isSummand :: g -> Bool
    isSummand node = isJust (project node :: Maybe a)
{-# INLINE summandIsTypeT #-}

-------------------------------------------------------------------------------

-- Plain two-field product, used internally as the payload of the helper monads below.
data P a b = P a b

-- Second component of a 'P'.
pSnd :: P a b -> b
pSnd p = case p of
           P _ y -> y
{-# INLINE pSnd #-}

-- Run the computation, then require its 'Maybe' result to be 'Just'.
-- 'projectWithFailMsgM' fails with the given message otherwise.
checkSuccessPMaybe :: Monad m => String -> m (Maybe a) -> m a
checkSuccessPMaybe msg ma = projectWithFailMsgM msg =<< ma
{-# INLINE checkSuccessPMaybe #-}

-------------------------------------------------------------------------------

-- These are used for defining 'allT' in terms of 'allR'.
-- However, they are unlikely to be of use to the KURE user.

-- Writer-like transformer: each result is paired with an accumulated monoidal value @w@.
newtype AllT w m a = AllT (m (P a w))

unAllT :: AllT w m a -> m (P a w)
unAllT t = case t of
             AllT mw -> mw
{-# INLINE unAllT #-}

-- Sequencing combines the accumulated values with '<>', left result first.
instance (Monoid w, Monad m) => Monad (AllT w m) where
  -- return :: a -> AllT w m a
  return a = AllT (return (P a mempty))
  {-# INLINE return #-}

  -- fail :: String -> AllT w m a
  fail msg = AllT (fail msg)
  {-# INLINE fail #-}

  -- (>>=) :: AllT w m a -> (a -> AllT w m d) -> AllT w m d
  AllT mx >>= k = AllT $
    mx >>= \ (P x w1) ->
    unAllT (k x) >>= \ (P y w2) ->
    return (P y (w1 <> w2))
  {-# INLINE (>>=) #-}

-- Catching is delegated to the underlying monad; the handler's accumulated value replaces
-- that of the failed computation.
instance (Monoid w, MonadCatch m) => MonadCatch (AllT w m) where
  -- catchM :: AllT w m a -> (String -> AllT w m a) -> AllT w m a
  catchM (AllT ma) handler = AllT (catchM ma (\ msg -> unAllT (handler msg)))
  {-# INLINE catchM #-}


-- | Lift a 'Translate' into a 'Rewrite' in the 'AllT' monad.
--   The node is returned unchanged, and the translation's result becomes the accumulated value.
wrapAllT :: Monad m => Translate c m g b -> Rewrite c (AllT b m) g
wrapAllT t = readerT (\ a -> resultT (\ mb -> AllT (liftM (P a) mb)) t)
{-# INLINE wrapAllT #-}

-- | Unwrap a 'Translate' from the 'AllT' monad transformer.
--   The rebuilt node is discarded and only the accumulated monoidal value is kept.
--   Failures are prefixed with @"allT failed: "@ (with a trailing space), matching the
--   @"<name> failed: "@ prefixes used by the other 'prefixFailMsg' call sites in this module.
unwrapAllT :: MonadCatch m => Rewrite c (AllT b m) g -> Translate c m g b
unwrapAllT = prefixFailMsg "allT failed: " . resultT (liftM pSnd . unAllT)
{-# INLINE unwrapAllT #-}

-------------------------------------------------------------------------------

-- We could probably build this on top of OneR or AllT

-- These are used for defining 'oneT' in terms of 'allR'.
-- However, they are unlikely to be of use to the KURE user.

-- State-like transformer threading an optional result: 'unwrapOneT' starts it at
-- 'Nothing', and 'wrapOneT' fills it in at the first successful translation.
newtype OneT w m a = OneT (Maybe w -> m (P a (Maybe w)))

unOneT :: OneT w m a -> Maybe w -> m (P a (Maybe w))
unOneT t = case t of
             OneT f -> f
{-# INLINE unOneT #-}

-- State-passing monad: the optional result is threaded from each step to the next.
instance Monad m => Monad (OneT w m) where
  -- return :: a -> OneT w m a
  --   Leaves the threaded state untouched.
  return a = OneT $ \ mw -> return (P a mw)
  {-# INLINE return #-}

  -- fail :: String -> OneT w m a
  fail msg = OneT (\ _ -> fail msg)
  {-# INLINE fail #-}

  -- (>>=) :: OneT w m a -> (a -> OneT w m d) -> OneT w m d
  --   Run the first computation on the incoming state, then the continuation on the
  --   state it produced. The lambda needs no enclosing @do@.
  ma >>= f = OneT $ \ mw1 -> do
      P a mw2 <- unOneT ma mw1
      unOneT (f a) mw2
  {-# INLINE (>>=) #-}

-- On failure, the handler is run from the same incoming state as the failed computation.
instance MonadCatch m => MonadCatch (OneT w m) where
  -- catchM :: OneT w m a -> (String -> OneT w m a) -> OneT w m a
  catchM (OneT g) f = OneT (\ mw -> catchM (g mw) (\ msg -> unOneT (f msg) mw))
  {-# INLINE catchM #-}


-- | Lift a 'Translate' into a 'Rewrite' in the 'OneT' monad.
--   Once a result has been recorded, the node is passed through untouched. Otherwise the
--   translation is attempted: success records its result, and failure leaves the state
--   as 'Nothing'. The node itself is never changed.
wrapOneT :: MonadCatch m => Translate c m g b -> Rewrite c (OneT b m) g
wrapOneT t = rewrite $ \ c a ->
  let attempt = liftM (P a . Just) (apply t c a) <+ return (P a Nothing)
   in OneT (maybe attempt (\ w -> return (P a (Just w))))
{-# INLINE wrapOneT #-}

-- | Unwrap a 'Translate' from the 'OneT' monad transformer.
--   It runs from the empty state and fails with "oneT failed" if nothing was recorded.
unwrapOneT :: Monad m => Rewrite c (OneT b m) g -> Translate c m g b
unwrapOneT = resultT run
  where
    run w = checkSuccessPMaybe "oneT failed" (liftM pSnd (unOneT w Nothing))
{-# INLINE unwrapOneT #-}

-------------------------------------------------------------------------------

-- A value paired with a strict, unpacked counter.
data PInt a = PInt {-# UNPACK #-} !Int a

-- Map over the value, keeping the counter.
-- Defined with a single argument on the left-hand side, presumably so the INLINE
-- pragma can fire at partial applications such as @secondPInt . second@.
secondPInt :: (a -> b) -> PInt a -> PInt b
secondPInt f = \ p -> case p of
                        PInt n x -> PInt n (f x)
{-# INLINE secondPInt #-}

-------------------------------------------------------------------------------

-- This is hideous.
-- Admittedly, part of the problem is using MonadCatch.  If allR just used Monad, this (and other things) would be much simpler.
-- And currently, the only use of MonadCatch is that it allows the error message to be modified.

-- Failure should not occur, so it doesn't really matter where the KureM monad sits in the GetChild stack.
-- I've arbitrarily made it a local failure.

-- Helper monad used by 'getChild' to pick out one child via 'allR'.
-- It threads an Int counter and at most one captured (context, child) pair, and keeps
-- failure inside 'KureM'.
newtype GetChild c g a = GetChild (Int -> PInt (KureM a, Maybe (c,g)))

unGetChild :: GetChild c g a -> Int -> PInt (KureM a, Maybe (c,g))
unGetChild gc = case gc of
                  GetChild f -> f
{-# INLINE unGetChild #-}

-- Counter-threading monad. Failure is stored in the 'KureM' component rather than
-- short-circuiting the state, so the counter and any captured child survive a failure.
instance Monad (GetChild c g) where
-- return :: a -> GetChild c g a
--   Counter unchanged; nothing captured.
   return a = GetChild $ \ i -> PInt i (return a, Nothing)
   {-# INLINE return #-}

-- fail :: String -> GetChild c g a
--   Counter unchanged; nothing captured; the failure is recorded in 'KureM'.
   fail msg = GetChild $ \ i -> PInt i (fail msg, Nothing)
   {-# INLINE fail #-}

-- (>>=) :: GetChild c g a -> (a -> GetChild c g b) -> GetChild c g b
--   Run the first computation. On success, run the continuation from the updated
--   counter and merge captured children with 'mplus', so a child captured by the first
--   computation takes precedence. On failure, skip the continuation but keep the
--   counter and captured child reached so far.
   ma >>= f = GetChild $ \ i0 -> let PInt i1 (kma, mcg) = unGetChild ma i0
                                  in runKureM (\ a   -> (secondPInt.second) (mplus mcg) $ unGetChild (f a) i1)
                                              (\ msg -> PInt i1 (fail msg, mcg))
                                              kma
   {-# INLINE (>>=) #-}

-- If the computation succeeds, its whole result (counter and captured child) is kept as-is.
-- Otherwise the handler runs from the counter reached by the failed computation. Any
-- child captured before the failure is merged in front of the handler's via 'mplus'.
instance MonadCatch (GetChild c g) where
-- catchM :: GetChild c g a -> (String -> GetChild c g a) -> GetChild c g a
   ma `catchM` f = GetChild $ \ i0 -> let p@(PInt i1 (kma, mcg)) = unGetChild ma i0
                                       in runKureM (\ _   -> p)
                                                   (\ msg -> (secondPInt.second) (mplus mcg) $ unGetChild (f msg) i1)
                                                   kma
   {-# INLINE catchM #-}


-- Rewrite used per child by 'getChild'. It returns the child unchanged, bumps the counter,
-- and captures (context, child) when the counter equals the requested index.
wrapGetChild :: Int -> Rewrite c (GetChild c g) g
wrapGetChild n = rewrite $ \ c a ->
  let captured m | m == n    = Just (c, a)
                 | otherwise = Nothing
   in GetChild (\ m -> PInt (m + 1) (return a, captured m))
{-# INLINE wrapGetChild #-}

-- Run a 'GetChild' rewrite with the counter starting at 0 and return whatever child it
-- captured, ignoring the rewrite's own result.
unwrapGetChild :: Rewrite c (GetChild c g) g -> Translate c Maybe g (c,g)
unwrapGetChild r = translate $ \ c a ->
  case unGetChild (apply r c a) 0 of
    PInt _ (_, mcg) -> mcg
{-# INLINE unwrapGetChild #-}

-- The child with the given index, together with its context, or 'Nothing' if there is none.
-- Built from 'allR' by counting the children it visits.
getChild :: Walker c g => Int -> Translate c Maybe g (c, g)
getChild n = unwrapGetChild (allR (wrapGetChild n))
{-# INLINE getChild #-}

-------------------------------------------------------------------------------

-- Helper monad used by 'setChild' to replace one child via 'allR'.
-- It threads an Int counter and keeps failure inside 'KureM'.
newtype SetChild a = SetChild (Int -> PInt (KureM a))

unSetChild :: SetChild a -> Int -> PInt (KureM a)
unSetChild sc = case sc of
                  SetChild f -> f
{-# INLINE unSetChild #-}

-- Counter-threading monad. Failure is stored in 'KureM' rather than short-circuiting the
-- state, so the counter reached at the point of failure is preserved.
instance Monad SetChild where
-- return :: a -> SetChild c g a
--   Counter unchanged.
   return a = SetChild $ \ i -> PInt i (return a)
   {-# INLINE return #-}

-- fail :: String -> SetChild c g a
--   Counter unchanged; the failure is recorded in 'KureM'.
   fail msg = SetChild $ \ i -> PInt i (fail msg)
   {-# INLINE fail #-}

-- (>>=) :: SetChild c g a -> (a -> SetChild c g b) -> SetChild c g b
--   On success, continue from the updated counter. On failure, skip the continuation
--   and keep the counter.
   ma >>= f = SetChild $ \ i0 -> let PInt i1 ka = unSetChild ma i0
                                  in runKureM (\ a   -> unSetChild (f a) i1)
                                              (\ msg -> PInt i1 (fail msg))
                                              ka
   {-# INLINE (>>=) #-}

-- On success, the computation's result and counter are kept. On failure, the handler is
-- run from the counter reached by the failed computation.
instance MonadCatch SetChild where
-- catchM :: SetChild c g a -> (String -> SetChild c g a) -> SetChild c g a
   ma `catchM` f = SetChild $ \ i0 -> let PInt i1 ka = unSetChild ma i0
                                       in runKureM (\ _   -> PInt i1 ka)
                                                   (\ msg -> unSetChild (f msg) i1)
                                                   ka
   {-# INLINE catchM #-}


-- Rewrite used per child by 'setChild'. It bumps the counter and substitutes the
-- replacement when the counter equals the requested index; other children are kept.
wrapSetChild :: Int -> g -> Rewrite c SetChild g
wrapSetChild n g = contextfreeT $ \ a ->
  SetChild (\ m -> PInt (m + 1) (return (if m == n then g else a)))
{-# INLINE wrapSetChild #-}

-- Run a 'SetChild' rewrite with the counter starting at 0, turning its 'KureM' outcome
-- into the target monad ('return' on success, 'fail' on failure).
unwrapSetChild :: Monad m => Rewrite c SetChild g -> Rewrite c m g
unwrapSetChild r = rewrite $ \ c a ->
  case unSetChild (apply r c a) 0 of
    PInt _ ka -> runKureM return fail ka
{-# INLINE unwrapSetChild #-}

-- Replace the child with the given index by the supplied node, using 'allR' to count children.
setChild :: (Walker c g, Monad m) => Int -> g -> Rewrite c m g
setChild n g = unwrapSetChild (allR (wrapSetChild n g))
{-# INLINE setChild #-}

-------------------------------------------------------------------------------

-- Default 'childL': a 'Lens' to the n-th child, built from 'getChild' and 'setChild'.
-- The getter fails with "there is no child number n" when 'getChild' finds nothing.
-- The setter rebuilds the original node with the n-th child replaced.
-- The explicit 'forall' lets the local signatures below refer to @c@, @m@ and @g@
-- (ScopedTypeVariables).
childL_default :: forall c m g. (Walker c g, MonadCatch m) => Int -> Lens c m g g
childL_default n = lens $ do cg <- getter
                             k  <- setter
                             return (cg, k)
  where
    -- Focus: the child's own context and value.
    getter :: Translate c m g (c,g)
    getter = translate $ \ c a -> maybe (fail $ "there is no child number " ++ show n) return (apply (getChild n) c a)
    {-# INLINE getter #-}

    -- Put-back: given a new child, rebuild the parent @a@ (in context @c@) with it substituted.
    setter :: Translate c m g (g -> m g)
    setter = translate $ \ c a -> return (\ b -> apply (setChild n b) c a)
    {-# INLINE setter #-}

{-# INLINE childL_default #-}

-------------------------------------------------------------------------------