-- Copyright (c) Facebook, Inc. and its affiliates.
--
-- This source code is licensed under the MIT license found in the
-- LICENSE file in the root directory of this source tree.
--
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Retrie.SYB
  ( everywhereMWithContextBut
  , GenericCU
  , GenericMC
  , Strategy
  , topDown
  , bottomUp
  , everythingMWithContextBut
  , GenericMCQ
  , module Data.Generics
  ) where

import Control.Monad
import Data.Generics hiding (Fixity(..))

-- | Monadic rewrite with context.
--
-- Given a context of type @c@ and a value of any type with a 'Data'
-- instance, produce a value of the same type inside the monad @m@.
type GenericMC m c = forall a. Data a => c -> a -> m a

-- | Context update.
--
-- Given the current context, the child number, and the parent node, create
-- (in the monad @m@) the context to use for that child. The traversals in
-- this module number children starting from zero.
type GenericCU m c = forall a. Data a => c -> Int -> a -> m c

-- | Monadic traversal with pruning and context propagation.
--
-- At every node the stop condition is checked first: if it holds, the node
-- is returned unchanged and none of its children are visited. Otherwise the
-- given 'Strategy' decides how the rewrite of the node itself is combined
-- with the rewrite of its immediate children. Each child is visited with a
-- context computed by the context update function from the current context,
-- the child's zero-based index, and the node whose children are traversed.
everywhereMWithContextBut
  :: forall m c. Monad m
  => Strategy m    -- ^ Traversal order (see 'topDown' and 'bottomUp')
  -> GenericQ Bool -- ^ Short-circuiting stop condition
  -> GenericCU m c -- ^ Context update function
  -> GenericMC m c -- ^ Context-aware rewrite
  -> GenericMC m c
everywhereMWithContextBut strategy stop upd f = go
  where
    -- Rewrite a whole subtree rooted at @x@ under context @ctxt@.
    go :: GenericMC m c
    go ctxt x
      | stop x    = return x
      | otherwise = strategy (f ctxt) (h ctxt) x

    -- Rewrite only the immediate children of @parent@; each child gets its
    -- own context derived from @ctxt@, its index, and @parent@.
    h :: GenericMC m c
    h ctxt parent = gforMIndexed parent $ \i child -> do
      ctxt' <- upd ctxt i parent
      go ctxt' child

-- | Monadic query with context.
--
-- Given a context of type @c@ and a value of any type with a 'Data'
-- instance, compute a result of type @r@ inside the monad @m@.
type GenericMCQ m c r = forall a. Data a => c -> a -> m r

-- | Monadic query with pruning and context propagation.
--
-- At every node the stop condition is checked first: if it holds, the node
-- contributes 'mempty' and none of its children are visited. Otherwise the
-- query is run on the node, then on each immediate child (with a context
-- computed from the current context, the child's zero-based index, and the
-- node), and the node's result is combined with the children's results, in
-- that order, using 'mconcat'.
everythingMWithContextBut
  :: forall m c r. (Monad m, Monoid r)
  => GenericQ Bool -- ^ Short-circuiting stop condition
  -> GenericCU m c -- ^ Context update function
  -> GenericMCQ m c r -- ^ Context-aware query
  -> GenericMCQ m c r
everythingMWithContextBut stop upd q = go
  where
    -- Query the whole subtree rooted at @x@ under context @ctxt@.
    go :: GenericMCQ m c r
    go ctxt x
      | stop x = return mempty
      | otherwise = do
        r <- q ctxt x
        rs <- gforQIndexed x $ \i child -> do
          ctxt' <- upd ctxt i x
          go ctxt' child
        return $ mconcat (r:rs)

-- | Traversal strategy.
--
-- Given a rewrite on the node (first argument) and a rewrite on the node's
-- children (second argument), define a composite rewrite. See 'topDown' and
-- 'bottomUp' for the two orderings provided by this module.
type Strategy m = forall a. Monad m => (a -> m a) -> (a -> m a) -> a -> m a

-- | Perform a top-down traversal.
--
-- The rewrite on the node runs first; the rewrite on the children is then
-- applied to its result.
topDown :: Strategy m
topDown p cs = p >=> cs

-- | Perform a bottom-up traversal.
--
-- The rewrite on the children runs first; the rewrite on the node is then
-- applied to its result.
bottomUp :: Strategy m
bottomUp p cs = cs >=> p

-- | 'gmapM' with arguments flipped and providing zero-based index of child
-- to mapped function.
gforMIndexed
  :: (Monad m, Data a) => a -> (forall d. Data d => Int -> d -> m d) -> m a
gforMIndexed x f =
  -- The accumulator is seeded with -1 (standing for the constructor) and
  -- 'accumIndex' increments it before each use, so the first child sees 0.
  -- Only the rebuilt monadic value is kept; the final index is discarded.
  snd (gmapAccumM (accumIndex f) (-1) x)

-- | Turn an index-aware function into an accumulating step.
--
-- The incoming index is incremented, and the incremented index is both
-- returned as the new accumulator and passed to the function together with
-- the value. The new index is forced before the pair is built, so no chain
-- of unevaluated additions accumulates across successive steps.
accumIndex :: (Int -> a -> b) -> Int -> a -> (Int, b)
accumIndex f i y = i' `seq` (i', f i' y)
  where
    i' = i + 1

-- | Indexed monadic query over the immediate children of a value.
--
-- The function is applied to each child together with the child's
-- zero-based index; the resulting actions are then run in sequence and
-- their results collected into a list.
gforQIndexed
  :: (Monad m, Data a) => a -> (forall d. Data d => Int -> d -> m r) -> m [r]
gforQIndexed x f =
  -- Seeded with -1 because 'accumIndex' increments before each use, so the
  -- first child sees index 0. The final index is discarded.
  sequence $ snd $ gmapAccumQ (accumIndex f) (-1) x