{-# LANGUAGE GADTs #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleInstances #-}

module Morph.Migrator
  ( migrate
  ) where

import Control.Monad

import Data.Function
import Data.List
import Data.Monoid
import Data.String

import System.Directory
import System.FilePath
import System.IO

import Database.PostgreSQL.Simple

-- | Where a migration was loaded from, which determines how much SQL it
-- carries: a migration read from file contains both sides, while one read from
-- the database contains only the down side.
--
-- The constructors are used at the type level (via @DataKinds@) as the index
-- of 'Migration' and 'MigrationSQL'.
data MigrationType = Full | Rollback

-- | The SQL payload carried by a 'Migration' of the given 'MigrationType'.
type family MigrationSQL (a :: MigrationType) :: * where
  -- A pair: 'doMigration' executes the first component and stores the second
  -- one in the @rollback_sql@ column.
  MigrationSQL 'Full     = (Query, String)
  -- The single query that 'rollbackMigration' executes.
  MigrationSQL 'Rollback = Query

-- | A migration, indexed by the 'MigrationType' it was loaded as.
--
-- * @migrationIdentifier@ is the key under which the migration is recorded in
--   the @migrations@ table.
-- * @migrationSQL@ is the SQL payload; its shape depends on the index, see
--   'MigrationSQL'.
data Migration :: MigrationType -> * where
  Migration ::
    { migrationIdentifier :: String
    , migrationSQL        :: MigrationSQL a
    } -> Migration a

-- | Create the @migrations@ bookkeeping table unless it already exists.
--
-- The table maps a non-empty migration id to the (optional, non-empty) SQL
-- needed to roll that migration back. The number of affected rows reported by
-- the database is discarded.
--
-- NOTE(review): the literal below is used as a 'Query', which relies on
-- @OverloadedStrings@ being enabled outside this file's pragmas — confirm.
createMigrationTable :: Connection -> IO ()
createMigrationTable conn = void $ execute_ conn
  "CREATE TABLE IF NOT EXISTS migrations (\
  \  id           varchar PRIMARY KEY CHECK (id <> ''),\
  \  rollback_sql text CHECK (rollback_sql <> '')\
  \);"

-- | List the migrations recorded in the @migrations@ table, ordered by
-- ascending id.
--
-- A row whose @rollback_sql@ is @NULL@ yields a migration with an empty
-- rollback query.
listDone :: Connection -> IO [Migration 'Rollback]
listDone conn =
  map toMigration
    <$> query_ conn "SELECT id, rollback_sql FROM migrations ORDER BY id ASC"
  where
    -- Turn one @(id, rollback_sql)@ row into a rollback-only migration.
    toMigration :: (String, Maybe String) -> Migration 'Rollback
    toMigration (identifier, mSQL) = Migration
      { migrationIdentifier = identifier
      , migrationSQL        = maybe "" fromString mSQL
      }

-- | Load every migration found in the given directory.
--
-- Each file whose name ends in @.up.sql@ yields one migration. Its identifier
-- is the leading run of decimal digits of the file name, and its rollback SQL
-- comes from the first @.down.sql@ file (in sorted name order) that has the
-- same identifier. The result follows the sorted order of the @.up.sql@ file
-- names.
--
-- When no matching @.down.sql@ file exists, the rollback SQL is a statement
-- that raises an error naming the migration.
listGoals :: FilePath -> IO [Migration 'Full]
listGoals dir = do
    allNames <- sort <$> getDirectoryContents dir
    let upNames   = filter (".up.sql"   `isSuffixOf`) allNames
        downNames = filter (".down.sql" `isSuffixOf`) allNames

    forM upNames $ \upName -> do
      let identifier = extractIdentifier upName
      up   <- readMigrationFile upName
      down <- readDownMigrationFile downNames identifier
      return Migration
        { migrationIdentifier = identifier
        , migrationSQL        = (up, down)
        }

  where
    -- The identifier of a migration file: the digits its name starts with.
    extractIdentifier :: FilePath -> String
    extractIdentifier = takeWhile (`elem` ("0123456789" :: String))

    -- Read a file of the migration directory completely.
    --
    -- 'readFile' is lazy and keeps its handle open until the contents have
    -- been consumed. Forcing the length reads the whole file right away, so no
    -- handle lingers for a migration whose SQL never gets executed.
    readFileStrict :: FilePath -> IO String
    readFileStrict name = do
      contents <- readFile $ dir </> name
      length contents `seq` return contents

    -- Read the forward SQL of a migration.
    readMigrationFile :: FilePath -> IO Query
    readMigrationFile path = fromString <$> readFileStrict path

    -- Find and read the rollback SQL for the given identifier among the
    -- @.down.sql@ file names.
    readDownMigrationFile :: [FilePath] -> String -> IO String
    readDownMigrationFile paths identifier =
      case find ((== identifier) . extractIdentifier) paths of
        -- RAISE is a PL/pgSQL statement rather than plain SQL, so it has to be
        -- wrapped in a DO block to be executable as a query. The identifier
        -- only contains digits, so it needs no quoting.
        Nothing -> return $
          "DO $$ BEGIN RAISE EXCEPTION 'No rollback migration found for "
          <> identifier <> "'; END $$;"
        Just path -> readFileStrict path

-- | Undo one applied migration.
--
-- Announces the rollback on stderr, executes the migration's rollback query
-- and then deletes its row from the @migrations@ table. The affected-row
-- counts are discarded.
rollbackMigration :: Connection -> Migration 'Rollback -> IO ()
rollbackMigration conn migration = do
  hPutStrLn stderr $
    "Rollbacking migration " ++ migrationIdentifier migration ++ " ..."
  void $ execute_ conn $ migrationSQL migration
  void $ execute conn "DELETE FROM migrations WHERE id = ?" $
    Only $ migrationIdentifier migration

-- | Apply one migration.
--
-- Announces the migration on stderr, executes its forward query and then
-- records its identifier together with its rollback SQL in the @migrations@
-- table. The affected-row counts are discarded.
doMigration :: Connection -> Migration 'Full -> IO ()
doMigration conn migration = do
  hPutStrLn stderr $
    "Running migration " ++ migrationIdentifier migration ++ " ..."
  let (up, down) = migrationSQL migration
  void $ execute_ conn up
  void $ execute conn "INSERT INTO migrations (id, rollback_sql) VALUES (?, ?)"
                 (migrationIdentifier migration, down)

-- | Bring the database in line with the migrations found in a directory.
--
-- Migrations recorded in the database but absent from the directory are
-- rolled back, in descending identifier order. Then the migrations present in
-- the directory but not yet recorded are applied, in the order 'listGoals'
-- returned them.
--
-- The first argument selects whether the rollbacks and the new migrations run
-- inside a single transaction. The @migrations@ table is created, and both
-- migration lists are read, before that transaction starts.
migrate :: Bool -> Connection -> FilePath -> IO ()
migrate inTransaction conn dir = do
  createMigrationTable conn

  doneMigrations <- listDone  conn
  goalMigrations <- listGoals dir

  let doneIdentifiers = map migrationIdentifier doneMigrations
      goalIdentifiers = map migrationIdentifier goalMigrations

      toRollbackIdentifiers = doneIdentifiers \\ goalIdentifiers
      toDoIdentifiers       = goalIdentifiers \\ doneIdentifiers

      -- Identifiers are compared as strings, so this is the reverse
      -- lexicographic order, not a numeric one.
      toRollback = sortBy (flip (compare `on` migrationIdentifier)) $
        filter ((`elem` toRollbackIdentifiers) . migrationIdentifier)
               doneMigrations
      toDo = filter ((`elem` toDoIdentifiers) . migrationIdentifier)
                    goalMigrations

  (if inTransaction then withTransaction conn else id) $ do
    forM_ toRollback $ rollbackMigration conn
    forM_ toDo       $ doMigration       conn