{-# LANGUAGE ScopedTypeVariables, TemplateHaskell, TypeSynonymInstances, FlexibleInstances #-}
module Database.Persist.HsSqlPpp
  (Query,
   persistSql,
   persistSqlFile,
   parseEntityFromFile,
   parseEntity,
   selectFromQuery,
   selectFromQuery',
   checkSQL
  ) where

import Control.Applicative
import Control.Monad.Reader
import Control.Monad.Trans
import Control.Monad.Error
import Control.Monad.Instances
import Control.Monad.IO.Control
import Data.Char
import Data.List (intercalate)
import qualified Data.ByteString as B
import qualified Data.Text as Text
import Language.Haskell.TH
import qualified Language.Haskell.TH.Lift as THL
import qualified Language.Haskell.TH.Syntax as TH
import Language.Haskell.TH.Quote

import Database.Persist
import Database.Persist.Base
import Database.Persist.GenericSql.Raw
import qualified Database.Persist.GenericSql.Internal as I
import Database.Persist.TH

import Database.HsSqlPpp.Types as Types
import Database.HsSqlPpp.Ast as AST
import Database.HsSqlPpp.Annotation
import Database.HsSqlPpp.Catalog
import Database.HsSqlPpp.Parser
import Database.HsSqlPpp.Pretty

-- | Declared column types: a list of (column name, column type) pairs,
-- as produced by 'parseTypes' and searched with 'lookup'.
type Types = [(String, String)]

-- | (parsed SQL query expression, entity definition describing its columns)
type Query = (QueryExpr, EntityDef)

-- | Things that can be rendered as a plain (unqualified) name string.
class IsName x where
  -- | Extract the name as a 'String'.
  name2string :: x -> String

-- | A plain string is already a name.
instance IsName String where
  name2string str = str

-- | A (possibly qualified) name is rendered as its last component.
instance IsName AST.Name where
  name2string (AST.Name _ components) = ncStr (last components)

-- | A single component is rendered via 'ncStr'.
instance IsName NameComponent where
  name2string = ncStr

-- | A component list is rendered as its last component.
-- Like the underlying 'last', this is undefined for an empty list.
instance IsName [NameComponent] where
  name2string = ncStr . last

-- | Convert a 'B.ByteString' to a 'String', mapping every byte to the
-- character with the same code point (no decoding is performed).
toString :: B.ByteString -> String
toString bs = [chr (fromIntegral byte) | byte <- B.unpack bs]

-- Template Haskell 'Lift' instances for the HsSqlPpp types listed below.
-- 'persistSql' and 'persistSqlFile' call 'TH.lift' on values containing a
-- 'QueryExpr', which needs these instances.
-- NOTE(review): the order of these splices may matter to the compiler —
-- confirm before reordering or regrouping them.
THL.deriveLift ''SelectItem
THL.deriveLift ''SelectList
THL.deriveLift ''Distinct
THL.deriveLift ''FrameClause
THL.deriveLift ''TypeName
THL.deriveLift ''LiftFlavour
THL.deriveLift ''IntervalField
THL.deriveLift ''InList
THL.deriveLift ''ExtractField
THL.deriveLift ''ScalarExpr
THL.deriveLift ''Direction
THL.deriveLift ''TableAlias
THL.deriveLift ''NameComponent
THL.deriveLift ''AST.Name
THL.deriveLift ''JoinExpr
THL.deriveLift ''JoinType
THL.deriveLift ''Natural
THL.deriveLift ''TableRef
THL.deriveLift ''CombineType
THL.deriveLift ''WithQuery
THL.deriveLift ''TypeError
THL.deriveLift ''FunFlav
THL.deriveLift ''CastContext
THL.deriveLift ''PseudoType
THL.deriveLift ''Types.Type
THL.deriveLift ''CatalogUpdate
THL.deriveLift ''Annotation
THL.deriveLift ''QueryExpr

-- | Apply a check to every list element, left to right.
--
-- Stops at the first element whose check yields @Left@ and returns that
-- error. The @Right@ payloads of individual checks are discarded; when all
-- checks succeed (or the list is empty) the supplied value is returned.
allE :: (a -> Either e b) -> b -> [a] -> Either e b
allE check ok = foldr step (Right ok)
  where
    -- A failing check short-circuits; a passing one defers to the rest.
    step x rest = either Left (const rest) (check x)

-- | Combine two check results into a function producing the final result.
--
-- When both checks succeeded the function is 'Right' (it wraps its
-- argument). When one failed, the function ignores its argument and returns
-- that error; when both failed, the two messages are joined with @\"; \"@.
(<&>) :: Either String a -> Either String a -> (a -> Either String a)
first <&> second =
  case (first, second) of
    (Right _, Right _) -> Right
    (Right _, Left e)  -> const (Left e)
    (Left e,  Right _) -> const (Left e)
    (Left e,  Left e') -> const (Left (e ++ "; " ++ e'))

