module Database.Selda.Prepared (Preparable, Prepare, prepared) where
import Database.Selda.Backend.Internal
import Database.Selda.Caching
import Database.Selda.Column
import Database.Selda.Compile
import Database.Selda.Query.Type
import Database.Selda.SQL (param, paramType)
import Database.Selda.Types (TableName)
import Control.Exception
import Control.Monad.IO.Class
import qualified Data.HashMap.Strict as M
import Data.IORef
import Data.Proxy
import Data.Text (Text)
import System.IO.Unsafe
data Placeholder = Placeholder Int
deriving Show
instance Exception Placeholder
firstParamIx :: Int
firstParamIx = 0
type family ResultT f where
ResultT (a -> b) = ResultT b
ResultT (m a) = a
type family Equiv q f where
Equiv (Col s a -> q) (a -> f) = Equiv q f
Equiv (Query s a) (m [b]) = Res a ~ b
type CompResult = (Text, [Either Int Param], [SqlTypeRep], [TableName])
class Preparable q where
mkQuery :: MonadSelda m
=> Int
-> q
-> [SqlTypeRep]
-> m CompResult
class Prepare q f where
mkFun :: Preparable q
=> IORef (Maybe (BackendID, CompResult))
-> StmtID
-> q
-> [Param]
-> f
instance (SqlType a, Prepare q b) => Prepare q (a -> b) where
mkFun ref sid qry ps x = mkFun ref sid qry (param x : ps)
instance (MonadSelda m, a ~ Res (ResultT q), Result (ResultT q)) =>
Prepare q (m [a]) where
mkFun ref sid qry arguments = do
conn <- seldaConnection
let backend = connBackend conn
args = reverse arguments
stmts <- liftIO $ readIORef (connStmts conn)
case M.lookup sid stmts of
Just stm -> do
liftIO $ do
runQuery conn stm args
_ -> do
compiled <- liftIO $ readIORef ref
(q, params, reps, ts) <- case compiled of
Just (bid, comp) | bid == backendId backend -> do
return comp
_ -> do
comp <- mkQuery firstParamIx qry []
liftIO $ writeIORef ref (Just (backendId backend, comp))
return comp
liftIO $ do
hdl <- prepareStmt backend sid reps q
let stm = SeldaStmt
{ stmtHandle = hdl
, stmtParams = params
, stmtTables = ts
, stmtText = q
}
atomicModifyIORef' (connStmts conn) $ \m -> (M.insert sid stm m, ())
runQuery conn stm args
where
runQuery conn stm args = do
let backend = connBackend conn
ps = replaceParams (stmtParams stm) args
key = (connDbId conn, stmtText stm, ps)
hdl = stmtHandle stm
mres <- cached key
case mres of
Just res -> do
return res
_ -> do
res <- runPrepared backend hdl ps
cache (stmtTables stm) key res
return $ map (toRes (Proxy :: Proxy (ResultT q))) (snd res)
instance (SqlType a, Preparable b) => Preparable (Col s a -> b) where
mkQuery n f ts = mkQuery (n+1) (f x) (sqlType (Proxy :: Proxy a) : ts)
where x = C $ Lit $ LCustom (throw (Placeholder n) :: Lit a)
instance Result a => Preparable (Query s a) where
mkQuery _ q types = do
b <- seldaBackend
case compileWithTables (ppConfig b) q of
(tables, (q', ps)) -> do
(ps', types') <- liftIO $ inspectParams (reverse types) ps
return (q', ps', types', tables)
prepared :: (Preparable q, Prepare q f, Equiv q f) => q -> f
prepared q = unsafePerformIO $ do
ref <- newIORef Nothing
sid <- freshStmtId
return $ mkFun ref sid q []
replaceParams :: [Either Int Param] -> [Param] -> [Param]
replaceParams params = map fromRight . go firstParamIx params
where
go n ps (x:xs) = go (n+1) (map (subst n x) ps) xs
go _ ps _ = ps
subst n x (Left n') | n == n' = Right x
subst _ _ old = old
fromRight (Right x) = x
fromRight _ = error "BUG: query parameter not substituted!"
inspectParams :: [SqlTypeRep] -> [Param] -> IO ([Either Int Param], [SqlTypeRep])
inspectParams ts (x:xs) = do
res <- try $ pure $! forceParam x
let (x', t) = case res of
Right p -> (Right p, paramType p)
Left (Placeholder ix) -> (Left ix, ts !! ix)
(xs', ts') <- inspectParams ts xs
return (x' : xs', t : ts')
inspectParams _ [] = do
return ([], [])
forceParam :: Param -> Param
forceParam p@(Param (LCustom x)) | x `seq` True = p
forceParam p = p