{-|
Module      :  Database.Persist.Migration.Core
Maintainer  :  Brandon Chinn <brandonchinn178@gmail.com>
Stability   :  experimental
Portability :  portable

Defines a migration framework for the persistent library.
-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# OPTIONS_GHC -fno-warn-redundant-constraints #-}

module Database.Persist.Migration.Core
  ( Version
  , OperationPath
  , (~>)
  , Migration
  , MigrationPath(..)
  , opPath
  , MigrateSettings(..)
  , defaultSettings
  , validateMigration
  , runMigration
  , getMigration
  ) where

import Control.Monad (unless, when)
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Reader (mapReaderT)
import Data.List (nub)
import Data.Maybe (fromMaybe)
import qualified Data.Text as Text
import Data.Time.Clock (getCurrentTime)
import Database.Persist.Migration.Backend (MigrateBackend(..))
import Database.Persist.Migration.Operation (Operation(..), validateOperation)
import Database.Persist.Migration.Operation.Types
    (Column(..), ColumnProp(..), TableConstraint(..))
import Database.Persist.Migration.Utils.Plan (getPath)
import Database.Persist.Migration.Utils.Sql (MigrateSql, executeSql)
import Database.Persist.Sql
    (PersistValue(..), Single(..), SqlPersistT, rawExecute, rawSql)
import Database.Persist.Types (SqlType(..))

-- | The version of a database. An operation migrates from the given version to another version.
--
-- The version must be increasing, such that the lowest version is the first version and the highest
-- version is the most up-to-date version.
--
-- A version represents a version of the database schema. In other words, any set of operations
-- taken to get to version X *MUST* all result in the same database schema.
type Version = Int

-- | The path that an operation takes: the version it migrates /from/, paired with the version it
-- migrates /to/.
type OperationPath = (Version, Version)

-- | An infix constructor for 'OperationPath': @old ~> new@ is simply the pair @(old, new)@.
(~>) :: Version -> Version -> OperationPath
old ~> new = (old, new)

-- | A migration list that defines operations to manually migrate a database schema.
type Migration = [MigrationPath]

-- | A path representing the operations needed to run to get from one version of the database schema
-- to the next, written infix, e.g. @(0 ~> 1) := [op1, op2]@.
data MigrationPath = OperationPath := [Operation]
  deriving (Show)

-- | Get the 'OperationPath' (the @(from, to)@ version pair) of a 'MigrationPath'.
opPath :: MigrationPath -> OperationPath
opPath = \case
  path := _ -> path

-- | Get the current version of the database, or Nothing if none exists.
--
-- Before querying, this executes the backend's SQL for creating the @persistent_migration@
-- bookkeeping table (described by @migrationSchema@ below). The version returned is taken from the
-- row with the most recent @timestamp@.
getCurrVersion :: MonadIO m => MigrateBackend -> SqlPersistT m (Maybe Version)
getCurrVersion backend = do
  -- create the persistent_migration table if it doesn't already exist
  -- NOTE(review): this runs on every call, so it relies on the backend's CreateTable SQL being
  -- safe to re-run against an existing table — confirm in the backend implementation.
  createTableSql <- mapReaderT liftIO (getMigrationSql backend migrationSchema)
  mapM_ executeSql createTableSql
  rows <- rawSql queryVersion []
  pure (extractVersion rows)
  where
    -- Schema of the table recording every completed migration.
    migrationSchema = CreateTable
      { name = "persistent_migration"
      , schema =
          [ Column "id" SqlInt32 [NotNull, AutoIncrement]
          , Column "version" SqlInt32 [NotNull]
          , Column "label" SqlString []
          , Column "timestamp" SqlDayTime [NotNull]
          ]
      , constraints = [PrimaryKey ["id"]]
      }

    queryVersion = "SELECT version FROM persistent_migration ORDER BY timestamp DESC LIMIT 1"

    -- The query has LIMIT 1, so anything other than zero or one row is unexpected.
    extractVersion rows = case rows of
      [] -> Nothing
      [Single v] -> Just v
      _ -> error "Invalid response from the database."

-- | Get the list of operations to run, given the current state of the database.
--
-- The search starts at the database's current version (or, when there is none, at the lowest
-- version in the migration) and ends at the highest version in the migration. When 'getPath' finds
-- a path, the operation lists along it are concatenated into 'Right'. Otherwise the @(start, end)@
-- pair that could not be connected is returned in 'Left'.
getOperations :: Migration -> Maybe Version -> Either (Version, Version) [Operation]
getOperations migration mVersion =
  maybe (Left (start, end)) (Right . concat) (getPath edges start end)
  where
    edges = [ toEdge step | step <- migration ]
    toEdge (path := ops) = (path, ops)
    start = fromMaybe (getFirstVersion migration) mVersion
    end = getLatestVersion migration