-- | An 'Annotation' with every field left empty ('Nothing' or @[]@).
-- Used for all AST nodes this module constructs itself.
emptyAnn :: Annotation
emptyAnn = Annotation Nothing Nothing [] Nothing [] Nothing Nothing

-- | Apply every filter in turn to the query via 'addWhere'.
-- Fails with the first error reported by 'addWhere'.
makeQuery :: (PersistEntity t) => I.Connection -> QueryExpr -> [Filter t] -> Either String QueryExpr
makeQuery conn = foldM (addWhere conn)

-- | Add one filter to the condition of a @Select@ query.
--
-- The new condition is built by 'convertCond' from the filter and the
-- query's existing condition (if any); all other fields are kept as they
-- are. Any other kind of query expression is rejected with @Left@.
addWhere :: (PersistEntity t) => I.Connection -> QueryExpr -> Filter t -> Either String QueryExpr
addWhere conn expr filt =
  case expr of
    Select ann dist fields trefs oldCond x1 x2 x3 x4 x5 ->
      let newCond = convertCond conn filt oldCond
      in  Right (Select ann dist fields trefs (Just newCond) x1 x2 x3 x4 x5)
    other ->
      Left ("Unsupported query expr in addWhere: " ++ show other)

-- | Build a single-component 'AST.Name' with an empty annotation.
name :: String -> AST.Name
name s = Name emptyAnn components
  where
    components = [Nmc s]

-- | Turn a filter into a condition, combined with an existing condition.
--
-- Without an existing condition the result is just the converted filter;
-- otherwise it is a call to @\"!and\"@ with the converted filter first and
-- the existing condition second.
convertCond :: (PersistEntity t) => I.Connection -> Filter t -> Maybe ScalarExpr -> ScalarExpr
convertCond conn f existing = maybe fresh conjoin existing
  where
    fresh = convertFilter conn f
    conjoin old = FunCall emptyAnn (name "!and") [fresh, old]

-- | Translate a single Persistent 'Filter' into a SQL scalar expression.
--
-- The result is a 'FunCall' whose name is the operator text given by
-- 'showSqlFilter' and whose arguments are the escaped column identifier
-- followed by the filter value(s) converted to literals.
--
-- Calls 'error' when the filter carries a 'PersistValue' constructor that
-- 'persist2expr' does not handle.
convertFilter conn (Filter field value op) =
    FunCall emptyAnn (name $ showSqlFilter op) (Identifier emptyAnn (Nmc fname): allVals)
  where
    -- Every filter value as a literal expression (one for a single value,
    -- several when the filter carries a list).
    allVals = map persist2expr $ filterValueToPersistValues value

    -- Map a Persistent value to the corresponding SQL literal node.
    -- Numbers, dates and times are rendered with 'show'.
    persist2expr (PersistText t) = StringLit emptyAnn (Text.unpack t)
    persist2expr (PersistByteString bstr) = StringLit emptyAnn (toString bstr)
    persist2expr (PersistInt64 i) = NumberLit emptyAnn (show i)
    persist2expr (PersistDouble x) = NumberLit emptyAnn (show x)
    persist2expr (PersistBool b) = BooleanLit emptyAnn b
    persist2expr (PersistDay d) = StringLit emptyAnn (show d)
    persist2expr (PersistTimeOfDay tod) = StringLit emptyAnn (show tod)
    persist2expr (PersistUTCTime t) = StringLit emptyAnn (show t)
    persist2expr (PersistNull) = NullLit emptyAnn
    persist2expr x = error $ "persist2expr: unsupported Persistent value in filter: " ++ show x

    -- Normalise "one value or a list of values" to a list of 'PersistValue's.
    filterValueToPersistValues :: forall a.  PersistField a => Either a [a] -> [PersistValue]
    filterValueToPersistValues v = map toPersistValue $ either return id v

    -- Column name of the filtered field, escaped for the given connection.
    fname = I.escapeName conn $ getFieldName t $ columnName $ persistColumnDef field
    -- Entity definition of the filtered record type. 'dummyFromFilts' never
    -- yields a real value (it is an 'error' call); it is used here only to
    -- select the entity type.
    -- NOTE(review): this relies on 'entityDef' not evaluating its argument — confirm.
    t = entityDef $ dummyFromFilts [Filter field value op]

    -- Raw (unescaped) field name of the named column in the entity.
    getFieldName :: EntityDef -> String -> I.RawName
    getFieldName t s = I.rawFieldName $ tableColumn t s

