{-# LANGUAGE EmptyDataDecls      #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE KindSignatures      #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators       #-}
--------------------------------------------------------------------------------
-- |
-- Module      :  Data.Comp.Multi.Term
-- Copyright   :  (c) 2011 Patrick Bahr
-- License     :  BSD3
-- Maintainer  :  Patrick Bahr <paba@diku.dk>
-- Stability   :  experimental
-- Portability :  non-portable (GHC Extensions)
--
-- This module defines the central notion of mutually recursive (or higher-order)
-- /terms/ and its generalisation to (higher-order) contexts. All definitions
-- are generalised versions of those in "Data.Comp.Term".
--
--------------------------------------------------------------------------------

module Data.Comp.Multi.Term
    (Cxt (..),
     Hole,
     NoHole,
     Context,
     Term,
     Const,
     constTerm,
     unTerm,
     toCxt,
     simpCxt
     ) where

import Data.Comp.Multi.HFoldable
import Data.Comp.Multi.HFunctor
import Data.Comp.Multi.HTraversable

import Data.Kind

import Control.Monad

import Unsafe.Coerce

-- | A constant of the signature @f@: the signature applied to the
-- family @'K' ()@ in its argument positions. 'constTerm' converts such a
-- value into a 'Term'.
type Const (f :: (Type -> Type) -> Type -> Type) = f (K ())

-- | This function converts a constant to a term. This assumes that
-- the argument is indeed a constant, i.e. does not have a value for
-- the argument type of the functor @f@.
--
-- Every argument position is replaced by an error thunk. If the
-- assumption is violated and such a position is later forced, the
-- resulting error names this function instead of a bare 'undefined'.
constTerm :: (HFunctor f) => Const f :-> Term f
constTerm = Term . hfmap (const (error msg))
  where
    msg :: String
    msg = "Data.Comp.Multi.Term.constTerm: argument is not a constant"

-- | This data type represents contexts over a signature. Contexts are
-- terms containing zero or more holes. The first type parameter is
-- supposed to be one of the phantom types 'Hole' and 'NoHole'. The
-- second parameter is the signature of the context. The third
-- parameter is the type family of the holes. The last parameter is
-- the index/label.

data Cxt h f a i where
    -- | An inner node: one layer of the signature @f@ whose argument
    -- positions are again contexts with the same parameters.
    Term ::  f (Cxt h f a) i -> Cxt h f a i
    -- | A hole carrying a value of the hole family @a@ at index @i@.
    -- Its result type fixes the first parameter to 'Hole', so a 'Cxt'
    -- whose first parameter is 'NoHole' can never contain one.
    Hole :: a i -> Cxt Hole f a i

-- | Phantom type that signals that a 'Cxt' might contain holes.
-- It has no constructors; it is used only at the type level.
data Hole
-- | Phantom type that signals that a 'Cxt' does not contain holes.
-- It has no constructors; it is used only at the type level.
data NoHole

-- | A context might contain holes: @'Context' f a@ is a 'Cxt' over the
-- signature @f@ whose holes hold values of the family @a@. The synonym
-- is eta-reduced, so it may be used partially applied (e.g. @Context f@).
type Context = Cxt Hole

-- | A (higher-order) term is a context with no holes. The first
-- parameter is 'NoHole', so the 'Hole' constructor cannot occur. The
-- hole family is fixed to @'K' ()@ because it is never inhabited.
type Term f = Cxt NoHole f (K ())

-- | This function unravels the given term at the topmost layer.
--
-- Matching only on the 'Term' constructor is exhaustive here. A 'Hole'
-- would require the phantom parameter to be 'Hole', but the 'Term' type
-- synonym fixes it to 'NoHole'.
unTerm :: Term f t -> f (Term f) t
unTerm (Term t) = t

-- | Mapping a natural transformation over a context rewrites the values
-- stored in its holes and keeps the 'Term' structure around them.
instance (HFunctor f) => HFunctor (Cxt h f) where
    hfmap f (Hole x) = Hole (f x)
    -- The inner 'hfmap' is this instance, recursing into the subcontexts;
    -- the outer one is the signature's own 'hfmap', which pushes that
    -- recursion through one layer of @f@.
    hfmap f (Term t) = Term (hfmap (hfmap f) t)

-- | Folding a context combines the values stored in its holes. 'Term'
-- nodes are not visited themselves; they delegate to the signature's own
-- 'HFoldable' instance, which determines the order in which subcontexts
-- are folded.
instance (HFoldable f) => HFoldable (Cxt h f) where
    -- Right fold: a hole combines its value with the accumulator via @op@;
    -- a 'Term' node threads the accumulator through its subcontexts using
    -- the signature's 'hfoldr'.
    hfoldr = hfoldr'
      where
        hfoldr' :: forall a b. (a :=> (b -> b)) -> b -> Cxt h f a :=> b
        hfoldr' op c x = run x c
          where
            run :: Cxt h f a :=> (b -> b)
            run (Hole v) e = v `op` e
            run (Term t) e = hfoldr run e t

    -- Left fold: the mirror image of 'hfoldr' above, with the accumulator
    -- as the first argument of @op@.
    hfoldl = hfoldl'
      where
        hfoldl' :: forall a b. (b -> a :=> b) -> b -> Cxt h f a :=> b
        hfoldl' op = run
          where
            run :: b -> Cxt h f a :=> b
            run e (Hole v) = e `op` v
            run e (Term t) = hfoldl run e t

    -- Holes already contain monoid values (wrapped in 'K'); 'Term' nodes
    -- combine the folds of their subcontexts.
    hfold (Hole (K v)) = v
    hfold (Term t) = hfoldMap hfold t

    -- Map every hole value into the monoid with @g@ and combine the
    -- results via the signature's 'hfoldMap'.
    hfoldMap = hfoldMap'
      where
        hfoldMap' :: forall m a. Monoid m => (a :=> m) -> Cxt h f a :=> m
        hfoldMap' g = run
          where
            run :: Cxt h f a :=> m
            run (Hole v) = g v
            run (Term t) = hfoldMap run t

-- | Traversing a context runs the effectful transformation on every hole
-- and rebuilds the 'Term' structure around the results. The order of
-- effects is that of the signature's own 'HTraversable' instance.
instance (HTraversable f) => HTraversable (Cxt h f) where
    -- Monadic variant: holes are transformed with @g@; 'Term' nodes recurse
    -- through the signature's 'hmapM'.
    hmapM = hmapM'
      where
        hmapM' :: forall m a b. (Monad m) => NatM m a b -> NatM m (Cxt h f a) (Cxt h f b)
        hmapM' g = run
          where
            run :: NatM m (Cxt h f a) (Cxt h f b)
            run (Hole x) = liftM Hole (g x)
            run (Term t) = liftM Term (hmapM run t)

    -- Applicative variant: the inner 'htraverse' is this instance
    -- (recursion into subcontexts); the outer one is the signature's.
    htraverse g (Hole x) = Hole <$> g x
    htraverse g (Term t) = Term <$> htraverse (htraverse g) t

-- | Lift a single layer of the signature @f@ into a context: every
-- argument position, holding a value of the family @a@, becomes a 'Hole'
-- and the layer itself becomes the 'Term' node on top.
simpCxt :: (HFunctor f) => f a i -> Context f a i
simpCxt = Term . hfmap Hole

-- | Cast a term over a signature to a context over the same signature.
--
-- This is equivalent to @Term . (hfmap toCxt) . unTerm@, but is
-- implemented with 'unsafeCoerce' so that the term is not rebuilt node
-- by node.
--
-- Why the cast is sound: a 'Term' is a 'Cxt' with phantom parameter
-- 'NoHole', so it consists solely of 'Term' constructors. Its hole family
-- @K ()@ is therefore never inhabited inside it, and the same value is a
-- valid @'Context' f a@ for any @a@.
-- NOTE(review): this reasoning must be revisited if 'Cxt' ever gains
-- further constructors.
toCxt :: (HFunctor f) => Term f :-> Context f a
{-# INLINE toCxt #-}
toCxt = unsafeCoerce