{-# OPTIONS -Wall #-}
{-# LANGUAGE EmptyDataDecls    #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- | PostgreSQL migration strategy for the persistent ODBC backend.
module Database.Persist.MigratePostgres
    ( getMigrationStrategy
    ) where

import Database.Persist.Sql
import Control.Monad.IO.Class (MonadIO (..))
import qualified Data.Text as T
import Data.Text (pack,Text)

import Data.Either (partitionEithers)
import Control.Arrow
import Data.List (find, intercalate, sort, groupBy)
import Data.Function (on)
import Data.Conduit (connect, (.|))
import qualified Data.Conduit.List as CL
import Data.Maybe (mapMaybe)

import qualified Data.Text.Encoding as T

import Database.Persist.ODBCTypes
import Data.Acquire (with)

#if DEBUG
import Debug.Trace
-- Debug builds: emit the message through 'trace' and hand back the value.
tracex :: String -> a -> a
tracex msg value = trace msg value
#else
-- Non-debug builds: the message is ignored and the value is returned as is.
tracex :: String -> a -> a
tracex _ value = value
#endif

-- | Build the migration strategy for a Postgres database type.
--
-- Any other 'DBType' is rejected with a call to 'error'.
getMigrationStrategy :: DBType -> MigrationStrategy
getMigrationStrategy dbtype =
    case dbtype of
        Postgres {} ->
            MigrationStrategy
                { dbmsLimitOffset = decorateSQLWithLimitOffset "LIMIT ALL"
                , dbmsMigrate     = migrate'
                , dbmsInsertSql   = insertSql'
                , dbmsEscape      = escape
                , dbmsType        = dbtype
                }
        _ -> error $ "Postgres: calling with invalid dbtype " ++ show dbtype

-- | Work out the statements needed to bring the table of one entity in line
-- with its definition.
--
-- The current columns and unique constraints are read with 'getColumns'.
-- When that returns nothing a @CREATE TABLE@ is produced, followed by the
-- unique constraints and foreign keys; otherwise the differences are
-- computed by 'getAlters'.
--
-- Returns @Left@ with the error messages collected while reading the
-- existing schema, or @Right@ with the statements rendered by 'showAlterDb'
-- (each paired with its \"unsafe\" flag).
migrate' :: [EntityDef]
         -> (Text -> IO Statement)
         -> EntityDef
         -> IO (Either [Text] [(Bool, Text)])
migrate' allDefs getter val = fmap (fmap $ map showAlterDb) $ do
    let name = entityDB val
    old <- getColumns getter val
    case partitionEithers old of
        ([], old'') -> do
            let old' = partitionEithers old''
            let (newcols', udefs, fdefs) = mkColumns allDefs val
            -- Columns flagged SafeToRemove are left out of the target schema.
            let newcols = filter (not . safeToRemove val . cName) newcols'
            let udspair = map udToPair udefs
            if null old
                then do
                    let idtxt = case entityPrimary val of
                                  Just pdef -> concat [" PRIMARY KEY (", intercalate "," $ map (T.unpack . escape . fieldDB) $ compositeFields pdef, ")"]
                                  Nothing   -> concat [T.unpack $ escape $ fieldDB $ entityId val
                                        , " SERIAL PRIMARY KEY UNIQUE"]
                    let addTable = AddTable $ concat
                            -- Lower case e: see Database.Persist.Sql.Migration
                            [ "CREATe TABLE "
                            , T.unpack $ escape name
                            , "("
                            , idtxt
                            , if null newcols then [] else ","
                            , intercalate "," $ map showColumn newcols
                            , ")"
                            ]
                    let uniques = map (\(uname, ucols) -> AlterTable name $ AddUniqueConstraint uname ucols) udspair
                        -- Total replacement for the former partial lambda:
                        -- columns without a reference simply yield Nothing.
                        addReference c = case cReference c of
                            Just (refTblName, _) -> getAddReference allDefs name refTblName (cName c) (cReference c)
                            Nothing -> Nothing
                        references = mapMaybe addReference newcols
                        foreignAlter fdef =
                            let (childfields, parentfields) = unzip (map (\((_,b),(_,d)) -> (b,d)) (foreignFields fdef))
                             in AlterColumn name (foreignRefTableDBName fdef, AddReference (foreignConstraintNameDBName fdef) childfields parentfields)
                        foreignsAlt = map foreignAlter fdefs
                    return $ Right $ addTable : uniques ++ references ++ foreignsAlt
                else do
                    let (acs, ats) = getAlters allDefs val (newcols, udspair) old'
                    let acs' = map (AlterColumn name) acs
                    let ats' = map (AlterTable name) ats
                    return $ Right $ acs' ++ ats'
        (errs, _) -> return $ Left errs

-- | Flag carried by 'Drop': 'showAlterDb' reports a column drop as unsafe
-- exactly when this flag is @False@.
type SafeToRemove = Bool

-- | A single change applied to a table through 'showAlter'.
--
-- 'Type', 'IsNull', 'NotNull', 'Default' and 'NoDefault' render as
-- @ALTER COLUMN@ clauses, 'Add'' and 'Drop' add or drop a column, and
-- 'Update'' renders an @UPDATE@ that fills rows where the column is NULL.
-- 'AddReference' carries the constraint name, the local columns and the
-- referenced columns; 'DropReference' carries the constraint name.
data AlterColumn = Type SqlType | IsNull | NotNull | Add' Column | Drop SafeToRemove
                 | Default String | NoDefault | Update' String
                 | AddReference DBName [DBName] [DBName] | DropReference DBName
-- | An 'AlterColumn' paired with a name. 'showAlter' treats the name as the
-- column being changed, except for 'AddReference', where it is rendered as
-- the referenced table.
type AlterColumn' = (DBName, AlterColumn)

-- | Table-level constraint changes rendered by 'showAlterTable': add a
-- unique constraint (name and columns) or drop a constraint by name.
data AlterTable = AddUniqueConstraint DBName [DBName]
                | DropConstraint DBName

-- | A migration step: a raw @CREATE TABLE@ statement, or a column/table
-- alteration tagged with the table it applies to.
data AlterDB = AddTable String
             | AlterColumn DBName AlterColumn'
             | AlterTable DBName AlterTable

-- | Returns all of the columns in the given table currently in the database.
--
-- Two catalogue queries are run through @getter@: one against
-- @information_schema.columns@ (every column except the entity id) and one
-- against @information_schema.key_column_usage@ for constraints that are
-- neither primary nor foreign keys. Each result element is
--
-- * @Left msg@ when a column row could not be interpreted by 'getColumn',
-- * @Right (Left column)@ for an existing column,
-- * @Right (Right (constraint, columns))@ for an existing unique constraint.
--
-- The dump of the results is routed through 'tracex', so it only appears in
-- DEBUG builds instead of being printed to stdout on every migration.
getColumns :: (Text -> IO Statement)
           -> EntityDef
           -> IO [Either Text (Either Column (DBName, [DBName]))]
getColumns getter def = do
    let sqlv=concat ["SELECT "
                          ,"column_name "
                          ,",is_nullable "
                          ,",udt_name "
                          ,",column_default "
                          ,",numeric_precision "
                          ,",numeric_scale "
                          ,"FROM information_schema.columns "
                          ,"WHERE table_catalog=current_database() "
                          ,"AND table_schema=current_schema() "
                          ,"AND table_name=? "
                          ,"AND column_name <> ?"]

    stmt <- getter $ pack sqlv
    -- Both queries take the same parameters: table name and id column name.
    let vals =
            [ PersistText $ unDBName $ entityDB def
            , PersistText $ unDBName $ fieldDB $ entityId def
            ]
    cs <- with (stmtQuery stmt vals) (`connect` helperClmns)
    let sqlc=concat ["SELECT "
                          ,"c.constraint_name, "
                          ,"c.column_name "
                          ,"FROM information_schema.key_column_usage c, "
                          ,"information_schema.table_constraints k "
                          ,"WHERE c.table_catalog=current_database() "
                          ,"AND c.table_catalog=k.table_catalog "
                          ,"AND c.table_schema=current_schema() "
                          ,"AND c.table_schema=k.table_schema "
                          ,"AND c.table_name=? "
                          ,"AND c.table_name=k.table_name "
                          ,"AND c.column_name <> ? "
                          ,"AND c.ordinal_position=1 "
                          ,"AND c.constraint_name=k.constraint_name "
                          ,"AND k.constraint_type not in ('PRIMARY KEY','FOREIGN KEY') "
                          ,"ORDER BY c.constraint_name, c.column_name"]

    stmt' <- getter $ pack sqlc

    us <- with (stmtQuery stmt' vals) (`connect` helperU)
    tracex ("\n\ngetColumns cs="++show cs++"\n\nus="++show us) $
        return $ cs ++ us
  where
    -- Drain the constraint rows into (constraint name, column name) pairs,
    -- accepting either text or UTF-8 bytes from the driver.
    getAll front = do
        x <- CL.head
        case x of
            Nothing -> return $ front []
            Just [PersistText con, PersistText col] -> getAll (front . (:) (con, col))
            Just [PersistByteString con, PersistByteString col] -> getAll (front . (:) (T.decodeUtf8 con, T.decodeUtf8 col))
            Just xx -> error $ "oops: unexpected datatype returned odbc postgres  xx="++show xx
    -- Group the pairs by constraint name. 'groupBy' never yields an empty
    -- group, so taking 'head' of a group is safe here.
    helperU = do
        rows <- getAll id
        return $ map (Right . Right . (DBName . fst . head &&& map (DBName . snd)))
               $ groupBy ((==) `on` fst) rows

    -- Turn every column row into a 'Column' (or an error message).
    helperClmns = CL.mapM getIt .| CL.consume
        where
          getIt = fmap (either Left (Right . Left)) .
                  liftIO .
                  getColumn getter (entityDB def)
{-
    helper = do
        x <- CL.head
        case x of
            Nothing -> tracex "getColumns helper Nothing!!!" $ return []
            Just x' -> do
                col <- liftIO $ getColumn getter (entityDB def) x'
                let col' = tracex ("getColumns helper: col="++show col) $ case col of
                            Left e -> Left e
                            Right c -> Right $ Left c
                cols <- helper
                return $ col' : cols
-}
-- | Tell whether the entity declares a field with the given database name
-- that carries the @SafeToRemove@ attribute.
safeToRemove :: EntityDef -> DBName -> Bool
safeToRemove def (DBName colName) = any flagged (entityFields def)
  where
    flagged field =
        fieldDB field == DBName colName
            && "SafeToRemove" `elem` fieldAttrs field

-- | Compare the wanted columns and unique constraints (first pair) with the
-- ones found in the database (second pair) and list the required changes.
getAlters :: [EntityDef]
          -> EntityDef
          -> ([Column], [(DBName, [DBName])])
          -> ([Column], [(DBName, [DBName])])
          -> ([AlterColumn'], [AlterTable])
getAlters allDefs def (c1, u1) (c2, u2) =
    (columnAlters c1 c2, uniqueAlters u1 u2)
  where
    -- Walk the wanted columns; whatever is left of the existing ones at the
    -- end is dropped.
    columnAlters :: [Column] -> [Column] -> [AlterColumn']
    columnAlters wanted existing = case wanted of
        [] -> [ (cName c, Drop (safeToRemove def (cName c))) | c <- existing ]
        new : rest ->
            let (alters, remaining) = findAlters allDefs (entityDB def) new existing
             in alters ++ columnAlters rest remaining

    -- Same walk for unique constraints, compared by name and column set.
    uniqueAlters :: [(DBName, [DBName])]
                 -> [(DBName, [DBName])]
                 -> [AlterTable]
    uniqueAlters wanted existing = case wanted of
        [] -> [ DropConstraint cname | (cname, _) <- existing, not (isManual cname) ]
        (cname, cols) : rest -> case lookup cname existing of
            Nothing -> AddUniqueConstraint cname cols : uniqueAlters rest existing
            Just oldCols ->
                let remaining = filter ((/= cname) . fst) existing
                    others = uniqueAlters rest remaining
                 in if sort cols == sort oldCols
                        then others
                        else DropConstraint cname
                            : AddUniqueConstraint cname cols
                            : others

    -- Don't drop constraints which were manually added.
    isManual (DBName x) = "__manual_" `T.isPrefixOf` x

-- | Interpret one row of the column query issued by 'getColumns'.
--
-- The row is expected to hold six values: column name, nullability
-- (@YES@ marks a nullable column), udt name, default expression, numeric
-- precision and numeric scale. The three leading values are accepted
-- either as 'PersistText' or as UTF-8 encoded 'PersistByteString', in line
-- with how the other result rows in this module are read.
--
-- Returns @Left@ with a message when the row has an unexpected shape, the
-- default has an unsupported representation, or the precision/scale of a
-- @numeric@ column cannot be read. Otherwise returns the 'Column', with its
-- foreign-key reference looked up via a further catalogue query.
getColumn :: (Text -> IO Statement)
          -> DBName -> [PersistValue]
          -> IO (Either Text Column)
getColumn getter tname [xv, yv, zv, d, npre, nscl]
    | Just x <- asText xv
    , Just y <- asText yv
    , Just z <- asText zv =
        case d' of
            Left s -> return $ Left s
            Right d'' ->
                case getType z of
                    Left s -> return $ Left s
                    Right t -> do
                        let cname = DBName x
                        ref <- getRef cname
                        return $ Right Column
                            { cName = cname
                            , cNull = y == "YES"
                            , cSqlType = t
                            , cDefault = d''
                            , cDefaultConstraintName = Nothing
                            , cMaxLen = Nothing
                            , cReference = ref
                            }
  where
    -- Textual view of a value, whichever of the two representations the
    -- driver handed back.
    asText :: PersistValue -> Maybe Text
    asText (PersistText t) = Just t
    asText (PersistByteString bs) = Just $ T.decodeUtf8 bs
    asText _ = Nothing

    -- Look for a foreign key on this table whose constraint name is the
    -- one 'refName' derives from the table and column names.
    getRef cname = do
        let sql = pack $ concat
                [ "SELECT "
                  ,"tc.table_name, "
                  ,"kcu.column_name, "
                  ,"ccu.table_name AS foreign_table_name, "
                  ,"ccu.column_name AS foreign_column_name, "
                  ,"kcu.ordinal_position "
                  ,"FROM "
                  ,"information_schema.table_constraints AS tc "
                  ,"JOIN information_schema.key_column_usage "
                  ,"AS kcu ON tc.constraint_name = kcu.constraint_name "
                  ,"JOIN information_schema.constraint_column_usage "
                  ,"AS ccu ON ccu.constraint_name = tc.constraint_name "
                  ,"WHERE constraint_type = 'FOREIGN KEY' "
                  ,"and tc.table_name=? "
                  ,"and tc.constraint_name=? "
                  ,"and tc.table_catalog=current_database() "
                  ,"AND tc.table_catalog=kcu.table_catalog "
                  ,"AND tc.table_catalog=ccu.table_catalog "
                  ,"AND tc.table_schema=current_schema() "
                  ,"AND tc.table_schema=kcu.table_schema "
                  ,"AND tc.table_schema=ccu.table_schema "
                  ,"order by tc.table_name,tc.constraint_name, kcu.ordinal_position "
                  ]

        let ref = refName tname cname
        stmt <- getter sql
        with (stmtQuery stmt
                     [ PersistText $ unDBName tname
                     , PersistText $ unDBName ref
                     ]) (`connect` do
            -- Only the first row is inspected.
            m <- CL.head

            return $ case m of
              Just [PersistText _table, PersistText _col, PersistText reftable, PersistText _refcol, PersistInt64 _pos] -> Just (DBName reftable, ref)
              Just [PersistByteString _table, PersistByteString _col, PersistByteString reftable, PersistByteString _refcol, PersistInt64 _pos] -> Just (DBName (T.decodeUtf8 reftable), ref)
              Nothing -> Nothing
              _ -> error $ "unexpected result found ["++ show m ++ "]" )
    -- Default expression: absent, or text in either representation.
    d' = case d of
            PersistNull   -> Right Nothing
            PersistText t -> Right $ Just t
            PersistByteString bs -> Right $ Just $ T.decodeUtf8 bs
            _ -> Left $ pack $ "Invalid default column: " ++ show d
    -- Map the udt name onto a 'SqlType'; unknown names become 'SqlOther'.
    getType "int4"        = Right $ SqlInt32
    getType "int8"        = Right $ SqlInt64
    getType "varchar"     = Right $ SqlString
    getType "date"        = Right $ SqlDay
    getType "bool"        = Right $ SqlBool
    getType "timestamp"   = Right $ SqlDayTime
--    getType "timestamptz" = Right $ SqlDayTimeZoned
    getType "float4"      = Right $ SqlReal
    getType "float8"      = Right $ SqlReal
    getType "bytea"       = Right $ SqlBlob
    getType "time"        = Right $ SqlTime
    getType "numeric"     = getNumeric npre nscl
    getType a             = Right $ SqlOther a

    getNumeric (PersistInt64 a) (PersistInt64 b) = Right $ SqlNumeric (fromIntegral a) (fromIntegral b)
    getNumeric a b = Left $ pack $ "Can not get numeric field precision, got: " ++ show a ++ " and " ++ show b ++ " as precision and scale"
getColumn _ a2 x =
    return $ Left $ pack $ "Invalid result from information_schema: " ++ show x ++ " a2[" ++ show a2 ++ "]"

-- | Compare one wanted column with the existing columns of a table.
--
-- When no existing column has the same name the result is a single 'Add''
-- and the existing list is returned unchanged. Otherwise the result lists
-- the reference, default, nullability and type changes (in that order)
-- needed to turn the existing column into the wanted one, together with
-- the existing columns minus the one that was matched.
--
-- Calls 'error' when a reference has to be added to a table that has no
-- entity definition in @defs@.
findAlters :: [EntityDef] -> DBName -> Column -> [Column] -> ([AlterColumn'], [Column])
findAlters defs tablename col@(Column name isNull sqltype def _defConstraintName _maxLen ref) cols =
    tracex ("\n\n\nfindAlters tablename="++show tablename++ " name="++ show name++" col="++show col++"\ncols="++show cols++"\n\n\n") $ case filter ((name ==) . cName) cols of
        [] -> ([(name, Add' col)], cols)
        Column _ isNull' sqltype' def' defConstraintName' _maxLen' ref':_ ->
            let refDrop Nothing = []
                refDrop (Just (_, cname)) = tracex ("\n\n\n44444 findAlters dropping fkey defConstraintName'="++show defConstraintName' ++" name="++show name++" cname="++show cname++" tablename="++show tablename++"\n\n\n") $
                                             [(name, DropReference cname)]
                refAdd Nothing = []
                refAdd (Just (tname, a)) = tracex ("\n\n\n33333 findAlters adding fkey defConstraintName'="++show defConstraintName' ++" name="++show name++" tname="++show tname++" a="++show a++" tablename="++show tablename++"\n\n\n") $
                                           case find ((==tname) . entityDB) defs of
                                                Just refdef -> [(tname, AddReference a [name] [fieldDB $ entityId refdef])]
                                                Nothing -> error $ "could not find the entityDef for reftable[" ++ show tname ++ "]"
                -- References are compared by constraint name only. The
                -- trace used to be labelled "modType" and printed the SQL
                -- types; it now reports what this branch actually compares.
                modRef = tracex ("modRef: ref[" ++ show ref ++ "] ref'[" ++ show ref' ++ "] name=" ++ show name) $
                    if fmap snd ref == fmap snd ref'
                        then []
                        else tracex ("\n\n\nmodRef findAlters drop/add cos ref doesnt match ref[" ++ show ref ++ "] ref'[" ++ show ref' ++ "] tablename="++show tablename++"\n\n\n") $
                              refDrop ref' ++ refAdd ref
                -- When a column becomes NOT NULL and has a default, rows
                -- holding NULL are first updated to that default.
                modNull = case (isNull, isNull') of
                            (True, False) -> [(name, IsNull)]
                            (False, True) ->
                                let up = case def of
                                            Nothing -> id
                                            Just s -> (:) (name, Update' $ T.unpack s)
                                 in up [(name, NotNull)]
                            _ -> []
                modType = tracex ("modType: sqltype[" ++ show sqltype ++ "] sqltype'[" ++ show sqltype' ++ "] name=" ++ show name) $
                          if sqltype == sqltype' then [] else [(name, Type sqltype)]
                modDef = tracex ("modDef col=" ++ show col ++ " def=" ++ show def ++ " def'=" ++ show def') $
                    if cmpdef def def'
                        then []
                        else case def of
                                Nothing -> [(name, NoDefault)]
                                Just s -> [(name, Default $ T.unpack s)]
             in (modRef ++ modDef ++ modNull ++ modType,
                 filter (\c -> cName c /= name) cols)

-- | Decide whether a wanted default (first argument) matches the default
-- reported by the database (second argument).
--
-- Two absent defaults match, as do two identical ones. Otherwise the
-- reported default is cut after its last @:@; the defaults still match if
-- the remaining prefix ends in @::@ and, without that suffix, equals the
-- wanted default (e.g. @'abc'@ against @'abc'::character varying@).
cmpdef :: Maybe Text -> Maybe Text -> Bool
cmpdef wanted reported = case (wanted, reported) of
    (Nothing, Nothing) -> True
    (Just new, Just old) ->
        new == old || maybe False (new ==) (withoutCast old)
    _ -> False
  where
    -- Prefix up to and including the last ':' with the trailing "::" removed.
    withoutCast = T.stripSuffix "::" . fst . T.breakOnEnd ":"

-- | Build the statement adding the foreign key of one column, if the column
-- has a reference at all.
--
-- The constraint is named by 'refName' and points at the id column of the
-- entity whose table is @reftable@; evaluating that id column calls
-- 'error' when no such entity is in @allDefs@.
getAddReference :: [EntityDef] -> DBName -> DBName -> DBName -> Maybe (DBName, DBName) -> Maybe AlterDB
getAddReference allDefs table reftable cname ref =
    case ref of
        Just (s, z) ->
            tracex ("\n\ngetaddreference table="++ show table++" reftable="++show reftable++" s="++show s++" z=" ++ show z++"\n\n") $
                Just (AlterColumn table (s, AddReference fkeyName [cname] [refIdColumn]))
        Nothing -> Nothing
  where
    fkeyName = refName table cname
    -- Left lazy on purpose: a missing entity only fails once the column
    -- name is actually needed.
    refIdColumn = case find ((== reftable) . entityDB) allDefs of
        Just entDef -> fieldDB (entityId entDef)
        Nothing -> error $ "Could not find ID of entity " ++ show reftable


-- | Render a column definition as used in @CREATE TABLE@ and @ADD COLUMN@:
-- quoted name, SQL type, nullability and, when present, the default.
showColumn :: Column -> String
showColumn (Column colName nullable sqlType' dflt _defConstraintName maxLen _ref) =
    T.unpack (escape colName)
        ++ " "
        ++ showSqlType sqlType' maxLen
        ++ " "
        ++ nullability
        ++ defaultClause
  where
    nullability :: String
    nullability
        | nullable  = "NULL"
        | otherwise = "NOT NULL"

    defaultClause :: String
    defaultClause = maybe "" ((" DEFAULT " ++) . T.unpack) dflt

-- | Render a 'SqlType' as a PostgreSQL type name. The optional length is
-- only used for 'SqlString', where it becomes the @VARCHAR@ size.
showSqlType :: SqlType -> Maybe Integer -> String
showSqlType sqlType maxLen = case sqlType of
    SqlString -> case maxLen of
        Nothing  -> "VARCHAR"
        Just len -> "VARCHAR(" ++ show len ++ ")"
    SqlInt32 -> "INT4"
    SqlInt64 -> "INT8"
    SqlReal -> "DOUBLE PRECISION"
    SqlNumeric precision scale ->
        "NUMERIC(" ++ show precision ++ "," ++ show scale ++ ")"
    SqlDay -> "DATE"
    SqlTime -> "TIME"
    SqlDayTime -> "TIMESTAMP"
--  SqlDayTimeZoned -> "TIMESTAMP WITH TIME ZONE"
    SqlBlob -> "BYTEA"
    SqlBool -> "BOOLEAN"
    SqlOther t -> T.unpack t

-- | Render a migration step as SQL, paired with a flag telling whether the
-- step is unsafe. Only dropping a column that is not marked safe to remove
-- is reported as unsafe.
showAlterDb :: AlterDB -> (Bool, Text)
showAlterDb step = case step of
    AddTable sql -> (False, pack sql)
    AlterColumn tbl colAlter@(_, change) ->
        (isUnsafe change, pack (showAlter tbl colAlter))
    AlterTable tbl tblAlter -> (False, pack (showAlterTable tbl tblAlter))
  where
    isUnsafe (Drop safeToRem) = not safeToRem
    isUnsafe _ = False

-- | Render a table-level constraint change as an @ALTER TABLE@ statement.
showAlterTable :: DBName -> AlterTable -> String
showAlterTable table change = case change of
    AddUniqueConstraint cname cols ->
        alterTable
            ++ " ADD CONSTRAINT "
            ++ quote cname
            ++ " UNIQUE("
            ++ intercalate "," (map quote cols)
            ++ ")"
    DropConstraint cname ->
        alterTable ++ " DROP CONSTRAINT " ++ quote cname
  where
    quote :: DBName -> String
    quote = T.unpack . escape

    alterTable :: String
    alterTable = "ALTER TABLE " ++ quote table

-- | Render a column-level change on the given table as SQL.
--
-- The name paired with the change is the column being altered, except for
-- 'AddReference', where it is rendered as the referenced table, and for
-- 'Add'' and 'DropReference', where it is not used.
showAlter :: DBName -> AlterColumn' -> String
showAlter table (n, change) = case change of
    Type t      -> alterColumn ++ " TYPE " ++ showSqlType t Nothing
    IsNull      -> alterColumn ++ " DROP NOT NULL"
    NotNull     -> alterColumn ++ " SET NOT NULL"
    Add' col    -> alterTable ++ " ADD COLUMN " ++ showColumn col
    Drop _      -> alterTable ++ " DROP COLUMN " ++ quotedName
    Default s   -> alterColumn ++ " SET DEFAULT " ++ s
    NoDefault   -> alterColumn ++ " DROP DEFAULT"
    -- Fill rows that still hold NULL with the given expression.
    Update' s   ->
        "UPDATE " ++ quotedTable
            ++ " SET " ++ quotedName ++ "=" ++ s
            ++ " WHERE " ++ quotedName ++ " IS NULL"
    AddReference fkeyname t2 id2 ->
        alterTable
            ++ " ADD CONSTRAINT " ++ quote fkeyname
            ++ " FOREIGN KEY(" ++ quoteAll t2
            ++ ") REFERENCES " ++ quotedName
            ++ "(" ++ quoteAll id2 ++ ")"
    DropReference cname ->
        alterTable ++ " DROP CONSTRAINT " ++ quote cname
  where
    quote :: DBName -> String
    quote = T.unpack . escape

    quoteAll :: [DBName] -> String
    quoteAll = intercalate "," . map quote

    quotedTable, quotedName, alterTable, alterColumn :: String
    quotedTable = quote table
    quotedName  = quote n
    alterTable  = "ALTER TABLE " ++ quotedTable
    alterColumn = alterTable ++ " ALTER COLUMN " ++ quotedName

-- | Quote a database name: wrap it in double quotes and double every
-- double quote it contains.
escape :: DBName -> Text
escape (DBName s) = T.concat [quote, T.replace quote doubled s, quote]
  where
    quote   = T.singleton '"'
    doubled = T.append quote quote


-- | Name of the foreign-key constraint for a column of a table:
-- @table_column_fkey@.
refName :: DBName -> DBName -> DBName
refName tbl col =
    DBName $ T.intercalate "_" [unDBName tbl, unDBName col, "fkey"]

-- | Reduce a unique definition to its constraint name and the second
-- component of each of its fields.
udToPair :: UniqueDef -> (DBName, [DBName])
udToPair ud = (uniqueDBName ud, [ dbField | (_, dbField) <- uniqueFields ud ])

-- | Build the @INSERT@ statement for an entity.
--
-- Entities with a declared primary key get a plain insert
-- ('ISRManyKeys', carrying the values); all others get an insert with a
-- @RETURNING@ clause for the id column ('ISRSingle').
--
-- An entity without any fields is inserted with @DEFAULT VALUES@, because
-- PostgreSQL rejects an empty column list followed by @VALUES()@.
insertSql' :: EntityDef -> [PersistValue] -> InsertSqlResult
insertSql' ent vals = tracex ("\n\n\nGBTEST " ++ show (entityFields ent) ++ "\n\n\n") $
  case entityPrimary ent of
    Just _pdef -> ISRManyKeys (pack insertStmt) vals
    Nothing ->
      ISRSingle $ pack $ concat
        [ insertStmt
        , " RETURNING "
        , T.unpack $ escape $ fieldDB $ entityId ent
        ]
  where
    fields = entityFields ent

    -- Statement shared by both result kinds.
    insertStmt :: String
    insertStmt = concat
        [ "INSERT INTO "
        , T.unpack $ escape $ entityDB ent
        , columnsAndValues
        ]

    columnsAndValues :: String
    columnsAndValues
        | null fields = " DEFAULT VALUES"
        | otherwise = concat
            [ "("
            , intercalate "," $ map (T.unpack . escape . fieldDB) fields
            , ") VALUES("
            , intercalate "," (map (const "?") fields)
            , ")"
            ]