-- This is copy-paste from Database.Persist.GenericSql.Internal
-- | SQL operator text for a Persistent filter operator.
-- Backend-specific filters carry their own text, which is returned as is.
showSqlFilter op =
  case op of
    Eq    -> "="
    Ne    -> "<>"
    Gt    -> ">"
    Lt    -> "<"
    Ge    -> ">="
    Le    -> "<="
    In    -> " IN "
    NotIn -> " NOT IN "
    BackendSpecificFilter s -> s

-- | Type-level helper: yields a \"value\" of the filtered record type.
-- The result is an 'error' call and must never be evaluated; the argument
-- is ignored.
dummyFromFilts :: [Filter v] -> v
dummyFromFilts = \_ -> error "dummyFromFilts"

-- | Look up a column definition of an entity by column name.
--
-- The entity's id column name yields a synthetic @Int64@ column with no
-- attributes; otherwise the first entity column with a matching name is
-- returned. Calls 'error' when no column matches.
tableColumn :: EntityDef -> String -> ColumnDef
tableColumn t s
  | s == idColumn = ColumnDef idColumn "Int64" []
  | otherwise     = foldr pick notFound (entityColumns t)
  where
    idColumn = I.unRawName $ I.rawTableIdName t
    notFound = error $ "Unknown table column: " ++ s
    -- Keep the first matching column; otherwise continue with the rest.
    pick col@(ColumnDef colName _ _) next
      | colName == s = col
      | otherwise    = next
-- End of copy-paste

-- | Build an entity definition from a SELECT query.
--
-- Each item of the select list becomes one column (see 'convertColumn').
-- The entity is named @\"Undefined\"@ and derives @Eq@ and @Show@.
-- Anything other than a plain @Select@ is rejected with @Left@.
convertSelect :: Types -> QueryExpr -> Either String EntityDef
convertSelect types expr =
  case expr of
    Select _ _ (SelectList _ items) _ _ _ _ _ _ _ -> do
      columns <- mapM (convertColumn types) items
      return (EntityDef "Undefined" [] columns [] ["Eq", "Show"])
    _ ->
      Left $ "Unsupported SQL query syntax in " ++ printQueryExpr expr

-- | Names of the result columns of a SELECT query, in select-list order.
-- Anything other than a plain @Select@ is rejected with @Left@.
getColumnNames :: QueryExpr -> Either String [ColumnName]
getColumnNames expr =
  case expr of
    Select _ _ (SelectList _ items) _ _ _ _ _ _ _ -> mapM getColumnName items
    _ -> Left $ "Unsupported SQL query syntax in " ++ printQueryExpr expr

