{-# LANGUAGE GADTs, RankNTypes, TypeOperators, ViewPatterns, TypeFamilies, FlexibleInstances, FlexibleContexts, TemplateHaskell #-}

{-|

  This module defines the NuElim typeclass, which allows eliminations
  of (multi-) bindings. The high-level idea is that, when a fresh name
  is created with 'nu', the fresh name can also be substituted for the
  bound name in a binding. See the documentation for 'nuWithElim1' and
  'nuWithElimMulti' for details.
-}

module Data.Binding.Hobbits.NuElim (
  nuWithElimMulti, nuWithElim, nuWithElim1,
  NuElim(..), mkNuElimData, NuElimList(..),
  NuElimProof(),
) where

--import Data.Data
import Data.Binding.Hobbits.Internal
--import Data.Binding.Hobbits.Context
import Data.Binding.Hobbits.Mb
import Data.Type.List
import Language.Haskell.TH hiding (Name)
import qualified Language.Haskell.TH as TH
--import Unsafe.Coerce (unsafeCoerce)
import Control.Monad.State
import Control.Monad.Identity

-- "proofs" that type a is an allowed type for nu-elimination; i.e.,
-- that the multi-binding of an Mb ctx a can be eliminated by nuWithElim
data NuElimProof a where
    NuElimName :: NuElimProof (Name a)
    NuElimMb :: NuElimProof a -> NuElimProof (Mb ctx a)
    NuElimFun :: NuElimProof a -> NuElimProof b -> NuElimProof (a -> b)
    NuElimData :: NuElimDataProof a -> NuElimProof a

data NuElimDataProof a =
    NuElimDataProof (forall ctx. MapC Name ctx -> MapC Name ctx -> a -> a)


-- map the names in one MapC to the corresponding ones in another
mapNamesPf :: NuElimProof a -> MapC Name ctx -> MapC Name ctx -> a -> a
mapNamesPf NuElimName Nil Nil n = n
mapNamesPf NuElimName (names :> m) (names' :> m') n =
    case cmpName m n of
      Just Refl -> m'
      Nothing -> mapNamesPf NuElimName names names' n
mapNamesPf NuElimName _ _ _ = error "Should not be possible! (in mapNamesPf)"
mapNamesPf (NuElimMb nuElim) names1 names2 (MkMb ns b) =
    undefined
{- FIXME: something like this...
    let names2' = fresh_names ns in
    MkMb names2' $ mapNamesPf (names1 |++> ns) (names2 |++> names2') b
-}
mapNamesPf (NuElimFun nuElimIn nuElimOut) names names' f =
    (mapNamesPf nuElimOut names names') . f . (mapNamesPf nuElimIn names' names)
mapNamesPf (NuElimData (NuElimDataProof mapFun)) names names' x =
    mapFun names names' x


-- just like mapNamesPf, except use the NuElim class
mapNames :: NuElim a => MapC Name ctx -> MapC Name ctx -> a -> a
mapNames = mapNamesPf nuElimProof


{-|
  Instances of the @NuElim a@ class allow the type @a@ to be used with
  'nuWithElimMulti' and 'nuWithElim1'. The structure of this class is
  mostly hidden from the user; see 'mkNuElimData' to see how to create
  instances of the @NuElim@ class.
-}
class NuElim a where
    nuElimProof :: NuElimProof a

instance NuElim (Name a) where
    nuElimProof = NuElimName

instance (NuElim a, NuElim b) => NuElim (a -> b) where
    nuElimProof = NuElimFun nuElimProof nuElimProof

{-
instance NuElim a => NuElim (Mb Nil a) where
    nuElimProof = NuElimMbNil nuElimProof

instance NuElim (Mb ctx a) => NuElim (Mb (ctx :> c) a) where
    nuElimProof = NuElimMbCons nuElimProof
-}

instance NuElim a => NuElim (Mb ctx a) where
    nuElimProof = NuElimMb nuElimProof

instance NuElim Int where
    nuElimProof = NuElimData (NuElimDataProof $ (\c1 c2 -> id))

instance NuElim Char where
    nuElimProof = NuElimData (NuElimDataProof $ (\c1 c2 -> id))

