{-# LANGUAGE EmptyDataDecls             #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings          #-}
{-# LANGUAGE QuasiQuotes                #-}
{-# LANGUAGE RecordWildCards            #-}
{-# LANGUAGE ScopedTypeVariables        #-}
{-# LANGUAGE TypeFamilies               #-}
module Drifter.PostgreSQL
    ( PGMigration
    , Method(..)
    , DBConnection(..)
    , ChangeHistory(..)
    , runMigrations
    , getChangeHistory
    ) where

-------------------------------------------------------------------------------
import           Control.Applicative
import           Control.Exception
import           Control.Monad
import           Control.Monad.Trans
import           Control.Monad.Trans.Either
import           Data.Time
import           Database.PostgreSQL.Simple
import           Database.PostgreSQL.Simple.FromField
import           Database.PostgreSQL.Simple.FromRow
import           Database.PostgreSQL.Simple.SqlQQ
import           Drifter
-------------------------------------------------------------------------------


-- | Uninhabited tag type. It is never constructed; it only serves as the
-- index that selects the PostgreSQL instances of the 'Method' and
-- 'DBConnection' data families and of the 'Drifter' class in this module.
data PGMigration


-- | The two ways a PostgreSQL change can be carried out.
data instance Method PGMigration
    = MigrationQuery Query
      -- ^ Run a query against the database
    | MigrationCode (Connection -> IO (Either String ()))
      -- ^ Run any arbitrary IO code against the connection; a 'Left'
      -- result reports failure with a message


-- | Connection handle used for 'PGMigration' changes: a thin wrapper
-- around a 'Connection'.
data instance DBConnection PGMigration = DBConnection Connection


-- | Applies a single change over the wrapped connection by running
-- 'migrateChange' and unwrapping its 'EitherT' result.
instance Drifter PGMigration where
  migrateSingle (DBConnection pgConn) = runEitherT . migrateChange pgConn


-------------------------------------------------------------------------------
-- Change History Tracking
-------------------------------------------------------------------------------
-- | Identifier of a 'ChangeHistory' entry, wrapping an 'Int'. The
-- 'FromField' instance is obtained from 'Int' through newtype deriving.
--
-- NOTE(review): this type is not in the module's export list even though
-- the exported 'histId' field exposes it.
newtype ChangeId = ChangeId Int deriving (Eq, Ord, Show, FromField)


-- | One row of the @schema_migrations@ table, i.e. the record of a change
-- that has been logged.
data ChangeHistory = ChangeHistory
    { histId          :: ChangeId
      -- ^ Value of the @id@ column (the table's primary key)
    , histName        :: ChangeName
      -- ^ Value of the @name@ column
    , histDescription :: Maybe Description
      -- ^ Value of the nullable @description@ column
    , histTime        :: UTCTime
      -- ^ Value of the @time@ column
    } deriving (Show)


-- | Two history entries are considered equal when their 'histName's are
-- equal; every other field is ignored. Note that the 'Ord' instance
-- orders by 'histId' instead.
instance Eq ChangeHistory where
    ChangeHistory{histName = lhs} == ChangeHistory{histName = rhs} = lhs == rhs


-- | History entries are ordered by 'histId' alone. Note that the 'Eq'
-- instance compares 'histName' instead.
instance Ord ChangeHistory where
    compare ChangeHistory{histId = lhs} ChangeHistory{histId = rhs} =
        compare lhs rhs


-- | Reads four columns in the order id, name, description, time, which is
-- the column order selected by 'changeHistoryQ'.
instance FromRow ChangeHistory where
    fromRow = do
        rowId   <- field
        rowName <- field
        rowDesc <- field
        rowTime <- field
        return ChangeHistory
            { histId          = rowId
            , histName        = ChangeName rowName
            , histDescription = rowDesc
            , histTime        = rowTime
            }


-------------------------------------------------------------------------------
-- Queries
-------------------------------------------------------------------------------
-- | DDL that creates the @schema_migrations@ bookkeeping table if it does
-- not exist yet. @id@ is the primary key, @name@ must be unique, and
-- @time@ defaults to the database's @now()@.
bootstrapQ :: Query
bootstrapQ = [sql|
CREATE TABLE IF NOT EXISTS schema_migrations (
    id              serial      NOT NULL,
    name            text        NOT NULL,
    description     text,
    time            timestamptz NOT NULL DEFAULT now(),

    PRIMARY KEY (id),
    UNIQUE (name)
);
|]


-------------------------------------------------------------------------------
-- | Selects every row of @schema_migrations@ ordered by @id@. The column
-- order (id, name, description, time) is the order the 'FromRow' instance
-- of 'ChangeHistory' reads its fields in.
changeHistoryQ :: Query
changeHistoryQ =
  "SELECT id, name, description, time FROM schema_migrations ORDER BY id;"


-------------------------------------------------------------------------------
-- | Parameterised insert that records one change in @schema_migrations@.
-- The three placeholders are, in order: name, description, time.
insertLogQ :: Query
insertLogQ =
  "INSERT INTO schema_migrations (name, description, time) VALUES (?, ?, ?);"


-------------------------------------------------------------------------------
-- | Drop the leading changes that are already recorded in the history.
--
-- Both lists are walked in step. While the name of the current history
-- entry equals the name of the current change, that change is skipped
-- (and a \"Skipping\" line is printed). The walk ends when:
--
-- * the history runs out: the changes still left are returned;
-- * the names differ: the current change and everything after it are
--   returned;
-- * the changes run out while history entries remain: a notice is printed
--   and the empty list is returned.
findNext :: [ChangeHistory] -> [Change PGMigration] -> IO [Change PGMigration]
findNext history pending = case (history, pending) of
  ([], _) -> return pending
  (entry:restHistory, next:restPending)
    | histName entry == changeName next -> do
        putStrLn $ "Skipping: " ++ show (changeName next)
        findNext restHistory restPending
    | otherwise -> return pending
  _ -> do
    putStrLn "Change Set Exhausted"
    return []


-------------------------------------------------------------------------------
-- | Apply one change: run its method, then record it in
-- @schema_migrations@, then print a confirmation line. A 'Left' from
-- either of the first two steps short-circuits the remaining ones.
migrateChange :: Connection -> Change PGMigration -> EitherT String IO ()
migrateChange conn change =
  runMethod conn (changeMethod change)
    >> logChange conn change
    >> lift (putStrLn ("Committed: " ++ show (changeName change)))


-------------------------------------------------------------------------------
-- | Execute the body of a change against the connection.
--
-- A 'MigrationQuery' is executed as a statement and its row count is
-- discarded; a 'MigrationCode' action is called with the connection and
-- its own 'Either' result is used. In both cases the exceptions covered
-- by 'errorHandlers' are converted into a 'Left'.
runMethod :: Connection -> Method PGMigration -> EitherT String IO ()
runMethod conn method = EitherT $ attempt `catches` errorHandlers
  where
    attempt = case method of
      MigrationQuery statement -> execute_ conn statement >> return (Right ())
      MigrationCode action     -> action conn


-------------------------------------------------------------------------------
-- | Record a change in @schema_migrations@ using 'insertLogQ'.
--
-- The row holds the change's name text, its description and the current
-- time as read from the client's clock. Exceptions covered by
-- 'errorHandlers' that are raised by the insert are converted into a
-- 'Left'.
logChange :: Connection -> Change PGMigration -> EitherT String IO ()
logChange conn change = EitherT $ do
    stamp <- getCurrentTime
    let row = ( changeNameText (changeName change)
              , changeDescription change
              , stamp
              )
    (execute conn insertLogQ row >> return (Right ())) `catches` errorHandlers


-------------------------------------------------------------------------------
-- | Handlers for use with 'catches' that turn a 'SqlError', 'FormatError',
-- 'ResultError' or 'QueryError' into a 'Left' carrying the 'show'n
-- exception. Any other exception type is not handled here.
errorHandlers :: [Handler (Either String b)]
errorHandlers =
    [ Handler $ \(err :: SqlError)    -> asLeft err
    , Handler $ \(err :: FormatError) -> asLeft err
    , Handler $ \(err :: ResultError) -> asLeft err
    , Handler $ \(err :: QueryError)  -> asLeft err
    ]
  where
    -- Render any showable exception as a failed result.
    asLeft :: Show e => e -> IO (Either String r)
    asLeft = return . Left . show


-------------------------------------------------------------------------------
-- | Bring the database up to date with the given list of changes. Use
-- this instead of 'migrate'.
--
-- The @schema_migrations@ table is created when missing and the recorded
-- history is read; the leading changes that the history already covers
-- are dropped (see 'findNext'). The remaining changes are then applied
-- inside a single transaction, which is committed when 'migrate' returns
-- 'Right' and rolled back when it returns 'Left' or throws.
runMigrations :: Connection -> [Change PGMigration] -> IO (Either String ())
runMigrations conn changes = do
  _ <- execute_ conn bootstrapQ
  pending <- getChangeHistory conn >>= \applied -> findNext applied changes
  begin conn
  outcome <- migrate (DBConnection conn) pending `onException` rollback conn
  either (const (rollback conn)) (const (commit conn)) outcome
  return outcome


-------------------------------------------------------------------------------
-- | Fetch every entry recorded in the @schema_migrations@ table, in
-- ascending @id@ order (see 'changeHistoryQ').
getChangeHistory :: Connection -> IO [ChangeHistory]
getChangeHistory = flip query_ changeHistoryQ