-- | Check that a query is an allowed SELECT.
--
-- Only the select list and the table references are inspected. Every plain
-- table reference must name one of @tables@; the referenced name is
-- lower-cased before the comparison, the entries of @tables@ are not.
-- On success the query itself is returned; otherwise @Left@ with a message.
--
-- NOTE(review): the remaining fields of 'Select' (matched with @_@ below)
-- are not inspected at all — confirm that this is intended.
checkSelect :: [String] -> QueryExpr -> Either String QueryExpr
checkSelect tables s@(Select _ _ (SelectList _ items) trefs _ _ _ _ _ _) = 
    -- Both checks are evaluated; '<&>' joins their messages if both fail.
    (allE goodTref s trefs <&> allE goodCol s items) s
  where
    -- Table references: no function calls, both sides of a join are checked,
    -- sub-selects are checked recursively.
    goodTref (FunTref _ _ a)            = Left $ "Function call as table reference is not allowed: " ++ show a
    goodTref (JoinTref _ t1 _ _ t2 _ _) = (goodTref t1 <&> goodTref t2) s
    goodTref (SubTref _ q _)            = checkSelect tables q
    goodTref (Tref _ n@(Name _ list) _) =
      let tname = lowerName n
      in  if tname `elem` tables
            then Right s
            else Left $ "Table reference is not allowed: " ++ tname

    -- Select-list items: check the underlying expression, with or without alias.
    goodCol (SelExp _ expr)       = goodExpr expr
    goodCol (SelectItem _ expr _) = goodExpr expr

    -- Expressions: only the listed aggregate functions may be called
    -- (name compared case-insensitively); their arguments are checked too.
    goodExpr (FunCall _ fn args) =
      let name = map toLower (name2string fn)
      in  if name `elem` ["count", "max", "min", "avg", "sum"]
            then allE goodExpr s args
            else Left $ "Function call not allowed: " ++ name2string fn
    goodExpr (ScalarSubQuery _ q) = checkSelect tables q
    goodExpr (Case {})       = Left "CASE expressions are not allowed"
    goodExpr (CaseSimple {}) = Left "CASE expressions are not allowed"
    goodExpr (Cast {})       = Left "CAST expressions are not allowed"
    goodExpr (Exists _ q)    = checkSelect tables q
    goodExpr (Extract {})    = Left "EXTRACT expressions are not allowed"
    goodExpr (InPredicate _ e _ (InList _ es)) = (goodExpr e <&> allE goodExpr s es) s
    goodExpr (InPredicate _ e _ (InQueryExpr _ q)) = (goodExpr e <&> checkSelect tables q) s
    goodExpr (LiftOperator _ _ _ es) = allE goodExpr s es
    -- NOTE(review): qualified identifiers are accepted without any further
    -- check; the original author flagged this line with "!?".
    goodExpr (QIdentifier _ _) = Right s
    goodExpr (WindowFn {}) = Left "WINDOW functions are not allowed"
    -- Every expression form not listed above is accepted.
    goodExpr _ = Right s
-- Anything that is not a plain Select is rejected.
checkSelect _ q = Left $ "Non-SELECT query forbidden: " ++ printQueryExpr q

-- | Obtain the current database connection by wrapping the reader's 'ask'
-- in the 'SqlPersist' constructor.
getConnection :: (MonadControlIO m) => SqlPersist m I.Connection
getConnection = SqlPersist ask

-- | Select list of records from DB using given SQL query.
--
-- Every filter is added to the query's condition (see 'makeQuery'); the
-- resulting expression is pretty-printed and executed, and each returned
-- row is decoded with 'fromPersistValues'.
--
-- Returns @Left@ with a message when the filters cannot be applied to the
-- query or when a row cannot be decoded; otherwise @Right@ with the decoded
-- rows in the order they were fetched.
selectFromQuery :: forall m a. (MonadControlIO m, PersistEntity a)
                => QueryExpr               -- ^ SQL query
                -> [Filter a]              -- ^ Filters
                -> SqlPersist m (Either String [a])
selectFromQuery query filts = do
    conn <- getConnection
    case makeQuery conn query filts of
      -- Report the problem through the declared 'Either' result, like
      -- row-decoding failures, rather than calling 'fail' in the
      -- underlying monad.
      Left err -> return (Left err)
      Right expr -> withStmt (Text.pack $ printQueryExpr expr) [] worker
  where
    -- Run the row loop in 'ErrorT' so a decoding failure aborts it.
    worker :: I.RowPopper (SqlPersist m) -> SqlPersist m (Either String [a])
    worker popper = runErrorT (go popper)

    -- Pop rows until the statement is exhausted, decoding each one.
    go popper = do
      row <- lift popper
      case row of
        Nothing -> return []
        Just list -> case fromPersistValues list of
                       -- Runs inside 'ErrorT String', where 'fail' becomes
                       -- the @Left@ result of 'runErrorT'.
                       Left err -> fail err
                       Right rec -> do
                                    next <- go popper
                                    return (rec: next)

-- | Parse a SQL query and check that it is an allowed SELECT query
-- (see 'checkSelect'). Parse errors and check errors are both returned
-- as @Left@.
checkSQL :: [String] -- ^ Allowed table names
         -> String   -- ^ SQL query
         -> Either String QueryExpr
checkSQL tables query =
  case parseSelect "<no file>" query of
    Left err        -> Left err
    Right (expr, _) -> checkSelect tables expr

