{-#LANGUAGE NoImplicitPrelude #-}
{-#LANGUAGE OverloadedStrings #-}
{-#LANGUAGE TypeFamilies #-}
{-#LANGUAGE MultiParamTypeClasses #-}
{-#LANGUAGE FlexibleInstances #-}
{-#LANGUAGE FlexibleContexts #-}
{-#LANGUAGE LambdaCase #-}
{-#LANGUAGE DeriveGeneric #-}

module Web.Sprinkles.Databases
where

import Web.Sprinkles.Prelude
import qualified Control.Exception as Exception
import Data.Aeson as JSON
import Data.Aeson.TH as JSON
import qualified Data.Text as Text
import qualified Database.HDBC as HDBC
import Database.HDBC.PostgreSQL (connectPostgreSQL)
import Database.HDBC.Sqlite3 (connectSqlite3)
import Database.HDBC.MySQL (connectMySQL, MySQLConnectInfo (..))
import qualified Data.Serialize as Cereal
import Data.Serialize (Serialize)
import Data.Expandable

-- | A database connection specification: the backend driver plus a
-- driver-specific connection string.
data DSN =
    DSN
        { dsnDriver :: SqlDriver
          -- ^ Which database backend to connect with.
        , dsnDetails :: Text
          -- ^ Driver-specific connection details. SQLite and PostgreSQL
          -- receive it verbatim; MySQL parses it as @key=value;...@ pairs.
        }
    deriving (Show, Generic)

-- | Which result set(s) of a query to use.
--
-- JSON spellings (see the 'FromJSON' instance): @"merge"@, @"first"@
-- (same as index 0), @"last"@, or a non-negative number.
data ResultSetMode
    = ResultsMerge
      -- ^ JSON @"merge"@. Presumably combines all result sets; confirm
      -- against the consumer.
    | ResultsNth Int
      -- ^ The result set at the given zero-based index.
    | ResultsLast
      -- ^ JSON @"last"@.
    deriving (Show, Generic)

-- | Accepts the strings @"merge"@, @"first"@ and @"last"@, or a
-- non-negative number (rounded to an 'Int') selecting a result set by
-- index. Any other value is rejected with a descriptive message.
instance FromJSON ResultSetMode where
    parseJSON (String "merge") = return ResultsMerge
    parseJSON (String "first") = return (ResultsNth 0)
    parseJSON (String "last") = return ResultsLast
    parseJSON (String x) = fail $ "Invalid result set mode " ++ show x
    parseJSON (Number i)
        | i < 0 = fail $ "Invalid result set index " ++ show i
        | otherwise = return . ResultsNth $ round i
    parseJSON x =
        fail $ "Expected integer or string for result set mode, got " ++ show x

-- | Binary encoding of 'ResultSetMode'. No methods are given, so cereal's
-- default ('Generic'-based) 'Cereal.put' / 'Cereal.get' are used.
instance Serialize ResultSetMode

-- | Expansion only touches the connection details; the driver is kept
-- as-is.
instance ExpandableM Text DSN where
    expandM expand (DSN driver details) =
        fmap (DSN driver) (expand details)

-- | The supported HDBC database backends.
data SqlDriver
    = SqliteDriver
      -- ^ SQLite 3, via HDBC-sqlite3.
    | PostgreSQLDriver
      -- ^ PostgreSQL, via HDBC-postgresql.
    | MySQLDriver
      -- ^ MySQL, via HDBC-mysql.
    deriving (Show, Generic)

-- | Binary encoding of 'SqlDriver' as a short 'String' tag
-- (@"sqlt"@, @"pg"@, @"my"@). An unknown tag fails decoding.
instance Serialize SqlDriver where
    put driver =
        Cereal.put (driverTag driver)
        where
            driverTag :: SqlDriver -> String
            driverTag SqliteDriver = "sqlt"
            driverTag PostgreSQLDriver = "pg"
            driverTag MySQLDriver = "my"
    get = do
        tag <- Cereal.get :: Cereal.Get String
        case tag of
            "sqlt" -> return SqliteDriver
            "pg" -> return PostgreSQLDriver
            "my" -> return MySQLDriver
            unknown -> fail $ "Invalid database driver: " <> show unknown

-- | Canonical textual name of a driver. It is used for JSON output and by
-- 'dsnToText', and is accepted back by 'sqlDriverFromID'.
sqlDriverID :: SqlDriver -> Text
sqlDriverID = \case
    SqliteDriver -> "sqlite"
    PostgreSQLDriver -> "postgres"
    MySQLDriver -> "mysql"

-- | Look up a driver by name. Several aliases are accepted per driver, and
-- matching is exact (case-sensitive). Unknown names give 'Nothing'.
sqlDriverFromID :: Text -> Maybe SqlDriver
sqlDriverFromID = \case
    -- SQLite
    "sqlite" -> Just SqliteDriver
    "sqlite3" -> Just SqliteDriver
    -- PostgreSQL
    "pg" -> Just PostgreSQLDriver
    "pgsql" -> Just PostgreSQLDriver
    "postgres" -> Just PostgreSQLDriver
    "postgresql" -> Just PostgreSQLDriver
    -- MySQL
    "my" -> Just MySQLDriver
    "mysql" -> Just MySQLDriver
    _ -> Nothing

-- | Encodes a driver as a JSON string holding its canonical 'sqlDriverID'.
instance ToJSON SqlDriver where
    toJSON driver = String (sqlDriverID driver)

-- | Decodes a JSON string naming any alias known to 'sqlDriverFromID'.
instance FromJSON SqlDriver where
    parseJSON x = do
        driverName <- parseJSON x
        case sqlDriverFromID driverName of
            Nothing -> fail "Invalid SQL Driver"
            Just driver -> return driver

-- | Binary encoding of a 'DSN': the driver tag followed by the details,
-- which are serialized as a 'String'.
instance Serialize DSN where
    put dsn = do
        Cereal.put (dsnDriver dsn)
        Cereal.put (unpack (dsnDetails dsn) :: String)
    get = do
        driver <- Cereal.get
        details <- Cereal.get
        return $ DSN driver (pack (details :: String))

-- | Render a 'DSN' as @driver:details@, where @driver@ is the canonical
-- 'sqlDriverID'.
dsnToText :: DSN -> Text
dsnToText dsn =
    let driverPrefix = sqlDriverID (dsnDriver dsn)
    in driverPrefix <> ":" <> dsnDetails dsn

-- | Decodes an object with a @driver@ key (any 'SqlDriver' alias) and a
-- @dsn@ key (the connection details). Non-objects are rejected.
instance FromJSON DSN where
    parseJSON = \case
        Object obj -> DSN <$> obj .: "driver" <*> obj .: "dsn"
        _ -> fail "Invalid DSN"

-- | Open a connection for the given 'DSN', run @inner@ with it, then
-- disconnect. The disconnect also runs if @inner@ throws.
--
-- The connection is closed as soon as @inner@ returns. @inner@ must
-- therefore fully consume any rows it needs before returning: prefer
-- strict fetching such as 'HDBC.quickQuery'' over lazily fetched results.
-- HDBC's 'HDBC.disconnect' discards uncommitted work, so @inner@ must
-- 'HDBC.commit' any writes it wants to keep.
withConnection :: DSN -> (HDBC.ConnWrapper -> IO a) -> IO a
withConnection dsn inner =
    Exception.bracket (connect dsn) HDBC.disconnect inner

-- | Open an HDBC connection for the given 'DSN' and wrap it in a
-- driver-agnostic 'HDBC.ConnWrapper'.
--
-- SQLite and PostgreSQL get the details string verbatim. For MySQL it is
-- first parsed by 'parseMysqlConnectInfo'.
connect :: DSN -> IO HDBC.ConnWrapper
connect (DSN SqliteDriver details) =
    HDBC.ConnWrapper <$> connectSqlite3 (unpack details)
connect (DSN PostgreSQLDriver details) =
    HDBC.ConnWrapper <$> connectPostgreSQL (unpack details)
connect (DSN MySQLDriver details) =
    parseMysqlConnectInfo details >>= fmap HDBC.ConnWrapper . connectMySQL

-- | Parse a MySQL connection string of the form @key=value;key=value;...@
-- into a 'MySQLConnectInfo'.
--
-- Keys are case-insensitive and surrounding whitespace is ignored, so
-- @"host=db; User=app"@ works. Values are taken verbatim, except @port@,
-- which is trimmed before being read as a number.
--
-- Recognized keys and defaults:
--
-- * @host@: @"localhost"@
-- * @user@, @password@, @database@, @socket@: @""@
-- * @port@: @3306@
-- * @group@: optional, absent by default
--
-- Unknown keys and segments without an @=@ are ignored. Fails (via
-- 'fail') when @port@ is present but not a valid integer.
parseMysqlConnectInfo :: Monad m => Text -> m MySQLConnectInfo
parseMysqlConnectInfo details =
    MySQLConnectInfo
        <$> getStrDef "localhost" "host"
        <*> getStrDef "" "user"
        <*> getStrDef "" "password"
        <*> getStrDef "" "database"
        <*> getIntDef 3306 "port"
        <*> getStrDef "" "socket"
        <*> getStr "group"
    where
        -- Look up a (lowercase) key and run a value parser on it, if present.
        getWith :: Monad m => (Text -> m a) -> Text -> m (Maybe a)
        getWith parseValue key =
            maybe
                (return Nothing)
                (fmap Just . parseValue)
                (lookup key dict)

        getStr :: Monad m => Text -> m (Maybe String)
        getStr = getWith parseStr

        getStrDef :: Monad m => String -> Text -> m String
        getStrDef d = fmap (fromMaybe d) . getStr

        getInt :: Monad m => Text -> m (Maybe Int)
        getInt = getWith parseInt

        getIntDef :: Monad m => Int -> Text -> m Int
        getIntDef d = fmap (fromMaybe d) . getInt

        dict :: HashMap Text Text
        dict = mapFromList . catMaybes . map parsePair . splitElem ';' $ details

        -- Split one segment at its first '='. The key is trimmed and
        -- lowercased so "; Host = x" still matches "host".
        parsePair :: Text -> Maybe (Text, Text)
        parsePair src =
            case break (== '=') src of
                (_, "") -> Nothing
                (key, value) -> Just (toLower (Text.strip key), drop 1 value)

        parseStr :: Monad m => Text -> m String
        parseStr = return . unpack

        -- NOTE(review): 'fail' under a plain 'Monad' constraint only
        -- compiles before GHC 8.8 (MonadFail proposal); switch this
        -- function's constraints to 'MonadFail' when upgrading.
        parseInt :: Monad m => Text -> m Int
        parseInt = maybe (fail "Invalid number") return . readMay . Text.strip