instance NuElim a => NuElim [a] where
    nuElimProof = NuElimData (NuElimDataProof $ (\c1 c2 -> map (mapNames c1 c2)))

instance (NuElim a, NuElim b) => NuElim (a,b) where
    nuElimProof = NuElimData (NuElimDataProof $ (\c1 c2 (a,b) -> (mapNames c1 c2 a, mapNames c1 c2 b)))

instance (NuElim a, NuElim b) => NuElim (Either a b) where
    nuElimProof = NuElimData
                  (NuElimDataProof
                   $ (\c1 c2 x -> case x of
                                    Left l -> Left (mapNames c1 c2 l)
                                    Right r -> Right (mapNames c1 c2 r)))


{-
type family NuElimListProof args
type instance NuElimListProof Nil = ()
type instance NuElimListProof (args :> arg) = (NuElimListProof args, NuElimProof arg)

-- the NuElimList class, for saying that NuElim holds for a context of types
class NuElimList args where
    nuElimListProof :: NuElimListProof args

instance NuElimList Nil where
    nuElimListProof = ()

instance (NuElimList args, NuElim a) => NuElimList (args :> a) where
    nuElimListProof = (nuElimListProof, nuElimProof)
-}

-- the NuElimList class, for saying that NuElim holds for a context of types
class NuElimList args where
    nuElimListProof :: MapC NuElimProof args

instance NuElimList Nil where
    nuElimListProof = Nil

instance (NuElimList args, NuElim a) => NuElimList (args :> a) where
    nuElimListProof = nuElimListProof :> nuElimProof


-- filter out names from a MapC Name
{-
data FilterRes where
    FilterRes :: (MapC Name ctx, MapC Name ctx) -> FilterRes

filterName :: Name a -> MapC Name ctx -> MapC Name ctx -> FilterRes
filterName n Nil Nil = Nil
filterName 

filterNames :: [Int] -> MapC Name ctx -> MapC Name ctx -> FilterRes
filterNames Nil names1 names2 = FilterRes names1 names2
Nil Nil = FilterRes Nil
filterNames (names1 :> n1)
-}

-- FIXME: add more examples!!
{-|
  The expression @nuWithElimMulti args f@ takes a sequence @args@ of
  zero or more multi-bindings, each of type @Mb ctx ai@ for the same
  type context @ctx@ of bound names, and a function @f@ and does the
  following:

  * Creates a multi-binding that binds names @n1,...,nn@, one name for
    each type in @ctx@;

  * Substitutes the names @n1,...,nn@ for the names bound by each
    argument in the @args@ sequence, yielding the bodies of the @args@
    (using the new name @n@); and then

  * Passes the sequence @n1,...,nn@ along with the result of
    substituting into @args@ to the function @f@, which then returns
    the value for the newly created binding.

  Note that the types in @args@ must each have a @NuElim@ instance;
  this is represented with the @NuElimList@ type class.

  Here are some examples:

> commuteFun :: (NuElim a, NuElim b) => Mb ctx (a -> b) -> Mb ctx a -> Mb ctx b
> commuteFun f a =
>     nuWithElimMulti ('mbToProxy' f) ('Nil' :> f :> a)
>                     (\_ ('Nil' :> 'Identity' f' :> 'Identity' a') -> f' a')
-}
nuWithElimMulti :: NuElimList args => MapC f ctx -> MapC (Mb ctx) args ->
                   (MapC Name ctx -> MapC Identity args -> b) -> Mb ctx b
nuWithElimMulti nameProxies args f =
    nuMulti nameProxies $ \names -> f names (mapArgs nuElimListProof args names)
    where
      mapArgs :: MapC NuElimProof args -> MapC (Mb ctx) args ->
                 MapC Name ctx -> MapC Identity args
      mapArgs Nil Nil _ = Nil
      mapArgs (proofs :> proof) (args :> arg) names =
          mapArgs proofs args names :>
                      (Identity $ mapNamesSubst proof names arg)
      mapArgs _ _ _ = error "Should not be possible! (in mapArgs)"

      mapNamesSubst :: NuElimProof arg -> MapC Name ctx -> Mb ctx arg -> arg
      mapNamesSubst proof names (MkMb namesOld arg) =
          mapNamesPf proof (unsafeNamesFromInts namesOld) names arg


