{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}

module Text.GrammarCombinators.Base.Domain (
    DomainMap(supIx, subIx),
    SubVal(MkSubVal, unSubVal),
    IxMapId, IxMapBase, IxMapSeq, ApplyIxMap, 
    MemoFam(toMemo, fromMemo), Memo, memoFamily,
    EqFam(overrideIdx, eqIdx), overrideIdxK, 
    memoFamilyK, toMemoK, fromMemoK,
    LiftFam(liftIdxE, liftIdxP)
  ) where

import Generics.MultiRec.Base
import Generics.MultiRec.HFunctor

import Language.Haskell.TH.Syntax (Exp, Pat)

-- | Index map that leaves every index unchanged.
data IxMapId

-- | Index map that applies the type constructor @f@ to an index.
data IxMapBase (f :: * -> *)

-- | Index map that first applies the type constructor @g@ to an index and
-- then interprets the index map @l@ on the result.
data IxMapSeq (l :: *) (g :: * -> *)

-- | Interpret an index map (built from 'IxMapId', 'IxMapBase' and
-- 'IxMapSeq') at the index @ix@.
type family ApplyIxMap (m :: *) ix
type instance ApplyIxMap IxMapId i = i
type instance ApplyIxMap (IxMapBase f) i = f i
type instance ApplyIxMap (IxMapSeq l g) i = ApplyIxMap l (g i)

-- | Domains that can tabulate a function defined on all of their
-- non-terminals into a 'Memo' table and read the function back from it.
--
-- The class only fixes the signatures; see 'memoFamily' for the intended
-- round trip @fromMemo . toMemo@.
class MemoFam (phi :: * -> *) where
  -- | Table type of the domain @phi@, parameterised by the value family that
  -- is stored in it.
  data Memo phi :: (* -> *) -> *
  -- | Build a table from a function over the whole domain.
  toMemo :: (forall ix. phi ix -> v ix) -> Memo phi v
  -- | Recover a function over the whole domain from a table.
  fromMemo :: Memo phi v -> (forall ix. phi ix -> v ix)

-- | Memoise a function over the whole domain: the function is tabulated with
-- 'toMemo', and the returned function answers every lookup via 'fromMemo'.
--
-- The table is bound once per application of 'memoFamily' (in the @where@
-- clause), not per lookup. Avoid eta-expanding this definition
-- (@memoFamily f idx = ...@), which would risk rebuilding the table on every call.
memoFamily :: (MemoFam phi) =>
              (forall ix. phi ix -> v ix) -> (forall ix. phi ix -> v ix)
memoFamily f = fromMemo table
  where
    table = toMemo f

-- | Variant of 'toMemo' for functions whose result type does not depend on
-- the index. Each result is wrapped in 'K0' so that it fits the
-- index-dependent value family stored in a 'Memo' table.
toMemoK :: (MemoFam phi) =>
           (forall ix. phi ix -> v) -> Memo phi (K0 v)
toMemoK f = toMemo (\idx -> K0 (f idx))

-- | Variant of 'fromMemo' for tables whose values are wrapped in 'K0' (such
-- as those built by 'toMemoK'); each looked-up result is unwrapped with 'unK0'.
fromMemoK :: (MemoFam phi) =>
             Memo phi (K0 v) -> phi ix -> v
fromMemoK table = unK0 . lookupIdx
  where
    -- Kept as a single shared partial application, mirroring a plain
    -- composition; do not eta-expand.
    lookupIdx = fromMemo table

-- | Variant of 'memoFamily' for functions whose result type is the same for
-- every non-terminal. The table is built once per application.
memoFamilyK :: (MemoFam phi) =>
               (forall ix. phi ix -> v) -> (forall ix. phi ix -> v)
memoFamilyK f = fromMemoK table
  where
    table = toMemoK f

-- NOTE(review): 'FoldFam' is not listed in this module's export list, so it
-- cannot be used or instantiated from other modules -- confirm this is intended.

-- | Domains whose non-terminals can be enumerated with a fold.
--
-- An instance for the domain @phi@ provides 'foldFam', which folds a step
-- function over every non-terminal of the domain.
class FoldFam phi where
  -- | @foldFam step z@ folds @step@ over all non-terminals of the domain,
  -- starting from the accumulator @z@.
  foldFam :: (forall ix. phi ix -> b -> b) -> b -> b

-- NOTE(review): 'ShowFam' is not listed in this module's export list, so it
-- cannot be used or instantiated from other modules -- confirm this is intended.

-- | Domains whose non-terminal proof terms can be rendered as 'String's.
class ShowFam phi where
  -- | Render a non-terminal proof term as a 'String'.
  showIdx :: phi ix -> String

