module Data.Comp.Automata
    ( module Data.Comp.Automata,
      module Data.Comp.Automata.Product
    ) where
import Data.Comp.Zippable
import Data.Comp.Automata.Product
import Data.Comp.Term
import Data.Comp.Algebra
import Data.Comp.Show ()
import Data.Map (Map)
import qualified Data.Map as Map
-- Fixities for building literal maps: '|->' binds tighter than '&', and '&'
-- associates to the right, so  k1 |-> v1 & k2 |-> v2 & m  reads as
-- (k1 |-> v1) & ((k2 |-> v2) & m).
infix 1 |->
infixr 0 &

-- | Left-biased union ('Map.union'): when a key occurs in both operands the
-- binding from the left operand is kept.
(&) :: Ord k => Map k v -> Map k v -> Map k v
lhs & rhs = Map.union lhs rhs

-- | A map containing exactly one binding ('Map.singleton').
(|->) :: k -> a -> Map k a
key |-> val = Map.singleton key val
-- | The empty map; the natural terminator of a chain written with '&' and
-- '|->'.
o :: Map k a
o = Map.fromList []
-- | Read a child's state through the implicit @?below@ function and use
-- 'pr' to obtain the @p@ component from the resulting @q@.
below :: (?below :: a -> q, p :< q) => a -> p
below child = pr (?below child)
-- | Read the current node's state from the implicit @?above@ and use 'pr'
-- to obtain the @p@ component from it.
above :: (?above :: q, p :< q) => p
above = let current = ?above in pr current
-- | Discharge the two implicit parameters of a computation by supplying
-- them explicitly: the first argument becomes @?above@, the second
-- becomes @?below@.
explicit :: q -> (a -> q) -> ((?above :: q, ?below :: a -> q) => b) -> b
explicit abv blw computation =
    let ?above = abv
        ?below = blw
    in computation
-- | A state-dependent homomorphism. At each node it may read the node's
-- state through @?above@ and the children's states through @?below@, and
-- it produces a 'Context' over the target signature @g@ whose holes are
-- drawn from the (parametric) children.
type QHom f q g = forall b . (?below :: b -> q, ?above :: q)
                => f b -> Context g b
-- | A bottom-up transducer. Given a node whose children are paired with
-- their states, it returns the node's state together with a 'Context' over
-- @g@ whose holes are drawn from the children.
type UpTrans f q g = forall c . f (q, c) -> (q, Context g c)
-- | Turn a bottom-up transducer into an algebra over pairs of a state and
-- an already translated term. The 'Context' produced by the transducer is
-- flattened with 'appCxt', which splices in the translated children that
-- sit in its holes.
upAlg :: Functor g => UpTrans f q g -> Alg f (q, Term g)
upAlg trans node = case trans node of
    (st, out) -> (st, appCxt out)
-- | Run a bottom-up transducer over a whole term. Returns the state reached
-- at the root together with the translated term.
runUpTrans :: (Functor f, Functor g) => UpTrans f q g -> Term f -> (q, Term g)
runUpTrans trans term = cata (upAlg trans) term
-- | Variant of 'runUpTrans' for contexts whose holes already carry a state.
-- The state stored in a hole is taken as-is, and the hole is kept in the
-- output without its state.
runUpTrans' :: (Functor f, Functor g) => UpTrans f q g -> Context f (q,a) -> (q, Context g a)
runUpTrans' trans = go
  where
    go (Hole (st, x)) = (st, Hole x)
    go (Term node)    = case trans (fmap go node) of
                          (st, res) -> (st, appCxt res)
-- | Sequential composition of bottom-up transducers. @t1@ runs first, and
-- @t2@ is run (via 'runUpTrans'') over the 'Context' that @t1@ produces at
-- each node. The composite state pairs @t1@'s state with @t2@'s state.
compUpTrans :: (Functor f, Functor g, Functor h)
               => UpTrans g p h -> UpTrans f q g -> UpTrans f (q,p) h
compUpTrans t2 t1 node = ((q, p), result)
  where
    -- move t2's part of the state next to the child so t1 treats it as
    -- opaque payload that ends up in the holes of t1's output
    reassoc ((q', p'), x) = (q', (p', x))
    (q, inner)  = t1 (fmap reassoc node)
    (p, result) = runUpTrans' t2 inner
