{-# LANGUAGE OverloadedStrings #-}

{- |
Functions to help with building database migrations.

Most users will want to create a database migration using @defaultMain@ as
follows,

>
> import Database.PostgreSQL.Migrations
>
> main = defaultMain up down
>
> up = migrate $ do
>       create_table "posts"
>         [ column "title" "VARCHAR(255) NOT NULL"
>         , column "author_id" "integer references authors(id)"]
> 
> down = migrate $ drop_table "posts"
>
-}
module Database.PostgreSQL.Migrations (
    -- * Utilities
    defaultMain
  , connectEnv
  , runSqlFile
    -- * DSL
  , Migration, migrate
  , column
    -- ** Adding
  , create_table
  , add_column
  , create_index
  , create_unique_index
    -- ** Removing
  , drop_table
  , drop_column
  , drop_index
    -- ** Modifying
  , rename_column
  , change_column
    -- ** Statements
  , create_table_stmt, add_column_stmt, create_index_stmt
  , drop_table_stmt, drop_column_stmt, drop_index_stmt
  , rename_column_stmt, change_column_stmt
  ) where

import Control.Monad
import Control.Monad.Reader
import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as S8
import Data.Int
import Data.Maybe
import Database.PostgreSQL.Simple hiding (connect)
import Database.PostgreSQL.Simple.Internal (exec)
import Database.PostgreSQL.Simple.Types
import System.Environment
import System.Exit

import Database.PostgreSQL.Escape

-- | Creates a PostgreSQL 'Connection' using the /DATABASE_URL/ environment
-- variable, if it exists. If it does, it should match the format:
--
-- @
--   postgresql:\/\/[[USERNAME\@PASSWORD]HOSTNAME[:PORT]]/[DBNAME]
-- @
--
-- If it is not present, the environment variables /PG_DBNAME/ /PG_HOST/ etc,
-- are used.
connectEnv :: IO Connection
connectEnv = do
  psqlStr <- getEnvironment >>=
             return . (fromMaybe "") . (lookup "DATABASE_URL")
  connectPostgreSQL $ S8.pack psqlStr

--
-- Migration Monad
--

type Migration = ReaderT Connection IO

migrate :: Migration a -> Connection -> IO ()
migrate = (void .) . runReaderT

executeQuery_ :: Query -> Migration Int64
executeQuery_ q = ask >>= \conn -> liftIO $ execute_ conn q

-- | Runs the SQL file at the given path, relative to the current working
-- directory.
runSqlFile :: FilePath -> Migration ()
runSqlFile sqlFile = void $ do
    conn <- ask
    liftIO $ do
      rawSql <- S.readFile sqlFile
      exec conn rawSql

-- | Returns a column defition by quoting the given name
column :: S8.ByteString -- ^ name
       -> S8.ByteString -- ^ type, definition, constraints
       -> S8.ByteString
column name def = S8.concat [quoteIdent name, " ", def]

-- | Creates a table. See 'column' for constructing the column list.
create_table :: S8.ByteString
             -- ^ Table name
             -> [S8.ByteString]
             -- ^ Column definitions
             -> Migration Int64
create_table = (executeQuery_ .) . create_table_stmt

-- | Returns a 'Query' that creates a table, for example:
--
-- @
--   create_table \"posts\"
--     [ column \"title\" \"VARCHAR(255) NOT NULL\"
--     , column \"body\"  \"text\"]
-- @
create_table_stmt :: S8.ByteString
                  -- ^ Table name
                  -> [S8.ByteString]
                  -- ^ Column definitions
                  -> Query
create_table_stmt tableName colDefs = Query $ S8.concat $
  [ "create table "
  , quoteIdent tableName
  , " ("] ++ (S8.intercalate ", " colDefs):([");"])

-- | Drops a table
drop_table :: S8.ByteString -> Migration Int64
drop_table = executeQuery_ . drop_table_stmt

-- | Returns a 'Query' that drops a table
drop_table_stmt :: S8.ByteString -> Query
drop_table_stmt tableName = Query $ S8.concat
  [ "drop table ", quoteIdent tableName, ";"]

-- | Adds a column to the given table. For example,
--
-- @
--   add_column \"posts\" \"title\" \"VARCHAR(255)\"
-- @
--
-- adds a varchar column called \"title\" to the table \"posts\".
--
add_column :: S8.ByteString
           -- ^ Table name
           -> S8.ByteString
           -- ^ Column name
           -> S8.ByteString
           -- ^ Column definition
           -> Migration Int64
add_column  = ((executeQuery_ .) .) . add_column_stmt

-- | Returns a 'Query' that adds a column to the given table. For example,
--
-- @
--   add_column \"posts\" \"title\" \"VARCHAR(255)\"
-- @
--
-- Returns the query
--
-- @
--   ALTER TABLE \"posts\" add \"title\" VARCHAR(255);
-- @
add_column_stmt :: S8.ByteString
                -- ^ Table name
                -> S8.ByteString
                -- ^ Column name
                -> S8.ByteString
                -- ^ Column definition
                -> Query
add_column_stmt tableName colName colDef = Query $ S8.concat
  [ "alter table ", quoteIdent tableName, " add ", column colName colDef, ";"]

-- | Drops a column from the given table. For example,
--
-- @
--   drop_column \"posts\" \"title\"
-- @
--
-- drops the column \"title\" from the \"posts\" table.
drop_column :: S8.ByteString
            -- ^ Table name
            -> S8.ByteString
            -- ^ Column name
            -> Migration Int64
drop_column = (executeQuery_ .) . drop_column_stmt

-- | Returns a 'Query' that drops a column from the given table. For example,
--
-- @
--   drop_column \"posts\" \"title\"
-- @
--
-- Returns the query
--
-- @
--   ALTER TABLE \"posts\" add \"title\";
-- @
drop_column_stmt :: S8.ByteString
                 -- ^ Table name
                 -> S8.ByteString
                 -- ^ Column name
                 -> Query
drop_column_stmt tableName colName = Query $ S8.concat
  ["alter table ", quoteIdent tableName, " drop ", quoteIdent colName, ";"]

-- | Renames a column in the given table. For example,
--
-- @
--   rename_column \"posts\" \"title\" \"name\"
-- @
--
-- renames the column \"title\" in the \"posts\" table to \"name\".
rename_column :: S8.ByteString
              -- ^ Table name
              -> S8.ByteString
              -- ^ Old column name
              -> S8.ByteString
              -- ^ New column name
              -> Migration Int64
rename_column = ((executeQuery_ .) .) . rename_column_stmt

-- | Returns a 'Query' that renames a column in the given table. For example,
--
-- @
--   rename_column \"posts\" \"title\" \"name\"
-- @
--
-- Returns the query
--
-- @
--   ALTER TABLE \"posts\" RENAME \"title\" TO \"name\";
-- @
rename_column_stmt :: S8.ByteString
                   -- ^ Table name
                   -> S8.ByteString
                   -- ^ Old column name
                   -> S8.ByteString
                   -- ^ New column name
                   -> Query
rename_column_stmt tableName colName colNameNew = Query $ S8.concat
  [ "alter table ", quoteIdent tableName, " rename "
  , quoteIdent colName, " to ", quoteIdent colNameNew, ";"]

-- | Alters a column in the given table. For example,
--
-- @
--   change_column \"posts\" \"title\" \"DROP DEFAULT\"
-- @
--
-- drops the default constraint for the \"title\" column in the \"posts\"
-- table.
change_column :: S8.ByteString
              -- ^ Table name
              -> S8.ByteString
              -- ^ Column name
              -> S8.ByteString
              -- ^ Action
              -> Migration Int64
change_column = ((executeQuery_ .) .) . change_column_stmt

-- | Returns a 'Query' that alters a column in the given table. For example,
--
-- @
--   change_column \"posts\" \"title\" \"DROP DEFAULT\"
-- @
--
-- Returns the query
--
-- @
--   ALTER TABLE \"posts\" ALTER \"title\" DROP DEFAULT;
-- @
change_column_stmt :: S8.ByteString
                   -- ^ Table name
                   -> S8.ByteString
                   -- ^ Column name
                   -> S8.ByteString
                   -- ^ Action
                   -> Query
change_column_stmt tableName colName action = Query $ S8.concat
  [ "alter table ", quoteIdent tableName, " alter "
  , quoteIdent colName, " ", action, ";"]

data CmdArgs = CmdArgs { cmd :: String
                       , cmdVersion :: String
                       , cmdCommit :: Bool }

-- | Creates an index for efficient lookup.
create_index :: S8.ByteString
             -- ^ Index name
             -> S8.ByteString
             -- ^ Table name
             -> [S8.ByteString]
             -- ^ Column names
             -> Migration Int64
create_index = ((executeQuery_ .) .) . (create_index_stmt False)

-- | Creates a unique index for efficient lookup.
create_unique_index :: S8.ByteString
                    -- ^ Index name
                    -> S8.ByteString
                    -- ^ Table name
                    -> [S8.ByteString]
                    -- ^ Column names
                    -> Migration Int64
create_unique_index = ((executeQuery_ .) .) . (create_index_stmt True)

-- | Returns a 'Query' that creates an index for the given columns on the given
-- table. For example,
--
-- @
--   create_index_stmt \"post_owner_index\" \"posts\" \"owner\"
-- @
--
-- Returns the query
--
-- @
--   CREATE INDEX \"post_owner_index\" ON \"posts\" (\"owner\")
-- @
create_index_stmt :: Bool 
                  -- ^ Unique index?
                  -> S8.ByteString
                  -- ^ Index name
                  -> S8.ByteString
                  -- ^ Table name
                  -> [S8.ByteString]
                  -- ^ Column names
                  -> Query
create_index_stmt unq indexName tableName colNames = Query $ S8.concat
  [ "create", unique, " index ", quoteIdent indexName, " on "
  , quoteIdent tableName, " (", cols, ")", ";" ]
  where cols = S8.intercalate ", " $ map quoteIdent colNames
        unique = if unq then " unique" else ""

-- | Drops an index.
drop_index :: S8.ByteString
           -- ^ Index name
           -> Migration Int64
drop_index = executeQuery_ . drop_index_stmt

-- | Returns a 'Query' that drops an index.
--
-- @
--   drop_index_stmt \"post_owner_index\"
-- @
--
-- Returns the query
--
-- @
--   DROP INDEX \"post_owner_index\"
-- @
drop_index_stmt :: S8.ByteString
                -- ^ Index name
                -> Query
drop_index_stmt indexName = Query $ S8.concat
  [ "drop index ", quoteIdent indexName, ";" ]

parseCmdArgs :: [String] -> Maybe CmdArgs
parseCmdArgs args = do
  mycmd <- listToMaybe args
  let args0 = tail args
  myversion <- listToMaybe args0
  return $ go (CmdArgs mycmd myversion False) $ tail args0
  where go res [] = res
        go res (arg:as) =
          let newRes = case arg of
                        "--with-db-commit" -> res { cmdCommit = True }
                        _ -> res
          in go newRes as

defaultMain :: (Connection -> IO ()) -- ^ Migration function
            -> (Connection -> IO ()) -- ^ Rollback function
            -> IO ()
defaultMain up down = do
  (Just cmdArgs) <- getArgs >>= return . parseCmdArgs
  case cmd cmdArgs of
    "up" -> do
      conn <- connectEnv
      res <- query_ conn
          "select version from schema_migrations order by version desc limit 1"
      let currentVersion = case res of
                      [] -> ""
                      (Only v):_ -> v
      let version = cmdVersion cmdArgs
      if currentVersion < version then do
          begin conn
          up conn
          void $ execute conn "insert into schema_migrations values(?)"
                              (Only version)
          if cmdCommit cmdArgs then
            commit conn
            else rollback conn
        else exitWith $ ExitFailure 1
    "down" -> do
      conn <- connectEnv
      res <- query_ conn
          "select version from schema_migrations order by version desc limit 1"
      let currentVersion = case res of
                      [] -> ""
                      (Only v):_ -> v
      let version = cmdVersion cmdArgs
      if currentVersion == version then do
          begin conn
          down conn
          void $ execute conn "delete from schema_migrations where version = ?"
              (Only version)
          if cmdCommit cmdArgs then
            commit conn
            else rollback conn
        else
          exitWith $ ExitFailure 1
    _ -> exitWith $ ExitFailure 1