{-|
  Similar to 'nuWithElimMulti' but binds only one name.
-}
nuWithElim :: NuElimList args => MapC (Mb (Nil :> a)) args ->
              (Name a -> MapC Identity args -> b) -> Binding a b
nuWithElim args f =
    nuWithElimMulti (Nil :> Proxy) args (\(Nil :> n) -> f n)


{-|
  Similar to 'nuWithElimMulti' but takes only one argument that binds
  a single name.
-}
nuWithElim1 :: NuElim arg => Binding a arg ->
               (Name a -> arg -> b) -> Binding a b
nuWithElim1 (MkMb [nOld] arg) f =
    nu $ \n -> f n (mapNames (Nil :> MkName nOld) (Nil :> n) arg)
nuWithElim1 (MkMb _ arg) f =
    error "Internal error in nuWithElim1"


-- now we define some TH to create NuElimDataProofs

natsFrom i = i : natsFrom (i+1)

fst3 :: (a,b,c) -> a
fst3 (x,_,_) = x

snd3 :: (a,b,c) -> b
snd3 (_,y,_) = y

thd3 :: (a,b,c) -> c
thd3 (_,_,z) = z


type Names = (TH.Name, TH.Name, TH.Name, TH.Name)

mapNamesType a = [t| forall ctx. MapC Name ctx -> MapC Name ctx -> $a -> $a |]