-- | A bottom-up state transition: an algebra computing a node's state from
-- its children's states.
type UpState f s = Alg f s
-- | Transport a bottom-up state along a pair of conversions. The children's
-- states are converted back with @proj@ before the original transition
-- runs, and its result is converted with @inj@.
tagUpState :: (Functor f) => (q -> p) -> (p -> q) -> UpState f q -> UpState f p
tagUpState inj proj st node = inj (st (fmap proj node))
-- | Compute the bottom-up state at the root of a term (a catamorphism).
runUpState :: (Functor f) => UpState f q -> Term f -> q
runUpState st term = cata st term
-- | Run two independent bottom-up states side by side. Each component
-- sees only its own projection of the children's paired states.
prodUpState :: Functor f => UpState f p -> UpState f q -> UpState f (p,q)
prodUpState sp sq node = (sp (fmap fst node), sq (fmap snd node))
-- | Build a bottom-up transducer from a bottom-up state and a
-- state-dependent homomorphism.
--
-- The node's state is computed by @st@ from the children's states. The
-- homomorphism then runs with @?above@ bound to that state and @?below@
-- reading each child's state ('fst'). Finally, the states are stripped
-- from the holes of its output ('snd').
upTrans :: (Functor f, Functor g) => UpState f q -> QHom f q g -> UpTrans f q g
upTrans st f t =
    let here = st (fmap fst t)
        out  = explicit here fst f t
    in (here, fmap snd out)
-- | Run a state-dependent homomorphism bottom-up, with states computed by
-- the given bottom-up state. Returns the root state and the translated term.
runUpHom :: (Functor f, Functor g) => UpState f q -> QHom f q g -> Term f -> (q,Term g)
runUpHom st hom term = runUpTrans (upTrans st hom) term
-- | A dependent bottom-up state. It computes the component @q@ of a
-- compound state @p@, and may read the full compound state of the children
-- (@?below@) and of the current node (@?above@).
type DUpState f p q
    = forall c . (?below :: c -> p, ?above :: p, q :< p)
    => f c -> q
-- | Lift an ordinary bottom-up state into a dependent one that only looks
-- at its own component of each child's compound state (via 'below').
dUpState :: Functor f => UpState f q -> DUpState f p q
dUpState f node = f (fmap below node)
-- | Close a dependent bottom-up state over itself, using the state as its
-- own compound state.
--
-- @?below@ is 'id' (children already carry the state). @?above@ is the
-- node's own result, which is bound lazily through a recursive @let@.
upState :: DUpState f q q -> UpState f q
upState f node = let knot = explicit knot id f node in knot
-- | Compute the root state of a term for a self-contained dependent
-- bottom-up state.
runDUpState :: Functor f => DUpState f q q -> Term f -> q
runDUpState st term = runUpState (upState st) term
-- | Pair two dependent bottom-up states that share the compound state @c@.
-- Both are evaluated on the same node.
prodDUpState :: (p :< c, q :< c)
             => DUpState f c p -> DUpState f c q -> DUpState f c (p,q)
prodDUpState sp sq = \node -> (sp node, sq node)
-- | Infix synonym for 'prodDUpState'.
--
-- NOTE(review): since base-4.8 the Prelude also exports Applicative's
-- @(<*>)@, so modules that import this one unqualified must hide one of
-- the two to use the operator.
(<*>) :: (p :< c, q :< c)
      => DUpState f c p -> DUpState f c q -> DUpState f c (p,q)
sp <*> sq = prodDUpState sp sq
-- | A top-down transducer. Given a node and the state in which it is
-- reached, it produces a 'Context' over @g@ whose holes are the children,
-- each tagged with the state in which that child is to be processed.
type DownTrans f q g = forall c . (q, f c) -> Context g (q, c)
-- | Run a top-down transducer from the initial state @q@. Holes of the
-- input context are copied unchanged into the output.
runDownTrans :: (Functor f, Functor g) => DownTrans f q g -> q -> Cxt h f a -> Cxt h g a
runDownTrans tr q0 t0 = go q0 t0
  where
    go st (Term node) = appCxt (fmap (uncurry go) (tr (st, node)))
    go _  (Hole x)    = Hole x
