{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TemplateHaskell #-}

module Database.Persist.Relational.TH
       where

#if __GLASGOW_HASKELL__ < 710
import Control.Applicative
#endif

import Data.Int
import Data.List
import qualified Data.Map as M
import Data.Maybe
import qualified Data.Text as T
import Database.Persist
import Database.Persist.Quasi
import Database.Persist.Relational.ToPersistEntity
import qualified Database.Persist.Sql as PersistSql
import Database.Record (PersistableWidth (..))
import Database.Record.FromSql
import Database.Record.TH ( deriveNotNullType
#if MIN_VERSION_persistable_record(0, 5, 0)
                          , recordTemplate
#else
                          , recordType
#endif
                          )
import Database.Record.ToSql
import Database.Relational.Query hiding ((!))
import Database.Relational.Query.TH (defineTable, defineScalarDegree)
import Language.Haskell.TH

-- | Translate a persistent 'FieldType' into a Template Haskell type.
--
-- Type constructors are referred to by name (qualified with their module when
-- one is given); applications and lists are translated recursively.
ftToType :: FieldType -> TypeQ
ftToType ft =
    case ft of
        FTTypeCon Nothing name -> namedCon name
        -- This type is generated from the Quasi-Quoter.
        -- Resolving it to the real 'Int64' name here means users do not
        -- have to import Data.Int themselves.
        FTTypeCon (Just "Data.Int") "Int64" -> conT ''Int64
        FTTypeCon (Just modName) name -> namedCon (T.concat [modName, ".", name])
        FTApp fun arg -> appT (ftToType fun) (ftToType arg)
        FTList element -> appT listT (ftToType element)
  where
    -- Refer to a type constructor by its textual (possibly qualified) name.
    namedCon = conT . mkName . T.unpack

-- | List the columns of an entity as pairs of database column name and
-- Haskell column type. The id field ('entityId') always comes first,
-- followed by the remaining fields in 'entityFields' order.
makeColumns :: EntityDef
            -> [(String, TypeQ)]
makeColumns entity =
    [ (T.unpack (unDBName (fieldDB field)), mkFieldType field)
    | field <- entityId entity : entityFields entity
    ]

-- | Generate all templates about table from persistent table definition.
--
-- The entity whose Haskell name equals the base name of the given record
-- type name is looked up in the entity list (the first match is used).
-- The generated declarations are those produced by 'defineTable' followed by
-- the 'ToPersistEntity' instance from 'makeToPersistEntityInstance'.
-- When no entity matches, the splice is aborted via 'fail'.
defineTableFromPersistentWithConfig
    :: Config                   -- ^ Configuration for haskell relational record
    -> String                   -- ^ Database schema name
    -> Name                     -- ^ Name of the persistent record type corresponding to the table
    -> [EntityDef]              -- ^ @EntityDef@s generated by persistent
    -> Q [Dec]
defineTableFromPersistentWithConfig config schema persistentRecordName entities =
    case filter ((== nameBase persistentRecordName) . T.unpack . unHaskellName . entityHaskell) entities of
        (t:_) -> do
            let columns = makeColumns t
                tableName = T.unpack . unDBName . entityDB $ t
            -- NOTE(review): the trailing @[0]@ and @Just 0@ presumably designate
            -- the first column (the id column, which 'makeColumns' puts first)
            -- as primary key and as the not-null column -- confirm against
            -- the 'defineTable' documentation.
            tblD <- defineTable
                        config
                        schema
                        tableName
                        columns
                        (map (mkName . T.unpack) . entityDerives $ t)
                        [0]
                        (Just 0)
            entI <- makeToPersistEntityInstance config schema tableName persistentRecordName columns
            return $ tblD ++ entI
        -- Report the problem through the Q monad (like the other generators in
        -- this module) and name the function that actually failed.
        _ -> fail $ "defineTableFromPersistentWithConfig: Table related to " ++ show persistentRecordName ++ " not found"

-- | Generate all templates about table from persistent table definition,
-- using 'defaultConfig' with schema qualification switched off.
--
-- No schema name is supplied: the schema argument passed on is an 'error'
-- thunk that blows up only if it is ever evaluated.
defineTableFromPersistent
    :: Name                     -- ^ Name of the persistent record type corresponding to the table
    -> [EntityDef]              -- ^ @EntityDef@s generated by persistent
    -> Q [Dec]
defineTableFromPersistent recordName entities =
    defineTableFromPersistentWithConfig unqualifiedConfig unusedSchema recordName entities
  where
    unqualifiedConfig = defaultConfig { schemaNameMode = SchemaNotQualified }
    unusedSchema =
        error "[bug] Database.Persist.Relational.TH.defineTableFromPersistent: schema name must not be used"

-- | Generate the 'ToPersistEntity' instance relating the
-- haskell-relational-record record type of a table to the persistent record
-- type named by @persistentRecordName@.
--
-- The persistent type is inspected with 'reify'; it must be a data type
-- without type parameters that has exactly one record constructor.
-- Any other shape aborts the splice via 'fail'.
makeToPersistEntityInstance :: Config -> String -> String -> Name -> [(String, TypeQ)] -> Q [Dec]
makeToPersistEntityInstance config schema tableName persistentRecordName columns = do
    (typName, dataConName) <- recType =<< reify persistentRecordName
    deriveToPersistEntityForRecord hrrRecordType (conT typName, conE dataConName) columns
  where
    -- Extract the type name and the record constructor name, failing in the
    -- Q monad (rather than throwing from pure code) on unexpected input.
    recType :: Info -> Q (Name, Name)
#if MIN_VERSION_template_haskell(2, 11, 0)
    recType (TyConI (DataD _ tName [] _ [RecC dcName _] _)) = return (tName, dcName)
#else
    recType (TyConI (DataD _ tName [] [RecC dcName _] _)) = return (tName, dcName)
#endif
    recType info = fail $ "makeToPersistEntityInstance: unexpected record info " ++ show info

    -- Record type that haskell-relational-record derives for this table.
#if MIN_VERSION_persistable_record(0, 5, 0)
    hrrRecordType = fst $ recordTemplate (recordConfig . nameConfig $ config) schema tableName
#else
    hrrRecordType = recordType (recordConfig . nameConfig $ config) schema tableName
#endif

-- | Generate instances for haskell-relational-record.
--
-- Runs 'mkHrrInstancesEachEntityDef' on every entity, in order, and
-- concatenates the resulting declarations.
mkHrrInstances :: [EntityDef] -> Q [Dec]
mkHrrInstances entities = concat <$> mapM mkHrrInstancesEachEntityDef entities

-- | Generate the haskell-relational-record instances for one entity.
-- Only the entity's id field is considered.
mkHrrInstancesEachEntityDef :: EntityDef -> Q [Dec]
mkHrrInstancesEachEntityDef entity = mkPersistablePrimaryKey (entityId entity)

-- | Generate, for the type of the given (primary key) field, the
-- declarations produced by each generator listed below, concatenated in
-- that order.
mkPersistablePrimaryKey :: FieldDef -> Q [Dec]
mkPersistablePrimaryKey fd =
    concat <$> sequence
        [ deriveNotNullType keyType
        , defineFromToSqlPersistValue keyType
        , defineScalarDegree keyType
        , mkShowConstantTermsSQL keyType
        , deriveTrivialToPersistEntity keyType
        ]
  where
    keyType = mkFieldType fd

-- | Generate a 'ShowConstantTermsSQL' instance for the given type.
--
-- The generated @showConstantTermsSQL'@ converts the value with
-- 'PersistSql.fromSqlKey' and then delegates to the @showConstantTermsSQL'@
-- of the converted value.
mkShowConstantTermsSQL :: TypeQ -> Q [Dec]
mkShowConstantTermsSQL typ =
    [d|instance ShowConstantTermsSQL $typ where
           showConstantTermsSQL' = showConstantTermsSQL' . PersistSql.fromSqlKey|]

-- | Haskell type of a field: the translated 'fieldType', wrapped in 'Maybe'
-- when the field attributes mark it as nullable by the @Maybe@ attribute.
mkFieldType :: FieldDef -> TypeQ
mkFieldType fd
    | Nullable ByMaybeAttr <- nullable (fieldAttrs fd) = appT (conT ''Maybe) baseType
    | otherwise = baseType
  where
    baseType = ftToType (fieldType fd)

-- | Generate 'FromSql' 'PersistValue' and 'ToSql' 'PersistValue' instances for 'PersistField' types.
--
-- The 'FromSql' instance decodes with 'unsafePersistValueFromSql'; the
-- 'ToSql' instance encodes with 'toPersistValue'. The 'FromSql' instance
-- comes first in the result.
defineFromToSqlPersistValue :: TypeQ -> Q [Dec]
defineFromToSqlPersistValue typ =
    (++) <$> fromSqlInstance <*> toSqlInstance
  where
    fromSqlInstance =
        [d| instance FromSql PersistValue $typ where
                recordFromSql = valueRecordFromSql unsafePersistValueFromSql |]
    toSqlInstance =
        [d| instance ToSql PersistValue $typ where
                recordToSql = valueRecordToSql toPersistValue |]

-- | Generate the identity-like 'ToPersistEntity' instance that relates the
-- given type to itself, with @recordFromSql'@ defined as 'recordFromSql'.
deriveTrivialToPersistEntity :: TypeQ -> Q [Dec]
deriveTrivialToPersistEntity typ =
    [d| instance ToPersistEntity $typ $typ where
            recordFromSql' = recordFromSql |]

-- | Generate the 'ToPersistEntity' instance from a haskell-relational-record
-- record type to the 'Entity' of a persistent record type.
--
-- The first column is taken to be the key: its parser supplies the first
-- argument of 'Entity'. The parsers of the remaining columns are applied, in
-- order, to the persistent data constructor to build the second argument.
-- An empty column list makes the splice fail.
deriveToPersistEntityForRecord :: TypeQ -> (TypeQ, ExpQ) -> [(String, TypeQ)] -> Q [Dec]
deriveToPersistEntityForRecord hrrTyp (pTyp, pCon) ((_, pKeyTyp):columns) =
    [d| instance ToPersistEntity $hrrTyp (Entity $pTyp) where
            recordFromSql' = Entity <$> (recordFromSql :: RecordFromSql PersistValue $pKeyTyp) <*> $valueParser |]
  where
    -- One type-annotated 'recordFromSql' expression per non-key column.
    columnParsers =
        [ [| recordFromSql :: RecordFromSql PersistValue $colTyp |]
        | (_, colTyp) <- columns
        ]
    -- @pure Con \<*\> p1 \<*\> p2 ...@, built left to right.
    valueParser = foldl' applyParser [| pure $pCon |] columnParsers
    applyParser acc parser = [| $acc <*> $parser |]
deriveToPersistEntityForRecord _ _ [] = fail "deriveToPersistEntityForRecord: missing columns"

-- | Decode a 'PersistValue' with 'fromPersistValue'.
--
-- Unsafe: a decoding failure is turned into a call to 'error' carrying the
-- failure message.
unsafePersistValueFromSql :: PersistField a => PersistValue -> a
unsafePersistValueFromSql = either (error . T.unpack) id . fromPersistValue

-- | Collect the types that have a 'PersistField' instance according to
-- 'reify', keyed by type name.
--
-- Only instances with an empty context, no reported member declarations and
-- a plain type constructor as the instance argument are considered; types
-- whose base name appears in the blacklist are dropped.
persistValueTypesFromPersistFieldInstances
    :: [String] -- ^ blacklist types
    -> Q (M.Map Name TypeQ)
persistValueTypesFromPersistFieldInstances blacklist = do
    classInfo <- reify ''PersistField
    classTy <- [t|PersistField|]
    case classInfo of
        ClassI _ instances ->
            return (M.fromList (mapMaybe (pick classTy) instances))
        unknown ->
            fail $ "persistValueTypesFromPersistFieldInstances: unknown declaration: " ++ show unknown
  where
#if MIN_VERSION_template_haskell(2, 11, 0)
    pick classTy (InstanceD _ [] (AppT headTy ty@(ConT tyName)) [])
#else
    pick classTy (InstanceD [] (AppT headTy ty@(ConT tyName)) [])
#endif
        | headTy == classTy
        , nameBase tyName `notElem` blacklist = Just (tyName, return ty)
    pick _ _ = Nothing

-- | Collect the types that have a 'PersistableWidth' instance according to
-- 'reify', keyed by type name. Only instances whose argument is a plain type
-- constructor are included.
persistableWidthTypes :: Q (M.Map Name TypeQ)
persistableWidthTypes = do
    classInfo <- reify ''PersistableWidth
    case classInfo of
        ClassI _ instances -> return (M.fromList (mapMaybe instanceType instances))
        unknown -> fail $ "persistableWidthTypes: Unknown declaration: " ++ show unknown
  where
#if MIN_VERSION_template_haskell(2, 11, 0)
    instanceType (InstanceD _ _cxt (AppT _insT ty@(ConT tyName)) _defs) = Just (tyName, return ty)
#else
    instanceType (InstanceD _cxt (AppT _insT ty@(ConT tyName)) _defs) = Just (tyName, return ty)
#endif
    instanceType _ = Nothing

-- | Derive instances for every type found by
-- 'persistValueTypesFromPersistFieldInstances' (minus the blacklist):
-- the 'FromSql' / 'ToSql' instances, the trivial 'ToPersistEntity' instance,
-- and -- only for types not already found by 'persistableWidthTypes' -- the
-- declarations from 'deriveNotNullType'. Results are concatenated in that order.
derivePersistableInstancesFromPersistFieldInstances
    :: [String] -- ^ blacklist types
    -> Q [Dec]
derivePersistableInstancesFromPersistFieldInstances blacklist = do
    fieldTypes <- persistValueTypesFromPersistFieldInstances blacklist
    widthTypes <- persistableWidthTypes
    concat <$> sequence
        [ generateFor defineFromToSqlPersistValue fieldTypes
        , generateFor deriveTrivialToPersistEntity fieldTypes
        , generateFor deriveNotNullType (M.difference fieldTypes widthTypes)
        ]
  where
    -- Run a generator on every type in the map and concatenate the output.
    generateFor :: (Q Type -> Q [Dec]) -> M.Map Name TypeQ -> Q [Dec]
    generateFor gen typeMap = concat <$> mapM gen (M.elems typeMap)