{-# LANGUAGE
  TemplateHaskell
  ,MultiParamTypeClasses
  ,FlexibleInstances
  ,FlexibleContexts
  ,RankNTypes
  ,UndecidableInstances
  ,DeriveDataTypeable #-}

{-# OPTIONS_GHC -funbox-strict-fields -fno-warn-incomplete-patterns #-}

module Data.AdaptiveTuple.TH (
 makeDatas
 ,makeData
 ,makeReify
 ,deriveInstances
 ,deriveFunctor
 ,deriveApplicative
 ,deriveAdaptive
 )
where

import Data.AdaptiveTuple.AdaptiveTuple
import Language.Haskell.TH
import qualified Data.TypeLevel.Num as T
import Data.Data
import Control.Monad
import Control.Applicative
import Control.Arrow

-- |Whether a field strictness annotation denotes a strict field.
--
-- Every annotation other than 'NotStrict' counts as strict. This covers
-- 'IsStrict' and also, on template-haskell versions that have it,
-- 'Unpacked': an unpacked field is always strict as well. 'makeData' uses
-- this to choose its strict, UNPACKed field template.
checkStrict :: Strict -> Bool
checkStrict NotStrict = False
checkStrict _         = True

-- | Placeholder type used only inside the declaration quotes in this module.
--
-- The quoted templates (signature, Functor / Applicative / AdaptiveTuple
-- instances) are written against 'T1' because they need a concrete type to
-- mention. The generators pattern-match the resulting 'Dec's to pull out the
-- class, method and type pieces, and substitute the real tuple type, so 'T1'
-- does not appear in the generated code. The extra parameter @s@ stands in
-- for the tuple's type-level size parameter.
data T1 s a = T1 a

-- |Generate a reification function named @reifyTupleN@, where @N@ is @maxn@.
--
-- The generated function has the type
--
-- > Int -> [el] -> (forall c s. (AdaptiveTuple c s, T.Nat s) => c s el -> r) -> r
--
-- and is applied to a size @n@, a list @xs@ and a continuation @f@.
--
-- * For each literal @n@ in @[0..maxn]@ there is a dedicated clause. It
--   builds the tuple with @helper (undefined :: ATupleN DN ()) xs@ and passes
--   it to @f@.
-- * Any other @n@ falls through to a final generic clause. That clause
--   reifies @n@ with 'T.reifyIntegral' and passes @makeListTuple n' xs@ to
--   @f@.
--
-- NOTE(review): the names @ATuple0@ .. @ATupleN@ and @D0@ .. @DN@ are
-- built with 'mkName'. They must therefore be in scope at the splice site
-- for the chosen @maxn@; confirm they exist.
makeReify :: Integer -> Q [Dec]
makeReify maxn = do
  let fname = mkName $ "reifyTuple" ++ show maxn
  -- Template definition quoted only to obtain its rank-2 type signature;
  -- its clause is discarded below.
  let basedecl = [d| rf :: forall el r.Int -> [el] -> (forall c s. (AdaptiveTuple c s, T.Nat s) => c s el -> r) -> r; rf 0 xs f = f (toATuple xs :: T1 s el) |]
  d <- basedecl
  -- The quote yields [signature, binding]; keep only the signature's type.
  let [(SigD _fname typ), _] = d
  -- Clause for a literal size n:  reifyTupleN n xs f = f (helper w xs)
  -- where w = (undefined :: ATupleN DN ()). The 'undefined' value is never
  -- evaluated; it only fixes the tuple type that 'helper' produces.
  let mkclause n = do
        ([xsPat,fPat],[xsExp,fExp]) <- genPE 2
        let tupN = mkName $ "ATuple" ++ show n
        let dN   = mkName $ "D" ++ show n
        let atName = [| undefined :: $(appT (appT (conT tupN) (conT dN)) (conT ''() )) |]
        clause [litP (IntegerL n),xsPat,fPat]
         (normalB [| $fExp (helper $atName $xsExp)|]) []
  -- Fallback clause for every size without a literal clause: reify the
  -- runtime size to a type-level number and build the tuple generically.
  let defclause = do
        ([nPat,xsPat,fPat],[nExp,xsExp,fExp]) <- genPE 3
        clause [nPat,xsPat,fPat]
         (normalB [| T.reifyIntegral $nExp (\n' -> $fExp (makeListTuple n' $xsExp))|]) []
  cls <- mapM mkclause [0..maxn]
  defcl <- defclause
  -- Literal clauses first, so the catch-all clause comes last.
  return [SigD fname typ, FunD fname (cls++[defcl])]

-- |Convert a list into an adaptive tuple whose constructor @c@ and size @s@
-- are fixed by the first argument.
--
-- The first argument is never inspected. It acts purely as a type witness
-- (e.g. @undefined :: ATuple3 D3 ()@), so its element type @x@ does not
-- matter.
helper :: (AdaptiveTuple c s, T.Nat s) => c s x -> [el] -> c s el
helper _witness xs = toATuple xs

-- |Generate the declarations for @ATupleM@ through @ATupleN@ (inclusive),
-- all using the same field strictness. Produces no declarations when
-- @m > n@.
makeDatas :: Strict -> Int -> Int -> Q [Dec]
makeDatas strict m n = do
  perSize <- forM [m .. n] (makeData strict)
  return (concat perSize)

-- |Build the declaration of the data type @ATupleN s a@ (with @N = n@).
--
-- The type has a single constructor @ATupleN@ holding @n@ fields of type
-- @a@; @s@ does not appear in any field. When 'checkStrict' accepts
-- @strict@, the fields are strict and UNPACKed; otherwise they are lazy.
-- The type derives Show, Eq, Typeable and Data.
makeData :: Strict -> Int -> Q [Dec]
makeData strict n = do
  template <- if checkStrict strict then strictTemplate else lazyTemplate
  -- Reuse the template's type variables, field type and deriving list, but
  -- rename it and repeat its single field n times.
  let [DataD [] _ tyVars [NormalC _ [field]] derivs] = template
      tupName = mkName ("ATuple" ++ show n)
  return [DataD [] tupName tyVars [NormalC tupName (replicate n field)] derivs]
  where
    strictTemplate = [d| data MX s a = MX {-# UNPACK #-} !a deriving (Show, Eq, Typeable, Data)|]
    lazyTemplate   = [d| data MX s a = MX a deriving (Show, Eq, Typeable, Data)|]

-- |Generate the Functor, Applicative and AdaptiveTuple instances for the
-- type @(t s)@. Each individual deriver runs in that order and their
-- declarations are concatenated.
deriveInstances :: Name -> Name -> Q [Dec]
deriveInstances t s =
  liftM concat $ mapM (\derive -> derive t s)
    [deriveFunctor, deriveApplicative, deriveAdaptive]

-- |Derive a 'Functor' instance for the type @(t s)@.
--
-- @t@ must name a @data@ type whose constructors are all in plain prefix
-- (non-record) form; @s@ names the type supplied as its first argument.
-- Each generated @fmap@ clause rebuilds the constructor with the function
-- applied to every field.
deriveFunctor :: Name -> Name -> Q [Dec]
deriveFunctor t s = do
  TyConI (DataD _ _ _ cons _) <- reify t
  -- Template instance: only the class type and the method name are reused.
  template <- [d| instance Functor (T1 s) where fmap _ x = x|]
  let [InstanceD [] (AppT functorT _) [FunD fmapName _]] = template
  headTy <- appT (conT t) (conT s)
  fmapDec <- funD fmapName (map mapCon cons)
  return [InstanceD [] (AppT functorT headTy) [fmapDec]]
  where
    -- fmap f (C x1 .. xn) = C (f x1) .. (f xn)
    mapCon (NormalC con fields) = do
      (fnPat : argPats, fnVar : argVars) <- genPE (length fields + 1)
      clause [fnPat, conP con argPats]
             (normalB (appsE (conE con : map (appE fnVar) argVars))) []

-- |Derive an 'Applicative' instance for the type @(t s)@.
--
-- @t@ must name a @data@ type whose constructors are all in plain prefix
-- (non-record) form; @s@ names the type supplied as its first argument.
-- The generated methods are:
--
-- * @pure v = C v .. v@, with one copy of @v@ per field
-- * @(C f1 .. fn) <*> (C x1 .. xn) = C (f1 x1) .. (fn xn)@
--
-- The instance is emitted with an empty context.
deriveApplicative :: Name -> Name -> Q [Dec]
deriveApplicative t s = do
  TyConI (DataD _ _ _ cons _) <- reify t
  -- Template instance: only the class type and the method names are reused.
  template <- [d| instance Functor (T1 s) => Applicative (T1 s) where pure a = T1 a; (T1 a) <*> (T1 b) = T1 (a b)|]
  let [InstanceD _ (AppT applicativeT _) [FunD pureName _, FunD apName _]] = template
  headTy <- appT (conT t) (conT s)
  pureDec <- funD pureName (map pureFor cons)
  apDec   <- funD apName   (map apFor cons)
  return [InstanceD [] (AppT applicativeT headTy) [pureDec, apDec]]
  where
    -- pure v = C v .. v
    pureFor (NormalC con fields) = do
      ([vPat], [vVar]) <- genPE 1
      clause [vPat] (normalB (appsE (conE con : replicate (length fields) vVar))) []
    -- (C f1 .. fn) <*> (C x1 .. xn) = C (f1 x1) .. (fn xn)
    apFor (NormalC con fields) = do
      let arity = length fields
      (fPats, fVars) <- genPE arity
      (xPats, xVars) <- genPE arity
      clause [conP con fPats, conP con xPats]
             (normalB (appsE (conE con : zipWith appE fVars xVars))) []

-- |Generate an 'AdaptiveTuple' instance for the type @(t s)@.
--
-- @t@ must name a @data@ type whose constructors are all in plain prefix
-- (non-record) form; @s@ names the type supplied as its first argument.
-- For a constructor @C x0 .. xk@ the generated methods are:
--
-- * @getIndex (C ..) i@ returns field @i@
-- * @setIndex i e (C ..)@ replaces field @i@ with @e@
-- * @mapIndex f i (C ..)@ applies @f@ to field @i@
-- * @toATuple (x0 : .. : xk : _) = C x0 .. xk@; extra list elements are
--   ignored
-- * @fromATuple (C x0 .. xk) = [x0, .., xk]@
-- * @sequenceAT (C m0 .. mk)@ binds @m0 .. mk@ in order, then returns
--   @C@ applied to the results
--
-- An index with no matching clause falls through to @oObExcp@, named after
-- the method. A list too short for @toATuple@ falls through to @insExcp@.
-- Each fallback clause is emitted once, after the clauses of every
-- constructor. A per-constructor wildcard would shadow the clauses of all
-- later constructors.
deriveAdaptive :: Name -> Name -> Q [Dec]
deriveAdaptive t s = do
  TyConI (DataD _ _ _ constructors _) <- reify t
  tT <- conT t
  sT <- conT s
  -- Template instance: only the class and method names are reused; the
  -- placeholder bodies are replaced by the generated clauses.
  d <- [d| instance (T.Nat s, Applicative (T1 s)) => AdaptiveTuple T1 s where getIndex _ _ = undefined; setIndex _ _ c = c; mapIndex _ _ c = c; toATuple _ = undefined; fromATuple _ = []; sequenceAT _ = undefined|]
  let [InstanceD _ (AppT (AppT adtT _) _) [FunD getF _, FunD setF _, FunD mapF _, FunD toATF _, FunD fromATF _, FunD seqATF _]] = d
  let newty = AppT (AppT adtT tT) sT
  -- Literal pattern matching a field index.
  let idxP i = litP (integerL (fromIntegral i))
  -- One clause per field index 0 .. n-1, from a pattern and a body builder.
  let perField n mkPats mkBody =
        [clause (mkPats i) (mkBody i) [] | i <- [0 .. n - 1]]
  -- getIndex (C x0 .. xk) i = xi
  let getClauses (NormalC name fields) = do
        (aP, aV) <- genPE (length fields)
        return $ perField (length fields)
          (\i -> [conP name aP, idxP i])
          (\i -> normalB (aV !! i))
  -- setIndex i e (C x0 .. xk) = C x0 .. e .. xk
  let setClauses (NormalC name fields) = do
        ([elP], [elV]) <- genPE 1
        (aP, aV) <- genPE (length fields)
        return $ perField (length fields)
          (\i -> [idxP i, elP, conP name aP])
          (\i -> normalB (appsE (conE name : replaceAt aV i elV)))
  -- mapIndex f i (C x0 .. xk) = C x0 .. (f xi) .. xk
  let mapClauses (NormalC name fields) = do
        ([fP], [fV]) <- genPE 1
        (aP, aV) <- genPE (length fields)
        return $ perField (length fields)
          (\i -> [fP, idxP i, conP name aP])
          (\i -> normalB (appsE (conE name : replaceAt aV i (appE fV (aV !! i)))))
  -- toATuple (x0 : .. : xk : _) = C x0 .. xk
  let toClause (NormalC name fields) = do
        (aP, aV) <- genPE (length fields)
        let listP = foldr (\p rest -> infixP p '(:) rest) wildP aP
        clause [listP] (normalB (appsE (conE name : aV))) []
  -- fromATuple (C x0 .. xk) = [x0, .., xk]
  let fromClause (NormalC name fields) = do
        (aP, aV) <- genPE (length fields)
        clause [conP name aP] (normalB $ listE aV) []
  -- sequenceAT (C m0 .. mk) = m0 >>= \x0 -> .. >>= \xk -> return (C x0 .. xk)
  let seqClause (NormalC name fields) = do
        (aP, aV) <- genPE (length fields)
        let step vals acc = case vals of
              (mx:xs) -> [| $mx >>= \x -> $(step xs (appE acc [|x|])) |]
              _      -> [|return $acc|]
        clause [conP name aP] (normalB (step aV (conE name))) []
  -- Fallback clauses, appended once after all constructors' clauses.
  let getErr = clause [wildP, wildP] (normalB [| oObExcp "getIndex" |]) []
      setErr = clause [wildP, wildP, wildP] (normalB [| oObExcp "setIndex" |]) []
      mapErr = clause [wildP, wildP, wildP] (normalB [| oObExcp "mapIndex" |]) []
      toErr  = clause [wildP] (normalB [| insExcp |]) []
  getCs <- liftM concat (mapM getClauses constructors)
  setCs <- liftM concat (mapM setClauses constructors)
  mapCs <- liftM concat (mapM mapClauses constructors)
  getters <- funD getF  (getCs ++ [getErr])
  setters <- funD setF  (setCs ++ [setErr])
  maps    <- funD mapF  (mapCs ++ [mapErr])
  tos     <- funD toATF (map toClause constructors ++ [toErr])
  froms   <- funD fromATF (map fromClause constructors)
  seqAT   <- funD seqATF  (map seqClause constructors)
  return [InstanceD [] newty [getters,setters,maps,tos,froms,seqAT]]

-- |Create @n@ fresh names (each based on @\"x\"@). Return them as variable
-- patterns and as variable expressions, in the same order: the i-th
-- pattern binds the variable referenced by the i-th expression.
genPE :: Int -> Q ([PatQ], [ExpQ])
genPE n = liftM (map varP &&& map varE) (replicateM n (newName "x"))

-- |@replaceAt xs n el@ returns @xs@ with the element at 0-based index @n@
-- replaced by @el@.
--
-- An out-of-range index (negative, or @>= length xs@) leaves the list
-- unchanged. The previous version crashed on @tail []@ in the first case
-- and silently replaced the head in the second.
replaceAt :: [a] -> Int -> a -> [a]
replaceAt xs n el
  | n < 0     = xs
  | otherwise = case splitAt n xs of
      (before, _ : after) -> before ++ el : after
      (_, [])             -> xs