{-# LANGUAGE EmptyDataDecls             #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings          #-}
{-# LANGUAGE QuasiQuotes                #-}
{-# LANGUAGE RecordWildCards            #-}
{-# LANGUAGE ScopedTypeVariables        #-}
{-# LANGUAGE TypeFamilies               #-}
module Drifter.PostgreSQL
    ( PGMigration
    , Method(..)
    , DBConnection(..)
    , ChangeHistory(..)
    , runMigrations
    , getChangeHistory
    , getChangeNameHistory
    ) where

import           Control.Applicative                  as A
import           Control.Exception
import           Control.Monad
import           Control.Monad.Trans
import           Control.Monad.Trans.Except
import           Data.Set                             (Set)
import qualified Data.Set                             as Set
import           Data.Time
import           Database.PostgreSQL.Simple
import           Database.PostgreSQL.Simple.FromField
import           Database.PostgreSQL.Simple.FromRow
import           Database.PostgreSQL.Simple.SqlQQ
import           Drifter

-- | Uninhabited tag type that selects the PostgreSQL backend: it indexes the
-- 'Method' and 'DBConnection' data instances and the 'Drifter' instance below.
data PGMigration

-- | The kinds of migration step available for the PostgreSQL backend.
data instance Method PGMigration
    = MigrationQuery Query
      -- ^ Run a query against the database
    | MigrationCode (Connection -> IO (Either String ()))
      -- ^ Run any arbitrary IO code

-- | Connection handle for the PostgreSQL backend; wraps a
-- 'PGMigrationConnection'.
data instance DBConnection PGMigration = DBConnection PGMigrationConnection

-- | State threaded through a migration run: the set of change names already
-- recorded in @schema_migrations@, paired with the database 'Connection'.
data PGMigrationConnection = PGMigrationConnection (Set ChangeName) Connection

-- | Applies a single change by delegating to 'migrateChange' and unwrapping
-- the resulting 'ExceptT' into a plain 'Either'.
instance Drifter PGMigration where
  migrateSingle (DBConnection pgConn) =
    runExceptT . migrateChange pgConn

-- Change History Tracking
-- | Identifier of a @schema_migrations@ row (the @id@ column). The 'FromField'
-- instance is derived from the underlying 'Int'.
newtype ChangeId = ChangeId Int deriving (Eq, Ord, Show, FromField)

-- | One row of the @schema_migrations@ table, describing a change that has
-- been recorded as applied.
data ChangeHistory = ChangeHistory
    { histId          :: ChangeId
      -- ^ Value of the @id@ column.
    , histName        :: ChangeName
      -- ^ Value of the @name@ column.
    , histDescription :: Maybe Description
      -- ^ Value of the nullable @description@ column.
    , histTime        :: UTCTime
      -- ^ Value of the @time@ column.
    } deriving (Show)

-- | History entries are equal when their 'histName's match; no other field is
-- compared.
--
-- NOTE(review): the 'Ord' instance orders by 'histId' instead, so the two
-- instances agree only while names and ids correspond one-to-one.
instance Eq ChangeHistory where
    ChangeHistory{histName = lhs} == ChangeHistory{histName = rhs} = lhs == rhs

-- | History entries are ordered by 'histId' alone.
instance Ord ChangeHistory where
    compare ChangeHistory{histId = lhs} ChangeHistory{histId = rhs} =
        lhs `compare` rhs

-- | Decodes four consecutive columns into a 'ChangeHistory': id, name,
-- description and time, in that order.
instance FromRow ChangeHistory where
    fromRow = do
        rowId   <- field
        rowName <- field
        rowDesc <- field
        rowTime <- field
        return (ChangeHistory rowId (ChangeName rowName) rowDesc rowTime)

-- Queries
-- | Creates the @schema_migrations@ bookkeeping table when it does not exist
-- yet. Each row carries a serial @id@ (primary key), a unique @name@, an
-- optional @description@ and a @time@ that defaults to @now()@.
bootstrapQ :: Query
bootstrapQ = [sql|
CREATE TABLE IF NOT EXISTS schema_migrations (
    id              serial      NOT NULL,
    name            text        NOT NULL,
    description     text,
    time            timestamptz NOT NULL DEFAULT now(),

    PRIMARY KEY (id),
    UNIQUE (name)
);
|]

-- | Selects every column of @schema_migrations@, ordered by @id@. The column
-- order matches what the 'FromRow' instance of 'ChangeHistory' reads.
changeHistoryQ :: Query
changeHistoryQ =
  "SELECT id, name, description, time FROM schema_migrations ORDER BY id;"

-- | Selects only the @name@ column of @schema_migrations@, ordered by @id@.
changeNameHistoryQ :: Query
changeNameHistoryQ =
  "SELECT name FROM schema_migrations ORDER BY id;"

-- | Inserts one @schema_migrations@ row; takes three parameters in the order
-- name, description, time.
insertLogQ :: Query
insertLogQ =
  "INSERT INTO schema_migrations (name, description, time) VALUES (?, ?, ?);"

-- | Applies one change unless its name is already in the recorded set.
--
-- A change whose name is in the set is skipped with a @Skipping:@ message.
-- Otherwise its method is run, a row is inserted into @schema_migrations@,
-- and a @Committed:@ message is printed. A failure in either step
-- short-circuits through 'ExceptT' with a 'String' error.
migrateChange :: PGMigrationConnection -> Change PGMigration -> ExceptT String IO ()
migrateChange (PGMigrationConnection hist c) ch@Change{..}
  | Set.member changeName hist = announce "Skipping: "
  | otherwise = do
      runMethod c changeMethod
      logChange c ch
      announce "Committed: "
  where
    -- Both messages print the bare name text so the log lines are uniform.
    announce prefix =
      lift $ putStrLn $ prefix ++ show (changeNameText changeName)

-- | Executes a migration method on the given connection. A 'MigrationQuery'
-- is run with 'execute_' and its row count discarded; a 'MigrationCode'
-- action is invoked with the connection. Exceptions matched by
-- 'errorHandlers' are turned into a 'Left' carrying the shown exception.
runMethod :: Connection -> Method PGMigration -> ExceptT String IO ()
runMethod conn method = case method of
    MigrationQuery q     -> guarded (Right () <$ execute_ conn q)
    MigrationCode action -> guarded (action conn)
  where
    guarded io = ExceptT (io `catches` errorHandlers)

-- | Records a change in @schema_migrations@ with its name, description and
-- the current time. Exceptions matched by 'errorHandlers' are turned into a
-- 'Left' carrying the shown exception.
logChange :: Connection -> Change PGMigration -> ExceptT String IO ()
logChange c Change{..} = do
    now <- lift getCurrentTime
    void $ ExceptT $ (Right <$> go now) `catches` errorHandlers
  where
    -- Parameter order must match the placeholders in 'insertLogQ'.
    go now = execute c insertLogQ (changeNameText changeName, changeDescription, now)

-- | Handlers for use with 'catches' that convert the database exceptions
-- 'SqlError', 'FormatError', 'ResultError' and 'QueryError' into a 'Left'
-- holding the shown exception. Any other exception is not caught here.
errorHandlers :: [Handler (Either String b)]
errorHandlers = [ Handler (\(ex::SqlError) -> return $ Left $ show ex)
                , Handler (\(ex::FormatError) -> return $ Left $ show ex)
                , Handler (\(ex::ResultError) -> return $ Left $ show ex)
                , Handler (\(ex::QueryError) -> return $ Left $ show ex)
                ]

-- | Prepares the state that is threaded through a migration run: makes sure
-- the migration table exists (via 'bootstrapQ'), then loads the names of all
-- previously recorded changes into a 'Set' and pairs it with the connection.
makePGMigrationConnection :: Connection -> IO PGMigrationConnection
makePGMigrationConnection conn = do
  _ <- execute_ conn bootstrapQ
  recorded <- Set.fromList <$> getChangeNameHistory conn
  pure (PGMigrationConnection recorded conn)

-- | Takes the list of all migrations, skips the ones that have already been
-- recorded, and runs the rest inside a single transaction. Use this instead
-- of 'migrate'.
--
-- The transaction is committed when the run yields 'Right' and rolled back
-- when it yields 'Left'. If an exception escapes, the transaction is rolled
-- back and the exception is rethrown.
runMigrations :: Connection -> [Change PGMigration] -> IO (Either String ())
runMigrations conn changes = do
  begin conn
  -- Bootstrapping happens after 'begin', so it must sit under the same
  -- rollback guard as the migrations; otherwise a failure while creating or
  -- reading the history table would leave the transaction open.
  res <- runAll `onException` rollback conn
  case res of
    Right _ -> commit conn
    Left  _ -> rollback conn
  return res
  where
    runAll = do
      migrationConn <- makePGMigrationConnection conn
      migrate (DBConnection migrationConn) changes

-- | Fetch every row of the schema_migrations table, i.e. the full record of
-- each migration that has previously run, ordered by id.
getChangeHistory :: Connection -> IO [ChangeHistory]
getChangeHistory = flip query_ changeHistoryQ

-- | Fetch only the names of the migrations recorded in schema_migrations,
-- ordered by id.
getChangeNameHistory :: Connection -> IO [ChangeName]
getChangeNameHistory conn = do
  rows <- query_ conn changeNameHistoryQ
  return [ChangeName nm | Only nm <- rows]