-- | Domains that support overriding a function over the whole domain at a
-- single non-terminal (via 'overrideIdx') and, derived from that by default,
-- testing two non-terminal proof terms for equality (via 'eqIdx').
class EqFam phi where
  -- | @overrideIdx f idx v@ behaves like @f@, except that at the non-terminal
  -- @idx@ it yields @v@ instead.
  overrideIdx :: (forall ix'. phi ix' -> r ix') -> phi oix ->
                 r oix -> phi ix -> r ix
  -- | Test whether two non-terminal proof terms denote the same non-terminal.
  --
  -- The default overrides the constant-'False' function with 'True' at the
  -- first argument and evaluates the result at the second argument, so it
  -- returns 'True' exactly when 'overrideIdx' treats the second argument as
  -- the overridden non-terminal.
  eqIdx :: phi ix1 -> phi ix2 -> Bool
  eqIdx expected = overrideIdxK (\_ -> False) expected True

-- | Variant of 'overrideIdx' for functions whose result type is the same for
-- every non-terminal: @overrideIdxK f idx v@ behaves like @f@, except that it
-- yields @v@ at the non-terminal @idx@.
overrideIdxK :: (EqFam phi) => (forall ix'. phi ix' -> v) -> phi oix -> v -> phi ix -> v
overrideIdxK f idx v = unK0 . overridden
  where
    -- Results are wrapped in 'K0' so that they fit the index-dependent result
    -- type of 'overrideIdx', and are unwrapped again on the way out.
    overridden = overrideIdx (\i -> K0 (f i)) idx (K0 v)

-- NOTE(review): 'Domain' is not listed in this module's export list, so the
-- instances the documentation below asks for cannot be declared outside this
-- module -- confirm this is intended.

-- | Umbrella class for a well-behaved domain @phi@: one that is an instance of
-- 'FoldFam', 'ShowFam', 'EqFam' and 'MemoFam'.
--
-- Prefer the more specific classes in constraints whenever possible.
-- Instances of this class are not derived automatically; declare one by hand
-- with an empty body, e.g. @instance Domain MyDomain@.
class (EqFam phi, FoldFam phi, MemoFam phi, ShowFam phi) => Domain phi

-- | Relates a domain @phi'@ to a domain @phi@ whose indices are obtained by
-- applying the type constructor @supIxT@ to indices of @phi'@. Presumably
-- @phi'@ is the smaller domain, as in 'SubVal'.
class DomainMap phi phi' supIxT where
  -- | Map a proof term of @phi'@ at index @ix@ to the proof term of @phi@ at
  -- the mapped index @supIxT ix@.
  supIx :: phi' ix -> phi (supIxT ix)
  -- | Map a proof term of @phi@ at a mapped index back to @phi'@. Presumably
  -- the inverse of 'supIx'; the class does not enforce this.
  subIx :: phi (supIxT ix) -> phi' ix
-- NOTE(review): 'DomainEmbedding' is not listed in this module's export list,
-- so it cannot be used or instantiated from other modules -- confirm this is
-- intended.

-- | A 'DomainMap' that can additionally translate pattern-functor values of
-- the domain @phi'@ into pattern-functor values of the domain @phi@.
class DomainMap phi phi' supIxT => DomainEmbedding phi phi' supIxT where
  -- | Given the proof terms for an index in both domains, convert a 'PF'
  -- value of @phi'@, parameterised by @SubVal supIxT r@, into a 'PF' value of
  -- @phi@ parameterised by @r@ at the mapped index @supIxT ix@.
  supPF :: (HFunctor phi (PF phi)) =>
           phi' ix -> phi (supIxT ix) ->
           PF phi' (SubVal supIxT r) ix -> PF phi r (supIxT ix)

-- | A generic wrapper that restricts a semantic value family @v@ over a bigger
-- domain to a smaller domain: the value stored for an index @ix@ of the
-- smaller domain is the bigger domain's value at the mapped index @supIxT ix@.
data SubVal (supIxT :: * -> *) v ix =
  MkSubVal
    { unSubVal :: v (supIxT ix)
      -- ^ The wrapped value at the mapped index @supIxT ix@.
    }
  deriving Show

-- | Domains whose non-terminal proof terms can be turned into Template
-- Haskell syntax.
class LiftFam phi where
  -- | Produce a Template Haskell expression for a non-terminal proof term
  -- (presumably one that rebuilds that proof term in spliced code).
  liftIdxE :: phi ix -> Exp
  -- | Produce a Template Haskell pattern for a non-terminal proof term
  -- (presumably one that matches that proof term in spliced code).
  liftIdxP :: phi ix -> Pat