{-# LANGUAGE TemplateHaskell #-}

-- | Template Haskell generation of the auxiliary declarations (phantom constructor types and the
-- 'PersistField' and 'PersistEntity' instances) that Groundhog needs to store a user datatype.
module Database.Groundhog.TH
  ( deriveEntity
  , setDbEntityName
  , setConstructor
  , setPhantomName
  , setDbConstrName
  , setConstraints
  , setField
  , setDbFieldName
  , setExprFieldName
  ) where

import Database.Groundhog.Core(PersistEntity(..), PersistField(..), NeverNull, Primitive(toPrim), PersistBackend(..), DbType(DbEntity), Constraint, Constructor(..), namedType, EntityDef(..), ConstructorDef(..), PersistValue(..))
import Language.Haskell.TH
import Language.Haskell.TH.Syntax(StrictType, VarStrictType)
import Control.Monad(liftM, forM, forM_)
import Control.Monad.Trans.State(State, runState, modify)
import Data.Char(toUpper, toLower, isSpace)
import Data.List(nub, (\\))

-- Running example referred to by the field comments of the description types below:
-- data SomeData a = U1 { foo :: Int } | U2 { bar :: Maybe String, asc :: Int64, add :: a } | U3 deriving (Show, Eq)

-- | Description of an entity (a user datatype) from which the declarations are generated.
-- It starts from defaults derived from the datatype and can be adjusted with the setters.
data THEntityDef = THEntityDef {
    dataName :: Name -- ^ Haskell type name, e.g. @SomeData@
  , dbEntityName :: String  -- ^ database table name, e.g. @SQLSomeData@; defaults to the type name
  , thTypeParams :: [TyVarBndr] -- ^ type variables of the datatype
  , thConstructors :: [THConstructorDef] -- ^ one description per data constructor, in declaration order
} deriving Show

-- | Description of one data constructor of an entity.
data THConstructorDef = THConstructorDef {
    thConstrName    :: Name -- ^ Haskell constructor name, e.g. @U2@
  , thPhantomConstrName :: String -- ^ name of the generated phantom datatype, e.g. @U2Constructor@
  , dbConstrName    :: String -- ^ name of the constructor-specific table, e.g. @SQLU2@
  , thConstrParams  :: [FieldDef] -- ^ fields of the constructor, in declaration order
  , thConstrConstrs :: [Constraint] -- ^ constraints of the constructor, referring to database field names
} deriving Show

-- | Description of one field of a data constructor.
data FieldDef = FieldDef {
    fieldName :: String -- ^ Haskell field name, e.g. @bar@; for non-record constructors the
                        -- constructor name with a lowercased first letter followed by the field index
  , dbFieldName :: String -- ^ database column name, e.g. @SQLbar@
  , exprName :: String -- ^ name of the field constructor used in expressions, e.g. @BarField@
  , fieldType :: Type -- ^ declared type of the field
} deriving Show

-- | Override the database table name of the entity. By default it is the
-- name of the Haskell datatype.
setDbEntityName :: String -> State THEntityDef ()
setDbEntityName name = modify withName where
  withName entity = entity { dbEntityName = name }

-- | Apply a modification to the description of the named data constructor.
-- Generation aborts with an error if the entity has no constructor with that
-- name.
setConstructor :: Name -> State THConstructorDef () -> State THEntityDef ()
setConstructor name f = modify updateConstructors where
  updateConstructors entity =
    entity { thConstructors = replaceOne thConstrName name f (thConstructors entity) }

-- | Override the name of the phantom datatype generated for the constructor.
-- The phantom type parametrises the constructor's field descriptors. By
-- default it is the constructor name followed by @Constructor@.
setPhantomName :: String -> State THConstructorDef ()
setPhantomName name = modify withPhantom where
  withPhantom constr = constr { thPhantomConstrName = name }

-- | Override the name of the constructor-specific table. By default it is
-- the constructor name.
setDbConstrName :: String -> State THConstructorDef ()
setDbConstrName name = modify withDbName where
  withDbName constr = constr { dbConstrName = name }

-- | Replace the constraints of the constructor. Fields inside a constraint
-- are referred to by their database names, not their Haskell names.
setConstraints :: [Constraint] -> State THConstructorDef ()
setConstraints cs = modify withConstraints where
  withConstraints constr = constr { thConstrConstrs = cs }

-- | Apply a modification to one field of the constructor. For a record
-- constructor the field is identified by its record field name. Otherwise it
-- is identified by the constructor name with a lowercased first letter,
-- followed by the zero-based field number (e.g. @normal0@). Generation aborts
-- with an error if no field has that name.
setField :: String -> State FieldDef () -> State THConstructorDef ()
setField name f = modify updateFields where
  updateFields constr =
    constr { thConstrParams = replaceOne fieldName name f (thConstrParams constr) }

-- | Override the name of the database column that stores the field. By
-- default it is the Haskell field name.
setDbFieldName :: String -> State FieldDef ()
setDbFieldName name = modify withColumn where
  withColumn field = field { dbFieldName = name }

-- | Override the name of the field constructor used to refer to the field in
-- expressions. By default it is the field name with a capitalized first
-- letter followed by @Field@.
setExprFieldName :: String -> State FieldDef ()
setExprFieldName name = modify withExprName where
  withExprName field = field { exprName = name }

-- | Run a state modification on the single list element whose key (selected
-- by the projection) equals the given one, keeping all other elements
-- unchanged. Aborts with 'error' if no element, or more than one element,
-- has that key.
replaceOne :: (Eq b, Show b) => (a -> b) -> b -> State a () -> [a] -> [a]
replaceOne key wanted modification items =
  case filter matches items of
    [_] -> [ if matches item then runModify modification item else item | item <- items ]
    []  -> error $ "Element with name " ++ show wanted ++ " not found"
    _   -> error $ "Found more than one element with name " ++ show wanted
  where
    matches item = key item == wanted

-- | Run a state action purely for its effect on the state, returning the
-- final state.
runModify :: State a () -> a -> a
runModify m = snd . runState m

-- | Creates the auxiliary structures for a user datatype, which are required by Groundhog to manipulate it.
--
-- It creates a 'Fields' data instance (in GADT style) for referring to the fields in
-- expressions, and phantom types for the data constructors. For record constructors the field
-- constructor name is the record field name with its first letter capitalized and \"Field\" appended.
-- For ordinary (non-record) constructors each field is named after the constructor, with its first
-- letter lowercased, followed by the zero-based position of the field. The field constructor is that
-- name capitalized with \"Field\" appended (e.g. @Normal0Field@). The constructor phantom datatypes
-- have the same names as the constructors with \"Constructor\" appended.
--
-- The generation can be adjusted using the optional modifier function; pass 'Nothing' to keep the
-- defaults. Example:
--
-- > data SomeData a = Normal Int | Record { bar :: Maybe String, asc :: a}
-- > deriveEntity ''SomeData $ Just $ do
-- >   setDbEntityName "SomeTableName"
-- >   setConstructor 'Normal $ do
-- >     setPhantomName "NormalConstructor" -- the same as default
--
-- It will generate these new datatypes and the required instances:
--
-- > data NormalConstructor
-- > data RecordConstructor
-- > instance PersistField a => PersistField (SomeData a) where
-- > ...
-- > instance PersistField a => PersistEntity (SomeData a) where
-- >   data Fields (SomeData a) c f where
-- >     Normal0Field :: Fields (SomeData a) NormalConstructor Int
-- >     BarField :: Fields (SomeData a) RecordConstructor (Maybe String)
-- >     AscField :: Fields (SomeData a) RecordConstructor a
-- > ...
deriveEntity :: Name -> Maybe (State THEntityDef ()) -> Q [Dec]
deriveEntity name f = do
  info <- reify name
  -- Without a modifier the default description is used unchanged.
  let f' = maybe id runModify f
  -- Problems are reported with 'fail'. In the Q monad that becomes an ordinary
  -- compile error at the splice site. 'error' would instead surface as an
  -- exception thrown from compile-time code.
  case info of
    TyConI def@(DataD _ _ _ _ _) ->
      either fail mkDecs $ validate $ f' $ mkTHEntityDef def
    TyConI (NewtypeD _ _ _ _ _) -> fail "Newtypes are not supported"
    TyConI _ -> fail "Unknown declaration type"
    _ -> fail "Only datatypes can be processed"

-- | Check that a (possibly user-modified) entity description can be turned
-- into valid declarations. On success it is returned unchanged. Otherwise
-- the result is 'Left' with a message describing the first problem found.
--
-- The checks are:
--
-- * constructor phantom names and constructor database names are unique
--   within the entity;
-- * expression and database field names are unique within each constructor;
-- * phantom names and expression field names contain no whitespace.
validate :: THEntityDef -> Either String THEntityDef
validate def = do
  -- Keys (under f) that occur more than once in xs, each listed once.
  let notUniqueBy f xs = let xs' = map f xs in nub $ xs' \\ nub xs'
  -- Failures are produced with Left directly. 'fail' is not a sound way to
  -- signal errors in Either: depending on the base version it is either
  -- 'error' or unavailable, since Either has no MonadFail instance.
  let assertUnique f xs what = case notUniqueBy f xs of
        [] -> Right ()
        ys -> Left $ "All " ++ what ++ " must be unique: " ++ show ys
  -- We need to validate datatype names because TH just creates unusable fields with spaces.
  let assertSpaceFree s what
        | any isSpace s = Left $ "Spaces in " ++ what ++ " are not allowed: " ++ show s
        | otherwise     = Right ()
  let constrs = thConstructors def
  assertUnique thPhantomConstrName constrs "constructor phantom name"
  assertUnique dbConstrName constrs "constructor db name"
  forM_ constrs $ \cdef -> do
    let fields = thConstrParams cdef
    assertSpaceFree (thPhantomConstrName cdef) "constructor phantom name"
    assertUnique exprName fields "expr field name in a constructor"
    assertUnique dbFieldName fields "db field name in a constructor"
    forM_ fields $ \fdef -> assertSpaceFree (exprName fdef) "field expr name"
  return def

-- | Build the default description of a plain @data@ declaration.
--
-- Defaults:
--
-- * the table is named after the type;
-- * every constructor gets a phantom type named @<Constructor>Constructor@,
--   a table named after the constructor, and no constraints;
-- * a field's database name equals its Haskell name, and its expression name
--   is that name with a capitalized first letter followed by @Field@.
--
-- Fields of non-record constructors are named after the constructor (first
-- letter lowercased) followed by their zero-based position. Infix and
-- existentially quantified constructors are rejected with 'error'.
mkTHEntityDef :: Dec -> THEntityDef
mkTHEntityDef (DataD _ typeName tyVars constructors _) =
  THEntityDef
    { dataName       = typeName
    , dbEntityName   = nameBase typeName
    , thTypeParams   = tyVars
    , thConstructors = map toConstrDef constructors
    }
  where
    toConstrDef con = case con of
      NormalC cName args ->
        let prefix = firstLetter toLower (nameBase cName)
        in defaultConstr cName [ defaultField (prefix ++ show i) ty | (i, (_, ty)) <- zip [0 :: Int ..] args ]
      RecC cName args ->
        defaultConstr cName [ defaultField (nameBase fName) ty | (fName, _, ty) <- args ]
      InfixC _ _ _ -> error "Types with infix constructors are not supported"
      ForallC _ _ _ -> error "Types with existential quantification are not supported"

    defaultConstr cName fields = THConstructorDef
      { thConstrName        = cName
      , thPhantomConstrName = nameBase cName ++ "Constructor"
      , dbConstrName        = nameBase cName
      , thConstrParams      = fields
      , thConstrConstrs     = []
      }

    defaultField name ty = FieldDef
      { fieldName   = name
      , dbFieldName = name
      , exprName    = firstLetter toUpper name ++ "Field"
      , fieldType   = ty
      }
mkTHEntityDef _ = error "Only datatypes can be processed"

-- | Apply a character transformation to the first character of a string,
-- leaving the rest unchanged (e.g. @firstLetter toUpper "bar" == "Bar"@).
-- Total: the empty string is returned unchanged instead of failing on 'head'.
firstLetter :: (Char -> Char) -> String -> String
firstLetter _ []     = []
firstLetter f (c:cs) = f c : cs

-- | Produce every declaration generated for an entity, in this order:
-- the phantom constructor datatypes, their 'Constructor' instances, the
-- 'PersistField' instance and the 'PersistEntity' instance.
--
-- To inspect the output while debugging, print it from the splice with
-- @runIO (putStrLn (pprint decs))@.
mkDecs :: THEntityDef -> Q [Dec]
mkDecs def = liftM concat (mapM ($ def) generators)
  where
    generators =
      [ mkPhantomConstructors
      , mkPhantomConstructorInstances
      , mkPersistFieldInstance
      , mkPersistEntityInstance
      ]

-- | Declare one empty datatype per constructor. It is used as a phantom type
-- tagging the fields that belong to that constructor.
mkPhantomConstructors :: THEntityDef -> Q [Dec]
mkPhantomConstructors def = forM (thConstructors def) $ \cdef ->
  dataD (cxt []) (mkName (thPhantomConstrName cdef)) [] [] []
  
-- | Generate a 'Constructor' instance for every phantom constructor type.
-- The instance exposes the constructor's database name and its zero-based
-- position among the entity's constructors.
mkPhantomConstructorInstances :: THEntityDef -> Q [Dec]
mkPhantomConstructorInstances def =
  forM (zip [0 :: Int ..] (thConstructors def)) $ \(cNum, c) -> do
    let phantom = conT (mkName (thPhantomConstrName c))
        nameImpl = funD 'phantomConstrName
          [clause [wildP] (normalB (stringE (dbConstrName c))) []]
        numImpl = funD 'phantomConstrNum
          [clause [wildP] (normalB [| cNum |]) []]
    instanceD (cxt []) (appT (conT ''Constructor) phantom) [nameImpl, numImpl]

-- | Generate the 'PersistEntity' instance for the entity.
--
-- * @Fields@: a data instance with one constructor per field. Each
--   constructor forces the constructor parameter to be the phantom
--   constructor type and the field parameter to be the declared field type.
-- * 'entityDef': a runtime description of the entity (names, type
--   parameters, constructors with their fields and constraints).
-- * 'toPersistValues' \/ 'fromPersistValues': encode a value as its
--   constructor index followed by its field values, and decode that layout.
--   Any other input makes 'fromPersistValues' fail with @"Invalid values: ..."@.
-- * 'getConstraints': the constructor index, together with the values of the
--   fields named in each constraint.
-- * 'showField' \/ 'eqField': the database name of a field, and field
--   constructor equality.
--
-- The instance context comes from 'paramsContext'.
mkPersistEntityInstance :: THEntityDef -> Q [Dec]
mkPersistEntityInstance def = do
  -- The entity type applied to its own type variables, e.g. SomeData a.
  let entity = foldl AppT (ConT (dataName def)) $ map getType $ thTypeParams def

  -- data instance Fields entity c f, where each field constructor carries the
  -- equality constraints c ~ PhantomConstructor and f ~ fieldType.
  fields' <- do
    cParam <- newName "c"
    fParam <- newName "f"
    let mkField name field = ForallC [] ([EqualP (VarT cParam) (ConT name), EqualP (VarT fParam) (fieldType field)]) $ NormalC (mkName $ exprName field) []
    let f cdef = map (mkField $ mkName $ thPhantomConstrName cdef) $ thConstrParams cdef
    let constrs = concatMap f $ thConstructors def
    return $ DataInstD [] ''Fields [entity, VarT cParam, VarT fParam] constrs []

  entityDef' <- do
    v <- newName "v"
    -- `undefined :: forall params. entity -> t` applied to v: gives a value of
    -- type t, instantiated at v's type, for namedType (presumably never evaluated).
    let mkLambda t = [|undefined :: $(forallT (thTypeParams def) (cxt []) [t| $(return entity) -> $(return t) |]) |]
    let typeParams' = listE $ map (\t -> [| namedType ($(mkLambda $ getType t) $(varE v)) |]) $ thTypeParams def
    -- Description of one field: its database name and namedType of a value of
    -- its type. If the type mentions type variables, the value is taken out of
    -- v by pattern matching, so that the variables are instantiated as in v.
    let mkField c fNum f = do
          a <- newName "a"
          let fname = dbFieldName f
          let pats = replicate fNum wildP ++ [varP a] ++ replicate (length (thConstrParams c) - fNum - 1) wildP
          let nvar = case hasFreeVars (fieldType f) of
                True ->  appE (lamE [conP (thConstrName c) pats] (varE a)) (varE v)
                False -> [| undefined :: $(return $ fieldType f) |]
          [| (fname, namedType $nvar) |]

    let constrs = listE $ zipWith (\cNum c@(THConstructorDef _ _ name params conss) -> [| ConstructorDef cNum name $(listE $ zipWith (mkField c) [0..] params) conss |]) [0..] $ thConstructors def
    let body = normalB [| EntityDef $(stringE $ dbEntityName def) $typeParams' $constrs |]
    funD 'entityDef $ [ clause [varP v] body [] ]

  -- toPersistValues (C f1 .. fn) = sequence [toPersistValue (cNum :: Int), toPersistValue f1, ..]
  toPersistValues' <- liftM (FunD 'toPersistValues) $ forM (zip [0..] $ thConstructors def) $ \(cNum, c) -> do
    names <- mapM (const $ newName "f") $ thConstrParams c
    let pat = conP (thConstrName c) (map varP names)
    let body = normalB $ [| sequence $(listE $ map (appE (varE 'toPersistValue)) ([|cNum::Int|]:map varE names) ) |]
    clause [pat] body []

  -- fromPersistValues (PersistInt64 cNum : [x1 .. xn]) decodes every field and
  -- rebuilds the constructor; any other input fails.
  fromPersistValues' <- do
    clauses <- forM (zip [0..] (thConstructors def)) $ \(cNum, c) -> do
      names <- mapM (const $ newName "x") $ thConstrParams c
      names' <- mapM (const $ newName "x'") $ thConstrParams c
      let pat = conP '(:) [conP 'PersistInt64 [litP $ integerL cNum], listP $ map varP names]
      let result = noBindS (appE (varE 'return) ( foldl (\a -> appE a . varE) (conE (thConstrName c)) names'))
      let getField name name' f = bindS (varP name') [| fromPersistValue $(varE name) |]
      let body = normalB $ doE $ (zipWith3 getField names names' (thConstrParams c)) ++ [result]
      clause [pat] body []
    unexpected <- newName "xs" >>= \xs -> clause [varP xs] (normalB [| fail $ "Invalid values: " ++ show $(varE xs) |]) []
    return $ FunD 'fromPersistValues $ clauses ++ [unexpected]

  -- getConstraints returns the constructor index plus, for every constraint,
  -- its name and the (database field name, value) pairs of its fields.
  getConstraints' <- let
    hasConstraints = not . null . thConstrConstrs
    clauses = zipWith mkClause [0::Int ..] (thConstructors def)
    mkClause cNum cdef | not (hasConstraints cdef) = clause [conP (thConstrName cdef) pats] (normalB [| (cNum, []) |]) [] where
      pats = map (const wildP) $ thConstrParams cdef
    mkClause cNum cdef = clause [conP (thConstrName cdef) (map varP names)] body [] where
      getFieldName n = case filter ((==n).dbFieldName) $ thConstrParams cdef of
        [f] -> varE $ mkName $ fieldName f
        []  -> error $ "Database field name " ++ show n ++ " declared in constraint not found"
        _   -> error $ "It can never happen. Found several fields with one database name " ++ show n
      body = normalB $ [| (cNum, $(listE $ map (\(cname, fnames) -> [|(cname, $(listE $ map (\fname -> [| (fname, toPrim $(getFieldName fname)) |] ) fnames )) |] ) $ thConstrConstrs cdef)) |]
      names = map (mkName . fieldName) $ thConstrParams cdef
    in funD 'getConstraints clauses

  -- showField maps each field constructor to its database column name.
  showField' <- do
    let fields = concatMap thConstrParams $ thConstructors def
    let fieldClauses = map (\f -> clause [conP (mkName $ exprName f) []] (normalB $ stringE $ dbFieldName f) []) fields
    -- An entity whose constructors are all nullary has no fields. GHC rejects a
    -- function binding without equations, so emit a single catch-all clause.
    let noFieldsClause = clause [wildP] (normalB [| error "showField: entity has no fields" |]) []
    funD 'showField $ if null fields then [noFieldsClause] else fieldClauses

  -- eqField is True exactly when both arguments are the same field constructor.
  eqField' <- let
    fieldNames = thConstructors def >>= thConstrParams >>= return.mkName.exprName
    clauses = map (\n -> clause [conP n [], conP n []] (normalB [| True |]) []) fieldNames
    -- The catch-all False clause is needed when there are two or more fields.
    -- It is also needed when there are none, because GHC rejects a function
    -- binding without equations. With exactly one field it would be redundant.
    in funD 'eqField $ if length clauses == 1
     then clauses
     else clauses ++ [clause [wildP, wildP] (normalB [| False |]) []]

  let context = paramsContext def
  let decs = [fields', entityDef', toPersistValues', fromPersistValues', getConstraints', showField', eqField']
  return $ [InstanceD context (AppT (ConT ''PersistEntity) entity) decs]
  
-- | Generate the 'PersistField' instance for the entity type, which lets a
-- value of the entity appear as a field of another entity.
--
-- * 'persistName': the database entity name. For a parametrised type it is
--   followed by the 'persistName' of every type parameter, each separated by @$@.
-- * 'toPersistValue': stores the value with 'insertBy' and converts the
--   result with 'toPrim' (either alternative of the returned 'Either').
-- * 'fromPersistValue': decodes the stored value (presumably a key) and loads
--   the row with 'get'. It fails with @"No data with id ..."@ when nothing is found.
-- * 'dbType': 'DbEntity' wrapping the entity's 'entityDef'.
--
-- The instance context comes from 'paramsContext'.
mkPersistFieldInstance :: THEntityDef -> Q [Dec]
mkPersistFieldInstance def = do
  -- The entity type applied to its own type variables, e.g. SomeData a.
  let types = map getType $ thTypeParams def
  let entity = foldl AppT (ConT (dataName def)) types
  
  persistName' <- do
    v <- newName "v"
    -- `undefined :: forall params. entity -> t` applied to v: gives a value of
    -- type t, instantiated at v's type, for the nested persistName call
    -- (presumably never evaluated; only its type matters).
    let mkLambda t = [|undefined :: $(forallT (thTypeParams def) (cxt []) [t| $(return entity) -> $(return t) |]) |]
    
    -- "$"-separated persistNames of all type parameters. foldr1 fails on an
    -- empty list, so this is only used when the type has parameters.
    let paramNames = foldr1 (\p xs -> [| $p ++ "$" ++ $xs |] ) $ map (\t -> [| persistName ($(mkLambda t) $(varE v)) |]) types
    let namesList = case null types of
         True  -> [| $(stringE $ dbEntityName def) |]
         False -> [| $(stringE $ dbEntityName def) ++ "$" ++ $(paramNames) |]
    let body = normalB $ namesList
    funD 'persistName $ [ clause [varP v] body [] ]
  
  -- toPersistValue = liftM (either toPrim toPrim) . insertBy
  toPersistValue' <- do
    let body = normalB [| liftM (either toPrim toPrim) . insertBy |]
    funD 'toPersistValue $ [ clause [] body [] ]
  -- fromPersistValue x = fromPersistValue x >>= get >>= maybe (fail ...) return
  fromPersistValue' <- do
    x <- newName "x"
    let body = normalB [| fromPersistValue $(varE x) >>= get >>= maybe (fail $ "No data with id " ++ show $(varE x)) return |]
    funD 'fromPersistValue $ [ clause [varP x] body [] ]
  dbType' <- funD 'dbType $ [ clause [] (normalB [| DbEntity . entityDef |]) [] ]

  let context = paramsContext def
  let decs = [persistName', toPersistValue', fromPersistValue', dbType']
  return $ [InstanceD context (AppT (ConT ''PersistField) entity) decs]

-- | The instance context shared by the generated 'PersistField' and
-- 'PersistEntity' instances.
--
-- Every type parameter must be an instance of 'PersistField'. All field types
-- must be 'PersistField' instances too. When @Maybe p@ occurs anywhere in a
-- field type, for a type variable @p@, then @p@ must additionally be an
-- instance of 'NeverNull', so that @Maybe p@ is itself a 'PersistField'.
paramsContext :: THEntityDef -> Cxt
paramsContext def = persistFieldPreds ++ neverNullPreds
  where
    typeParams = map getType (thTypeParams def)
    persistFieldPreds = [ClassP ''PersistField [t] | t <- typeParams]
    fieldTypes = map fieldType (concatMap thConstrParams (thConstructors def))
    maybeParams = nub (concatMap insideMaybe fieldTypes)
    neverNullPreds = [ClassP ''NeverNull [t] | t <- maybeParams]

-- | The type variable bound by a binder, as a 'Type' (any kind annotation is
-- dropped).
getType :: TyVarBndr -> Type
getType tv = VarT $ case tv of
  PlainTV n    -> n
  KindedTV n _ -> n

-- | Fold over a type and its sub-types. The function is applied to the type
-- itself; for type applications and kind signatures it is also applied,
-- recursively, to the components. Results are combined left to right, and
-- the combinator associates to the left. Fails with 'error' on a @forall@ type.
foldType :: (Type -> a) -> (a -> a -> a) -> Type -> a
foldType f combine = walk
  where
    walk t = case t of
      ForallT _ _ _ -> error "forall'ed fields are not allowed"
      AppT l r      -> f t `combine` walk l `combine` walk r
      SigT inner _  -> f t `combine` walk inner
      _             -> f t

-- | Whether a type mentions any type variable, at any depth.
hasFreeVars :: Type -> Bool
hasFreeVars = foldType isTyVar (||)
  where
    isTyVar t = case t of
      VarT _ -> True
      _      -> False

-- | The type variables that occur directly as the argument of 'Maybe'
-- anywhere inside a type (e.g. @a@ in @Either Int (Maybe a)@). They are
-- listed in traversal order, possibly with repetitions.
insideMaybe :: Type -> [Type]
insideMaybe = foldType maybeArg (++)
  where
    maybeArg t = case t of
      AppT (ConT tc) arg@(VarT _) | tc == ''Maybe -> [arg]
      _ -> []