{-# LANGUAGE OverloadedStrings #-}
module Database.PostgreSQL.Typed.TypeCache
  ( PGTypes
  , pgGetTypes
  , PGTypeConnection
  , pgConnection
  , newPGTypeConnection
  , flushPGTypeConnection
  , lookupPGType
  , findPGType
  ) where

import Data.IORef (IORef, newIORef, readIORef, writeIORef)
import qualified Data.IntMap as IntMap
import Data.List (find)

import Database.PostgreSQL.Typed.Types (PGName, OID)
import Database.PostgreSQL.Typed.Dynamic
import Database.PostgreSQL.Typed.Protocol

-- |Type names indexed by OID. Keys are the OID converted to 'Int' with
-- 'fromIntegral', since 'IntMap.IntMap' keys are 'Int'.
type PGTypes = IntMap.IntMap PGName

-- |A 'PGConnection' paired with a cache of the server's type names.
-- The cache starts empty and is filled on demand by the lookup functions.
data PGTypeConnection = PGTypeConnection
  { pgConnection :: !PGConnection
    -- ^ The underlying database connection (strict field).
  , pgTypes :: IORef (Maybe PGTypes)
    -- ^ Cached type map: 'Nothing' before the first load and after a flush.
  }

-- |Create a 'PGTypeConnection' around an existing 'PGConnection'.
-- The type cache starts out empty ('Nothing'); nothing is queried here.
newPGTypeConnection :: PGConnection -> IO PGTypeConnection
newPGTypeConnection conn = PGTypeConnection conn <$> newIORef Nothing

-- |Discard the cached type map. The next lookup on this connection will
-- reload it from the database.
flushPGTypeConnection :: PGTypeConnection -> IO ()
flushPGTypeConnection tc = writeIORef cacheRef Nothing
  where
    cacheRef = pgTypes tc

-- |Get a map of types from the database.
--
-- Reads every row of @pg_catalog.pg_type@ and keys each type's formatted
-- name by its OID (converted with 'fromIntegral'). For domain types
-- (@typtype = 'd'@) the recorded name is that of the domain's base type
-- (@typbasetype@), not the domain itself.
pgGetTypes :: PGConnection -> IO PGTypes
pgGetTypes c =
  -- 'IntMap.fromList' rather than 'IntMap.fromAscList': the latter silently
  -- builds a broken map unless keys arrive ascending *as Ints*. The SQL
  -- ORDER BY sorts OIDs, but if OID is an unsigned 32-bit type (assumed)
  -- then OIDs >= 2^31 become negative under 'fromIntegral' on a 32-bit
  -- 'Int' and the order no longer holds.
  IntMap.fromList . map decodeRow . snd
    <$> pgSimpleQuery c "SELECT oid, format_type(CASE WHEN typtype = 'd' THEN typbasetype ELSE oid END, -1) FROM pg_catalog.pg_type ORDER BY oid"
  where
    -- Each row should carry exactly the two columns selected above; fail
    -- with a descriptive message instead of an opaque pattern-match error.
    decodeRow [to, tn] = (fromIntegral (pgDecodeRep to :: OID), pgDecodeRep tn)
    decodeRow row =
      error ("pgGetTypes: expected 2 columns per pg_type row, got " ++ show (length row))

-- |Return the connection's type map. If nothing is cached, it is fetched
-- with 'pgGetTypes', stored in the cache, and returned.
--
-- NOTE(review): the read and the write are separate IORef operations, so
-- concurrent callers on an empty cache may each run the query.
getPGTypes :: PGTypeConnection -> IO PGTypes
getPGTypes (PGTypeConnection conn cacheRef) = do
  cached <- readIORef cacheRef
  case cached of
    Just types -> return types
    Nothing -> do
      types <- pgGetTypes conn
      writeIORef cacheRef (Just types)
      return types

-- |Look up a type's name by its OID.
-- Cheap once the type map is cached: it is a single 'IntMap.lookup'. Only
-- the first call on a new or flushed connection queries the database.
lookupPGType :: PGTypeConnection -> OID -> IO (Maybe PGName)
lookupPGType c o = do
  types <- getPGTypes c
  return (IntMap.lookup (fromIntegral o) types)

-- |Look up a type's OID by its name.
-- This is a linear scan over the cached map (ascending key order), so it
-- is slower than 'lookupPGType'. If several OIDs share the name, the one
-- with the smallest key is returned.
findPGType :: PGTypeConnection -> PGName -> IO (Maybe OID)
findPGType c t = do
  types <- getPGTypes c
  return $ case find (\(_, name) -> t == name) (IntMap.toList types) of
    Nothing -> Nothing
    Just (key, _) -> Just (fromIntegral key)