{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE GADTs           #-}
{-# LANGUAGE KindSignatures  #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE MonomorphismRestriction #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# OPTIONS -fwarn-missing-signatures #-}



{-|
Example usage:

@
import Generics.MultiRec
import Generics.MultiRec.TH.Alt
import Data.Tree

data TheFam :: (* -> *) where
              Tree_Int   :: TheFam   (Tree Int)
              Forest_Int :: TheFam (Forest Int)
                           
$('deriveEverything'
  ('DerivOptions'
   [ ( [t| Tree   Int |],   \"Tree_Int\" )
   , ( [t| Forest Int |], \"Forest_Int\" )
   ]
   \"TheFam\"
   (\\t c -> \"CONSTRUCTOR_\" ++ t ++ \"_\" ++ c)
   \"ThePF\"
   True
  )
 )

type instance 'PF' TheFam = ThePF
@
-}

module Generics.MultiRec.TH.Alt
  ( 
    DerivOptions(..),
    deriveEverything,
  ) where

import Generics.MultiRec.TH.Alt.DerivOptions(DerivOptions(..))
import THUtils(AppliedTyCon, (@@), (@@@), toAppliedTyCon,
               fromAppliedTyCon, atc2constructors, pprintUnqual, sMatch, sClause,
               cleanConstructorName)
import BalancedFold(balancedFold, ascendFromLeaf)
import MonadRQ(RQ, message, messageReport, liftq, foreachType,
               foreachTypeNumbered, runRQ)
import Generics.MultiRec.Base((:>:)(..), C(..), El(..), (:=:)(..),
                              I0(I0), (:*:)(..), (:+:)(..), EqS(..), Fam(..), I(I), K(K), U(..))
import Generics.MultiRec.Constructor(Associativity(..), Fixity(..),
                                     Constructor(..))
import Control.Monad.Reader(Monad(return, fail, (>>)), Functor(..),
                            (=<<), mapM, sequence, liftM, zipWithM, asks)
import Language.Haskell.TH.Syntax(Lift(..))
import Language.Haskell.TH(newName, mkName, wildP, clause, conE,
                           appE, normalB, funD, dataD, instanceD, cxt, conT, appT,
                           Exp(VarE, SigE, LamE, ConE, CaseE, AppE), Match, Clause, Q,
                           Pat(WildP, VarP, ConP), TypeQ, Type(ConT),
                           Dec(TySynD, InstanceD, FunD), Name, 
                           Con(RecC, NormalC, InfixC),
                           FixityDirection(..),
                           Info(DataConI), nameBase, reify, stringE)
import Data.Map(lookup, elems)
import Control.Applicative((<$>))

import qualified Data.Map as Map
import qualified Language.Haskell.TH as TH


-- | Compile-time switch for the nesting of generated sums.  When 'True',
-- sums are built with 'balancedFold' and the L/R injections with
-- 'ascendFromLeaf' (see 'sumT', 'lrP' and 'lrE'); when 'False' (the current
-- setting) sums are right-nested, e.g. @a :+: (b :+: c)@.  All three
-- consult this flag so the pattern functor and the generated injections
-- use the same nesting.
bALANCED_MODE :: Bool
bALANCED_MODE = False

 
 

-- | Entry point for the splice.  Runs the constructor derivation and then
-- the family derivation under the given options, and concatenates the
-- resulting declarations in that order.
deriveEverything :: DerivOptions [(TypeQ, String)] -> Q [Dec]
deriveEverything =
  runRQ (concat <$> sequence [deriveConstructors, deriveFamily])
  -- NOTE: an optional sanity-check pass (the commented-out
  -- 'makeSanityChecks' at the end of this module) used to be selected here
  -- via 'mkSanityChecks'; it is currently disabled.
  

-- | For every type in the family, generate an empty datatype per value
-- constructor together with its 'Constructor' instance (see
-- 'constrInstance').  The per-type declaration lists are concatenated.
deriveConstructors :: RQ [Dec]
deriveConstructors = do
  perType <- foreachType constrInstance
  return (concat perType)

-- | Generate the family-level declarations from the options, in this
-- order:
--
-- * the pattern functor type synonym ('derivePF');
-- * the 'El' instances ('deriveEl');
-- * the 'Fam' instance ('deriveFam');
-- * the 'EqS' instance ('deriveEqS').
--
-- /IMPORTANT/: the index GADT's constructors are referred to by the
-- 'String' names paired with each family type in the options, so those
-- names must match the GADT declaration.
deriveFamily :: RQ [Dec]
deriveFamily =
  concat <$> sequence [derivePF, deriveEl, deriveFam, deriveEqS]

-- | Derive only the pattern functor.  It is a type synonym, named by
-- 'patternFunctorName', for the sum over all family types of their
-- per-type functors (see 'pfType').  Not needed if 'deriveFamily' is used.
--
-- It also reports the @type instance PF@ line that still has to be written
-- by hand.
derivePF :: RQ [Dec]
derivePF = do
  summands <- foreachType pfType
  synName  <- asks patternFunctorName
  gadtName <- asks indexGadtName
  messageReport $ concat
    [ "Reminder: Don't forget to add this line manually:\n"
    , "    type instance PF ", gadtName, " = ", synName
    ]
  return [TySynD (mkName synName) [] (sumT summands)]

             
-- | Combine summands with ':+:', using the nesting selected by
-- 'bALANCED_MODE'.  Expects a non-empty list; the right-nested form uses
-- 'foldr1'.
sumT :: [Type] -> Type
sumT = if bALANCED_MODE then balancedSumT else rightSumT
                   
-- | Right-nested sum: @[a, b, c]@ becomes @a :+: (b :+: c)@.  Partial: an
-- empty list fails, as 'foldr1' does.
rightSumT :: [Type] -> Type
rightSumT summands = foldr1 plusT summands

-- | Sum whose nesting is chosen by the project's 'balancedFold' helper.
balancedSumT :: [Type] -> Type
balancedSumT summands = balancedFold plusT summands
      
-- | The type @l :+: r@.
plusT :: Type -> Type -> Type
plusT l r = sumCon @@ l @@ r
  where
    sumCon = ConT ''(:+:)
            
-- | Right-nested product of field functors: @[a, b, c]@ becomes
-- @a :*: (b :*: c)@.  An empty list yields the unit functor 'U'.  This is
-- the same encoding 'pfCon' uses for nullary constructors, and it makes the
-- function total instead of failing in 'foldr1'.
prodT :: [Type] -> Type
prodT []     = ConT ''U
prodT fields = foldr1 timesT fields
            
-- | The type @l :*: r@.
timesT :: Type -> Type -> Type
timesT l r = prodCon @@ l @@ r
  where
    prodCon = ConT ''(:*:)
             
-- | Derive only the 'El' instances, one per family type (see
-- 'elInstance').  Not needed if 'deriveFamily' is used.

deriveEl :: RQ [Dec]
deriveEl = foreachType elInstance
           
-- | The index GADT as a type constructor, named by 'indexGadtName'.
indexGadtType :: RQ Type
indexGadtType = do
  gadtName <- asks indexGadtName
  return (ConT (mkName gadtName))

-- | Derive only the 'Fam' instance for the index GADT.  'from' gets one
-- clause per value constructor of every family type (see 'mkFrom'), and
-- 'to' gets one clause per family type (see 'mkTo').  Not needed if
-- 'deriveFamily' is used.
deriveFam :: RQ [Dec]
deriveFam = do
  fromClauses <- concat <$> foreachTypeNumbered mkFrom
  toClauses   <- foreachTypeNumbered mkTo
  gadt        <- indexGadtType
  let methods = [FunD 'from fromClauses, FunD 'to toClauses]
  return [InstanceD [] (ConT ''Fam @@ gadt) methods]

-- | Derive only the 'EqS' instance for the index GADT.
--
-- * Two identical index constructors give @Just Refl@.
-- * Any other pair falls through to a wildcard clause returning 'Nothing'.
-- * For a single-member family the wildcard clause is omitted, because the
--   one constructor clause already covers every input.
--
-- Not needed if 'deriveFamily' is used.
deriveEqS :: RQ [Dec]
deriveEqS = do
  gadt  <- indexGadtType
  names <- elems <$> asks familyTypes
  let sameClause n = sClause [ConP (mkName n) [], ConP (mkName n) []]
                             (ConE 'Just `AppE` ConE 'Refl)
      diffClause   = sClause [WildP, WildP] (ConE 'Nothing)
      fallback     = case names of
                       [_] -> []
                       _   -> [diffClause]
  return [ InstanceD [] (ConT ''EqS @@ gadt)
             [FunD 'eqS (map sameClause names ++ fallback)] ]

-- | For one family type, paired with its index-constructor name, emit an
-- empty datatype per value constructor ('mkData').  These are followed by
-- the matching 'Constructor' instances ('mkInstance').
constrInstance :: (AppliedTyCon,String) -> RQ [Dec]
constrInstance (atc, tag) = do
  cons  <- liftq (atc2constructors atc)
  datas <- mapM (mkData tag) cons
  insts <- mapM (mkInstance tag) cons
  return (datas ++ insts)

-- | Forget the field names of a record constructor, turning it into the
-- equivalent 'NormalC'; strictness and field types are kept.  Any other
-- constructor is returned unchanged.
stripRecordNames :: Con -> Con
stripRecordNames con = case con of
  RecC name vfields -> NormalC name [ (strictness, ty) | (_, strictness, ty) <- vfields ]
  _                 -> con

-- | Emit an empty datatype (no constructors) to serve as the tag type that
-- 'pfCon' places under 'C'.  Its name is
-- @constructorNameModifier tag (cleanConstructorName con)@.  Record and
-- infix constructors are first normalised to 'NormalC'.
-- TODO: Handle colons in the constructor name
mkData :: String -> Con -> RQ Dec
mkData tag con = case con of
  NormalC cname _ -> do
    rename <- asks constructorNameModifier
    let tyName = mkName (rename tag (cleanConstructorName (nameBase cname)))
    liftq (dataD (cxt []) tyName [] [] [])
  RecC _ _         -> mkData tag (stripRecordNames con)
  InfixC l cname r -> mkData tag (NormalC cname [l, r])

-- | Lets a 'Fixity' computed at splice time be embedded in generated code
-- (used by 'mkInstance' for infix constructors).
instance Lift Fixity where
  lift fx = case fx of
    Prefix           -> conE 'Prefix
    Infix assoc prec -> appE (appE (conE 'Infix) (lift assoc)) (lift prec)

-- | Lift each 'Associativity' to the expression for the same constructor.
instance Lift Associativity where
  lift assoc = conE $ case assoc of
    LeftAssociative  -> 'LeftAssociative
    RightAssociative -> 'RightAssociative
    NotAssociative   -> 'NotAssociative

-- | Emit the 'Constructor' instance for the tag type of one value
-- constructor, using the same name that 'mkData' produces.  'conName'
-- returns the original constructor's base name.
--
-- For infix constructors the declared fixity is looked up with 'reify' and
-- exposed via 'conFixity'.  If 'reify' does not return a 'DataConI',
-- 'Prefix' is used.
mkInstance :: String -> Con -> RQ Dec
mkInstance tag con = case con of
  NormalC cname _ -> do
    rename <- asks constructorNameModifier
    let tyName = rename tag (cleanConstructorName (nameBase cname))
    liftq (instanceD (cxt []) (constructorHead tyName) [conNameMethod cname])
  RecC _ _ -> mkInstance tag (stripRecordNames con)
  InfixC _ cname _ -> do
    rename <- asks constructorNameModifier
    let tyName = rename tag (cleanConstructorName (nameBase cname))
    info <- liftq (reify cname)
    let fi = case info of
               DataConI _ _ _ f -> toMultiRecFixity f
               _                -> Prefix
    liftq (instanceD (cxt []) (constructorHead tyName)
             [ conNameMethod cname
             , funD 'conFixity [clause [wildP] (normalB [| fi |]) []]
             ])
  where
    -- Instance head: Constructor <tag type>.
    constructorHead tyName = appT (conT ''Constructor) (conT (mkName tyName))
    -- conName _ = "<original constructor name>"
    conNameMethod cname =
      funD 'conName [clause [wildP] (normalB (stringE (nameBase cname))) []]
    -- Translate TH's fixity into multirec's representation.
    toMultiRecFixity (TH.Fixity prec dir) = Infix (toAssoc dir) prec
    toAssoc InfixL = LeftAssociative
    toAssoc InfixR = RightAssociative
    toAssoc InfixN = NotAssociative

-- | Per-type summand of the pattern functor: the sum of one family type's
-- constructor functors, tagged with that type via ':>:'.  Fails (through
-- 'guardEmptyData') if the type has no constructors.
pfType :: (AppliedTyCon,String) -> RQ Type
pfType (atc, tag) = do
  cons <- liftq (atc2constructors atc)
  guardEmptyData cons atc
  conFunctors <- mapM (pfCon tag) cons
  return (ConT ''(:>:) @@ sumT conFunctors @@ fromAppliedTyCon atc)
            

            
-- | Functor for a single value constructor: @C Tag fields@.
--
-- * @Tag@ is the empty type emitted by 'mkData'.
-- * @fields@ is the product of the field functors ('pfField'), or 'U'
--   when the constructor has no fields.
pfCon :: String -> Con -> RQ Type
pfCon tag con = case con of
  NormalC cname fields -> do
    rename <- asks constructorNameModifier
    let tagName = mkName (rename tag (cleanConstructorName (nameBase cname)))
    fieldFunctors <- mapM pfField (map snd fields)
    let body = if null fields then ConT ''U else prodT fieldFunctors
    return (ConT ''C @@ ConT tagName @@ body)
  RecC _ _         -> pfCon tag (stripRecordNames con)
  InfixC l cname r -> pfCon tag (NormalC cname [l, r])

-- | Field functor: 'I' for a recursive position (a family member) and 'K'
-- for any other type.
pfField :: Type -> RQ Type
pfField t = ifInFamily t recursive constant
  where
    recursive = ConT ''I @@ t
    constant  = ConT ''K @@ t

-- | Look up the index-constructor name registered for a type, if the type
-- is a family member once converted with 'toAppliedTyCon'.  Types that
-- cannot be converted are treated as non-members.
lookupFam :: Type -> RQ (Maybe String)
lookupFam t = do
  members   <- asks familyTypes
  converted <- liftq (toAppliedTyCon t)
  return (either (const Nothing) (`Map.lookup` members) converted)
            
-- | Pure-valued variant of @ifInFamily'@: returns the first value when the
-- type is a family member and the second otherwise.
ifInFamily :: Type -> a -> a -> RQ a
ifInFamily t whenMember whenNot = do
  found <- lookupFam t
  return (maybe whenNot (const whenMember) found)

-- | Runs the first action when the type is a family member (see
-- 'lookupFam') and the second action otherwise.
ifInFamily' :: Type -> RQ a -> RQ a -> RQ a
ifInFamily' t whenMember whenNot = do
  found <- lookupFam t
  case found of
    Just _  -> whenMember
    Nothing -> whenNot
                

-- | The instance @El Gadt T@ for one family type @T@, whose 'proof' method
-- comes from 'mkProof'.
elInstance :: (AppliedTyCon,String) -> RQ Dec
elInstance member = do
  gadt      <- indexGadtType
  proofDecl <- mkProof member
  let atc = fst member
  return (InstanceD [] (ConT ''El @@ gadt @@ fromAppliedTyCon atc) [proofDecl])

-- | Clauses of 'from' for one family type: one clause per value
-- constructor (see 'fromCon').  Each clause matches the index constructor
-- named by the 'String'.  The encoded value is wrapped in 'Tag' and then in
-- the L/R injection selected by @m@ and @i@, which are supplied by
-- 'foreachTypeNumbered'.
mkFrom :: Int -> Int -> (AppliedTyCon,String) -> RQ [Clause]
mkFrom m i (atc, tag) = do
  cons <- liftq (atc2constructors atc)
  let inject e = lrE m i (ConE 'Tag @@@ e)
      nCons    = length cons
  zipWithM (fromCon inject (mkName tag) nCons) [0..] cons
              

-- | The single 'to' clause for one family type @T@.  It matches the index
-- constructor named by the 'String' and returns
--
-- > \x -> case (x :: PF I0 T) of
-- >         <L/R injection> con -> case (con :: <summand of T> I0 T) of
-- >                                  <one alternative per constructor>
--
-- Both scrutinees carry explicit type signatures, presumably to guide type
-- inference for the GADT-indexed functor.  The alternatives come from
-- 'toCon'.
mkTo :: Int -> Int -> (AppliedTyCon,String) -> RQ Clause
mkTo m i (atc, tag) = do
  cons    <- liftq (atc2constructors atc)
  pfName  <- mkName <$> asks patternFunctorName
  alts    <- zipWithM (toCon (length cons) atc) [0..] cons
  x       <- liftq (newName "x")
  inner   <- liftq (newName "con")
  summand <- pfType (atc, tag)
  let ty        = fromAppliedTyCon atc
      -- Apply a functor to I0 and the family type: f I0 T.
      applied f = f @@ ConT ''I0 @@ ty
      innerCase = CaseE (VarE inner `SigE` applied summand) alts
      outerCase = CaseE (VarE x `SigE` applied (ConT pfName))
                        [sMatch (lrP m i (VarP inner)) innerCase]
  return (sClause [ConP (mkName tag) []] (LamE [VarP x] outerCase))
      

                          

-- | The 'proof' method for 'El': the index-GADT constructor named by the
-- 'String'.
mkProof :: (AppliedTyCon,String) -> RQ Dec
mkProof member = return (FunD 'proof [sClause [] (ConE (mkName (snd member)))])

-- | One clause of 'from'.  It matches the index constructor @n@ and the
-- value constructor, and builds @C (f0 :*: f1 :*: ...)@, or @C U@ when the
-- constructor is nullary.  The result is injected along this
-- constructor's L/R path (@m@ constructors, this one at index @i@) and
-- finally passed to @wrap@.
fromCon :: (Exp -> Exp) -> Name -> Int -> Int -> Con -> RQ Clause
fromCon wrap n m i con = case con of
  NormalC cn [] ->
    -- Nullary constructor case
    return (clauseFor (ConP cn []) (ConE 'C @@@ ConE 'U))
  NormalC cn fs -> do
    encoded <- zipWithM fromField [0..] (map snd fs)
    let vars = map (VarP . field) [0 .. length fs - 1]
    return (clauseFor (ConP cn vars) (ConE 'C @@@ foldr1 prod encoded))
  RecC _ _      -> fromCon wrap n m i (stripRecordNames con)
  InfixC l cn r -> fromCon wrap n m i (NormalC cn [l, r])
  where
    -- from <index constructor> <value pattern> = wrap (<L/R path> rep)
    clauseFor valuePat rep = sClause [ConP n [], valuePat] (wrap (lrE m i rep))
    prod x y = ConE '(:*:) @@@ x @@@ y

-- | One alternative of the inner @case@ in 'to'.  It matches
-- @Tag (<L/R path> (C (I (I0 f0) :*: K f1 :*: ...)))@, or @C U@ when the
-- constructor is nullary, and rebuilds the value by applying constructor
-- @cn@ to the field variables.  The 'AppliedTyCon' argument is only passed
-- along in recursive calls.
toCon
  :: Int            -- ^ Number of constructors
  -> AppliedTyCon
  -> Int            -- ^ Index of this constructor
  -> Con
  -> RQ Match
toCon m atc i con = case con of
  NormalC cn [] ->
    -- Nullary constructor case
    return (matchFor (ConP 'C [ConP 'U []]) (ConE cn))
  NormalC cn fs -> do
    pats <- zipWithM toField [0..] (map snd fs)
    let rebuilt = foldl AppE (ConE cn) (map (VarE . field) [0 .. length fs - 1])
    return (matchFor (ConP 'C [foldr1 prod pats]) rebuilt)
  RecC _ _      -> toCon m atc i (stripRecordNames con)
  InfixC l cn r -> toCon m atc i (NormalC cn [l, r])
  where
    -- Tag (<L/R path> rep) -> rhs
    matchFor rep rhs = sMatch (ConP 'Tag [lrP m i rep]) rhs
    prod x y = ConP '(:*:) [x, y]

-- | Expression encoding field number @nr@ (of type @t@) in 'from':
-- @I (I0 fN)@ for a family member and @K fN@ otherwise.  Non-member types
-- are also reported with an informational 'message'.
fromField :: Int -> Type -> RQ Exp
fromField nr t = ifInFamily' t member other
  where
    var    = VarE (field nr)
    member = return (ConE 'I @@@ (ConE 'I0 @@@ var))
    other  = do
      message ("* Info: Type not in family: " ++ pprintUnqual t)
      return (ConE 'K @@@ var)
         
-- | Pattern for field number @nr@ (of type @t@) in 'to': @I (I0 fN)@ for a
-- family member and @K fN@ otherwise.  This mirrors 'fromField'.
toField :: Int -> Type -> RQ Pat
toField nr t = ifInFamily t (nestIn [ 'I, 'I0 ]) (nestIn [ 'K ])
  where
    -- nestIn [c1, c2] builds c1 (c2 fN).
    nestIn cons = foldr (\c p -> ConP c [p]) (VarP (field nr)) cons

-- | Name of the variable bound to field number @n@: @f0@, @f1@, and so on.
field :: Int -> Name
field n = mkName ('f' : show n)

-- | Wrap a pattern in the 'L'/'R' injections that select alternative @i@ of
-- an @m@-way sum.
--
-- * Right-nested mode: alternative @i@ is @R^i (L p)@, except the last
--   (@i == m-1@), which is @R^i p@.  A one-way sum therefore adds no
--   wrapper.
-- * Balanced mode: the path comes from 'ascendFromLeaf'.
lrP :: Int -> Int -> Pat -> Pat
lrP m i p
  | bALANCED_MODE = ascendFromLeaf (\q -> ConP 'L [q]) (\q -> ConP 'R [q]) p m i
  | otherwise     = rightNested m i
  where
    rightNested :: Int -> Int -> Pat
    rightNested 1 0 = p
    rightNested _ 0 = ConP 'L [p]
    rightNested k j = ConP 'R [rightNested (k - 1) (j - 1)]

-- | Expression counterpart of 'lrP': wrap an expression in the 'L'/'R'
-- injections that select alternative @i@ of an @m@-way sum.
--
-- * Right-nested mode: @R^i (L e)@, or @R^i e@ for the last alternative.
-- * Balanced mode: the path comes from 'ascendFromLeaf'.
lrE :: Int -> Int -> Exp -> Exp
lrE m i e
  | bALANCED_MODE = ascendFromLeaf (ConE 'L @@@) (ConE 'R @@@) e m i
  | otherwise     = rightNested m i
  where
    rightNested :: Int -> Int -> Exp
    rightNested 1 0 = e
    rightNested _ 0 = ConE 'L @@@ e
    rightNested k j = ConE 'R @@@ rightNested (k - 1) (j - 1)

            
-- | Abort the splice with a readable error when a family type has no value
-- constructors.  'pfType' calls this before building the sum over the
-- constructors, and an empty sum cannot be represented.  The error names
-- the offending type.
guardEmptyData :: [Con] -> AppliedTyCon -> RQ ()
guardEmptyData [] atc =
  -- The closing parenthesis was missing from the original message.
  fail ("Empty types not supported yet ("
        ++ show (fromAppliedTyCon atc) ++ ")")
guardEmptyData _ _ = return ()


-- helper t = do
--   Right (AppliedTyCon n args) <- liftq (toAppliedTyCon t)
                        
--   let prefix = "Prf_"
                        
--   str <- if n == ''[]
--    then do
--      Right (AppliedTyCon n1 _) <- liftq (toAppliedTyCon (head args))
--      return ("T("++prefix++"List"++nameBase n1
--              ++",["++pprintUnqual (head args)++"])")
--    else
--       return ("T("++prefix++nameBase n
--               ++","++pprintUnqual t++")")

--   liftq . runIO $ appendFile "dump.dump" (str++"\n")

-- | Has the same shape as 'SigE' but discards the type and returns the
-- expression unchanged.  It is presumably meant as a switch to turn
-- generated annotations off; it is not referenced in the visible code.
noSigE :: Exp -> Type -> Exp
noSigE = const



-- makeSanityChecks :: RQ [Dec]
-- makeSanityChecks = concat <$> foreachType makeSanityCheck

-- makeSanityCheck :: (AppliedTyCon,String) -> RQ [Dec]
-- makeSanityCheck (atc,s) = do
--   famname <- mkName <$> asks indexGadtName
            
--   let
--       chkName = mkName ("sanityCheck"++s)
            
--   return [
--            SigD chkName (ConT famname @@ fromAppliedTyCon atc)
--          , ValD (VarP chkName)
--                 (NormalB (ConE (mkName s)))
--                 []
--          ]