{-|
  Template Haskell function for creating NuElim instances for (G)ADTs.
  Typical usage is to include the following line in the source file for
  (G)ADT @T@ (here assumed to have two type arguments):

> $(mkNuElimData [t| forall a b . T a b |])

  The 'mkNuElimData' call here will create an instance declaration for
  @'NuElim' (T a b)@. It is also possible to include a context in the
  forall type; for example, if we define the 'ID' data type as follows:

> data ID a = ID a

  then we can create a 'NuElim' instance for it like this:

> $( mkNuElimData [t| NuElim a => ID a |])

  Note that, when a context is included, the Haskell parser will add
  the @forall a@ for you.
-}
mkNuElimData :: Q Type -> Q [Dec]
mkNuElimData tQ =
    do t <- tQ
       (cxt, cType, tName, constrs, tyvars) <- getNuElimInfoTop t
       fName <- newName "f"
       x1Name <- newName "x1"
       x2Name <- newName "x2"
       clauses <- mapM (getClause (tName, fName, x1Name, x2Name)) constrs
       mapNamesT <- mapNamesType (return cType)
       return [InstanceD
               cxt (AppT (ConT ''NuElim) cType)
               [ValD (VarP 'nuElimProof)
                (NormalB
                 $ AppE (ConE 'NuElimData)
                   $ AppE (ConE 'NuElimDataProof)
                         $ LetE [SigD fName
                                 $ ForallT (map PlainTV tyvars) cxt mapNamesT,
                                 FunD fName clauses]
                               (VarE fName)) []]]

       {-
       return (LetE
               [SigD fName
                     (ForallT tyvars reqCxt
                     $ foldl AppT ArrowT
                           [foldl AppT (ConT conName)
                                      (map tyVarToType tyvars)]),
                FunD fname clauses]
               (VarE fname))
        -}
    where
      -- extract the name from a TyVarBndr
      tyBndrToName (PlainTV n) = n
      tyBndrToName (KindedTV n _) = n

      -- fail for getNuElimInfo
      getNuElimInfoFail t =
          fail ("mkNuElimData: " ++ show t
                ++ " is not a type constructor for a (G)ADT applied to zero or more distinct type variables")

      -- get info for conName (top-level call)
      getNuElimInfoTop t = getNuElimInfo [] [] t t

      -- get info for conName
      getNuElimInfo ctx tyvars topT (ConT tName) =
          do info <- reify tName
             case info of
               TyConI (DataD _ _ tyvarsReq constrs _) ->
                   let tyvarsRet = if tyvars == [] && ctx == []
                                   then map tyBndrToName tyvarsReq
                                   else tyvars in
                   return (ctx,
                           foldl AppT (ConT tName)
                                     (map VarT tyvars),
                           tName, constrs, tyvarsRet)
               _ -> getNuElimInfoFail topT

      getNuElimInfo ctx tyvars topT (AppT f (VarT argName)) =
          if elem argName tyvars then
              getNuElimInfoFail topT
          else
              getNuElimInfo ctx (argName:tyvars) topT f

      getNuElimInfo ctx tyvars topT (ForallT _ ctx' t) =
          getNuElimInfo (ctx ++ ctx') tyvars topT t

      getNuElimInfo ctx tyvars topT t = getNuElimInfoFail topT

      -- get the name from a data type
      getTCtor t = getTCtorHelper t t []
      getTCtorHelper (ConT tName) topT tyvars = Just (topT, tName, tyvars)
      getTCtorHelper (AppT t1 (VarT var)) topT tyvars =
          getTCtorHelper t1 topT (tyvars ++ [var])
      getTCtorHelper (SigT t1 _) topT tyvars = getTCtorHelper t1 topT tyvars
      getTCtorHelper _ _ _ = Nothing

      -- get a list of Clauses, one for each constructor in constrs
      getClauses :: Names -> [Con] -> Q [Clause]
      getClauses names constrs = mapM (getClause names) constrs

      getClause :: Names -> Con -> Q Clause
      getClause names (NormalC cName cTypes) =
          getClauseHelper names (map snd cTypes)
                          (natsFrom 0)
                          (\l -> ConP cName (map (VarP . fst3) l))
                          (\l -> foldl AppE (ConE cName) (map fst3 l))

      getClause names (RecC cName cVarTypes) =
          getClauseHelper names (map thd3 cVarTypes)
                         (map fst3 cVarTypes)
                         (\l -> RecP cName
                                (map (\(var,_,field) -> (field, VarP var)) l))
                         (\l -> RecConE cName
                                (map (\(exp,_,field) -> (field, exp)) l))

      getClause names (InfixC cType1 cName cType2) =
          undefined -- FIXME

      getClause names (ForallC _ _ con) =  getClause names con

      getClauseHelper :: Names -> [Type] -> [a] ->
                         ([(TH.Name,Type,a)] -> Pat) ->
                         ([(Exp,Type,a)] -> Exp) ->
                         Q Clause
      getClauseHelper names@(tName, fName, x1Name, x2Name) cTypes cData pFun eFun =
          do varNames <- mapM (newName . ("x" ++) . show . fst)
                         $ zip (natsFrom 0) cTypes
             let varsTypesData = zip3 varNames cTypes cData
             let expsTypesData = map (mkExpTypeData names) varsTypesData
             return $ Clause [(VarP x1Name), (VarP x2Name), (pFun varsTypesData)]
                        (NormalB $ eFun expsTypesData) []

      mkExpTypeData :: Names -> (TH.Name,Type,a) -> (Exp,Type,a)
      mkExpTypeData (tName, fName, x1Name, x2Name)
                    (varName, getTCtor -> Just (t, tName', _), cData)
          | tName == tName' =
              -- the type of the arg is the same as the (G)ADT we are
              -- recursing over; apply the recursive function
              (foldl AppE (VarE fName)
                         [(VarE x1Name), (VarE x2Name), (VarE varName)],
               t, cData)
      mkExpTypeData (tName, fName, x1Name, x2Name) (varName, t, cData) =
          -- the type of the arg is not the same as the (G)ADT; call mapNames
          (foldl AppE (VarE 'mapNames)
                     [(VarE x1Name), (VarE x2Name), (VarE varName)],
           t, cData)

-- FIXME: old stuff below

type CxtStateQ a = StateT Cxt Q a

-- create a NuElimDataProof for a (G)ADT
mkNuElimDataProofOld :: Q TH.Name -> Q Exp
mkNuElimDataProofOld conNameQ =
    do conName <- conNameQ
       (cxt, name, tyvars, constrs) <- getNuElimInfo conName
       (clauses, reqCxt) <- runStateT (getClauses cxt name tyvars constrs) []
       fname <- newName "f"
       return (LetE
               [SigD fname
                     (ForallT tyvars reqCxt
                     $ foldl AppT ArrowT
                           [foldl AppT (ConT conName)
                                      (map tyVarToType tyvars)]),
                FunD fname clauses]
               (VarE fname))
    where
      -- convert a TyVar to a Name
      tyVarToType (PlainTV n) = VarT n
      tyVarToType (KindedTV n _) = VarT n

      -- get info for conName
      getNuElimInfo conName =
          reify conName >>= \info ->
              case info of
                TyConI (DataD cxt name tyvars constrs _) ->
                    return (cxt, name, tyvars, constrs)
                _ -> fail ("mkNuElimDataProof: " ++ show conName
                           ++ " is not a (G)ADT")
      {-
      -- report failure
      getNuElimInfoFail t =
          fail ("mkNuElimDataProof: " ++ show t
                ++ " is not a fully applied (G)ADT")

      getNuElimInfo (ConT conName) topT =
          reify conName >>= \info ->
              case info of
                TyConI (DataD cxt name tyvars constrs _) ->
                    return (cxt, name, tyvars, constrs)
                _ -> getNuElimInfoFail topT
      getNuElimInfo (AppT t _) topT = getNuElimInfo t topT
      getNuElimInfo (SigT t _) topT = getNuElimInfo t topT
      getNuElimInfo _ topT = getNuElimInfoFail topT
       -}

      -- get a list of Clauses, one for each constructor in constrs
      getClauses :: Cxt -> TH.Name -> [TyVarBndr] -> [Con] -> CxtStateQ [Clause]
      getClauses cxt name tyvars constrs =
          mapM (getClause cxt name tyvars []) constrs

      getClause :: Cxt -> TH.Name -> [TyVarBndr] -> [TyVarBndr] -> Con ->
                   CxtStateQ Clause
      getClause cxt name tyvars locTyvars (NormalC cName cTypes) =
          getClauseHelper cxt name tyvars locTyvars (map snd cTypes)
                          (natsFrom 0)
                          (\l -> ConP cName (map (VarP . fst3) l))
                          (\l -> foldl AppE (ConE cName) (map (VarE . fst3) l))

      getClause cxt name tyvars locTyvars (RecC cName cVarTypes) =
          getClauseHelper cxt name tyvars locTyvars (map thd3 cVarTypes)
                         (map fst3 cVarTypes)
                         (\l -> RecP cName
                                (map (\(var,_,field) -> (field, VarP var)) l))
                         (\l -> RecConE cName
                                (map (\(var,_,field) -> (field, VarE var)) l))

      getClause cxt name tyvars locTyvars (InfixC cType1 cName cType2) =
          undefined -- FIXME

      getClause cxt name tyvars locTyvars (ForallC tyvars2 cxt2 con) =
          getClause (cxt ++ cxt2) name tyvars (locTyvars ++ tyvars2) con

      getClauseHelper :: Cxt -> TH.Name -> [TyVarBndr] -> [TyVarBndr] ->
                         [Type] -> [a] ->
                         ([(TH.Name,Type,a)] -> Pat) ->
                         ([(TH.Name,Type,a)] -> Exp) ->
                         CxtStateQ Clause
      getClauseHelper cxt name tyvars locTyvars cTypes cData pFun eFun =
          do varNames <- mapM (lift . newName . ("x" ++) . show . fst)
                         $ zip (natsFrom 0) cTypes
             () <- ensureCxt cxt locTyvars cTypes
             let varsTypesData = zip3 varNames cTypes cData
             return $ Clause [(pFun varsTypesData)]
                        (NormalB $ eFun varsTypesData) []

      -- ensure that NuElim a holds for each type a in cTypes
      ensureCxt :: Cxt -> [TyVarBndr] -> [Type] -> CxtStateQ ()
      ensureCxt cxt locTyvars cTypes =
          foldM (const (ensureCxt1 cxt locTyvars)) () cTypes

      -- FIXME: it is not possible (or, at least, not easy) to determine
      -- if NuElim a is implied from a current Cxt... so we just add
      -- everything we need to the returned Cxt, except for 
      ensureCxt1 :: Cxt -> [TyVarBndr] -> Type -> CxtStateQ ()
      ensureCxt1 cxt locTyvars t = undefined
      {-
      ensureCxt1 cxt locTyvars t = do
        curCxt = get
        let fullCxt = cxt ++ curCxt
        isOk <- isNuElim fullCxt 

      isNuElim 
       -}