{-# LANGUAGE TemplateHaskell #-}
{-# OPTIONS -XEmptyDataDecls #-}

module Data.AspectAG.Derive (deriveAG, attLabel, attLabels, typeList) where

import Language.Haskell.TH

import Data.Set (Set)
import Data.List (isPrefixOf)
import qualified Data.Set as S

import Data.AspectAG

data UserType  = UserD Name [Name] [Con]
type TypeDecls = (Set Name, [Dec])


declareLabel :: Name -> Name -> TypeQ -> Q [Dec]
declareLabel ndata nlabel t = do 
            dtl <- dataD (cxt []) ndata [] [] []
            lbl <- declareFnLabel nlabel t
            return $ dtl:lbl

declareFnLabel ::  Name -> TypeQ -> Q [Dec]
declareFnLabel nlabel t = do 
            sgn <- sigD nlabel (appT (conT $ mkName "Proxy") t)  
            let pxy = normalB [| proxy |]
            lbl <- funD nlabel [clause [] pxy []]
            return [sgn,lbl]

attLabel ::  String -> Q [Dec]
attLabel att = declareLabel attn (mkName att) (conT $ attn) 
  where
      attn  = mkName $ "Att_" ++ att
      

attLabels ::  [String] -> Q [Dec]
attLabels = liftM concat . mapM attLabel


chLabels ::  [Name] -> [Type] -> Q [Dec]
chLabels ns ts = (liftM concat) $ zipWithM label ns ts
  where
      label n t = declareLabel (chTName n) (chName n) (tyLabel (chTName n) t) 
      tyLabel n t = appT (appT (conT $ mkName "(,)") (conT n)) (return t) 

chName,chTName,ntName,prdName,prdTName ::  Name -> Name
chName   cn = mkName $ "ch_" ++ nameBase cn 
chTName  cn = mkName $ "Ch_" ++ nameBase cn 
ntName   cn = mkName $ "nt_" ++ nameBase cn 
prdName  cn = mkName $ "p_"  ++ nameBase cn 
prdTName cn = mkName $ "P_"  ++ nameBase cn 


deriveAG :: Name -> Q [Dec]
deriveAG n = do
              (_,decl) <- derive n (S.empty,[]) --eval)
              return decl




semName ::  Name -> Name
semName t = mkName ("sem_"++(nameBase t))

derive :: Name -> TypeDecls -> Q TypeDecls 
derive n (stn,decl) = 
    do
       info <- reify n 
       if (S.member n stn || not (isNT info))  
          then return (stn,decl)
          else let stn' = S.insert n stn
               in  do
                      (UserD _ _ lc) <- getUserType info
                      ((s,d),fc)   <- foldM deriveCons ((stn',decl),[]) lc
                      let semDecl = FunD (semName n) fc
                      nt <- declareFnLabel (ntName n) (conT $ n)
                      return (s,semDecl:(nt++d))

deriveCons :: (TypeDecls,[Clause]) -> Con -> Q (TypeDecls,[Clause])
deriveCons ((stn,decl),fc) c =                     
    do
      let (cht,chn,cn) = getCtx c
      (stn',decl') <- foldM (\td t -> deriveList (typeNames t) td) (stn,decl) cht
      conargs <- newNames cht
      body <- [| knit ($(aspV) # $(att cn)) $(childs cht chn conargs) |] 
      let semF = Clause (pat cn conargs) (NormalB body) []
      lp <- declareLabel (prdTName cn) (prdName cn) (conT $ prdTName cn)
      lc <- chLabels chn cht
      return ((stn',lp++lc++decl'),semF:fc)
  where
     newNames []     = return []
     newNames (_:as) = do
                        na  <- newName "x"
                        nas <- newNames as
                        return (na:nas)
     pat cn args = [aspP, ConP cn (map VarP args)] 
     aspP = VarP $ mkName "asp"
     aspV = varE $ mkName "asp"
     att cn  = varE $ prdName cn
     ch  cn  = varE $ chName cn
     childs []     _      _      = [| emptyRecord |]
     childs (t:ts) (n:ns) (p:ps) = case (typeName1 t) of
                                     Just tn -> [| $(ch n) .=. $(chFun tn p)  .*. $(childs ts ns ps) |]
                                     Nothing -> [| $(childs ts ns ps) |]
     childs _      _      _      = error "Impossible case!!"
     deriveList ns td = foldM (flip $ derive) td ns
     chFun tn n =  
              do
                  i <- reify (tn)
                  if isNT i
                   then [| $(varE (semName tn)) $(aspV) $(varE n) |]
                   else [| ( \(Record HNil) -> $(varE n) ) |]
           




getUserType :: Info -> Q UserType
getUserType info = do
    case info of
        TyConI d -> case d of
            (DataD     _ uname args cs  _)  -> return $ UserD uname args cs 
            (NewtypeD  _ uname args c   _)  -> return $ UserD uname args [c]
            _                               -> scopeError
        _ -> scopeError
    where scopeError = error $ "Can only be used on algebraic datatypes"

getCtx :: Con -> ([Type],[Name], Name) 
getCtx (RecC           name args) = (map thd args, map fst' args, name)
getCtx (NormalC name []) = ([],[],name)
getCtx (NormalC name _)  = error $ "Constructor " ++ (show name) ++ " is not a record."
getCtx (InfixC _ name _) = error $ "Constructor " ++ (show name) ++ " is not a record."	
getCtx _ = error $ "Cannot derive a 'forall' constructor."	

thd :: (a, b, c) -> c
thd (_, _, c) = c

fst' :: (a, b, c) -> a
fst' (a, _, _) = a



isNT ::  Info -> Bool
isNT (PrimTyConI _ _ _)          =  False
isNT (TyConI (DataD _ n _ _ _))  =  not $ isPrefixOf "GHC" (show n) -- GHC types can't be non-terminals
isNT (TyConI (TySynD _ _ _))     =  False   -- type synonyms to escape
isNT  _                          =  True


typeNames :: Type -> [ Name ]
typeNames t = case t of
    VarT _                   -> [  ]
    ConT conname             -> [ conname ]
    AppT t1 t2               -> typeNames t1 ++ typeNames t2
    _                        -> error $ "Not valid type " ++ (show t)


typeName1 :: Type -> Maybe Name 
typeName1 t = case t of
    VarT _                   -> Nothing
    ConT conname             -> Just conname
    AppT t1 _                -> typeName1 t1
    _                        -> error $ "Not valid type " ++ (show t)



typeList :: String -> String -> Q [Dec]
typeList nData nArg =  
   do
     let ddef = DataD [] dataName [] [ RecC consName [(hdName, NotStrict, ConT argName), (tlName ,NotStrict, ConT dataName) ]
                                     , NormalC nilName []] []
     return [ ddef ]
 where
  dataName = mkName nData
  argName  = mkName nArg
  consName = mkName $ "Cons" ++ nData
  nilName  = mkName $ "Nil"  ++ nData
  hdName   = mkName $ "hd"   ++ nData
  tlName   = mkName $ "tl"   ++ nData