{-# LANGUAGE CPP, PatternGuards, ScopedTypeVariables, FlexibleContexts, TemplateHaskell, DataKinds #-}
module Database.PostgreSQL.Typed.TH
( getTPGDatabase
, withTPGTypeConnection
, withTPGConnection
, useTPGDatabase
, reloadTPGTypes
, TPGValueInfo(..)
, tpgDescribe
, tpgTypeEncoder
, tpgTypeDecoder
, tpgTypeBinary
) where
#if !MIN_VERSION_base(4,8,0)
import Control.Applicative ((<$>), (<$))
#endif
import Control.Applicative ((<|>))
import Control.Concurrent.MVar (MVar, newMVar, takeMVar, putMVar, withMVar)
import Control.Exception (onException, finally, mask)
#ifdef VERSION_tls
import Control.Exception (throwIO)
#endif
import Control.Monad (liftM2)
import qualified Data.ByteString as BS
#ifdef VERSION_tls
import qualified Data.ByteString.Char8 as BSC
#endif
import qualified Data.ByteString.Lazy as BSL
import qualified Data.ByteString.UTF8 as BSU
import qualified Data.Foldable as Fold
import Data.Maybe (isJust, fromMaybe)
import Data.String (fromString)
import qualified Data.Traversable as Tv
import qualified Language.Haskell.TH as TH
import qualified Network.Socket as Net
import System.Environment (lookupEnv)
import System.IO.Unsafe (unsafePerformIO, unsafeInterleaveIO)
import Database.PostgreSQL.Typed.Types
import Database.PostgreSQL.Typed.Protocol
import Database.PostgreSQL.Typed.TypeCache
-- | Build a 'PGDatabase' from environment variables (defaults in parentheses):
--
--   * @TPG_USER@, else @USER@ (@postgres@): user name
--   * @TPG_DB@ (the user name): database name
--   * @TPG_HOST@ (@localhost@) and @TPG_PORT@ (@5432@): TCP address
--   * @TPG_SOCK@: Unix socket path; when set it is used instead of host/port
--     (not consulted when building on Windows)
--   * @TPG_PASS@ (empty): password
--   * @TPG_DEBUG@: any value sets 'pgDBDebug'
--
-- When built with TLS support, additionally:
--
--   * @TPG_TLS@: any value selects 'TlsNoValidate' (if no root cert is given)
--   * @TPG_TLS_ROOT_CERT@: PEM root certificate, validated via 'pgTlsValidate'
--   * @TPG_TLS_MODE@ (@ca@): @full@ or @ca@; any other value throws an
--     'IOError', as does a certificate that 'pgTlsValidate' rejects
getTPGDatabase :: IO PGDatabase
getTPGDatabase = do
  user <- fromMaybe "postgres" <$> liftM2 (<|>) (lookupEnv "TPG_USER") (lookupEnv "USER")
  db   <- fromMaybe user <$> lookupEnv "TPG_DB"
  host <- fromMaybe "localhost" <$> lookupEnv "TPG_HOST"
  pnum <- fromMaybe "5432" <$> lookupEnv "TPG_PORT"
#ifdef mingw32_HOST_OS
  let port = Right pnum
#else
  -- A socket path (Left) takes precedence over the TCP port (Right).
  port <- maybe (Right pnum) Left <$> lookupEnv "TPG_SOCK"
#endif
  pass  <- fromMaybe "" <$> lookupEnv "TPG_PASS"
  debug <- isJust <$> lookupEnv "TPG_DEBUG"
#ifdef VERSION_tls
  tlsEnabled <- isJust <$> lookupEnv "TPG_TLS"
  tlsVerifyMode <- lookupEnv "TPG_TLS_MODE" >>= \modeStr ->
    case modeStr of
      Just "full" -> pure TlsValidateFull
      Just "ca"   -> pure TlsValidateCA
      Just other  -> throwIO (userError ("Unknown verify mode: " ++ other))
      Nothing     -> pure TlsValidateCA
  mTlsCertPem <- lookupEnv "TPG_TLS_ROOT_CERT"
  dbTls <- case mTlsCertPem of
    Just certPem ->
      case pgTlsValidate tlsVerifyMode (BSC.pack certPem) of
        Right x  -> pure x
        Left err -> throwIO (userError err)
    Nothing | tlsEnabled -> pure TlsNoValidate
    Nothing -> pure TlsDisabled
#endif
  return $ defaultPGDatabase
    { pgDBAddr  = either (Right . Net.SockAddrUnix) (Left . (,) host) port
    , pgDBName  = BSU.fromString db
    , pgDBUser  = BSU.fromString user
    , pgDBPass  = BSU.fromString pass
    , pgDBDebug = debug
#ifdef VERSION_tls
    , pgDBTLS   = dbTls
#endif
    }
-- | Global compile-time state: the database settings to use, plus the cached
-- type-aware connection ('Nothing' until first opened, or after a failure).
--
-- 'unsafePerformIO' creates exactly one process-wide 'MVar'; the @NOINLINE@
-- pragma stops GHC from inlining (and thereby duplicating) it. Reading the
-- environment via 'getTPGDatabase' is deferred with 'unsafeInterleaveIO'
-- until the settings are first demanded.
{-# NOINLINE tpgState #-}
tpgState :: MVar (PGDatabase, Maybe PGTypeConnection)
tpgState = unsafePerformIO $ do
  db <- unsafeInterleaveIO getTPGDatabase
  newMVar (db, Nothing)
-- | Run an action with the shared compile-time 'PGTypeConnection'.
--
-- The global state lock is held for the whole action, so calls are
-- serialized. Calling this again from inside the action blocks forever.
-- The connection is opened on first use and cached for later calls.
--
-- If opening the connection fails, the cache is left empty. If the action
-- fails, the connection stays cached.
withTPGTypeConnection :: (PGTypeConnection -> IO a) -> IO a
withTPGTypeConnection f = mask $ \restore -> do
  -- Masked so an async exception cannot arrive after 'takeMVar' but before
  -- a handler is installed, which would leave the MVar empty for good.
  (db, cached) <- takeMVar tpgState
  tpg <- restore (maybe (openTypeConnection db) return cached)
    `onException` putMVar tpgState (db, Nothing)
  restore (f tpg) `finally` putMVar tpgState (db, Just tpg)
  where
    -- Close the raw connection if wrapping it fails, so it is not leaked.
    openTypeConnection db = do
      c <- pgConnect db
      newPGTypeConnection c `onException` pgDisconnect c
-- | Run an action with the raw compile-time 'PGConnection'.
-- This is 'withTPGTypeConnection' with the 'PGTypeConnection' wrapper
-- removed.
withTPGConnection :: (PGConnection -> IO a) -> IO a
withTPGConnection f = withTPGTypeConnection (f . pgConnection)
-- | Use a different database during compilation, overriding the settings
-- read from the TPG environment variables.
--
-- Call it as a top-level declaration splice; it produces no declarations.
--
-- * If the new settings equal the current ones, any cached connection is
--   passed through 'pgReconnect'.
-- * Otherwise the cached connection (if any) is disconnected, and a new one
--   is opened on next use.
-- * On failure, the new settings are stored with an empty connection cache.
useTPGDatabase :: PGDatabase -> TH.DecsQ
useTPGDatabase db = TH.runIO $ do
  -- Masked so the global MVar cannot be left empty by an async exception
  -- between the take and the put.
  mask $ \restore -> do
    (db', cached) <- takeMVar tpgState
    tpg <- restore (retarget db' cached)
      `onException` putMVar tpgState (db, Nothing)
    putMVar tpgState (db, tpg)
  return []
  where
    -- Decide what to do with the cached connection given the old settings.
    retarget :: PGDatabase -> Maybe PGTypeConnection -> IO (Maybe PGTypeConnection)
    retarget db' cached
      | db == db' = Tv.mapM reconnect cached
      | otherwise = Nothing <$ Fold.mapM_ (pgDisconnect . pgConnection) cached
    -- Swap in a reconnected PGConnection, keeping the rest of the wrapper.
    reconnect t = do
      c <- pgReconnect (pgConnection t) db
      return t{ pgConnection = c }
-- | Flush the cached type information of the compile-time connection, if one
-- is open, so types are looked up again.
--
-- This may be needed after structural changes to the database at compile
-- time. Use it as a top-level declaration splice; it produces no
-- declarations.
reloadTPGTypes :: TH.DecsQ
reloadTPGTypes = TH.runIO $ do
  -- Fold.mapM_ over the Maybe: Prelude's mapM_ is list-only on old base.
  withMVar tpgState (Fold.mapM_ flushPGTypeConnection . snd)
  return []
-- | Look up a type name by OID using the connection's type cache.
-- Fails in 'IO' (via 'fail') with an explanatory message if the OID is
-- unknown.
tpgType :: PGTypeConnection -> OID -> IO PGName
tpgType c o = lookupPGType c o >>= maybe (fail unknown) return
  where
    unknown = "Unknown PostgreSQL type: " ++ show o ++ "\nYou may need to use reloadTPGTypes or adjust search_path, or your postgresql-typed application may need to be rebuilt."
-- | Look up a type OID by its exact type name (the inverse of 'tpgType').
-- Fails in 'IO' (via 'fail') if no type with that name is found.
getTPGTypeOID :: PGTypeConnection -> PGName -> IO OID
getTPGTypeOID c t = findPGType c t >>= maybe (fail unknown) return
  where
    unknown = "Unknown PostgreSQL type: " ++ show t ++ "; be sure to use the exact type name from \\dTS"
-- | Information about a single query parameter or result column, as
-- produced by 'tpgDescribe'.
data TPGValueInfo = TPGValueInfo
  { tpgValueName :: BS.ByteString
    -- ^ Column name ('BS.empty' for parameters)
  , tpgValueTypeOID :: !OID
    -- ^ PostgreSQL type OID
  , tpgValueType :: PGName
    -- ^ PostgreSQL type name corresponding to 'tpgValueTypeOID'
  , tpgValueNullable :: Bool
    -- ^ Whether the value should be treated as possibly NULL
  }
-- | Type-aware wrapper around 'pgDescribe', run on the compile-time
-- connection.
--
-- Steps:
--
-- 1. Resolve the given parameter type names to OIDs with 'getTPGTypeOID'.
-- 2. Describe @sql@ with 'pgDescribe'; the @nulls@ flag is passed through
--    unchanged.
-- 3. Map every returned OID back to its type name with 'tpgType'.
--
-- Returns @(parameters, result columns)@. Parameters are always marked
-- nullable. A column is nullable when 'pgDescribe' reports it so, unless its
-- type is @void@. Fails if any type name or OID cannot be resolved.
tpgDescribe :: BS.ByteString -> [String] -> Bool -> IO ([TPGValueInfo], [TPGValueInfo])
tpgDescribe sql types nulls = withTPGTypeConnection $ \tpg -> do
  at <- mapM (getTPGTypeOID tpg . fromString) types
  (pt, rt) <- pgDescribe (pgConnection tpg) (BSL.fromStrict sql) at nulls
  -- Build a TPGValueInfo, resolving the OID to a type name.
  let info name o nullable = do
        ot <- tpgType tpg o
        return TPGValueInfo
          { tpgValueName = name
          , tpgValueTypeOID = o
          , tpgValueType = ot
          , tpgValueNullable = nullable
          }
  (,)
    <$> mapM (\o -> info BS.empty o True) pt
    <*> mapM (\(c, o, n) -> info c o (n && o /= voidOID)) rt
  where
    -- OID of PostgreSQL's @void@ pseudo-type; a void column is never NULL.
    voidOID :: OID
    voidOID = 2278
-- | Build the expression @f e (PGTypeProxy :: PGTypeID "t")@.
-- This applies function @f@ to variable @e@ and to a proxy whose type-level
-- string is the PostgreSQL type name @t@.
typeApply :: PGName -> TH.Name -> TH.Name -> TH.Exp
typeApply t f e =
  TH.VarE f `TH.AppE` TH.VarE e `TH.AppE` proxy
  where
    proxy = TH.ConE 'PGTypeProxy `TH.SigE`
      (TH.ConT ''PGTypeID `TH.AppT` TH.LitT (TH.StrTyLit (pgNameString t)))
-- | TH expression that encodes the variable named by the last argument as a
-- parameter of the value's PostgreSQL type.
--
-- When @lit@ is 'True' it uses 'pgEscapeParameter' (presumably for
-- splicing as an SQL literal; confirm against that function's docs).
-- Otherwise it uses 'pgEncodeParameter'.
tpgTypeEncoder :: Bool -> TPGValueInfo -> TH.Name -> TH.Exp
tpgTypeEncoder lit v = typeApply (tpgValueType v) encoder
  where
    encoder
      | lit       = 'pgEscapeParameter
      | otherwise = 'pgEncodeParameter
-- | TH expression that decodes the variable named by the last argument as a
-- column of the value's PostgreSQL type.
--
-- It uses 'pgDecodeColumn' only when @nulls@ is 'True' and the value is
-- nullable. Otherwise it uses 'pgDecodeColumnNotNull'.
tpgTypeDecoder :: Bool -> TPGValueInfo -> TH.Name -> TH.Exp
tpgTypeDecoder nulls v = typeApply (tpgValueType v) decoder
  where
    decoder
      | nulls && tpgValueNullable v = 'pgDecodeColumn
      | otherwise                   = 'pgDecodeColumnNotNull
-- | TH expression applying 'pgBinaryColumn' to the variable named by the
-- second argument, at the value's PostgreSQL type.
tpgTypeBinary :: TPGValueInfo -> TH.Name -> TH.Exp
tpgTypeBinary v = typeApply (tpgValueType v) 'pgBinaryColumn