{-# LANGUAGE FlexibleInstances, TemplateHaskell, CPP, PatternGuards #-}
module Happstack.State.ComponentTH
    ( mkMethods
    ) where

import Data.Char
import Language.Haskell.TH

import Happstack.State.Types
import Happstack.Data.Serialize
import Happstack.State.ComponentSystem

import Control.Monad.State
import Control.Monad.Reader
import Control.Monad.Writer
import Data.Maybe

import Data.List

import Data.Generics.Basics

-- | Build a context from the given constraint actions and drop duplicate
--   constraints, keeping the first occurrence of each.
nubCxt :: [TypeQ] -> Q [Type]
nubCxt = fmap nub . cxt

{-
  Error cases:
    not all state keys are tyvars.
    method not using the keys in either the args or the result.
    Checked: component not being a data declaration with a single constructor.
    Checked: method using a tyvar that isn't a key.
    Checked: method is not a function.
-}

{- | Generate the event-handling code for a MACID state component.

   Given the name of the state type and the names of the functions that
   should become its methods, this splice produces, in this order:

     * an 'UpdateEvent' or 'QueryEvent' instance for every method,

     * the 'Methods' instance for the state type,

     * one event data type per method (deriving 'Typeable'),

     * 'Serialize' and 'Version' instances for every event data type.

   The splice has to be run for every MACID state type.
   Ex:  @$(mkMethods ''Foo ['fooUpdate, 'fooQuery])@
-}
mkMethods :: Name -> [Name] -> Q [Dec]
mkMethods componentName componentMethods
    = do componentInfo <- reify componentName
         let keys = requireSimpleCon componentName componentInfo
         methodInfos <- getMethodInfos keys componentMethods
         eventInstances <- genEventInstances methodInfos
         let instanceContext = mkKeyConstraints keys ++
                               concatMap (mkMethodConstraints keys) methodInfos
             instanceHead = conT ''Methods `appT` mkType componentName keys
             handlerList = genComponentHandlers methodInfos
         methodsInstance <- instanceD (nubCxt instanceContext) instanceHead
                                [ funD 'methods [clause [wildP] (normalB handlerList) []] ]
         eventStructs <- genMethodStructs [''Typeable] methodInfos
         serializeInstances <- genSerializeInstances methodInfos
         return (concat [eventInstances, [methodsInstance], eventStructs, serializeInstances])

-- | Constraints required of every state key: first a 'Typeable' constraint
--   for each key, then a 'Serialize' constraint for each key.
mkKeyConstraints :: [Name] -> [TypeQ]
mkKeyConstraints keys
    = [ conT cls `appT` varT key | cls <- [''Typeable, ''Serialize], key <- keys ]

-- | The method's own context, rewritten in terms of the given keys and lifted
--   into 'Q' so it can be combined with other constraint actions.
mkMethodConstraints :: [Name] -> MethodInfo -> [TypeQ]
mkMethodConstraints keys method
    = [ return constraint | constraint <- substMethodContext method keys ]

-- | Rename the method's key variables inside its context.  The method's keys
--   are paired positionally with the supplied keys; every occurrence of a
--   method key is replaced by its partner, anything else is left untouched.
--   Inside a forall only the body is rewritten.
substMethodContext :: MethodInfo -> [Name] -> [Type]
substMethodContext method keys
    = map rename (methodContext method)
    where renaming = zip (methodKeys method) keys
          rename ty = case ty of
              VarT old             -> maybe ty VarT (lookup old renaming)
              AppT l r             -> AppT (rename l) (rename r)
              ForallT bound c body -> ForallT bound c (rename body)
              _                    -> ty

-- | Apply a type constructor to a list of type variables, left to right.
mkType :: Name -> [Name] -> TypeQ
mkType name args = foldl applyVar (conT name) args
    where applyVar ty arg = ty `appT` varT arg

-- | For every method, generate a 'Serialize' instance and an empty 'Version'
--   instance for its event data type.  Serialisation writes each constructor
--   field with 'safePut' and reads them back, in the same order, with
--   'safeGet'.
genSerializeInstances :: [MethodInfo] -> Q [Dec]
genSerializeInstances meths
    = liftM concat (mapM instancesFor meths)
    where
      instancesFor method
          = do serializeInst <- instanceD constraints (conT ''Serialize `appT` eventType)
                                    [ funD 'putCopy [clause [] (normalB [| contain . $(encoder) |]) []]
                                    , funD 'getCopy [clause [] (normalB [| contain $(decoder) |]) []] ]
               versionInst <- instanceD constraints (conT ''Version `appT` eventType) []
               return [serializeInst, versionInst]
          where
            keys = methodKeys method
            conName = upperName (methodName method)
            eventType = mkType conName keys
            constraints = nubCxt (mkKeyConstraints keys ++ map return (methodContext method))
            -- one fresh variable per constructor field
            freshArgs = replicateM (length (methodArgs method)) (newName "arg")
            -- \(Con arg1 .. argN) -> do { safePut arg1; ..; safePut argN; return () }
            encoder = do args <- freshArgs
                         lamE [conP conName (map varP args)] $
                              doE ([ noBindS (varE 'safePut `appE` varE arg) | arg <- args ]
                                   ++ [ noBindS [| return () |] ])
            -- do { arg1 <- safeGet; ..; argN <- safeGet; return (Con arg1 .. argN) }
            decoder = do args <- freshArgs
                         doE ([ bindS (varP arg) (varE 'safeGet) | arg <- args ]
                              ++ [ noBindS (varE 'return `appE` foldl appE (conE conName) (map varE args)) ])


{-
  [ Update $ \(SetComponent c) -> setComponent c
  , Query $ \GetComponent -> getComponent ]
-}
-- | Build the list expression of handlers, one per method.  Each handler
--   wraps, in the method's event constructor, a lambda that unpacks the
--   method's event value and applies the method to the unpacked fields.
genComponentHandlers :: [MethodInfo] -> ExpQ
genComponentHandlers meths = listE (map handlerFor meths)
    where handlerFor method
              = do args <- replicateM (length (methodArgs method)) (newName "arg")
                   let eventPat = conP (upperName (methodName method)) (map varP args)
                       call = foldl appE (varE (methodName method)) (map varE args)
                   conE (methodEv method) `appE` lamE [eventPat] call


-- | Generate the event-class instance for each method, in order.
genEventInstances :: [MethodInfo] -> Q [Dec]
genEventInstances methodsInfo
    = forM methodsInfo genEventInstance

-- | Generate the (body-less) instance of the method's event class relating
--   its event data type to its result type.  The instance context requires
--   'Serialize' for the event and the result, the usual key constraints, and
--   the method's own context.
genEventInstance :: MethodInfo -> Q Dec
genEventInstance method
    = instanceD (nubCxt constraints) instanceHead []
    where keys = methodKeys method
          eventType = mkType (upperName (methodName method)) keys
          resultType = return (methodResult method)
          constraints = [ conT ''Serialize `appT` eventType
                        , conT ''Serialize `appT` resultType ]
                        ++ mkKeyConstraints keys
                        ++ mkMethodConstraints keys method
          instanceHead = conT (methodClass method) `appT` eventType `appT` resultType


-- | Generate the event data type declarations for all methods, each deriving
--   the given classes.
genMethodStructs :: [Name] -> [MethodInfo] -> Q [Dec]
genMethodStructs derv meths
    = do declGroups <- mapM (genMethodStruct derv) meths
         return (concat declGroups)

-- FIXME: allow class constraints on keys.
-- | Declare the event data type for one method: a single constructor named
--   after the method (first letter upper-cased) with one non-strict field per
--   method argument, parameterised over the method's keys and deriving the
--   given classes.
genMethodStruct :: [Name] -> MethodInfo -> Q [Dec]
genMethodStruct derv method
    = return [ DataD [] structName (methodKeys method) [constructor] derv ]
    where structName = upperName (methodName method)
          constructor = NormalC structName [ (NotStrict, argType) | argType <- methodArgs method ]

-- | Make a fresh unqualified name from the base of the given name with its
--   first character upper-cased.
upperName :: Name -> Name
upperName name = mkName (upperFirst (nameBase name))

-- | Upper-case the first character of a string.  Calls 'error' on the empty
--   string.
upperFirst :: String -> String
upperFirst str = case str of
    []     -> error "ComponentTH.UpperFirst []"
    (c:cs) -> toUpper c : cs

-- | What the generators know about a single component method.
data MethodInfo = Method
    { methodName    :: Name    -- ^ Name the method was reified from.
    , methodKeys    :: [Name]  -- ^ Type variables occurring in the method's state type.
    , methodContext :: [Type]  -- ^ Constraints of the method's signature that pass @isRelevant@ for its keys.
    , methodArgs    :: [Type]  -- ^ Argument types, in declaration order.
    , methodClass   :: Name    -- ^ Event class of the method (@UpdateEvent@ or @QueryEvent@).
    , methodEv      :: Name    -- ^ Handler constructor of the method (@Update@ or @Query@).
    , methodResult  :: Type    -- ^ Result type of the method.
    }

-- | Reify every named method and validate it against the component's keys.
--
--   A method is accepted only when it mentions as many state keys as the
--   component declares and when the type variables appearing in its argument
--   and result types are exactly its state keys.  Any violation aborts the
--   splice via 'error'.
--
--   The first argument is the list of the component's keys, the second the
--   names of the methods to inspect.
getMethodInfos :: [Name] -> [Name] -> Q [MethodInfo]
getMethodInfos sessionKeys names
    = do ms <- mapM getMethodInfo names
         mapM worker ms
    where worker m
              | length keys /= length sessionKeys
                  = error $ "Inconsistent keys: " ++ pprint (methodName m) ++ ": " ++ show (sessionKeys, keys)
              | sort tyVars == sort keys
                  = return m
              -- Some type variable of the signature is not a state key.
              -- Deciding this by set difference (rather than by the
              -- lexicographic order of the two sorted name lists) keeps the
              -- message independent of how the variables happen to be named.
              | not (null (tyVars \\ keys))
                  = error $ "Method too general: " ++ pprint (methodName m)
              -- Every type variable is a key, so the mismatch is a key that
              -- the signature never uses or that occurs more than once in
              -- the state type.
              | otherwise
                  = error $ "Method not general enough: " ++ pprint (methodName m)
              where keys = methodKeys m
                    tyVars = methodTyVars m
          -- type variables reachable through type applications only
          getArgKeys (AppT t1 t2) = getArgKeys t1 ++ getArgKeys t2
          getArgKeys (VarT key) = [key]
          getArgKeys _ = []
          -- distinct type variables of the result and argument types
          methodTyVars m = nub $ concatMap getArgKeys (methodResult m:methodArgs m)

-- | Reify a single method and extract its 'MethodInfo' from its type.
--   Calls 'error' when the name does not refer to a variable.
getMethodInfo :: Name -> Q MethodInfo
getMethodInfo method
    = reify method >>= fromInfo
    where fromInfo (VarI _name funcType _decl _fixity)
              = return (getTypes funcType){methodName = method}
          fromInfo other
              = error $ "Method is not a function: " ++ nameBase method ++ " is a " ++ showInfo other

-- | Short human-readable description of what kind of thing a reified name
--   is, for use in error messages.  Kinds not listed here are pretty-printed.
showInfo :: Info -> String
showInfo info = case info of
    ClassI{}     -> "class"
    TyConI{}     -> "type constructor"
    PrimTyConI{} -> "primitive type constructor"
    DataConI{}   -> "data constructor"
    VarI{}       -> "variable"
    TyVarI{}     -> "type variable"
    _            -> pprint info


-- Extract method information from a reified method type.  Only an outermost
-- forall is unwrapped; a type without one is treated as having an empty
-- context.
--
-- Cases:
--  forall m. MonadState state m => X -> m Y
--  forall key. key -> Update ()
--  forall key m. MonadState state m => key -> m ()
--  X -> Ev (ReaderT state STM) Y
--  X -> Ev (StateT state STM) Y
getTypes :: Type -> MethodInfo
getTypes ty = case ty of
    ForallT _ contxt body -> getTypes' contxt body
    _                     -> getTypes' [] ty

-- FIXME: only allow type variables used by the component.
-- Build a MethodInfo from a method's context and its type (with any outer
-- forall already removed by the caller).
--
-- The type is walked along its function arrows: every argument type is
-- recorded, and the final result type decides whether the method is an
-- update or a query and which state type it works on.  The methodName field
-- is left as an error thunk; the caller is expected to fill it in.
getTypes' :: Cxt -> Type -> MethodInfo
getTypes' contxt t
    -- The Writer output is a difference list of argument types (see
    -- 'tell (t1:)' below); applying it to [] yields them in order.
    = case runWriter (worker t) of
        ((keys,className, typeName, res), args) -> Method { methodName = error "Method name not set", methodKeys = keys
                                                          -- keep only the constraints accepted by isRelevant for these keys
                                                          , methodContext = filter (isRelevant keys) contxt
                                                          , methodArgs = args [], methodClass = className
                                                          , methodEv = typeName, methodResult = res}
    -- NOTE: the guarded equations below have no 'otherwise' branch, so a
    -- type that matches a pattern but fails all of its guards falls through
    -- to the following equations, ending in the error case.  Their order
    -- therefore matters.
    where -- recursive case: A -> B
          worker (AppT (AppT ArrowT t1) t2)
              = do tell (t1:)
                   worker t2
          -- end case: Update state res  ||  Query state res
          worker (AppT (AppT (ConT c) state) r)
              | c == ''Update = return (getStateKeys state,''UpdateEvent, 'Update, r)
              | c == ''Query  = return (getStateKeys state,''QueryEvent, 'Query, r)
          -- end case: Ev (ReaderT state STM) res  ||  Ev (StateT state STM) res
          -- (only the transformer is inspected; the outer constructor and the
          -- innermost monad are matched but their names are not checked)
          worker (AppT (AppT (ConT _ev) (AppT (AppT (ConT m) state) (ConT _stm))) r)
              | m == ''StateT  = return (getStateKeys state,''UpdateEvent, 'Update, r)
              | m == ''ReaderT = return (getStateKeys state,''QueryEvent, 'Query, r)
          -- end case: m res    (check if m is an instance of MonadState)
          -- A MonadState constraint on m makes the method an update, a
          -- MonadReader constraint makes it a query; MonadState is tried first.
          worker (AppT name r)
              | Just state <- isMonadState contxt name  = return (getStateKeys state, ''UpdateEvent, 'Update, r)
              | Just state <- isMonadReader contxt name = return (getStateKeys state, ''QueryEvent, 'Query, r)
          -- error case
          worker c = error ("Unexpected method type: " ++ pprint c)

-- | Collect, left to right, the type variables of a type built only from
--   type constructors, type variables and applications.  Any other form of
--   type is rejected with 'error'.
getStateKeys :: Type -> [Name]
getStateKeys ty = case ty of
    AppT l r -> getStateKeys l ++ getStateKeys r
    VarT key -> [key]
    ConT _   -> []
    _        -> error $ "Bad state type: " ++ pprint ty ++ " (expected a constant, an application or a type variable)"

-- | Look for a @MonadState state m@ constraint on the given monad type in
--   the context and return the state type of the first one found.
isMonadState :: [Type] -> Type -> Maybe Type
isMonadState contxt name = listToMaybe (mapMaybe stateOf contxt)
    where stateOf (AppT (AppT (ConT cls) state) monad)
              | monad == name && cls == ''MonadState = Just state
          stateOf _ = Nothing

-- | Look for a @MonadReader state m@ constraint on the given monad type in
--   the context and return the state type of the first one found.
isMonadReader :: [Type] -> Type -> Maybe Type
isMonadReader contxt name = listToMaybe (mapMaybe stateOf contxt)
    where stateOf (AppT (AppT (ConT cls) state) monad)
              | monad == name && cls == ''MonadReader = Just state
          stateOf _ = Nothing

-- | A constraint is relevant when it is an acceptable context (see
--   'isAcceptableContext') and mentions at least one of the given keys.
isRelevant :: [Name] -> Type -> Bool
isRelevant keys t
    | isAcceptableContext t = any isKey (getStateKeys t)
    | otherwise             = False
    where isKey var = var `elem` keys

-- | A constraint is acceptable unless it mentions the 'MonadState' or
--   'MonadReader' class anywhere along its type applications.
isAcceptableContext :: Type -> Bool
isAcceptableContext ty = case ty of
    AppT l r -> all isAcceptableContext [l, r]
    ConT con -> not (con `elem` [''MonadState, ''MonadReader])
    _        -> True

-- | Return the type parameters of a reified component type.  Data
--   declarations, newtypes and type synonyms are accepted; for anything else
--   the splice is aborted with an error naming the offending component.
requireSimpleCon :: Name -> Info -> [Name]
requireSimpleCon name info = case info of
    TyConI (DataD _ _ names _ _)    -> names
    TyConI (NewtypeD _ _ names _ _) -> names
    TyConI (TySynD _ names _)       -> names
    _ -> error $ "Cannot create component from '"++pprint name++"'. Expected a data structure."