-- | Select list of records from DB using given SQL SELECT query.
-- The query is parsed and checked with 'checkSelect' first (arbitrary
-- function calls, complex expressions, etc are not permitted).
-- The result is the list of column names together with the rows,
-- each row being represented as [PersistValue].
--
-- NOTE(review): the statement that is executed is the original query text,
-- not the parsed expression that was checked — confirm this is intended.
selectFromQuery' :: (MonadControlIO m)
                 => [String]           -- ^ Names of allowed tables
                 -> String             -- ^ SQL query
                 -> SqlPersist m (Either String ([ColumnName], [[PersistValue]]))
selectFromQuery' tables query =
    either (return . Left) run checked
  where
    -- Parse and validate; yields the column names of an accepted query.
    checked = do
      (expr, names) <- parseSelect "<inline>" query
      case checkSelect tables expr of
        Left err -> Left ("Forbidden query: " ++ err)
        Right _  -> Right names

    -- Execute the query text and pair the rows with the column names.
    run names = do
      rows <- withStmt (Text.pack query) [] collect
      return (Right (names, rows))

    -- Pop rows until the statement is exhausted.
    collect popper =
      popper >>= maybe (return []) (\row -> liftM (row :) (collect popper))

-- | Quasi-quoter which parses SQL SELECT queries.
-- Example input:
--
-- @
--   SELECT family, salary FROM employee;
--   --------------------
--   family String
--   salary Int
-- @
--
-- The quoted text is parsed with 'parseEntity'; a parse failure is reported
-- as a compile-time error. The quoter can only be used in expression
-- context; any other context fails with an explanatory message.
--
-- NB: entity name will be \"Undefined\", so
-- you'll need to use record update syntax to
-- set the name you want, e.g. entity {entityName = \"Query\"}.
persistSql :: QuasiQuoter
persistSql = QuasiQuoter {
               quoteExp = either fail TH.lift . parseEntity "<inline>",
               quotePat = unsupported "pattern",
               quoteType = unsupported "type",
               quoteDec = unsupported "declaration" }
  where
    -- Fail with a readable message instead of crashing on 'undefined'.
    unsupported :: String -> String -> Q a
    unsupported context _ =
      fail $ "persistSql: this quasi-quoter cannot be used in a " ++ context ++ " context"

-- | Load entity declaration from file containing SQL query.
-- TH version: the file is read and parsed at compile time and the result of
-- 'parseEntityFromFile' (an 'Either' value) is lifted as it is into a
-- single expression.
persistSqlFile :: FilePath -> Q [Exp]
persistSqlFile path = do
  lifted <- runIO (parseEntityFromFile path) >>= TH.lift
  return [lifted]

-- | Read a file and parse its contents with 'parseEntity', using the file
-- path in error messages.
parseEntityFromFile :: FilePath -> IO (Either String Query)
parseEntityFromFile path = fmap (parseEntity path) (readFile path)

-- | Parse SQL entity declaration.
--
-- The text consists of a SELECT query, then a delimiter line made only of
-- @-@ or @=@ characters (spaces and tabs are allowed around them), then
-- one @columnName Type@ line per column. Blank lines inside the query part
-- are dropped before parsing.
--
-- If the text contains no delimiter line, all of it is treated as the
-- query and no column types are declared, so a column whose type has to be
-- looked up is reported as @Left@ rather than causing a pattern-match
-- failure.
parseEntity :: FilePath  -- ^ File name to use in error messages
            -> String    -- ^ Declaration text
            -> Either String Query
parseEntity path text = parseSelectT path types query
  where
    isDelimiter str = not (null str) && all (`elem` "-= \t") str && any (`elem` "-=") str
    (queryLines, rest) = break isDelimiter (lines text)
    -- 'drop 1' skips the delimiter line itself and, unlike matching on
    -- @_:typeLines@, is also defined when there is no delimiter at all.
    typeLines = drop 1 rest
    types = parseTypes typeLines
    isBlank line = all (`elem` " \t") line
    query = unlines $ filter (not . isBlank) queryLines

-- | Parse a query and derive its entity definition from the declared
-- column types. Parse errors are prefixed with
-- @\"Cannot parse query expression: \"@.
parseSelectT :: String -> Types -> String -> Either String Query
parseSelectT path types query =
    either parseFailure withEntity (parseQueryExpr path query)
  where
    parseFailure err = Left $ "Cannot parse query expression: " ++ show err
    withEntity expr = fmap ((,) expr) (convertSelect types expr)

-- | Parse a query and collect the names of its result columns.
-- Parse errors are rendered with 'show'.
parseSelect :: String -> String -> Either String (QueryExpr, [ColumnName])
parseSelect path query =
    either (Left . show) withColumns (parseQueryExpr path query)
  where
    withColumns expr = fmap (\cols -> (expr, cols)) (getColumnNames expr)

-- | Parse column type declarations, one per line.
-- The first word of a line is the column name, the remaining words (joined
-- with single spaces) are its type. Lines without any word are skipped.
parseTypes :: [String] -> Types
parseTypes declLines =
  [ (column, unwords typeWords) | (column : typeWords) <- map words declLines ]

-- | The name as a string, converted to lower case.
lowerName :: IsName name => name -> String
lowerName = map toLower . name2string

-- | Find the declared type of a named column.
-- The lookup uses the name exactly as given by 'name2string'; an unknown
-- name is reported as @Left@.
lookupType :: (IsName name) => name -> Types -> Either String ColumnType
lookupType name ts = maybe unknown return (lookup key ts)
  where
    key = name2string name
    unknown = Left $ "Do not know type of " ++ key

-- | Determine the type of a select-list item.
--
-- For an item without alias, identifiers are looked up in the declared
-- types and function calls are resolved by 'getFuncType'; other expressions
-- are rejected. For an aliased item, a type declared under the alias wins,
-- otherwise the type is derived from the expression with 'getType'.
getColumnType :: Types -> SelectItem -> Either String ColumnType
getColumnType t item =
  case item of
    SelExp _ expr ->
      case expr of
        Identifier _ ident  -> lookupType ident t
        QIdentifier _ ident -> lookupType ident t
        FunCall _ fn args   -> getFuncType t (lowerName fn) args
        _                   -> Left $ "Cannot guess type of " ++ printScalarExpr expr
    SelectItem _ expr alias ->
      maybe (getType t expr) return (lookup (name2string alias) t)

-- | Determine the result column name of a select-list item.
--
-- An aliased item is named by its alias. Without alias: identifiers use
-- their own name, function calls use the lower-cased function name, and
-- star forms yield @\"*\"@ or @\"qualifier.*\"@. Other expressions are
-- rejected.
getColumnName :: SelectItem -> Either String ColumnName
getColumnName (SelectItem _ _ alias) = return (name2string alias)
getColumnName (SelExp _ expr) =
  case expr of
    Identifier _ ident  -> return (name2string ident)
    QIdentifier _ ident -> return (name2string ident)
    FunCall _ fn _      -> return (lowerName fn)
    QStar _ qualifier   -> return (name2string qualifier ++ ".*")
    Star _              -> return "*"
    other               -> Left $ "Do not want invent the name for " ++ printScalarExpr other

-- | Build a column definition (name, type, no attributes) for a
-- select-list item. The name is resolved first, then the type.
convertColumn :: Types -> SelectItem -> Either String ColumnDef
convertColumn t s = do
  colName <- getColumnName s
  colType <- getColumnType t s
  return (ColumnDef colName colType [])

-- | Derive the column type of a scalar expression.
--
-- Literals and predicates have fixed types, identifiers are looked up in
-- the declared types, and function calls are resolved by 'getFuncType'.
-- Any other expression is rejected with @Left@.
getType :: Types -> ScalarExpr -> Either String ColumnType
getType t expr =
  case expr of
    BooleanLit _ _      -> return "Bool"
    Exists _ _          -> return "Bool"
    InPredicate _ _ _ _ -> return "Bool"
    NumberLit _ _       -> return "Int"
    StringLit _ _       -> return "String"
    Identifier _ ident  -> lookupType ident t
    QIdentifier _ ident -> lookupType ident t
    FunCall _ fn args   -> getFuncType t (lowerName fn) args
    other               -> Left $ "Unsupported scalar expression type: " ++ show other

-- | Derive the result type of a function call from its (lower-cased) name.
--
-- @max@, @min@, @sum@ and @avg@ applied to exactly one argument take the
-- type of that argument; @count@ is always @Int@. Everything else is looked
-- up by function name in the declared types.
getFuncType :: Types -> String -> [ScalarExpr] -> Either String ColumnType
getFuncType t fn args
  | fn `elem` ["max", "min", "sum", "avg"], [arg] <- args = getType t arg
  | fn == "count" = return "Int"
  | otherwise = maybe unknown return (lookup fn t)
  where
    unknown = Left $ "Cannot invent type for " ++ fn ++ show args