-- | Variant of 'runDownTrans' that records, in every hole of the output,
-- the state in which that hole was reached.
runDownTrans' :: (Functor f, Functor g) => DownTrans f q g -> q -> Cxt h f a -> Cxt h g (q,a)
runDownTrans' tr q0 t0 = go q0 t0
  where
    go st (Term node) = appCxt (fmap (uncurry go) (tr (st, node)))
    go st (Hole x)    = Hole (st, x)
-- | Sequential composition of top-down transducers. @t1@ translates the
-- node, and @t2@ is then run (via 'runDownTrans'') over @t1@'s output,
-- starting in @p@. The composite state pairs @t1@'s state with @t2@'s.
compDownTrans :: (Functor f, Functor g, Functor h)
              => DownTrans g p h -> DownTrans f q g -> DownTrans f (q,p) h
compDownTrans t2 t1 ((q, p), node) = fmap reassoc inner
  where
    inner = runDownTrans' t2 p (t1 (q, node))
    -- holes of 'inner' carry (t2 state, (t1 state, child))
    reassoc (p', (q', x)) = ((q', p'), x)
-- | A top-down state transition. Given a node and its state, it returns a
-- map from (some of) the children to their states. The consumers in this
-- module fall back to the parent's state for children that are missing
-- from the map.
type DownState f q = forall c . Ord c => (q, f c) -> Map c q
-- | Transport a top-down state along a pair of conversions. The incoming
-- state is converted with @proj@, and every child state produced is
-- converted with @inj@.
tagDownState :: (q -> p) -> (p -> q) -> DownState f q -> DownState f p
tagDownState inj proj st (p, node) = Map.map inj (st (proj p, node))
-- | Run two top-down states side by side. Each component produces its own
-- child map, and the maps are combined with 'prodMap'. A child mentioned by
-- only one component gets the parent's state for the other component.
prodDownState :: DownState f p -> DownState f q -> DownState f (p,q)
prodDownState sp sq ((p, q), node) =
    let left  = sp (p, node)
        right = sq (q, node)
    in prodMap p q left right
-- | Tag for an entry when two maps are merged: the value came only from the
-- left map, only from the right map, or from both.
data ProdState p q
    = LState p    -- ^ left-only value
    | RState q    -- ^ right-only value
    | BState p q  -- ^ values from both sides
-- | Pointwise product of two maps.
--
-- The result's key set is the union of the key sets of @mp@ and @mq@.
-- A key present in both maps gets both values. A key missing from one map
-- is paired with the corresponding default (@p@ or @q@).
--
-- Both inputs are pre-paired with the defaults, so a single total
-- 'Map.unionWith' suffices: no intermediate tagging and no unreachable
-- 'error' case. 'fst' / 'snd' keep the merge lazy in the stored values,
-- which matters because e.g. 'downState' ties a knot through its result map.
prodMap :: (Ord i) => p -> q -> Map i p -> Map i q -> Map i (p,q)
prodMap p q mp mq = Map.unionWith pick withQ withP
  where
    withQ = Map.map (\p' -> (p', q)) mp
    withP = Map.map (\q' -> (p, q')) mq
    -- on a shared key, take the left component from mp and the right from mq
    pick l r = (fst l, snd r)
-- | Annotate every child of @s@ with a state. A child that the map computed
-- by @qmap@ assigns a state to gets that state; every other child gets the
-- default @q@.
--
-- The children are first tagged with 'number' (presumably by position), so
-- that @qmap@ can key its result by child. That map does not depend on the
-- individual child, so it is computed once and shared by every lookup
-- instead of being rebuilt inside 'qfun' for each child.
appMap :: Zippable f => (forall i . Ord i => f i -> Map i q)
                       -> q -> f b -> f (q,b)
