{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE GADTs           #-}
{-# LANGUAGE KindSignatures  #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE MonomorphismRestriction #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# OPTIONS -fwarn-missing-signatures #-}

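{- |
Example usage:

@
import Generics.MultiRec
import Generics.MultiRec.TH.Alt
import Data.Tree

data TheFam :: (* -> *) where
              Tree_Int   :: TheFam   (Tree Int)
              Forest_Int :: TheFam (Forest Int)

$(deriveEverything
   (DerivOptions
    { familyTypes =
       [ ( [t| Tree   Int |],   \"Tree_Int\" )
       , ( [t| Forest Int |], \"Forest_Int\" )
       ]
    , indexGadtName           = \"TheFam\"
    , patternFunctorName      = \"ThePF\"
    , constructorNameModifier =
       (\\t c -> \"CONSTRUCTOR_\" ++ t ++ \"_\" ++ c)
    }))

type instance 'PF' TheFam = ThePF
@

Other 'DerivOptions' fields, if there are any, are omitted above. The
@type instance 'PF'@ line is not generated; write it by hand as shown.
-}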

module Generics.MultiRec.TH.Alt
  ) where

import Generics.MultiRec.TH.Alt.DerivOptions(DerivOptions(..))
import THUtils(AppliedTyCon, (@@), (@@@), toAppliedTyCon,
               fromAppliedTyCon, atc2constructors, pprintUnqual, sMatch, sClause,
import BalancedFold(balancedFold, ascendFromLeaf)
import MonadRQ(RQ, message, messageReport, liftq, foreachType,
               foreachTypeNumbered, runRQ)
import Generics.MultiRec.Base((:>:)(..), C(..), El(..), (:=:)(..),
                              I0(I0), (:*:)(..), (:+:)(..), EqS(..), Fam(..), I(I), K(K), U(..))
import Generics.MultiRec.Constructor(Associativity(..), Fixity(..),
import Control.Monad.Reader(Monad(return, fail, (>>)), Functor(..),
                            (=<<), mapM, sequence, liftM, zipWithM, asks)
import Language.Haskell.TH.Syntax(Lift(..))
import Language.Haskell.TH(newName, mkName, wildP, clause, conE,
                           appE, normalB, funD, dataD, instanceD, cxt, conT, appT,
                           Exp(VarE, SigE, LamE, ConE, CaseE, AppE), Match, Clause, Q,
                           Pat(WildP, VarP, ConP), TypeQ, Type(ConT),
                           Dec(TySynD, InstanceD, FunD), Name, 
                           Con(RecC, NormalC, InfixC),
                           Info(DataConI), nameBase, reify, stringE)
import Data.Map(lookup, elems)
import Control.Applicative((<$>))

import qualified Data.Map as Map
import qualified Language.Haskell.TH as TH



-- | Entry point: under the given options, run 'deriveConstructors'
-- and then 'deriveFamily', and return all generated declarations with
-- the constructor declarations first.
deriveEverything :: DerivOptions [(TypeQ, String)] -> Q [Dec]
deriveEverything opts = runRQ allDecls opts
  where
    -- A sanity-check pass gated on @mkSanityChecks opts@ is sketched in
    -- commented-out code and is currently disabled.
    allDecls = liftM concat (sequence [deriveConstructors, deriveFamily])

-- | For every type in the family, derive one empty tag datatype per
-- data constructor together with its 'Constructor' instance (see
-- 'constrInstance'), and concatenate the results.
deriveConstructors :: RQ [Dec]
deriveConstructors = liftM concat (foreachType constrInstance)

-- | Derive the pattern functor type synonym ('derivePF'), the 'El'
-- instances ('deriveEl'), the 'Fam' instance ('deriveFam') and the
-- 'EqS' instance ('deriveEqS'). The declarations are concatenated in
-- that order.
--
-- /IMPORTANT/: the index-GADT constructor for each family type must be
-- named exactly by the string paired with that type in the options.
deriveFamily :: RQ [Dec]
deriveFamily = do
    pf  <- derivePF
    el  <- deriveEl
    fam <- deriveFam
    eq  <- deriveEqS
    return (concat [pf, el, fam, eq])

-- | Derive only the pattern functor. This is a type synonym named by
-- 'patternFunctorName' whose right-hand side is the sum ('sumT') of
-- the per-type branches built by 'pfType'. Not needed if
-- 'deriveFamily' is used.
--
-- The @type instance PF ...@ declaration is /not/ generated. Instead, a
-- reminder to add it by hand is reported through 'messageReport'.
derivePF :: RQ [Dec]
derivePF = do
  branches <- foreachType pfType
  pfn      <- asks patternFunctorName
  famName  <- asks indexGadtName
  let pf = [TySynD (mkName pfn) [] (sumT branches)]
  messageReport
    (  "Reminder: Don't forget to add this line manually:\n"
    ++ "    type instance PF " ++ famName ++ " = " ++ pfn )
  return pf

-- | Join pattern-functor branches with ':+:'. 'bALANCED_MODE' selects
-- either a balanced tree or a right-nested chain. The right-nested form
-- uses 'foldr1', so it fails on an empty list.
sumT :: [Type] -> Type
sumT = if bALANCED_MODE then balancedSumT else rightSumT

-- | Right-nested sum: @a :+: (b :+: (... :+: z))@. Non-empty input only.
rightSumT :: [Type] -> Type
rightSumT branches = foldr1 plusT branches

-- | Sum shaped by 'balancedFold'.
balancedSumT :: [Type] -> Type
balancedSumT branches = balancedFold plusT branches

-- | The type-level sum @a :+: b@.
plusT :: Type -> Type -> Type
plusT l r = ConT ''(:+:) @@ l @@ r

-- | Right-nested product of field functors. Non-empty input only.
prodT :: [Type] -> Type
prodT factors = foldr1 timesT factors

-- | The type-level product @a :*: b@.
timesT :: Type -> Type -> Type
timesT l r = ConT ''(:*:) @@ l @@ r
-- | Derive only the 'El' instances: one per family type, each built by
-- 'elInstance'. Not needed if 'deriveFamily' is used.
deriveEl :: RQ [Dec]
deriveEl = foreachType elInstance
-- | The index GADT, named by 'indexGadtName', as a TH type constructor.
indexGadtType :: RQ Type
indexGadtType = do
  gadtName <- asks indexGadtName
  return (ConT (mkName gadtName))

-- | Derive only the 'Fam' instance for the index GADT. 'from' gets the
-- clauses produced by 'mkFrom' for every family type; 'to' gets one
-- clause per type from 'mkTo'. Not needed if 'deriveFamily' is used.
deriveFam :: RQ [Dec]
deriveFam = do
    fcs <- liftM concat $ foreachTypeNumbered mkFrom
    tcs <- foreachTypeNumbered mkTo
    s   <- indexGadtType
    return [ InstanceD [] (ConT ''Fam @@ s)
               [FunD 'from fcs, FunD 'to tcs] ]

-- | Derive only the 'EqS' instance. Not needed if 'deriveFamily'
-- is used.
--
-- For every index-GADT constructor @C@ the instance gets a clause
-- @eqS C C = Just Refl@. A final catch-all @eqS _ _ = Nothing@ is added
-- unless the family has exactly one member; in that case the single
-- 'Just' clause already covers every input.
deriveEqS :: RQ [Dec]
deriveEqS = do
  s  <- indexGadtType
  ns <- elems <$> asks familyTypes
  return [ InstanceD [] (ConT ''EqS @@ s)
             [FunD 'eqS (map trueClause ns ++ falses ns)] ]
  where
    trueClause n = sClause [ConP (mkName n) [], ConP (mkName n) []]
                           (ConE 'Just `AppE` ConE 'Refl)
    falseClause  = sClause [WildP, WildP] (ConE 'Nothing)
    -- One member: the catch-all would be redundant (overlapping match).
    falses [_]   = []
    falses _     = [falseClause]

-- | For one family type, given as its applied type constructor plus its
-- name string from the options: generate an empty tag datatype for
-- every data constructor ('mkData'), followed by the matching
-- 'Constructor' instances ('mkInstance').
constrInstance :: (AppliedTyCon,String) -> RQ [Dec]
constrInstance (atc,s) = do
    cs <- liftq (atc2constructors atc)
    ds <- mapM (mkData s) cs
    is <- mapM (mkInstance s) cs
    return $ ds ++ is

-- | Turn a record constructor into the equivalent positional one by
-- dropping its field names. Any other constructor is returned as is.
stripRecordNames :: Con -> Con
stripRecordNames con = case con of
  RecC name fields -> NormalC name [ (strictness, ty) | (_, strictness, ty) <- fields ]
  _                -> con

-- | Generate the empty datatype that tags one data constructor. Its name
-- is the constructor's cleaned base name passed, together with the
-- family type's string, through 'constructorNameModifier'. Record and
-- infix constructors are first normalised to 'NormalC'.
-- TODO: Handle colons in the constructor name
mkData :: String -> Con -> RQ Dec
mkData s (NormalC n _) = do
  rename <- asks constructorNameModifier
  let tagName = mkName (rename s (cleanConstructorName (nameBase n)))
  liftq (dataD (cxt []) tagName [] [] [])
mkData s r@(RecC _ _) = mkData s (stripRecordNames r)
mkData s (InfixC l n r) = mkData s (NormalC n [l, r])

-- | Allows a multirec 'Fixity' value to be spliced into generated code.
-- 'mkInstance' uses it for the 'conFixity' method.
instance Lift Fixity where
  lift fixity = case fixity of
    Prefix          -> conE 'Prefix
    Infix assoc lvl -> conE 'Infix `appE` lift assoc `appE` lift lvl

-- | Allows a multirec 'Associativity' value to be spliced into generated
-- code as the corresponding constructor.
instance Lift Associativity where
  lift assoc = conE $ case assoc of
    LeftAssociative  -> 'LeftAssociative
    RightAssociative -> 'RightAssociative
    NotAssociative   -> 'NotAssociative

-- | Generate the 'Constructor' instance for one constructor's tag type,
-- named the same way as in 'mkData'. 'conName' returns the original
-- constructor's base name. Infix constructors also get 'conFixity',
-- taken from reifying the constructor; if reification does not yield a
-- data constructor, 'Prefix' is used.
mkInstance :: String -> Con -> RQ Dec
mkInstance s (NormalC n _) = do
      modifier <- asks constructorNameModifier
      let n' = modifier s . cleanConstructorName . nameBase $ n
      liftq $
       instanceD (cxt [])
        (appT (conT ''Constructor) (conT . mkName $ n'))
        [funD 'conName [clause [wildP] (normalB (stringE (nameBase n))) []]]
mkInstance s r@(RecC _ _) =
  mkInstance s (stripRecordNames r)
mkInstance s (InfixC _ n _) = do
      modifier <- asks constructorNameModifier
      let n' = modifier s . cleanConstructorName . nameBase $ n
      i <- liftq (reify n)
      let fi = case i of
                 DataConI _ _ _ f -> convertFixity f
                 _ -> Prefix
      liftq $
       instanceD (cxt []) (appT (conT ''Constructor) (conT $ mkName n'))
        [funD 'conName   [clause [wildP] (normalB (stringE (nameBase n))) []],
         funD 'conFixity [clause [wildP] (normalB [| fi |]) []]]
  where
    -- TH fixity (precedence + direction) to multirec's 'Fixity'.
    convertFixity (TH.Fixity prec dir) = Infix (convertDirection dir) prec
    -- Qualified: the unqualified TH import list does not bring these
    -- constructors into scope.
    convertDirection TH.InfixL = LeftAssociative
    convertDirection TH.InfixR = RightAssociative
    convertDirection TH.InfixN = NotAssociative

-- | Pattern-functor branch for one family type. It is the sum of the
-- type's constructor functors (see 'pfCon'), tagged with the type itself
-- via ':>:'. Fails for types without constructors ('guardEmptyData').
pfType :: (AppliedTyCon,String) -> RQ Type
pfType (atc,s) = do
      cs <- liftq (atc2constructors atc)
      guardEmptyData cs atc
      b <- sumT <$> mapM (pfCon s) cs
      return $
       ConT ''(:>:) @@ b @@ fromAppliedTyCon atc

-- | Functor for a single constructor: @C Tag fields@. @Tag@ is the tag
-- type from 'mkData'. @fields@ is 'U' for a nullary constructor,
-- otherwise the right-nested ':*:' product of the field functors (see
-- 'pfField'). Record and infix constructors are normalised first.
pfCon :: String -> Con -> RQ Type
pfCon s (NormalC n fs) = do
  rename <- asks constructorNameModifier
  let tag = ConT (mkName (rename s (cleanConstructorName (nameBase n))))
  payload <- if null fs
               then return (ConT ''U)
               else prodT <$> mapM (pfField . snd) fs
  return (ConT ''C @@ tag @@ payload)
pfCon s r@(RecC _ _) = pfCon s (stripRecordNames r)
pfCon s (InfixC l n r) = pfCon s (NormalC n [l, r])

-- | Functor for one field of type @t@. Family members become the
-- recursive position @I t@; any other type becomes the constant @K t@.
pfField :: Type -> RQ Type
pfField t = ifInFamily t (wrap ''I) (wrap ''K)
  where
    wrap con = ConT con @@ t

-- | Look up the name string that 'familyTypes' registers for a type.
-- Returns 'Nothing' if the type is not in the family, or if it cannot
-- be converted to an 'AppliedTyCon'.
lookupFam :: Type -> RQ (Maybe String)
lookupFam t = do
      ts <- asks familyTypes
      t' <- liftq $ toAppliedTyCon t
      return $ case t' of
        Right applied -> Map.lookup applied ts
        Left _        -> Nothing
-- | Yield the first value if the type is a family member (per
-- 'lookupFam'), otherwise the second value.
ifInFamily :: Type -> a -> a -> RQ a
ifInFamily t inFam notInFam = liftM (maybe notInFam (const inFam)) (lookupFam t)

-- | Run the first action if the type is a family member (per
-- 'lookupFam'), otherwise run the second action.
ifInFamily' :: Type -> RQ a -> RQ a -> RQ a
ifInFamily' t inFam notInFam = do
  found <- lookupFam t
  case found of
    Just _  -> inFam
    Nothing -> notInFam

-- | The 'El' instance for one family type, shaped as
-- @instance El Gadt T where proof = Con@ (see 'mkProof').
elInstance :: (AppliedTyCon,String) -> RQ Dec
elInstance entry@(atc, _) = do
  gadt     <- indexGadtType
  proofDec <- mkProof entry
  let member = fromAppliedTyCon atc
  return (InstanceD [] (ConT ''El @@ gadt @@ member) [proofDec])

-- | Clauses of 'from' for family type number @i@ out of @m@: one clause
-- per data constructor (see 'fromCon'). Each clause matches the GADT
-- constructor named @s@. It then wraps the converted value in 'Tag' and
-- in the 'L'/'R' injections ('lrE') that select this type's branch.
mkFrom :: Int -> Int -> (AppliedTyCon,String) -> RQ [Clause]
mkFrom m i (atc,s) = do
      cs <- liftq (atc2constructors atc)
      let wrapE e = lrE m i (ConE 'Tag @@@ e)
          dn      = mkName s
      zipWithM (fromCon wrapE dn (length cs)) [0..] cs

-- | The clause of 'to' for family type number @i@ out of @m@. Roughly:
--
-- > to Con = \x -> case (x :: PF I0 T) of
-- >   <L/R injections for i> con -> case (con :: (... :>: T) I0 T) of
-- >                                   <one alternative per constructor>
--
-- The type annotations come from 'patternFunctorName' and 'pfType'.
-- The inner alternatives are built by 'toCon'.
mkTo :: Int -> Int -> (AppliedTyCon,String) -> RQ Clause
mkTo m i (atc,s) = do
      cs <- liftq (atc2constructors atc)
      pfname <- mkName <$> asks patternFunctorName
      matchesOfCons <-
          zipWithM (toCon (length cs) atc) [0..] cs
      xvar <- liftq (newName "x")
      convar <- liftq (newName "con")
      typeOfConvar <- do
            t0 <- pfType (atc,s)
            return (t0 @@ ConT ''I0 @@ fromAppliedTyCon atc)
      let typeOfXvar = ConT pfname @@ ConT ''I0 @@ fromAppliedTyCon atc
          body = LamE [VarP xvar]
                  (CaseE (VarE xvar `SigE` typeOfXvar)
                   [sMatch (lrP m i (VarP convar))
                    (CaseE (VarE convar `SigE` typeOfConvar)
                           matchesOfCons)])
      return (sClause [ConP (mkName s) []] body)


-- | The 'proof' method of an 'El' instance: the index-GADT constructor
-- named by the entry's string.
mkProof :: (AppliedTyCon,String) -> RQ Dec
mkProof (_, s) = return (FunD 'proof [sClause [] gadtCon])
  where
    gadtCon = ConE (mkName s)

-- | One clause of 'from', for data constructor number @i@ out of @m@ of
-- the family type whose index-GADT constructor is @n@. Fields are bound
-- to @f0@, @f1@, ... (see 'field') and converted by 'fromField'.
-- @wrap@ adds the outer 'Tag' and the type-selecting injections.
fromCon :: (Exp -> Exp) -> Name -> Int -> Int -> Con -> RQ Clause
fromCon wrap n m i (NormalC cn []) = return $
    -- Nullary constructor case
    sClause
      [ConP n [], ConP cn []]
      (wrap . lrE m i $ ConE 'C @@@ ConE 'U)
fromCon wrap n m i (NormalC cn fs) = do
      rhs <- zipWithM fromField [0..] (snd <$> fs)
      return $
        sClause
          [ ConP n []
          , ConP cn (fmap (VarP . field) [0 .. length fs - 1]) ]
          (wrap . lrE m i $ ConE 'C @@@ foldr1 prod rhs)
  where
    prod x y = ConE '(:*:) @@@ x @@@ y
fromCon wrap n m i r@(RecC _ _) =
  fromCon wrap n m i (stripRecordNames r)
fromCon wrap n m i (InfixC t1 cn t2) =
  fromCon wrap n m i (NormalC cn [t1,t2])

-- | The case alternative of 'to' for data constructor number @i@ out of
-- @m@. The pattern matches 'Tag' around the 'L'/'R' injections ('lrP')
-- and 'C'. It binds the fields to @f0@, @f1@, ... (see 'toField'), and
-- the alternative rebuilds the value with the original constructor.
-- The 'AppliedTyCon' argument is currently unused; it appears only in
-- commented-out type annotations.
toCon :: Int            -- ^ Number of constructors
      -> AppliedTyCon
      -> Int            -- ^ Index of this constructor
      -> Con
      -> RQ Match
toCon m _ i (NormalC cn []) = return $
    -- Nullary constructor case
    sMatch
      (ConP 'Tag [lrP m i $ ConP 'C [ConP 'U []]])
      (ConE cn)
toCon m _ i (NormalC cn fs) = do
      lhs <- zipWithM toField [0..] (fmap snd fs)
      return $
        sMatch
          (ConP 'Tag [lrP m i $ ConP 'C [foldr1 prod lhs]])
          (foldl AppE (ConE cn) (fmap (VarE . field) [0 .. length fs - 1]))
  where
    prod x y = ConP '(:*:) [x,y]
toCon m atc i r@(RecC _ _) =
  toCon m atc i (stripRecordNames r)
toCon m atc i (InfixC t1 cn t2) =
  toCon m atc i (NormalC cn [t1,t2])

-- | Right-hand side in 'from' for field number @nr@ of type @t@. Family
-- members give @I (I0 f)@. Other types give @K f@, and an info message
-- is reported for them.
fromField :: Int -> Type -> RQ Exp
fromField nr t = ifInFamily' t recursive constant
  where
    var       = VarE (field nr)
    recursive = return (ConE 'I @@@ (ConE 'I0 @@@ var))
    constant  = do
      message ("* Info: Type not in family: " ++ pprintUnqual t)
      return (ConE 'K @@@ var)
-- | Pattern in 'to' for field number @nr@ of type @t@: @I (I0 f)@ for a
-- family member, @K f@ otherwise.
toField :: Int -> Type -> RQ Pat
toField nr t = ifInFamily t (ConP 'I [ConP 'I0 [binder]]) (ConP 'K [binder])
  where
    binder = VarP (field nr)

-- | Name of the variable bound to field number @n@: @f0@, @f1@, ...
field :: Int -> Name
field n = mkName ('f' : show n)

-- | Wrap a pattern in the 'L' / 'R' injections that pick alternative @i@
-- (0-based) out of a sum of @m@ alternatives.
--
-- Right-nested encoding, which matches 'rightSumT': with a single
-- alternative the pattern is returned unchanged; alternative 0 becomes
-- @L p@; alternative @i > 0@ becomes @R@ applied to the pattern for
-- alternative @i - 1@ of @m - 1@. The last alternative is therefore
-- @R (R (... p))@. Expects @0 <= i < m@; a negative @i@ never reaches a
-- base case.
lrP :: Int -> Int -> ( Pat ->  Pat)
lrP m i p | bALANCED_MODE = ascendFromLeaf 
                            (ConP 'L . (:[] {- robot monkey -}))      
                            (ConP 'R . (:[]))      
                            -- NOTE(review): as written, this branch never
                            -- passes @m@, @i@ or @p@ to 'ascendFromLeaf';
                            -- the remaining arguments look like they are
                            -- missing. Confirm against
                            -- BalancedFold.ascendFromLeaf.

-- Right-nested encoding, used when 'bALANCED_MODE' is False:
lrP 1 0 p = p
lrP m 0 p = ConP 'L [p]
lrP m i p = ConP 'R [lrP (m-1) (i-1) p]

-- | Expression counterpart of 'lrP'. It wraps an expression in the 'L' /
-- 'R' injections that select alternative @i@ (0-based) out of a sum of
-- @m@ alternatives, in the same right-nested encoding. Expects
-- @0 <= i < m@.
lrE :: Int -> Int -> ( Exp ->  Exp)
lrE m i e | bALANCED_MODE = ascendFromLeaf
                            (ConE 'L @@@)
                            (ConE 'R @@@)
                            -- NOTE(review): as written, this branch never
                            -- passes @m@, @i@ or @e@ to 'ascendFromLeaf';
                            -- the remaining arguments look like they are
                            -- missing. Confirm against
                            -- BalancedFold.ascendFromLeaf.

-- Right-nested encoding, used when 'bALANCED_MODE' is False:
lrE 1 0 e = e
lrE m 0 e = ConE 'L @@@ e
lrE m i e = ConE 'R @@@ lrE (m-1) (i-1) e

-- | Fail, via the monad's 'fail', when a family type has no data
-- constructors; empty types are not supported. The message names the
-- offending type.
guardEmptyData :: [Con] -> AppliedTyCon -> RQ ()
guardEmptyData [] atc = fail ("Empty types not supported yet ("++
                              show (fromAppliedTyCon atc) ++ ")")
guardEmptyData _ _ = return ()

-- helper t = do
--   Right (AppliedTyCon n args) <- liftq (toAppliedTyCon t)
--   let prefix = "Prf_"
--   str <- if n == ''[]
--    then do
--      Right (AppliedTyCon n1 _) <- liftq (toAppliedTyCon (head args))
--      return ("T("++prefix++"List"++nameBase n1
--              ++",["++pprintUnqual (head args)++"])")
--    else
--       return ("T("++prefix++nameBase n
--               ++","++pprintUnqual t++")")

--   liftq . runIO $ appendFile "dump.dump" (str++"\n")

-- | Same shape as 'SigE', but ignores the type and returns the
-- expression without an annotation.
noSigE :: Exp -> Type -> Exp
noSigE e _ = e

-- makeSanityChecks :: RQ [Dec]
-- makeSanityChecks = concat <$> foreachType makeSanityCheck

-- makeSanityCheck :: (AppliedTyCon,String) -> RQ [Dec]
-- makeSanityCheck (atc,s) = do
--   famname <- mkName <$> asks indexGadtName
--   let
--       chkName = mkName ("sanityCheck"++s)
--   return [
--            SigD chkName (ConT famname @@ fromAppliedTyCon atc)
--          , ValD (VarP chkName)
--                 (NormalB (ConE (mkName s)))
--                 []
--          ]