{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE TypeFamilies #-}
module Database.Persist.MySQL
  ( withMySQLPool
  , withMySQLConn
  , createMySQLPool
  , module Database.Persist.Sql
  , MySQLConnectInfo
  , mkMySQLConnectInfo
  , setMySQLConnectInfoPort
  , setMySQLConnectInfoCharset
  , MySQLConf
  , mkMySQLConf
  , mockMigration
  
  , insertOnDuplicateKeyUpdate
  , insertEntityOnDuplicateKeyUpdate
  , insertManyOnDuplicateKeyUpdate
  , insertEntityManyOnDuplicateKeyUpdate
  , HandleUpdateCollision
  , pattern SomeField
  , SomeField
  , copyField
  , copyUnlessNull
  , copyUnlessEmpty
  , copyUnlessEq
  
  , setMySQLConnectInfoTLS
  , MySQLTLS.TrustedCAStore(..)
  , MySQLTLS.makeClientParams
  , MySQLTLS.makeClientParams'
  
  , myConnInfo
  , myPoolSize
) where
import Control.Arrow
import Control.Monad (void)
import Control.Monad.IO.Class (MonadIO (..))
import Control.Monad.IO.Unlift (MonadUnliftIO)
import Control.Monad.Logger (MonadLogger, runNoLoggingT)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Except (ExceptT, runExceptT, throwE)
import Control.Monad.Trans.Reader (runReaderT, ReaderT)
import Control.Monad.Trans.Writer (runWriterT)
import           Data.Conduit (ConduitM, (.|), runConduit, runConduitRes)
import qualified Data.Conduit.List as CL
import Data.Acquire (Acquire, mkAcquire, with)
import Data.Aeson
import Data.Aeson.Types (modifyFailure)
import qualified Data.ByteString.Lazy as BS
import qualified Data.ByteString.Char8  as BSC
import Data.Either (partitionEithers)
import Data.Fixed (Pico)
import Data.Function (on)
import Data.Int (Int64)
import Data.IORef
import Data.List (find, intercalate, sort, groupBy)
import qualified Data.Map as Map
import Data.Monoid ((<>))
import qualified Data.Monoid as Monoid
import Data.Pool (Pool)
import Data.Text (Text, pack)
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Data.Text.IO as T
import Text.Read (readMaybe)
import System.Environment (getEnvironment)
import Database.Persist.Sql
import Database.Persist.Sql.Types.Internal (mkPersistBackend, makeIsolationLevelStatement)
import qualified Database.Persist.Sql.Util as Util
import Database.Persist.MySQLConnectInfoShowInstance ()
import qualified Database.MySQL.Base    as MySQL
import qualified Database.MySQL.TLS     as MySQLTLS
import qualified Network.TLS            as TLS
import qualified System.IO.Streams      as Streams
import qualified Data.Time.Calendar     as Time
import qualified Data.Time.LocalTime    as Time
import qualified Network.Socket         as NetworkSocket
import qualified Data.Word              as Word
import           Data.String (fromString)
-- | Create a pool of MySQL connections with 'withSqlPool' and run the given
-- action with it. Every pooled connection is opened by 'open''.
withMySQLPool :: (MonadLogger m, MonadUnliftIO m)
              => MySQLConnectInfo
              -- ^ Connection information.
              -> Int
              -- ^ Size of the pool, passed through to 'withSqlPool'.
              -> (Pool SqlBackend -> m a)
              -- ^ Action that uses the connection pool.
              -> m a
withMySQLPool connInfo poolSize action =
    withSqlPool (open' connInfo) poolSize action
-- | Create a pool of MySQL connections with 'createSqlPool' and return it.
-- Every pooled connection is opened by 'open''.
createMySQLPool :: (MonadUnliftIO m, MonadLogger m)
                => MySQLConnectInfo
                -- ^ Connection information.
                -> Int
                -- ^ Size of the pool, passed through to 'createSqlPool'.
                -> m (Pool SqlBackend)
createMySQLPool connInfo poolSize =
    createSqlPool (open' connInfo) poolSize
-- | Open a single MySQL connection with 'open'' and run the given action with
-- the resulting 'SqlBackend', via 'withSqlConn'.
withMySQLConn :: (MonadUnliftIO m, MonadLogger m)
              => MySQLConnectInfo
              -- ^ Connection information.
              -> (SqlBackend -> m a)
              -- ^ Action that uses the connection.
              -> m a
withMySQLConn connInfo = withSqlConn (open' connInfo)
-- | Open a raw mysql-haskell connection.
--
-- Uses a plain connection when no TLS client parameters are configured.
-- Otherwise it connects through "Database.MySQL.TLS" with those parameters.
connect' :: MySQLConnectInfo -> IO MySQL.MySQLConn
connect' (MySQLConnectInfo innerCi mTls) =
    maybe (MySQL.connect innerCi)
          (\tls -> MySQLTLS.connect innerCi (tls, "persistent-mysql-haskell"))
          mTls
-- | Open one MySQL connection (see 'connect'') and wrap it in a 'SqlBackend'.
-- Every callback of the backend operates on that connection.
--
-- Autocommit is switched off right after connecting. Transaction boundaries
-- are issued explicitly through 'begin'', 'commit'' and 'rollback''.
open' :: MySQLConnectInfo -> LogFunc -> IO SqlBackend
open' ci@(MySQLConnectInfo innerCi _) logFunc = do
    conn <- connect' ci
    autocommit' conn False
    stmtMap <- newIORef Map.empty
    -- MySQL has no "LIMIT ALL". 18446744073709551615 (2^64 - 1) is the
    -- "all remaining rows" count suggested by the MySQL manual when only an
    -- offset is wanted.
    let noLimit :: Text
        noLimit = "LIMIT 18446744073709551615"
        backend = SqlBackend
          { connPrepare        = prepare' conn
          , connInsertSql      = insertSql'
          , connInsertManySql  = Nothing
          , connUpsertSql      = Nothing
          , connPutManySql     = Just putManySql
          , connRepsertManySql = Just repsertManySql
          , connStmtMap        = stmtMap
          , connClose          = MySQL.close conn
          , connMigrateSql     = migrate' innerCi
          , connBegin          = \_getter -> begin' conn
          , connCommit         = \_getter -> commit' conn
          , connRollback       = \_getter -> rollback' conn
          , connEscapeName     = pack . escapeDBName
          , connNoLimit        = noLimit
          , connRDBMS          = "mysql"
          , connLimitOffset    = decorateSQLWithLimitOffset noLimit
          , connLogFunc        = logFunc
          , connMaxParams      = Nothing
          }
    pure (mkPersistBackend backend)
-- | Turn the session's @autocommit@ setting on or off.
-- The flag is sent as a bound parameter (1 or 0, see 'encodeBool').
autocommit' :: MySQL.MySQLConn -> Bool -> IO ()
autocommit' conn enabled = do
    _ <- MySQL.execute conn "SET autocommit=?" [encodeBool enabled]
    pure ()
-- | Start a transaction with @BEGIN@.
--
-- When an isolation level is given, the statement built by
-- 'makeIsolationLevelStatement' is executed first.
begin' :: MySQL.MySQLConn -> Maybe IsolationLevel -> IO ()
begin' conn mIso = do
    case mIso of
      Nothing  -> pure ()
      Just iso -> void (MySQL.execute_ conn (fromString (makeIsolationLevelStatement iso)))
    void (MySQL.execute_ conn "BEGIN")
-- | Commit the current transaction. The server's OK result is discarded.
commit' :: MySQL.MySQLConn -> IO ()
commit' conn = MySQL.execute_ conn "COMMIT" >> pure ()
-- | Roll back the current transaction. The server's OK result is discarded.
rollback' :: MySQL.MySQLConn -> IO ()
rollback' conn = MySQL.execute_ conn "ROLLBACK" >> pure ()
-- | Build a persistent 'Statement' for the given SQL text.
--
-- No server-side prepared statement is created. The UTF-8 encoded SQL is
-- kept as a 'MySQL.Query' and sent again on each execute/query call, so
-- finalizing and resetting are no-ops.
prepare' :: MySQL.MySQLConn -> Text -> IO Statement
prepare' conn sql = pure Statement
    { stmtFinalize = pure ()
    , stmtReset    = pure ()
    , stmtExecute  = execute' conn mysqlQuery
    , stmtQuery    = withStmt' conn mysqlQuery
    }
  where
    mysqlQuery = MySQL.Query (BS.fromStrict (T.encodeUtf8 sql))
-- | Build the INSERT for one entity, with one @?@ placeholder per field.
--
-- Entities with a custom primary key return 'ISRManyKeys' with the inserted
-- values. Otherwise the generated id is fetched afterwards with
-- @SELECT LAST_INSERT_ID()@ ('ISRInsertGet').
insertSql' :: EntityDef -> [PersistValue] -> InsertSqlResult
insertSql' ent vals =
    case entityPrimary ent of
      Just _  -> ISRManyKeys sql vals
      Nothing -> ISRInsertGet sql "SELECT LAST_INSERT_ID()"
  where
    fields       = entityFields ent
    columnList   = intercalate "," (map (escapeDBName . fieldDB) fields)
    placeholders = intercalate "," (map (const "?") fields)
    sql = pack $ "INSERT INTO " ++ escapeDBName (entityDB ent)
              ++ "(" ++ columnList
              ++ ") VALUES(" ++ placeholders
              ++ ")"
-- | Run a non-query statement with the given parameters and return the
-- number of affected rows reported in the server's OK packet.
execute' :: MySQL.MySQLConn -> MySQL.Query -> [PersistValue] -> IO Int64
execute' conn query vals = do
    ok <- MySQL.execute conn query (map P vals)
    pure (fromIntegral (MySQL.okAffectedRows ok))
-- | Run a query, returning the column definitions and a stream of rows.
--
-- Without parameters, the parameterless 'MySQL.query_' is used. This is
-- presumably to skip parameter rendering entirely; confirm before merging
-- the two cases.
query'
  :: MySQL.QueryParam p => MySQL.MySQLConn -> MySQL.Query -> [p]
  -> IO ([MySQL.ColumnDef], Streams.InputStream [MySQL.MySQLValue])
query' conn qry ps
  | null ps   = MySQL.query_ conn qry
  | otherwise = MySQL.query conn qry ps
-- | Run a query and expose its rows as a conduit source.
--
-- Each cell is decoded with the 'getGetter' decoder of its column. The
-- result stream is acquired with 'mkAcquire'. On release, whatever rows were
-- not consumed are skipped to EOF, presumably so the connection can be used
-- for the next command.
withStmt' :: MonadIO m
          => MySQL.MySQLConn
          -> MySQL.Query
          -> [PersistValue]
          -> Acquire (ConduitM () [PersistValue] m ())
withStmt' conn query vals =
    fmap streamRows (mkAcquire openResult drainResult)
  where
    openResult = query' conn query (map P vals)
    drainResult (_, rowStream) = Streams.skipToEof rowStream
    streamRows (cols, rowStream) = CL.unfoldM nextRow rowStream
      where
        decoders = map getGetter cols
        nextRow src = do
          mrow <- liftIO (Streams.read src)
          pure (fmap (\row -> (zipWith ($) decoders row, src)) mrow)
-- | Encode a boolean as an unsigned TINYINT: 1 for 'True', 0 for 'False'.
encodeBool :: Bool -> MySQL.MySQLValue
encodeBool b = MySQL.MySQLInt8U (if b then 1 else 0)
-- | Widen any integral value into a 'PersistInt64'.
-- This goes through 'fromIntegral', so values outside the Int64 range wrap.
decodeInteger :: Integral a => a -> PersistValue
decodeInteger n = PersistInt64 (fromIntegral n)
-- | Convert any real number into a 'PersistDouble' via 'realToFrac'.
decodeDouble :: Real a => a -> PersistValue
decodeDouble x = PersistDouble (realToFrac x)
-- | Wrapper so that persistent 'PersistValue's can be passed as mysql-haskell
-- query parameters.
newtype P = P PersistValue

-- | Every value is mapped to a 'MySQL.MySQLValue' and written with
-- 'MySQL.putTextField'.
--
-- * Lists, arrays and maps are sent as JSON text.
-- * Rationals are truncated to 'Pico' precision and sent as a decimal.
-- * Object ids cannot be represented and raise an error.
instance MySQL.QueryParam P where
  render (P value) = MySQL.putTextField (toMySQL value)
    where
      toMySQL (PersistText t)       = MySQL.MySQLText t
      toMySQL (PersistByteString b) = MySQL.MySQLBytes b
      toMySQL (PersistInt64 i)      = MySQL.MySQLInt64 i
      toMySQL (PersistDouble d)     = MySQL.MySQLDouble d
      toMySQL (PersistBool b)       = encodeBool b
      toMySQL (PersistDay d)        = MySQL.MySQLDate d
      -- The leading 0 is the first (presumably sign) field of MySQLTime.
      toMySQL (PersistTimeOfDay t)  = MySQL.MySQLTime 0 t
      -- UTC times are sent as a TIMESTAMP in UTC local time.
      toMySQL (PersistUTCTime t)    = MySQL.MySQLTimeStamp (Time.utcToLocalTime Time.utc t)
      toMySQL PersistNull           = MySQL.MySQLNull
      toMySQL (PersistList l)       = MySQL.MySQLText (listToJSON l)
      toMySQL (PersistMap m)        = MySQL.MySQLText (mapToJSON m)
      -- 'show' on a 'Pico' yields a plain decimal literal that 'read' parses
      -- back as the decimal type expected by MySQLDecimal.
      toMySQL (PersistRational r)   = MySQL.MySQLDecimal (read (show (fromRational r :: Pico)))
      toMySQL (PersistDbSpecific b) = MySQL.MySQLBytes b
      toMySQL (PersistArray a)      = toMySQL (PersistList a)
      toMySQL (PersistObjectId _)   =
        error "Refusing to serialize a PersistObjectId to a MySQL value"
-- | Decoder from a single MySQL result cell to a value of type @a@.
type Getter a = MySQL.MySQLValue -> a

-- | Choose the decoder for a result column.
--
-- The 'MySQL.ColumnDef' is ignored. Decoding is driven purely by the
-- 'MySQL.MySQLValue' constructor of each cell.
getGetter :: MySQL.ColumnDef -> Getter PersistValue
getGetter _field cell = case cell of
    -- Integers of every width, and BIT, become 'PersistInt64'.
    -- NOTE(review): this goes through 'fromIntegral', so a BIGINT UNSIGNED
    -- or BIT(64) value above maxBound :: Int64 wraps around.
    MySQL.MySQLInt8U  v -> decodeInteger v
    MySQL.MySQLInt8   v -> decodeInteger v
    MySQL.MySQLInt16U v -> decodeInteger v
    MySQL.MySQLInt16  v -> decodeInteger v
    MySQL.MySQLInt32U v -> decodeInteger v
    MySQL.MySQLInt32  v -> decodeInteger v
    MySQL.MySQLInt64U v -> decodeInteger v
    MySQL.MySQLInt64  v -> decodeInteger v
    MySQL.MySQLBit    v -> decodeInteger v
    -- FLOAT, DOUBLE and DECIMAL all become 'PersistDouble' via 'realToFrac'.
    MySQL.MySQLFloat   v -> decodeDouble v
    MySQL.MySQLDouble  v -> decodeDouble v
    MySQL.MySQLDecimal v -> decodeDouble v
    MySQL.MySQLBytes v -> PersistByteString v
    MySQL.MySQLText  v -> PersistText v
    -- DATETIME and TIMESTAMP are interpreted as UTC local times.
    MySQL.MySQLDateTime  v -> PersistUTCTime (Time.localTimeToUTC Time.utc v)
    MySQL.MySQLTimeStamp v -> PersistUTCTime (Time.localTimeToUTC Time.utc v)
    -- A YEAR becomes January 1st of that year.
    MySQL.MySQLYear v -> PersistDay (Time.fromGregorian (fromIntegral v) 1 1)
    MySQL.MySQLDate v -> PersistDay v
    -- The first field of MySQLTime (presumably the sign) is discarded.
    MySQL.MySQLTime _ v -> PersistTimeOfDay v
    MySQL.MySQLNull -> PersistNull
    MySQL.MySQLGeometry v -> PersistDbSpecific v
-- | Compute the migration statements for one entity.
--
-- Reads the table's current columns and constraints from
-- @INFORMATION_SCHEMA@ (via 'getColumns'). It then compares them with the
-- columns, unique definitions and foreign definitions produced by
-- 'mkColumns' for @val@.
--
-- Returns @Left@ with the error texts when any existing column could not be
-- decoded. Otherwise returns @Right@ a list of @(isUnsafe, sql)@ pairs
-- rendered by 'showAlterDb'.
migrate' :: MySQL.ConnectInfo
         -> [EntityDef]
         -> (Text -> IO Statement)
         -> EntityDef
         -> IO (Either [Text] [(Bool, Text)])
migrate' connectInfo allDefs getter val = do
    let name = entityDB val
    (idClmn, old) <- getColumns connectInfo getter val
    let (newcols, udefs, fdefs) = mkColumns allDefs val
    let udspair = map udToPair udefs
    case (idClmn, old, partitionEithers old) of
      -- Nothing was found for this table, so it does not exist yet. Create
      -- it, then add its unique constraints and foreign keys.
      ([], [], _) -> do
        let uniques = flip concatMap udspair $ \(uname, ucols) ->
                      [ AlterTable name $
                        AddUniqueConstraint uname $
                        map (findTypeAndMaxLen name) ucols ]
        -- One foreign key per new column that carries a reference.
        let foreigns = do
              Column { cName=cname, cReference=Just (refTblName, refConstraintName) } <- newcols
              return $ AlterColumn name (refTblName, addReference allDefs refConstraintName refTblName cname)
        -- One (possibly multi-column) foreign key per explicit foreign
        -- definition, pairing child columns with parent columns.
        let foreignsAlt = map (\fdef -> let (childfields, parentfields) = unzip (map (\((_,b),(_,d)) -> (b,d)) (foreignFields fdef))
                                        in AlterColumn name (foreignRefTableDBName fdef, AddReference (foreignRefTableDBName fdef) (foreignConstraintNameDBName fdef) childfields parentfields)) fdefs
        return $ Right $ map showAlterDb $ (addTable newcols val): uniques ++ foreigns ++ foreignsAlt
      -- The table exists and every existing column decoded: diff old vs. new.
      (_, _, ([], old')) -> do
        -- Clear the reference of existing columns whose foreign-key
        -- constraint name matches an explicit foreign definition in fdefs.
        -- Presumably this keeps the per-column diff from dropping or
        -- re-adding those keys.
        let excludeForeignKeys (xs,ys) = (map (\c -> case cReference c of
                                                    Just (_,fk) -> case find (\f -> fk == foreignConstraintNameDBName f) fdefs of
                                                                     Just _ -> c { cReference = Nothing }
                                                                     Nothing -> c
                                                    Nothing -> c) xs,ys)
            (acs, ats) = getAlters allDefs name (newcols, udspair) $ excludeForeignKeys $ partitionEithers old'
            acs' = map (AlterColumn name) acs
            ats' = map (AlterTable  name) ats
        return $ Right $ map showAlterDb $ acs' ++ ats'
      -- At least one existing column could not be decoded: report the errors.
      (_, _, (errs, _)) -> return $ Left errs
      where
        -- (column name, Haskell field type, max length) for a unique
        -- constraint column, as needed by 'AddUniqueConstraint'.
        findTypeAndMaxLen tblName col = let (col', ty) = findTypeOfColumn allDefs tblName col
                                            (_, ml) = findMaxLenOfColumn allDefs tblName col
                                         in (col', ty, ml)
-- | Build the CREATE TABLE statement for a new entity table.
--
-- The primary key is either the entity's composite primary key or its id
-- column. The id column is declared NOT NULL, and also AUTO_INCREMENT when
-- it is a 'SqlInt64' without an explicit default.
--
-- NOTE: the statement deliberately starts with "CREATe" (lower-case final
-- "e"). This looks like the marker persistent's migration runner uses to
-- order table creations first. Confirm against
-- "Database.Persist.Sql.Migration" before changing the spelling.
addTable :: [Column] -> EntityDef -> AlterDB
addTable cols entity =
    AddTable $ "CREATe TABLE " ++ escapeDBName (entityDB entity)
            ++ "(" ++ keyClause ++ separator
            ++ intercalate "," (map showColumn cols)
            ++ ")"
  where
    separator = if null cols then "" else ","
    idField = entityId entity
    idType = fieldSqlType idField
    keyClause = maybe surrogateKey compositeKey (entityPrimary entity)
    compositeKey pdef =
      " PRIMARY KEY ("
        ++ intercalate "," (map (escapeDBName . fieldDB) (compositeFields pdef))
        ++ ")"
    surrogateKey =
      escapeDBName (fieldDB idField)
        ++ " " ++ showSqlType idType (findMaxLenOfField idField) False
        ++ " NOT NULL"
        ++ autoIncrement
        ++ " PRIMARY KEY"
    autoIncrement = case (idType, defaultAttribute (fieldAttrs idField)) of
      (SqlInt64, Nothing) -> " AUTO_INCREMENT"
      _                   -> ""
-- | Look up the Haskell field type of column @col@ in table @name@.
-- Returns it paired with the column name, and errors out (listing all
-- entity definitions) when the table or column is unknown.
findTypeOfColumn :: [EntityDef] -> DBName -> DBName -> (DBName, FieldType)
findTypeOfColumn allDefs name col =
    case lookupType of
      Just ty -> (col, ty)
      Nothing -> error $ "Could not find type of column " ++
                         show col ++ " on table " ++ show name ++
                         " (allDefs = " ++ show allDefs ++ ")"
  where
    lookupType = do
      entDef <- find ((== name) . entityDB) allDefs
      fieldType <$> find ((== col) . fieldDB) (entityFields entDef)
-- | Look up the @maxlen=@ attribute of column @col@ in table @name@.
--
-- Returns it paired with the column name. It falls back to 200 when the
-- table, column or attribute is missing; the value is used, for example,
-- as the index prefix length of text columns in unique constraints.
findMaxLenOfColumn :: [EntityDef] -> DBName -> DBName -> (DBName, Integer)
findMaxLenOfColumn allDefs name col = (col, maybe 200 id explicitMaxLen)
  where
    explicitMaxLen =
      find ((== name) . entityDB) allDefs
        >>= find ((== col) . fieldDB) . entityFields
        >>= findMaxLenOfField
-- | Parse the first @maxlen=N@ attribute of a field.
--
-- The prefix is matched case-insensitively. Returns 'Nothing' when there is
-- no such attribute or N is not a number.
findMaxLenOfField :: FieldDef -> Maybe Integer
findMaxLenOfField fieldDef =
    find (T.isPrefixOf prefix . T.toLower) (fieldAttrs fieldDef)
      >>= readMaybe . T.unpack . T.drop (T.length prefix)
  where
    prefix = "maxlen="
-- | Build a single-column foreign key from @cname@ to the key columns of
-- @reftable@, named @fkeyname@.
-- Errors out (listing all entity definitions) when @reftable@ is unknown.
addReference :: [EntityDef] -> DBName -> DBName -> DBName -> AlterColumn
addReference allDefs fkeyname reftable cname =
    AddReference reftable fkeyname [cname] referencedColumns
  where
    referencedColumns =
      case find ((== reftable) . entityDB) allDefs of
        Just entDef -> map fieldDB (entityKeyFields entDef)
        Nothing     -> error $ "Could not find ID of entity " ++ show reftable
                                 ++ " (allDefs = " ++ show allDefs ++ ")"
-- | A column-level change, rendered to SQL by 'showAlter'.
data AlterColumn = Change Column        -- redefine the column (ALTER TABLE ... CHANGE)
                 | Add' Column          -- add a new column
                 | Drop                 -- drop the column
                 | Default String       -- set the column default to this SQL text
                 | NoDefault            -- drop the column default
                 | Update' String       -- UPDATE rows where the column IS NULL to this value
                 -- add a foreign key constraint
                 | AddReference
                    DBName -- referenced table
                    DBName -- foreign key (constraint) name
                    [DBName] -- referencing columns
                    [DBName] -- referenced columns
                 | DropReference DBName -- drop the named foreign key
-- | A column alteration paired with a name. For most constructors this is
-- the column name. 'showAlter' ignores it for 'Add'', 'AddReference' and
-- 'DropReference'.
type AlterColumn' = (DBName, AlterColumn)
-- | A table-level change, rendered to SQL by 'showAlterTable'.
-- 'AddUniqueConstraint' carries the constraint name and, per column, its
-- name, Haskell field type and max length (the max length becomes an index
-- prefix length for text-like types).
data AlterTable = AddUniqueConstraint DBName [(DBName, FieldType, Integer)]
                | DropUniqueConstraint DBName
-- | One migration step, rendered by 'showAlterDb':
-- a complete CREATE TABLE statement, a column change on a table, or a
-- table-level change on a table.
data AlterDB = AddTable String
             | AlterColumn DBName AlterColumn'
             | AlterTable DBName AlterTable
-- | Reduce a unique definition to its constraint name and its columns'
-- database names.
udToPair :: UniqueDef -> (DBName, [DBName])
udToPair ud = (uniqueDBName ud, [ column | (_, column) <- uniqueFields ud ])
-- | Read the current shape of an entity's table from @INFORMATION_SCHEMA@.
--
-- Returns a pair:
--
-- * the rows describing the table's id column (queried separately), and
-- * every other column (@Right (Left Column)@, decoded by 'getColumn'),
--   followed by the table's constraints other than PRIMARY and foreign keys,
--   grouped per constraint name as @Right (Right (constraint, columns))@.
--
-- A row that could not be decoded appears as @Left@ error text.
--
-- NOTE(review): the id-column query selects only 4 fields while 'getColumn'
-- expects 8, so every id row decodes to a @Left@. Within this file
-- 'migrate'' only checks whether that list is empty.
getColumns :: MySQL.ConnectInfo
           -> (Text -> IO Statement)
           -> EntityDef
           -> IO ( [Either Text (Either Column (DBName, [DBName]))] -- id column
                 , [Either Text (Either Column (DBName, [DBName]))] -- everything else
                 )
getColumns connectInfo getter def = do
    -- Find the id column.
    stmtIdClmn <- getter $ T.concat
      [ "SELECT COLUMN_NAME, "
      ,   "IS_NULLABLE, "
      ,   "DATA_TYPE, "
      ,   "COLUMN_DEFAULT "
      , "FROM INFORMATION_SCHEMA.COLUMNS "
      , "WHERE TABLE_SCHEMA = ? "
      ,   "AND TABLE_NAME   = ? "
      ,   "AND COLUMN_NAME  = ?"
      ]
    -- Rows are collected into a list before being decoded, because
    -- 'getColumn' issues a further query per row (this avoids nesting
    -- queries on the same connection).
    inter1 <- with (stmtQuery stmtIdClmn vals) (\src -> runConduit $ src .| CL.consume)
    ids <- runConduitRes $ CL.sourceList inter1 .| helperClmns 
    -- Every other column.
    stmtClmns <- getter $ T.concat
      [ "SELECT COLUMN_NAME, "
      ,   "IS_NULLABLE, "
      ,   "DATA_TYPE, "
      ,   "COLUMN_TYPE, "
      ,   "CHARACTER_MAXIMUM_LENGTH, "
      ,   "NUMERIC_PRECISION, "
      ,   "NUMERIC_SCALE, "
      ,   "COLUMN_DEFAULT "
      , "FROM INFORMATION_SCHEMA.COLUMNS "
      , "WHERE TABLE_SCHEMA = ? "
      ,   "AND TABLE_NAME   = ? "
      ,   "AND COLUMN_NAME <> ?"
      ]
    inter2 <- with (stmtQuery stmtClmns vals) (\src -> runConduitRes $ src .| CL.consume)
    cs <- runConduitRes $ CL.sourceList inter2 .| helperClmns 
    -- Constraints that are neither PRIMARY nor foreign keys, one row per
    -- (constraint, column), ordered so that rows of a constraint are adjacent.
    stmtCntrs <- getter $ T.concat
      [ "SELECT CONSTRAINT_NAME, "
      ,   "COLUMN_NAME "
      , "FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE "
      , "WHERE TABLE_SCHEMA = ? "
      ,   "AND TABLE_NAME   = ? "
      ,   "AND COLUMN_NAME <> ? "
      ,   "AND CONSTRAINT_NAME <> 'PRIMARY' "
      ,   "AND REFERENCED_TABLE_SCHEMA IS NULL "
      , "ORDER BY CONSTRAINT_NAME, "
      ,   "COLUMN_NAME"
      ]
    us <- with (stmtQuery stmtCntrs vals) (\src -> runConduitRes $ src .| helperCntrs)
    
    return (ids, cs ++ us)
  where
    -- Shared parameters of all three queries: schema (database) name,
    -- table name, and id column name.
    vals = [ PersistText $ T.decodeUtf8 $ MySQL.ciDatabase connectInfo
           , PersistText $ unDBName $ entityDB def
           , PersistText $ unDBName $ fieldDB $ entityId def ]
    -- Decode each column row with 'getColumn', tagging successes as columns.
    helperClmns = CL.mapM getIt .| CL.consume
        where
          getIt = fmap (either Left (Right . Left)) .
                  liftIO .
                  getColumn connectInfo getter (entityDB def)
    -- Group adjacent (constraint, column) rows by constraint name. 'head' is
    -- safe here because 'groupBy' never yields an empty group.
    helperCntrs = do
      let check [ PersistText cntrName
                , PersistText clmnName] = return ( cntrName, clmnName )
          check other = fail $ "helperCntrs: unexpected " ++ show other
      rows <- mapM check =<< CL.consume
      return $ map (Right . Right . (DBName . fst . head &&& map (DBName . snd)))
             $ groupBy ((==) `on` fst) rows
-- | Decode one row of @INFORMATION_SCHEMA.COLUMNS@ into a 'Column'.
-- The column's foreign-key reference, if any, is looked up with an extra
-- query.
--
-- Expects exactly eight values, in this order: name, nullability, data
-- type, column type, max length, numeric precision, numeric scale, default.
-- This is the order selected by the column query in 'getColumns'.
--
-- Any other row shape, or a value that cannot be interpreted, is reported as
-- @Left@ with a description. Errors are raised with 'throwE' so they
-- actually reach that @Left@: 'fail' in 'ExceptT' is delegated to the
-- underlying IO and would throw an exception instead.
getColumn :: MySQL.ConnectInfo
          -> (Text -> IO Statement)
          -> DBName
          -> [PersistValue]
          -> IO (Either Text Column)
getColumn connectInfo getter tname [ PersistText cname
                                   , PersistText null_
                                   , PersistText dataType
                                   , PersistText colType
                                   , colMaxLen
                                   , colPrecision
                                   , colScale
                                   , default'] =
    fmap (either (Left . pack) Right) $
    runExceptT $ do
      -- Column default: absent, plain text, or UTF-8 bytes.
      default_ <- case default' of
                    PersistNull   -> return Nothing
                    PersistText t -> return (Just t)
                    PersistByteString bs ->
                      case T.decodeUtf8' bs of
                        Left exc -> throwE $ "Invalid default column: " ++
                                             show default' ++ " (error: " ++
                                             show exc ++ ")"
                        Right t  -> return (Just t)
                    _ -> throwE $ "Invalid default column: " ++ show default'
      -- Foreign-key reference of this column (same schema only).
      stmt <- lift . getter . T.concat $
        [ "SELECT REFERENCED_TABLE_NAME, "
        ,   "CONSTRAINT_NAME, "
        ,   "ORDINAL_POSITION "
        , "FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE "
        , "WHERE TABLE_SCHEMA = ? "
        ,   "AND TABLE_NAME   = ? "
        ,   "AND COLUMN_NAME  = ? "
        ,   "AND REFERENCED_TABLE_SCHEMA = ? "
        , "ORDER BY CONSTRAINT_NAME, "
        ,   "COLUMN_NAME"
        ]
      let vars = [ PersistText $ T.decodeUtf8 $ MySQL.ciDatabase connectInfo
                 , PersistText $ unDBName $ tname
                 , PersistText cname
                 , PersistText $ T.decodeUtf8 $ MySQL.ciDatabase connectInfo ]
      cntrs <- liftIO $ with (stmtQuery stmt vars) (\src -> runConduit $ src .| CL.consume)
      ref <- case cntrs of
               [] -> return Nothing
               -- The reference is only recorded when this column is at
               -- ORDINAL_POSITION 1 of the constraint.
               [[PersistText tab, PersistText ref, PersistInt64 pos]] ->
                   return $ if pos == 1 then Just (DBName tab, DBName ref) else Nothing
               -- NOTE(review): reachable when the column belongs to more than
               -- one foreign key; reported as Left rather than thrown.
               _ -> throwE "MySQL.getColumn/getRef: never here"
      let colMaxLen' = case colMaxLen of
            PersistInt64 l -> Just (fromIntegral l)
            _ -> Nothing
          ci = ColumnInfo
            { ciColumnType = colType
            , ciMaxLength = colMaxLen'
            , ciNumericPrecision = colPrecision
            , ciNumericScale = colScale
            }
      (typ, maxLen) <- parseColumnType dataType ci
      -- Assemble the column.
      return Column
        { cName = DBName $ cname
        , cNull = null_ == "YES"
        , cSqlType = typ
        , cDefault = default_
        , cDefaultConstraintName = Nothing
        , cMaxLen = maxLen
        , cReference = ref
        }
getColumn _ _ _ x =
    return $ Left $ pack $ "Invalid result from INFORMATION_SCHEMA: " ++ show x
-- | Column metadata from @INFORMATION_SCHEMA.COLUMNS@ that 'parseColumnType'
-- needs in addition to @DATA_TYPE@.
data ColumnInfo = ColumnInfo
  { ciColumnType :: Text
    -- ^ COLUMN_TYPE, e.g. @tinyint(1)@ or @varchar(255)@
  , ciMaxLength :: Maybe Integer
    -- ^ CHARACTER_MAXIMUM_LENGTH, when it was returned as an integer
  , ciNumericPrecision :: PersistValue
    -- ^ raw NUMERIC_PRECISION value
  , ciNumericScale :: PersistValue
    -- ^ raw NUMERIC_SCALE value
  }
-- | Map MySQL's @DATA_TYPE@ / @COLUMN_TYPE@ pair onto a persistent 'SqlType'
-- and an optional maximum length.
--
-- @int@ and @bigint@ are accepted both with and without a display width.
-- MySQL 8.0.19 and later omit integer display widths from COLUMN_TYPE
-- (@tinyint(1)@ excepted). Without this, such columns would parse as
-- 'SqlOther' and never compare equal to the requested 'SqlInt32'/'SqlInt64',
-- causing a CHANGE on every migration.
--
-- Unrecognised types become 'SqlOther' carrying the full column type. A
-- @decimal@ without integer precision and scale is reported through the
-- 'ExceptT' error channel, using 'throwE' so that it reaches the caller as
-- a @Left@ instead of an IO exception.
parseColumnType :: Text -> ColumnInfo -> ExceptT String IO (SqlType, Maybe Integer)
parseColumnType "tinyint" ci | ciColumnType ci == "tinyint(1)"                  = return (SqlBool, Nothing)
parseColumnType "int" ci     | ciColumnType ci `elem` ["int(11)", "int"]        = return (SqlInt32, Nothing)
parseColumnType "bigint" ci  | ciColumnType ci `elem` ["bigint(20)", "bigint"]  = return (SqlInt64, Nothing)
parseColumnType "double" ci  | ciColumnType ci == "double"                      = return (SqlReal, Nothing)
parseColumnType "decimal" ci                                   =
  case (ciNumericPrecision ci, ciNumericScale ci) of
    (PersistInt64 p, PersistInt64 s) ->
      return (SqlNumeric (fromIntegral p) (fromIntegral s), Nothing)
    _ ->
      throwE "missing DECIMAL precision in DB schema"
parseColumnType "varchar" ci                                   = return (SqlString, ciMaxLength ci)
parseColumnType "text" _                                       = return (SqlString, Nothing)
parseColumnType "varbinary" ci                                 = return (SqlBlob, ciMaxLength ci)
parseColumnType "blob" _                                       = return (SqlBlob, Nothing)
parseColumnType "time" _                                       = return (SqlTime, Nothing)
parseColumnType "datetime" _                                   = return (SqlDayTime, Nothing)
parseColumnType "date" _                                       = return (SqlDay, Nothing)
parseColumnType _ ci                                           = return (SqlOther (ciColumnType ci), Nothing)
-- | Compute the alterations that turn the existing table into the desired one.
--
-- Arguments: all entity definitions, the table name, the desired
-- @(columns, unique constraints)@, and the existing ones.
--
-- * Columns are matched by name through 'findAlters'. Existing columns left
--   unmatched at the end are dropped, together with their foreign key.
-- * Unique constraints are matched by name. New ones are added, ones whose
--   column list changed are dropped and re-added, and existing ones no longer
--   wanted are dropped.
getAlters :: [EntityDef]
          -> DBName
          -> ([Column], [(DBName, [DBName])])
          -> ([Column], [(DBName, [DBName])])
          -> ([AlterColumn'], [AlterTable])
getAlters allDefs tblName (wantedCols, wantedUniqs) (existingCols, existingUniqs) =
    (columnAlters wantedCols existingCols, uniqueAlters wantedUniqs existingUniqs)
  where
    -- Thread the not-yet-matched existing columns through the diff.
    columnAlters [] unmatched = concatMap dropColumn unmatched
    columnAlters (wanted:rest) unmatched =
        let (alters, unmatched') = findAlters tblName allDefs wanted unmatched
         in alters ++ columnAlters rest unmatched'

    dropColumn col =
        [ (cName col, DropReference fk) | Just (_, fk) <- [cReference col] ]
        ++ [(cName col, Drop)]

    uniqueAlters [] unmatched = map (DropUniqueConstraint . fst) unmatched
    uniqueAlters ((uname, ucols):rest) unmatched =
        case lookup uname unmatched of
          Nothing -> addUnique : uniqueAlters rest unmatched
          Just existing
            -- The existing column list is compared with the sorted wanted
            -- list (the constraint query orders columns by name).
            | sort ucols == existing -> uniqueAlters rest unmatched'
            | otherwise              -> DropUniqueConstraint uname
                                          : addUnique
                                          : uniqueAlters rest unmatched'
      where
        addUnique  = AddUniqueConstraint uname (map typeAndMaxLen ucols)
        unmatched' = filter ((/= uname) . fst) unmatched

    typeAndMaxLen col =
        let (col', ty) = findTypeOfColumn allDefs tblName col
            (_, ml)    = findMaxLenOfColumn allDefs tblName col
         in (col', ty, ml)
-- | Diff one desired column against the columns currently in the table.
--
-- Returns the alterations needed for @col@, together with the existing
-- columns that are still unmatched. Those are diffed against the remaining
-- desired columns, and whatever is left at the end is dropped by
-- 'getAlters'.
--
-- * No existing column has the same name: add it, plus a foreign key when
--   @col@ has a reference. Nothing was consumed, so all existing columns are
--   handed back. Previously the reference-less case returned @[]@ here,
--   which made later existing columns look new (duplicate ADD COLUMN) and
--   suppressed drops of stale columns.
-- * Otherwise compare reference, type/length/nullability and default, and
--   emit statements only for what differs.
findAlters :: DBName -> [EntityDef] -> Column -> [Column] -> ([AlterColumn'], [Column])
findAlters _tblName allDefs col@(Column name isNull type_ def _defConstraintName maxLen ref) cols =
    case filter ((name ==) . cName) cols of
        [] -> case ref of
               Nothing -> ([(name, Add' col)], cols)
               Just (tname, cname) -> let cnstr = [addReference allDefs cname tname name]
                                  in (map ((,) tname) (Add' col : cnstr), cols)
        Column _ isNull' type_' def' _defConstraintName' maxLen' ref':_ ->
            let
                -- Reference changed: drop the old foreign key (if any) ...
                refDrop = case (ref == ref', ref') of
                            (False, Just (_, cname)) -> [(name, DropReference cname)]
                            _ -> []
                -- ... and add the new one (if any).
                refAdd  = case (ref == ref', ref) of
                            (False, Just (tname, cname)) -> [(tname, addReference allDefs cname tname name)]
                            _ -> []
                -- Compare rendered SQL types (without charset) case-insensitively,
                -- plus nullability.
                modType | showSqlType type_ maxLen False `ciEquals` showSqlType type_' maxLen' False && isNull == isNull' = []
                        | otherwise = [(name, Change col)]
                -- A changed default is set, or dropped when no default is wanted.
                -- A wanted default whose text is NULL (any case) emits nothing.
                modDef | def == def' = []
                       | otherwise   = case def of
                                         Nothing -> [(name, NoDefault)]
                                         Just s -> if T.toUpper s == "NULL" then []
                                                   else [(name, Default $ T.unpack s)]
            in ( refDrop ++ modType ++ modDef ++ refAdd
               , filter ((name /=) . cName) cols )
  where
    ciEquals x y = T.toCaseFold (T.pack x) == T.toCaseFold (T.pack y)
-- | Render a column definition: name, SQL type (with character set),
-- nullability, an optional DEFAULT, and an optional REFERENCES clause.
-- A default whose text is NULL (any case) is omitted.
showColumn :: Column -> String
showColumn (Column n nu t def _defConstraintName maxLen ref) =
    escapeDBName n ++ " " ++ showSqlType t maxLen True ++ " " ++ nullability
      ++ defaultClause ++ referenceClause
  where
    nullability = if nu then "NULL" else "NOT NULL"
    defaultClause = case def of
      Just s | T.toUpper s /= "NULL" -> " DEFAULT " ++ T.unpack s
      _                              -> ""
    referenceClause = maybe "" ((" REFERENCES " ++) . escapeDBName . fst) ref
-- | Render a persistent 'SqlType' as a MySQL column type.
--
-- A maximum length turns BLOB into VARBINARY and TEXT into VARCHAR. The
-- flag appends @CHARACTER SET utf8@ to string types.
showSqlType :: SqlType
            -> Maybe Integer -- ^ maximum length, if any
            -> Bool          -- ^ include the character set for string types?
            -> String
showSqlType sqlType maxLen includeCharset =
    case sqlType of
      SqlBlob        -> maybe "BLOB" (\len -> "VARBINARY(" ++ show len ++ ")") maxLen
      SqlBool        -> "TINYINT(1)"
      SqlDay         -> "DATE"
      SqlDayTime     -> "DATETIME"
      SqlInt32       -> "INT(11)"
      SqlInt64       -> "BIGINT"
      SqlReal        -> "DOUBLE"
      SqlNumeric s p -> "NUMERIC(" ++ show s ++ "," ++ show p ++ ")"
      SqlString      -> maybe "TEXT" (\len -> "VARCHAR(" ++ show len ++ ")") maxLen ++ charset
      SqlTime        -> "TIME"
      SqlOther t     -> T.unpack t
  where
    charset = if includeCharset then " CHARACTER SET utf8" else ""
-- | Render a migration step as @(isUnsafe, sql)@. Only column drops are
-- flagged as unsafe.
showAlterDb :: AlterDB -> (Bool, Text)
showAlterDb alteration = case alteration of
    AddTable sql          -> (False, pack sql)
    AlterColumn t (c, ac) -> (isDrop ac, pack (showAlter t (c, ac)))
    AlterTable t at       -> (False, pack (showAlterTable t at))
  where
    isDrop Drop = True
    isDrop _    = False
-- | Render a table-level alteration (add/drop a unique index) as MySQL DDL.
showAlterTable :: DBName -> AlterTable -> String
showAlterTable table alteration = case alteration of
    AddUniqueConstraint cname cols ->
        "ALTER TABLE " ++ escapeDBName table
          ++ " ADD CONSTRAINT " ++ escapeDBName cname
          ++ " UNIQUE(" ++ intercalate "," (map indexPart cols) ++ ")"
    DropUniqueConstraint cname ->
        "ALTER TABLE " ++ escapeDBName table
          ++ " DROP INDEX " ++ escapeDBName cname
  where
    -- Text, String and ByteString fields get an explicit index prefix length
    -- (MySQL needs one to index TEXT/BLOB columns); other columns are listed
    -- by name only.
    indexPart (name, FTTypeCon _ tyName, maxlen)
      | tyName `elem` ["Text", "String", "ByteString"] =
          escapeDBName name ++ "(" ++ show maxlen ++ ")"
    indexPart (name, _, _) = escapeDBName name
-- | Render a column-level alteration on @table@ as MySQL SQL.
--
-- The name paired with the alteration is the affected column for 'Change',
-- 'Drop', 'Default', 'NoDefault' and 'Update''. It is ignored for the
-- others. 'Change' re-renders the column without its reference, because
-- foreign keys are managed by separate 'AddReference' / 'DropReference'
-- steps.
showAlter :: DBName -> AlterColumn' -> String
showAlter table (colName, alteration) = case alteration of
    Change (Column n nu t def defConstraintName maxLen _ref) ->
        alterTable ++ " CHANGE " ++ column ++ " "
          ++ showColumn (Column n nu t def defConstraintName maxLen Nothing)
    Add' col ->
        alterTable ++ " ADD COLUMN " ++ showColumn col
    Drop ->
        alterTable ++ " DROP COLUMN " ++ column
    Default s ->
        alterTable ++ " ALTER COLUMN " ++ column ++ " SET DEFAULT " ++ s
    NoDefault ->
        alterTable ++ " ALTER COLUMN " ++ column ++ " DROP DEFAULT"
    Update' s ->
        "UPDATE " ++ tbl ++ " SET " ++ column ++ "=" ++ s
          ++ " WHERE " ++ column ++ " IS NULL"
    AddReference reftable fkeyname t2 id2 ->
        alterTable ++ " ADD CONSTRAINT " ++ escapeDBName fkeyname
          ++ " FOREIGN KEY(" ++ intercalate "," (map escapeDBName t2)
          ++ ") REFERENCES " ++ escapeDBName reftable
          ++ "(" ++ intercalate "," (map escapeDBName id2) ++ ")"
    DropReference cname ->
        alterTable ++ " DROP FOREIGN KEY " ++ escapeDBName cname
  where
    tbl        = escapeDBName table
    column     = escapeDBName colName
    alterTable = "ALTER TABLE " ++ tbl
-- | 'escapeDBName' returning 'Text'.
escape :: DBName -> Text
escape name = T.pack (escapeDBName name)
-- | Quote an identifier with backticks, doubling any embedded backtick.
escapeDBName :: DBName -> String
escapeDBName (DBName s) = '`' : concatMap escapeChar (T.unpack s) ++ "`"
  where
    escapeChar '`' = "``"
    escapeChar c   = [c]
-- | Configuration for a MySQL connection pool: connection information plus
-- the pool size. Build one with 'mkMySQLConf' or parse it from JSON.
data MySQLConf = MySQLConf
    MySQLConnectInfo -- ^ connection information
    Int              -- ^ pool size
    deriving Show
-- | The connection information stored in a 'MySQLConf'.
myConnInfo :: MySQLConf -> MySQLConnectInfo
myConnInfo conf = case conf of
    MySQLConf connInfo _ -> connInfo
-- | The pool size stored in a 'MySQLConf'.
myPoolSize :: MySQLConf -> Int
myPoolSize conf = case conf of
    MySQLConf _ poolSize -> poolSize
-- | Replace the connection information of a 'MySQLConf', keeping its pool size.
setMyConnInfo :: MySQLConnectInfo -> MySQLConf -> MySQLConf
setMyConnInfo newInfo conf = case conf of
    MySQLConf _ poolSize -> MySQLConf newInfo poolSize
-- | Build a 'MySQLConf' from connection information and a pool size.
mkMySQLConf
  :: MySQLConnectInfo  -- ^ connection information
  -> Int               -- ^ pool size
  -> MySQLConf
mkMySQLConf connInfo poolSize = MySQLConf connInfo poolSize
-- | MySQL connection information: the underlying mysql-haskell
-- 'MySQL.ConnectInfo' plus optional TLS client parameters.
-- When TLS parameters are present, 'connect'' connects over TLS.
--
-- NOTE(review): the derived 'Show' presumably relies on the orphan instances
-- imported from "Database.Persist.MySQLConnectInfoShowInstance"; confirm.
data MySQLConnectInfo = MySQLConnectInfo
  { innerConnInfo :: MySQL.ConnectInfo
    -- ^ host, port, credentials, database, charset
  , innerConnTLS  :: (Maybe TLS.ClientParams)
    -- ^ TLS client parameters; 'Nothing' for a plain connection
  } deriving Show
-- | Build connection information from a host, user name, password and
-- database name, with TLS disabled.
--
-- All other settings keep the values of 'MySQL.defaultConnectInfo'. Adjust
-- them with 'setMySQLConnectInfoPort', 'setMySQLConnectInfoCharset' or
-- 'setMySQLConnectInfoTLS'.
mkMySQLConnectInfo
  :: NetworkSocket.HostName  -- ^ hostname
  -> BSC.ByteString          -- ^ user
  -> BSC.ByteString          -- ^ password
  -> BSC.ByteString          -- ^ database
  -> MySQLConnectInfo
mkMySQLConnectInfo host user pass db = MySQLConnectInfo
    { innerConnInfo = MySQL.defaultConnectInfo
        { MySQL.ciHost     = host
        , MySQL.ciUser     = user
        , MySQL.ciPassword = pass
        , MySQL.ciDatabase = db
        }
    , innerConnTLS = Nothing
    }
-- | Override the port of a 'MySQLConnectInfo'; all other settings are kept.
setMySQLConnectInfoPort
  :: NetworkSocket.PortNumber -> MySQLConnectInfo -> MySQLConnectInfo
setMySQLConnectInfoPort port ci =
  let updated = (innerConnInfo ci) { MySQL.ciPort = port }
  in ci { innerConnInfo = updated }
-- | Override the character-set code of a 'MySQLConnectInfo'; all other
-- settings are kept.
setMySQLConnectInfoCharset
  :: Word.Word8        -- ^ Character-set code, stored as 'MySQL.ciCharset'.
  -> MySQLConnectInfo  -- ^ Connection info to update.
  -> MySQLConnectInfo
setMySQLConnectInfoCharset charset ci =
  let updated = (innerConnInfo ci) { MySQL.ciCharset = charset }
  in ci { innerConnInfo = updated }
-- | Attach TLS client parameters to a 'MySQLConnectInfo', replacing any that
-- were already set.
setMySQLConnectInfoTLS
  :: TLS.ClientParams  -- ^ TLS client parameters.
  -> MySQLConnectInfo  -- ^ Connection info to update.
  -> MySQLConnectInfo
setMySQLConnectInfoTLS params connInfo = connInfo { innerConnTLS = pure params }
-- | Parse a 'MySQLConf' from a JSON object with the keys @database@, @host@,
-- @port@, @user@, @password@ and @poolsize@. All other connection settings
-- come from the driver defaults, and no TLS parameters are set.
--
-- Every failure message is prefixed with
-- @"Persistent: error loading MySQL conf: "@.
instance FromJSON MySQLConf where
    parseJSON v = modifyFailure ("Persistent: error loading MySQL conf: " ++) $
      flip (withObject "MySQLConf") v $ \o -> do
        database <- o .: "database"
        host     <- o .: "host"
        port     <- o .: "port"
        user     <- o .: "user"
        password <- o .: "password"
        pool     <- o .: "poolsize"
        -- The port is stored as a 16-bit 'PortNumber'. Converting a larger
        -- 'Word' with 'fromIntegral' would silently wrap (e.g. 70000 becomes
        -- 4464), so out-of-range values are rejected here instead.
        if (port :: Word) > 65535
          then fail ("port out of range (expected 0-65535): " ++ show port)
          else return ()
        let ci = MySQL.defaultConnectInfo
                   { MySQL.ciHost     = host
                   , MySQL.ciPort     = fromIntegral port
                   , MySQL.ciUser     = BSC.pack user
                   , MySQL.ciPassword = BSC.pack password
                   , MySQL.ciDatabase = BSC.pack database
                   }
        return $ MySQLConf (MySQLConnectInfo ci Nothing) pool
-- | 'PersistConfig' instance for MySQL.
--
-- 'applyEnv' lets the environment variables @MYSQL_HOST@, @MYSQL_PORT@,
-- @MYSQL_USER@, @MYSQL_PASSWORD@ and @MYSQL_DATABASE@ override the matching
-- connection settings. Unset variables leave the setting unchanged.
-- TLS parameters already on the configuration are preserved. A
-- @MYSQL_PORT@ that is not an integer in 0-65535 raises an 'IOError'.
instance PersistConfig MySQLConf where
    type PersistConfigBackend MySQLConf = SqlPersistT
    type PersistConfigPool    MySQLConf = ConnectionPool
    createPoolConfig (MySQLConf cs size)
      = runNoLoggingT $ createMySQLPool cs size
    runPool _ = runSqlPool
    loadConfig = parseJSON
    applyEnv conf = do
        env <- getEnvironment
        let lookupVar var = lookup ("MYSQL_" ++ var) env
            -- Byte-string settings are encoded with 'BSC.pack', as before.
            overrideBytes old var = maybe old BSC.pack (lookupVar var)
            connInfo = myConnInfo conf
            innerCi  = innerConnInfo connInfo
        -- Parse the port eagerly, in IO, so a bad value is reported here
        -- instead of surfacing later as a lazy 'read' failure.
        port <- maybe (return (MySQL.ciPort innerCi)) parsePort (lookupVar "PORT")
        let innerCiNew = innerCi
              { MySQL.ciHost     = maybe (MySQL.ciHost innerCi) id (lookupVar "HOST")
              , MySQL.ciPort     = port
              , MySQL.ciUser     = overrideBytes (MySQL.ciUser innerCi) "USER"
              , MySQL.ciPassword = overrideBytes (MySQL.ciPassword innerCi) "PASSWORD"
              , MySQL.ciDatabase = overrideBytes (MySQL.ciDatabase innerCi) "DATABASE"
              }
        -- Replace only the inner settings so existing TLS parameters survive.
        return $ setMyConnInfo (connInfo { innerConnInfo = innerCiNew }) conf
      where
        -- Accept an integer (surrounding whitespace allowed, like 'read')
        -- in the 16-bit port range; anything else is an IOError.
        parsePort raw =
          case [ p | (p, rest) <- reads raw, ("", "") <- lex rest ] of
            [p] | p >= 0 && p <= (65535 :: Integer) -> return (fromIntegral p)
            _ -> ioError (userError ("Persistent: invalid MYSQL_PORT: " ++ show raw))
-- | Migration function used by 'mockMigration'.
--
-- It never queries the database. It always returns, in order:
--
-- * the 'addTable' statement for the entity;
-- * one unique-constraint statement per unique definition;
-- * one foreign-key statement per referencing column;
-- * one foreign-key statement per foreign definition.
--
-- The connection info and statement getter are accepted only to fit the
-- 'connMigrateSql' shape and are ignored.
mockMigrate :: MySQL.ConnectInfo
         -> [EntityDef]
         -> (Text -> IO Statement)
         -> EntityDef
         -> IO (Either [Text] [(Bool, Text)])
mockMigrate _connectInfo allDefs _getter val =
    return $ Right $ map showAlterDb $
      addTable newcols val : uniques ++ foreigns ++ foreignsAlt
  where
    name = entityDB val
    (newcols, udefs, fdefs) = mkColumns allDefs val
    udspair = map udToPair udefs

    uniques =
      [ AlterTable name $
          AddUniqueConstraint uname $
          map (findTypeAndMaxLen name) ucols
      | (uname, ucols) <- udspair
      ]

    -- Columns that reference another table; columns without a reference
    -- are skipped by the pattern.
    foreigns =
      [ AlterColumn name (refTblName, addReference allDefs refConstraintName refTblName cname)
      | Column { cName = cname, cReference = Just (refTblName, refConstraintName) } <- newcols
      ]

    foreignsAlt = map mkForeign fdefs

    mkForeign fdef =
      let (childfields, parentfields) =
            unzip [ (child, parent) | ((_, child), (_, parent)) <- foreignFields fdef ]
          refTable = foreignRefTableDBName fdef
      in AlterColumn name
           ( refTable
           , AddReference refTable (foreignConstraintNameDBName fdef) childfields parentfields
           )

    findTypeAndMaxLen tblName col =
      let (col', ty) = findTypeOfColumn allDefs tblName col
          (_, ml)    = findMaxLenOfColumn allDefs tblName col
      in (col', ty, ml)
-- | Print, one statement per line, the SQL that a 'Migration' produces,
-- without connecting to a database.
--
-- The migration runs against a stub 'SqlBackend' whose 'connMigrateSql' is
-- 'mockMigrate'. Prepared statements can be finalized and reset, and queries
-- yield an empty result. Every other backend operation fails with a
-- descriptive error naming the stubbed field if the migration touches it.
mockMigration :: Migration -> IO ()
mockMigration mig = do
  smap <- newIORef Map.empty
  let sqlbackend = SqlBackend
        { connPrepare = \_ -> return Statement
            { stmtFinalize = return ()
            , stmtReset = return ()
            , stmtExecute = unsupported "stmtExecute"
            , stmtQuery = \_ -> return $ return ()
            }
        , connInsertManySql = Nothing
        , connInsertSql = unsupported "connInsertSql"
        , connStmtMap = smap
        , connClose = unsupported "connClose"
        , connMigrateSql = mockMigrate (unsupported "MySQL.ConnectInfo")
        , connBegin = unsupported "connBegin"
        , connCommit = unsupported "connCommit"
        , connRollback = unsupported "connRollback"
        , connEscapeName = unsupported "connEscapeName"
        , connNoLimit = unsupported "connNoLimit"
        , connRDBMS = unsupported "connRDBMS"
        , connLimitOffset = unsupported "connLimitOffset"
        , connLogFunc = unsupported "connLogFunc"
        , connUpsertSql = unsupported "connUpsertSql"
        , connPutManySql = unsupported "connPutManySql"
        , connMaxParams = Nothing
        , connRepsertManySql = Nothing
        }
  -- The second writer layer holds the generated (Bool, Text) statements.
  (_, migrationSql) <- runReaderT (runWriterT (runWriterT mig)) sqlbackend
  mapM_ (T.putStrLn . snd) migrationSql
  where
    -- Placeholder for backend features the mock does not provide; fails
    -- loudly with the field name instead of a bare 'undefined'.
    unsupported :: String -> a
    unsupported field =
      error ("Database.Persist.MySQL.mockMigration: " ++ field
             ++ " is not supported by the mock backend")
-- | Insert one record using MySQL's @INSERT ... ON DUPLICATE KEY UPDATE@.
--
-- When the insert collides with an existing key, the given updates are
-- applied to the existing row instead. This is
-- 'insertManyOnDuplicateKeyUpdate' on a one-element list with no
-- 'HandleUpdateCollision' entries.
insertOnDuplicateKeyUpdate
  :: ( backend ~ PersistEntityBackend record
     , PersistEntity record
     , MonadIO m
     , PersistStore backend
     , BackendCompatible SqlBackend backend
     )
  => record
  -> [Update record]
  -> ReaderT backend m ()
insertOnDuplicateKeyUpdate record updates =
  insertManyOnDuplicateKeyUpdate [record] [] updates
-- | Like 'insertOnDuplicateKeyUpdate', but inserts an 'Entity', so its key
-- is written together with its fields.
--
-- This is 'insertEntityManyOnDuplicateKeyUpdate' on a one-element list with
-- no 'HandleUpdateCollision' entries.
insertEntityOnDuplicateKeyUpdate
  :: ( backend ~ PersistEntityBackend record
     , PersistEntity record
     , MonadIO m
     , PersistStore backend
     , BackendCompatible SqlBackend backend
     )
  => Entity record
  -> [Update record]
  -> ReaderT backend m ()
insertEntityOnDuplicateKeyUpdate entity updates =
  insertEntityManyOnDuplicateKeyUpdate [entity] [] updates
-- | How one column is treated when a bulk insert
-- ('insertManyOnDuplicateKeyUpdate', 'insertEntityManyOnDuplicateKeyUpdate')
-- hits a duplicate key.
data HandleUpdateCollision record where
  -- | Overwrite the existing column with the value being inserted.
  CopyField
    :: EntityField record typ
    -> HandleUpdateCollision record
  -- | Overwrite the existing column with the value being inserted, unless
  -- that value equals the given one (or is NULL); then the existing value
  -- is kept.
  CopyUnlessEq
    :: PersistField typ
    => EntityField record typ
    -> typ
    -> HandleUpdateCollision record
-- | Deprecated alias for 'HandleUpdateCollision'.
type SomeField = HandleUpdateCollision
-- | Deprecated pattern synonym equivalent to the 'CopyField' constructor;
-- build values with 'copyField' instead.
pattern SomeField :: EntityField record typ -> SomeField record
pattern SomeField field = CopyField field
{-# DEPRECATED SomeField "The type SomeField is deprecated. Use the type HandleUpdateCollision instead, and use the function copyField instead of the data constructor." #-}
-- | Copy the inserted value over the existing one unless the inserted value
-- is 'Nothing' (SQL @NULL@), in which case the existing value is kept.
copyUnlessNull :: PersistField typ => EntityField record (Maybe typ) -> HandleUpdateCollision record
copyUnlessNull field = copyUnlessEq field Nothing
-- | Copy the inserted value over the existing one unless the inserted value
-- is 'Monoid.mempty' (for instance an empty text), in which case the
-- existing value is kept.
copyUnlessEmpty :: (Monoid.Monoid typ, PersistField typ) => EntityField record typ -> HandleUpdateCollision record
copyUnlessEmpty field = copyUnlessEq field Monoid.mempty
-- | Copy the inserted value over the existing one unless the inserted value
-- equals the given sentinel, in which case the existing value is kept.
copyUnlessEq :: PersistField typ => EntityField record typ -> typ -> HandleUpdateCollision record
copyUnlessEq field sentinel = CopyUnlessEq field sentinel
-- | Always copy the inserted value over the existing one on a duplicate key.
copyField :: PersistField typ => EntityField record typ -> HandleUpdateCollision record
copyField field = CopyField field
-- | Bulk insert using MySQL's @INSERT ... ON DUPLICATE KEY UPDATE@.
--
-- All records go into a single statement. When a row collides with an
-- existing key, the columns named by the 'HandleUpdateCollision' values are
-- copied from the new row and the 'Update's are applied to the existing row.
-- An empty list of records issues no query.
insertManyOnDuplicateKeyUpdate
    :: forall record backend m.
    ( backend ~ PersistEntityBackend record
    , BackendCompatible SqlBackend backend
    , PersistEntity record
    , MonadIO m
    )
    => [record] -- ^ Records to insert.
    -> [HandleUpdateCollision record] -- ^ Columns to copy from the new row on a collision.
    -> [Update record] -- ^ Updates applied to the existing row on a collision.
    -> ReaderT backend m ()
insertManyOnDuplicateKeyUpdate records fieldValues updates
    | null records = return ()
    | otherwise    = rawExecute sql vals
  where
    (sql, vals) = mkBulkInsertQuery (Left records) fieldValues updates
-- | Like 'insertManyOnDuplicateKeyUpdate', but inserts 'Entity' values, so
-- their keys are written together with their fields.
-- An empty list of entities issues no query.
insertEntityManyOnDuplicateKeyUpdate
    :: forall record backend m.
    ( backend ~ PersistEntityBackend record
    , BackendCompatible SqlBackend backend
    , PersistEntity record
    , MonadIO m
    )
    => [Entity record] -- ^ Entities to insert.
    -> [HandleUpdateCollision record] -- ^ Columns to copy from the new row on a collision.
    -> [Update record] -- ^ Updates applied to the existing row on a collision.
    -> ReaderT backend m ()
insertEntityManyOnDuplicateKeyUpdate entities fieldValues updates
    | null entities = return ()
    | otherwise     = rawExecute sql vals
  where
    (sql, vals) = mkBulkInsertQuery (Right entities) fieldValues updates
-- | Build the @INSERT ... ON DUPLICATE KEY UPDATE@ statement and its
-- parameters for a bulk insert.
--
-- With @Left records@ only the entity's fields are inserted. With
-- @Right entities@ the key columns are inserted as well. The parameter list
-- is ordered to match the placeholders:
--
-- * all row values;
-- * one value per 'Update';
-- * one value per 'CopyUnlessEq'.
--
-- When there is nothing to update, the no-op assignment
-- @firstColumn=firstColumn@ keeps the statement valid. Callers in this
-- module never pass an empty list, which would leave the VALUES clause empty.
mkBulkInsertQuery
    :: PersistEntity record
    => Either [record] [Entity record] -- ^ Plain records ('Left') or entities with keys ('Right').
    -> [HandleUpdateCollision record] -- ^ Columns to copy from the new row on a collision.
    -> [Update record] -- ^ Updates applied to the existing row on a collision.
    -> (Text, [PersistValue])
mkBulkInsertQuery records fieldValues updates =
    (q, recordValues <> updsValues <> copyUnlessValues)
  where
    mfieldDef x = case x of
        CopyField rec -> Right (fieldDbToText (persistFieldDef rec))
        CopyUnlessEq rec val -> Left (fieldDbToText (persistFieldDef rec), toPersistValue val)
    (fieldsToMaybeCopy, updateFieldNames) = partitionEithers $ map mfieldDef fieldValues
    fieldDbToText = T.pack . escapeDBName . fieldDB
    entityDef' = entityDef $ either id (map entityVal) records
    firstField = case entityFieldNames of
        [] -> error "The entity you're trying to insert does not have any fields."
        (field:_) -> field
    -- 'entityValues' emits no separate key value for entities with a custom
    -- (e.g. composite) primary key, because those key columns are already
    -- among the fields. 'keyAndEntityFields' lists columns the same way, so
    -- names and values stay aligned. Prepending 'entityId' unconditionally
    -- would add a nonexistent column.
    entityFieldNames = map fieldDbToText $ case records of
      Left _  -> entityFields entityDef'
      Right _ -> keyAndEntityFields entityDef'
    tableName = T.pack . escapeDBName . entityDB $ entityDef'
    copyUnlessValues = map snd fieldsToMaybeCopy
    values = either (map $ map toPersistValue . toPersistFields) (map entityValues) records
    recordValues = concat values
    recordPlaceholders = Util.commaSeparated $ map (Util.parenWrapped . Util.commaSeparated . map (const "?")) values
    -- Keep the existing value when the inserted one is NULL or equals the
    -- sentinel bound to the '?' placeholder.
    mkCondFieldSet n _ = T.concat
        [ n
        , "=COALESCE("
        ,   "NULLIF("
        ,     "VALUES(", n, "),"
        ,     "?"
        ,   "),"
        ,   n
        , ")"
        ]
    condFieldSets = map (uncurry mkCondFieldSet) fieldsToMaybeCopy
    fieldSets = map (\n -> T.concat [n, "=VALUES(", n, ")"]) updateFieldNames
    upds = map (Util.mkUpdateText' (pack . escapeDBName) id) updates
    updsValues = map updateValue updates
    updateValue upd = case upd of
        Update _ val _ -> toPersistValue val
        _ -> error "Database.Persist.MySQL.mkBulkInsertQuery: backend-specific updates are not supported"
    updateText = case fieldSets <> upds <> condFieldSets of
        [] -> T.concat [firstField, "=", firstField]
        xs -> Util.commaSeparated xs
    q = T.concat
        [ "INSERT INTO "
        , tableName
        , " ("
        , Util.commaSeparated entityFieldNames
        , ") "
        , " VALUES "
        , recordPlaceholders
        , " ON DUPLICATE KEY UPDATE "
        , updateText
        ]
-- | Multi-row upsert SQL for @n@ rows over the entity's 'entityFields'.
putManySql :: EntityDef -> Int -> Text
putManySql ent n = putManySql' (entityFields ent) ent n
-- | Multi-row upsert SQL for @n@ rows over the entity's key and fields
-- ('keyAndEntityFields').
repsertManySql :: EntityDef -> Int -> Text
repsertManySql ent n = putManySql' (keyAndEntityFields ent) ent n
-- | Build
-- @INSERT INTO table(c1,..) VALUES (?,..),.. ON DUPLICATE KEY UPDATE c1=VALUES(c1),..@
-- for the given columns and @n@ rows. On a duplicate key, every listed
-- column is overwritten with the value from the new row.
putManySql' :: [FieldDef] -> EntityDef -> Int -> Text
putManySql' fields ent n = T.concat
    [ "INSERT INTO "
    , tableName
    , Util.parenWrapped (Util.commaSeparated columnNames)
    , " VALUES "
    , Util.commaSeparated (replicate n rowPlaceholder)
    , " ON DUPLICATE KEY UPDATE "
    , Util.commaSeparated [ T.concat [c, "=VALUES(", c, ")"] | c <- columnNames ]
    ]
  where
    tableName = escape (entityDB ent)
    columnNames = map (escape . fieldDB) fields
    -- One "(?,?,...)" group per row, with one placeholder per column.
    rowPlaceholder = Util.parenWrapped (Util.commaSeparated (map (const "?") fields))