appMap qmap q s = fmap qfun s'
    where s' = number s
          states = qmap s'
          qfun k@(Numbered (_,a)) = (Map.findWithDefault q k states, a)
-- | Build a top-down transducer from a top-down state and a state-dependent
-- homomorphism.
--
-- Each child is annotated with its state (defaulting to the current state
-- @q@). The homomorphism then runs with @?above@ = @q@ and @?below@
-- reading a child's annotation ('fst').
downTrans :: Zippable f => DownState f q -> QHom f q g -> DownTrans f q g
downTrans st f (q, s) =
    let annotated = appMap (\node -> st (q, node)) q s
    in explicit q fst f annotated
-- | Run a state-dependent homomorphism top-down from the initial state
-- @q@, with child states supplied by the given top-down state.
runDownHom :: (Zippable f, Functor g)
            => DownState f q -> QHom f q g -> q -> Term f -> Term g
runDownHom st hom q0 term = runDownTrans (downTrans st hom) q0 term
-- | A dependent top-down state. It assigns the component @q@ of a compound
-- state @p@ to (some of) the children, and may read the compound state of
-- the current node (@?above@) and of the children (@?below@).
type DDownState f p q
    = forall k . (Ord k, ?below :: k -> p, ?above :: p, q :< p)
    => f k -> Map k q
-- | Lift an ordinary top-down state into a dependent one. It reads only its
-- own component of the current node's compound state (via 'above').
dDownState :: DownState f q -> DDownState f p q
dDownState f = \node -> f (above, node)
-- | Close a dependent top-down state over itself, using the state as its
-- own compound state.
--
-- @?above@ is the current state @q@. @?below@ looks a child up in the very
-- map being produced, falling back to @q@; this recursive @let@ relies on
-- laziness.
downState :: DDownState f q q -> DownState f q
downState f (q, s) =
    let childStates   = explicit q lookupChild f s
        lookupChild k = Map.findWithDefault q k childStates
    in childStates
-- | Pair two dependent top-down states that share the compound state @c@.
-- A child mentioned by only one of them gets, for the other component, the
-- projection of the current node's compound state ('above').
prodDDownState :: (p :< c, q :< c)
               => DDownState f c p -> DDownState f c q -> DDownState f c (p,q)
prodDDownState sp sq = \node -> prodMap above above (sp node) (sq node)
-- | Infix synonym for 'prodDDownState', with the same constraints.
--
-- No 'Functor' constraint is required: neither this definition nor
-- 'prodDDownState' maps over the node. Omitting it makes the operator at
-- least as general as any caller needs.
(>*<) :: (p :< c, q :< c)
         => DDownState f c p -> DDownState f c q -> DDownState f c (p,q)
(>*<) = prodDDownState
-- | Run a bottom-up state @up@ and a top-down state @down@ that depend on
-- each other over a term, returning the bottom-up state at the root.
--
-- @d@ is the top-down state used at the root. Each child is annotated with
-- the pair of its own (recursively computed) bottom-up state and the
-- top-down state that @down@ assigns it. Children that @down@ does not
-- mention inherit the parent's @d@. Both @down@ (giving @m@) and @up@
-- (giving @u@) see these annotations through @?below@. They see the node's
-- own pair through @?above@, which is defined in terms of the result @u@
-- itself, so the whole definition relies on laziness.
runDState :: Zippable f => DUpState f (u,d) u -> DDownState f (u,d) d -> d -> Term f -> u
runDState up down d (Term t) = u where
        t' = fmap bel $ number t
        bel (Numbered (i,s)) =
            -- Look the child up by its index only. The payload of the
            -- search key is a placeholder; it is an explicit error rather
            -- than 'undefined' so that forcing it reports where it came from.
            -- NOTE(review): relies on the Ord instance of 'Numbered'
            -- comparing indices only -- confirm where 'Numbered' is defined.
            let key = Numbered (i, error "runDState: placeholder key payload forced")
                d'  = Map.findWithDefault d key m
            in Numbered (i, (runDState up down d' s, d'))
        m = explicit (u,d) unNumbered down t'
        u = explicit (u,d) unNumbered up t'