{-# LANGUAGE TemplateHaskell #-}
{-# OPTIONS -XEmptyDataDecls #-}

module Data.AspectAG.Derive (deriveAG, attLabel, attLabels) where

import Control.Monad (foldM, liftM, zipWithM)
import Data.List (isPrefixOf)

import Data.Set (Set)
import qualified Data.Set as S
import Language.Haskell.TH

--import Data.HList

import Data.AspectAG

-- | A reified user datatype as extracted by 'getUserType': its name, the
-- type-parameter names taken from the @DataD@/@NewtypeD@ declaration, and its
-- constructors (a @newtype@'s single constructor is wrapped in a list).
data UserType  = UserD Name [Name] [Con]
-- | Accumulator threaded through 'derive' / 'deriveCons': the set of type
-- names already visited, paired with the declarations generated so far.
type TypeDecls = (Set Name, [Dec])

-- | Declare a label: a constructor-less data declaration named @ndata@,
-- followed by the signature and binding produced by 'declareFnLabel' for the
-- value @nlabel@ at type @Proxy t@.
declareLabel :: Name -> Name -> TypeQ -> Q [Dec]
declareLabel ndata nlabel t =
  dataD (cxt []) ndata [] [] [] >>= \emptyData ->
    fmap (emptyData :) (declareFnLabel nlabel t)

-- | Declare @nlabel :: Proxy t@ together with the binding @nlabel = proxy@.
-- The signature comes first in the returned list, then the binding.
declareFnLabel ::  Name -> TypeQ -> Q [Dec]
declareFnLabel nlabel t = sequence [signature, binding]
  where
    signature = sigD nlabel (conT (mkName "Proxy") `appT` t)
    binding   = funD nlabel [clause [] (normalB [| proxy |]) []]

-- | Declare an attribute label named @att@.
--
-- This produces a constructor-less datatype @Att_att@ and a value
-- @att :: Proxy Att_att@, both via 'declareLabel'.
attLabel ::  String -> Q [Dec]
attLabel att = declareLabel attn (mkName att) (conT attn)
  where
    -- Type-level name of the label; the value-level name is @att@ itself.
    attn = mkName ("Att_" ++ att)

-- | Declare several attribute labels at once, in order. The declarations
-- from each 'attLabel' call are concatenated.
attLabels ::  [String] -> Q [Dec]
attLabels names = fmap concat (mapM attLabel names)

-- | Declare one child label per record field.
--
-- Field names @ns@ are paired positionally with field types @ts@. For a field
-- @n :: t@ this declares a constructor-less datatype @Ch_n@ and a value
-- @ch_n :: Proxy (Ch_n, t)@ via 'declareLabel'.
chLabels ::  [Name] -> [Type] -> Q [Dec]
chLabels ns ts = fmap concat (zipWithM label ns ts)
  where
    label n t = declareLabel (chTName n) (chName n) (tyLabel (chTName n) t)
    -- The label's phantom type pairs the child tag with the field type:
    -- (Ch_n, t). tupleT builds the built-in pair constructor directly, which
    -- avoids resolving the string "(,)".
    tyLabel n t = appT (appT (tupleT 2) (conT n)) (return t)

-- | Naming scheme for generated identifiers. Each helper keeps only the
-- unqualified base of the input name and adds a fixed prefix:
-- child label value (@ch_@), child label type (@Ch_@), nonterminal label
-- value (@nt_@), production label value (@p_@) and production label type (@P_@).
chName,chTName,ntName,prdName,prdTName ::  Name -> Name
chName   = mkName . ("ch_" ++) . nameBase
chTName  = mkName . ("Ch_" ++) . nameBase
ntName   = mkName . ("nt_" ++) . nameBase
prdName  = mkName . ("p_"  ++) . nameBase
prdTName = mkName . ("P_"  ++) . nameBase

-- | Entry point: generate the attribute-grammar boilerplate for the datatype
-- named @n@ and for every datatype reachable through its fields. Derivation
-- starts with an empty visited-set and no declarations; only the generated
-- declarations are returned.
deriveAG :: Name -> Q [Dec]
deriveAG n = fmap snd (derive n (S.empty, []))

-- | Name of the generated semantic function for a type: @sem_@ followed by
-- the type's unqualified base name.
semName ::  Name -> Name
semName = mkName . ("sem_" ++) . nameBase

-- | Generate declarations for the datatype named @n@ and, recursively, for
-- the types of its constructor fields.
--
-- The accumulator pairs the set of already-visited type names with the
-- declarations generated so far. Two kinds of name are returned unchanged:
-- names already in the set, and names 'primitive' classifies as primitive.
--
-- Otherwise the following are added in front of the accumulated
-- declarations:
--
--   * a semantic function @sem_n@, with one clause per constructor,
--     built by 'deriveCons';
--
--   * a nonterminal label @nt_n :: Proxy n@;
--
--   * whatever 'deriveCons' generated for the constructors and their fields.
derive :: Name -> TypeDecls -> Q TypeDecls
derive n acc@(stn, decl)
  -- A visited name was already reified successfully, so skip the reify.
  | S.member n stn = return acc
  | otherwise = do
      info <- reify n
      if primitive info
        then return acc
        else do
          -- Mark n as visited before descending, so recursive types terminate.
          let stn' = S.insert n stn
          (UserD _ _ lc) <- getUserType info
          ((s, d), fc) <- foldM deriveCons ((stn', decl), []) lc
          let semDecl = FunD (semName n) fc
          nt <- declareFnLabel (ntName n) (conT n)
          return (s, semDecl : (nt ++ d))

-- | Fold step over the constructors of a datatype (used by 'derive').
--
-- Take a record constructor @cn@ with fields @f1 :: t1, ..., fk :: tk@.
-- This step does three things:
--
--   * It runs 'derive' on every field type, threading the visited-set and
--     the generated declarations.
--
--   * It builds a clause with patterns @asp@ and @(cn x1 .. xk)@ and body
--     @knit (asp # p_cn) (ch_f1 .=. c1 .*. ... .*. emptyRecord)@.
--     Here @ci@ is @\\(Record HNil) -> xi@ when the field type is primitive,
--     and @sem_ti asp xi@ otherwise. The clause is prepended to the
--     accumulated clause list.
--
--   * It declares the production label (@P_cn@ / @p_cn@) and the child
--     labels (see 'chLabels') in front of the accumulated declarations.
--
-- Non-record constructors are rejected by 'getCtx'.
deriveCons :: (TypeDecls,[Clause]) -> Con -> Q (TypeDecls,[Clause])
deriveCons ((stn, decl), fc) c = do
    let (cht, chn, cn) = getCtx c
    -- Make sure every child type has its own semantic function first.
    (stn', decl') <- foldM (\td t -> derive (typeName t) td) (stn, decl) cht
    -- One fresh variable per field, in field order.
    conargs <- mapM (const (newName "x")) cht
    body <- [| knit ($(aspV) # $(att cn)) $(childs cht chn conargs) |]
    let semF = Clause (pat cn conargs) (NormalB body) []
    lp <- declareLabel (prdTName cn) (prdName cn) (conT $ prdTName cn)
    lc <- chLabels chn cht
    return ((stn', lp ++ lc ++ decl'), semF : fc)
  where
    -- "asp" is built with mkName (not newName), so the pattern and the
    -- references in the body resolve to the same variable.
    pat cn args = [aspP, ConP cn (map VarP args)]
    aspP = VarP $ mkName "asp"
    aspV = varE $ mkName "asp"
    att cn = varE $ prdName cn
    ch cn = varE $ chName cn
    -- Build the children record: one labelled entry per field.
    childs []     _      _      = [| emptyRecord |]
    childs (t:ts) (n:ns) (p:ps) = [| $(ch n) .=. $(chFun (typeName t) p)  .*. $(childs ts ns ps) |]
    childs _      _      _      = error "Impossible case!!"
    -- Primitive children ignore their inherited record and return the value;
    -- other children are evaluated with their own semantic function.
    chFun tn n = do
        i <- reify tn
        if primitive i
          then [| ( \(Record HNil) -> $(varE n) ) |]
          else [| $(varE (semName tn)) $(aspV) $(varE n) |]

-- | Extract name, type parameters and constructors from a reified @data@ or
-- @newtype@ declaration. A newtype's single constructor is wrapped in a
-- list. Any other kind of 'Info' is rejected with 'error'.
getUserType :: Info -> Q UserType
getUserType (TyConI (DataD    _ uname args cs _)) = return (UserD uname args cs)
getUserType (TyConI (NewtypeD _ uname args c  _)) = return (UserD uname args [c])
getUserType _ = error "Can only be used on algebraic datatypes"

-- | Split a record constructor into its field types, its field names and
-- the constructor name. Plain, infix and other (e.g. @forall@) constructors
-- are rejected with 'error'.
getCtx :: Con -> ([Type],[Name], Name)
getCtx con = case con of
    RecC name fields -> (map thd fields, map fst' fields, name)
    NormalC name _   -> notRecord name
    InfixC _ name _  -> notRecord name
    _                -> error "Cannot derive a 'forall' constructor."
  where
    notRecord name = error ("Constructor " ++ show name ++ " is not a record.")

-- | Third component of a triple.
thd :: (a, b, c) -> c
thd = \(_, _, z) -> z

-- | First component of a triple.
fst' :: (a, b, c) -> a
fst' triple = case triple of
  (x, _, _) -> x

-- | Should derivation stop at this type? True in three cases:
--
--   * primitive type constructors;
--
--   * @data@ types whose shown name starts with @GHC@;
--
--   * type synonyms, which are deliberately not expanded.
primitive ::  Info -> Bool
primitive info = case info of
  PrimTyConI {}            -> True
  TyConI (DataD _ n _ _ _) -> "GHC" `isPrefixOf` show n
  TyConI (TySynD {})       -> True   -- type synonyms to escape
  _                        -> False

-- | Name of a type variable or a bare type constructor. Any other type
-- (applications, lists, tuples, ...) is rejected with 'error'.
typeName :: Type -> Name
typeName (VarT varname) = varname
typeName (ConT conname) = conname
typeName t              = error ("Not valid type " ++ show t)