-- | Get the first version in the given migration: the smallest source ('fst') version of any
-- 'OperationPath'.
--
-- The migration must be non-empty. An empty migration fails with an 'error' that names this
-- function, rather than the generic empty-list error raised by 'minimum'.
getFirstVersion :: Migration -> Version
getFirstVersion migration = case map (fst . opPath) migration of
  [] -> error "getFirstVersion: the migration does not contain any migration paths"
  sources -> minimum sources

-- | Get the most up-to-date version in the given migration: the largest destination ('snd') version
-- of any 'OperationPath'.
--
-- The migration must be non-empty. An empty migration fails with an 'error' that names this
-- function, rather than the generic empty-list error raised by 'maximum'.
getLatestVersion :: Migration -> Version
getLatestVersion migration = case map (snd . opPath) migration of
  [] -> error "getLatestVersion: the migration does not contain any migration paths"
  targets -> maximum targets

{- Migration plan and execution -}

-- | Settings to customize migration steps.
newtype MigrateSettings = MigrateSettings
  { versionToLabel :: Version -> Maybe String
    -- ^ Optionally give a version a human-readable label. 'runMigration' stores this label with
    -- each completed migration and falls back to the shown version number on 'Nothing'.
  }

-- | Default migration settings: no version receives a custom label.
defaultSettings :: MigrateSettings
defaultSettings = MigrateSettings (\_ -> Nothing)

-- | Validate the given migration, returning the first problem found.
--
-- Checks, in order:
--
-- * the migration contains at least one 'MigrationPath'. An empty migration has no first or latest
--   version, so it can never be planned or run;
-- * every 'OperationPath' goes from a version to a strictly higher version;
-- * no 'OperationPath' appears more than once.
validateMigration :: Migration -> Either String ()
validateMigration migration = do
  when (null migration) $
    Left "Migration must contain at least one migration path"
  unless (allIncreasing opVersions) $
    Left "Operation versions must be monotonically increasing"
  when (hasDuplicates opVersions) $
    Left "There may only be one operation per pair of versions"
  where
    opVersions = map opPath migration
    allIncreasing = all (uncurry (<))
    -- Quadratic via 'nub', which is fine for the handful of paths a migration normally has.
    hasDuplicates l = length (nub l) < length l

-- | Run the given migration. After successful completion, saves the migration to the database.
--
-- Does nothing when the database already reports a version at or beyond the latest version in the
-- migration. Otherwise it executes the SQL produced by 'getMigration'. It then appends a row to
-- @persistent_migration@ recording:
--
-- * the latest version;
-- * its label (from 'versionToLabel', falling back to the shown version number);
-- * the current time.
runMigration :: MonadIO m => MigrateBackend -> MigrateSettings -> Migration -> SqlPersistT m ()
runMigration backend settings@MigrateSettings{versionToLabel = toLabel} migration = do
  currVersion <- getCurrVersion backend
  let latestVersion = getLatestVersion migration
      alreadyCurrent = maybe False (>= latestVersion) currVersion
  unless alreadyCurrent $ do
    migrationSql <- getMigration backend settings migration
    mapM_ executeSql migrationSql
    now <- liftIO getCurrentTime
    let label = fromMaybe (show latestVersion) (toLabel latestVersion)
    rawExecute "INSERT INTO persistent_migration(version, label, timestamp) VALUES (?, ?, ?)"
      [ PersistInt64 (fromIntegral latestVersion)
      , PersistText (Text.pack label)
      , PersistUTCTime now
      ]

-- | Get the SQL queries for the given migration.
--
-- The steps are:
--
-- 1. Validate the migration.
-- 2. Read the database's current version.
-- 3. Find the operations leading from that version to the latest one.
-- 4. Validate each of those operations.
-- 5. Concatenate the backend SQL for each operation.
--
-- Any validation failure, or a missing path between versions, is raised with 'error'. The
-- 'MigrateSettings' argument is currently unused.
getMigration :: MonadIO m
  => MigrateBackend
  -> MigrateSettings
  -> Migration
  -> SqlPersistT m [MigrateSql]
getMigration backend _ migration = do
  orFail (validateMigration migration)
  currVersion <- getCurrVersion backend
  operations <- case getOperations migration currVersion of
    Left (start, end) ->
      error $ "Could not find path: " ++ show start ++ " ~> " ++ show end
    Right ops -> pure ops
  orFail (mapM_ validateOperation operations)
  concat <$> mapM (mapReaderT liftIO . getMigrationSql backend) operations
  where
    -- Turn a validation result into a monadic action, crashing on 'Left'.
    orFail result = either error pure result