module Snap.Snaplet.MysqlSimple (
Mysql(..)
, HasMysql(..)
, mysqlInit
, mysqlInit'
, getConnectionInfo
, query
, query_
, fold
, fold_
, forEach
, forEach_
, execute
, execute_
, executeMany
, rollback
, commit
, formatMany
, formatQuery
, M.Query
, M.In(..)
, M.Binary(..)
, M.Only(..)
, M.FormatError(..)
, M.QueryError(..)
, M.ResultError(..)
, MB.MySQLError()
, QueryResults(..)
, QueryParams(..)
, M.defaultConnectInfo
) where
import Prelude hiding ((++))
import Control.Lens
import Control.Monad.CatchIO (MonadCatchIO)
import Control.Monad.IO.Class
import Control.Monad.State
import Control.Monad.Trans.Reader
import Data.ByteString (ByteString)
import qualified Data.ByteString.Char8 as B8
import Data.Char
import qualified Data.Configurator as C
import qualified Data.Configurator.Types as C
import Data.Int
import Data.Maybe
import Data.Monoid(Monoid(..))
import Data.Pool
import qualified Data.Set as S
import Data.Text (Text)
import qualified Database.MySQL.Simple as M
import qualified Database.MySQL.Base as MB
import Database.MySQL.Simple.QueryParams
import Database.MySQL.Simple.QueryResults
import Snap
import Paths_snaplet_mysql_simple
infixr 5 ++

-- | Append for any 'Monoid'.  Replaces the list-only Prelude operator (which
-- this module hides) while keeping the same right-associative precedence.
(++) :: Monoid a => a -> a -> a
lhs ++ rhs = mappend lhs rhs
-- | Snaplet state.  Holds the pool from which a connection is borrowed for
-- every database action issued through this module.
data Mysql = Mysql { mysqlPool :: Pool M.Connection }
-- | Monads from which the 'Mysql' snaplet state can be obtained.  Every
-- database function exported by this module is written against this class.
class MonadCatchIO m => HasMysql m where
    -- | Fetch the 'Mysql' state (and with it the connection pool).
    getMysqlState :: m Mysql
-- | Inside a handler whose local state is 'Mysql', the state is simply the
-- handler's current state.
instance HasMysql (Handler b Mysql) where
    getMysqlState = get
-- | A reader over a whole @'Snaplet' 'Mysql'@: the state is extracted from
-- the environment through 'snapletValue'.
instance MonadCatchIO m => HasMysql (ReaderT (Snaplet Mysql) m) where
    getMysqlState = asks (\snaplet -> snaplet ^# snapletValue)
-- | A reader whose environment is the 'Mysql' state itself.
instance MonadCatchIO m => HasMysql (ReaderT Mysql m) where
    getMysqlState = ask
-- Generate lenses for the record fields of 'M.ConnectInfo' and 'MB.SSLInfo'
-- with one shared rule set.  'getConnectionInfo' refers to the generated
-- lenses by the unmodified field names (e.g. @connectHost@, @sslKey@); the
-- original records are only imported qualified, so the names do not clash.
$(let rules = LensRules Just Just (const Nothing)
                        (S.fromList [SimpleLenses, GenerateSignatures])
   in liftM concat (mapM (makeLensesWith rules) [''M.ConnectInfo, ''MB.SSLInfo]))
-- | Build a 'M.ConnectInfo' from the settings stored in a configurator
-- 'C.Config'.
--
-- Every setting is looked up as a 'String'.  A key whose lookup yields
-- nothing leaves the corresponding default ('M.defaultConnectInfo' /
-- 'MB.defaultSSLInfo') untouched.
-- NOTE(review): because the lookup type is 'String', numeric and boolean
-- settings presumably have to be written as quoted strings in the config
-- file -- confirm against configurator's @Configured String@ instance.
--
-- Things worth knowing about the result:
--
-- * @connectOptions@ is /replaced/ by the options found in the config, so
--   any options carried by 'M.defaultConnectInfo' are dropped.
--
-- * @connectSSL@ is always set to 'Just', even when no @ssl_*@ key is
--   configured.
--
-- * Flag-style options (@compress@, @found_rows@, ...) are switched on
--   whenever their lookup yields a value; the value itself is ignored.
--
-- Malformed numeric, boolean or protocol values are reported with 'error'
-- when the affected field is eventually forced.
getConnectionInfo :: MonadIO m => C.Config -> m M.ConnectInfo
getConnectionInfo config = do
    sslOpts  <- foldM applySetting MB.defaultSSLInfo sslParams
    connOpts <- liftM catMaybes $ mapM handleOpts optParams
    connInfo <- foldM applySetting M.defaultConnectInfo params
    return $ set connectSSL (Just sslOpts) $ set connectOptions connOpts connInfo
  where
    -- Settings that map directly onto a 'M.ConnectInfo' field.
    params :: [(C.Name, String -> M.ConnectInfo -> M.ConnectInfo)]
    params =
        [ ("host", set connectHost)
        , ("port", set connectPort . readSetting "port")
        , ("dbname", set connectDatabase)
        , ("user", set connectUser)
        , ("password", set connectPassword)
        , ("path", set connectPath)
        ]

    -- Settings that each contribute one 'MB.Option' to @connectOptions@.
    optParams :: [(C.Name, String -> MB.Option)]
    optParams =
        [ ("connection_timeout",
              MB.ConnectTimeout . readSetting "connection_timeout")
        , ("compress", const MB.Compress)
        , ("named_pipe", const MB.NamedPipe)
        , ("init_command", MB.InitCommand . B8.pack)
          -- Historical misspelling of the key above; still honoured so that
          -- existing configuration files keep working.
        , ("init_commmand", MB.InitCommand . B8.pack)
        , ("read_default_file", MB.ReadDefaultFile)
        , ("read_default_group", MB.ReadDefaultGroup . B8.pack)
        , ("charset_dir", MB.CharsetDir)
        , ("charset_name", MB.CharsetName)
        , ("local_in_file", MB.LocalInFile . readBool)
        , ("protocol", MB.Protocol . readProtocol)
        , ("shared_memory_base_name", MB.SharedMemoryBaseName . B8.pack)
        , ("read_timeout", MB.ReadTimeout . readSetting "read_timeout")
        , ("write_timeout", MB.WriteTimeout . readSetting "write_timeout")
        , ("use_remote_connection", const MB.UseRemoteConnection)
        , ("use_embedded_connection", const MB.UseEmbeddedConnection)
        , ("guess_connection", const MB.GuessConnection)
        , ("client_ip", MB.ClientIP . B8.pack)
        , ("secure_auth", MB.SecureAuth . readBool)
        , ("report_data_truncation", MB.ReportDataTruncation . readBool)
        , ("reconnect", MB.Reconnect . readBool)
        , ("ssl_verify_server_cert", MB.SSLVerifyServerCert . readBool)
        , ("found_rows", const MB.FoundRows)
        , ("ignore_sigpipe", const MB.IgnoreSIGPIPE)
        , ("ignore_space", const MB.IgnoreSpace)
        , ("interactive", const MB.Interactive)
        , ("local_files", const MB.LocalFiles)
        , ("multi_results", const MB.MultiResults)
        , ("multi_statements", const MB.MultiStatements)
        , ("no_schema", const MB.NoSchema)
        ]

    -- Settings that map onto a 'MB.SSLInfo' field.
    sslParams :: [(C.Name, String -> MB.SSLInfo -> MB.SSLInfo)]
    sslParams =
        [ ("ssl_key", set sslKey)
        , ("ssl_cert", set sslCert)
        , ("ssl_ca", set sslCA)
        , ("ssl_ca_path", set sslCAPath)
        , ("ssl_ciphers", set sslCiphers)
        ]

    -- Apply one setter to the record if its key yields a value; otherwise
    -- hand the record back unchanged.
    applySetting :: MonadIO m => a -> (C.Name, String -> a -> a) -> m a
    applySetting record (name, setter) = do
        found <- liftIO $ C.lookup config name
        return $ maybe record (`setter` record) found

    -- Look up one option key and, when it yields a value, build the option.
    handleOpts :: MonadIO m => (C.Name, String -> a) -> m (Maybe a)
    handleOpts (name, mkOption) =
        liftIO $ liftM (fmap mkOption) $ C.lookup config name

    -- Parse a value with its 'Read' instance.  Unlike a bare 'read', a
    -- failure names both the offending key and the rejected text.
    readSetting :: Read a => String -> String -> a
    readSetting key raw = case reads raw of
        [(value, rest)] | all isSpace rest -> value
        _ -> error $ "Snap.Snaplet.MysqlSimple: cannot parse value "
                  ++ show raw ++ " for setting " ++ show key

    -- Case-insensitive protocol name.
    readProtocol :: String -> MB.Protocol
    readProtocol s = case map toLower s of
        "tcp"    -> MB.TCP
        "socket" -> MB.Socket
        "pipe"   -> MB.Pipe
        "memory" -> MB.Memory
        _        -> error "no valid protocol."

    -- Case-insensitive boolean: yes/y/true and no/n/false.
    readBool :: String -> Bool
    readBool s = case map toLower s of
        "yes"   -> True
        "y"     -> True
        "true"  -> True
        "no"    -> False
        "n"     -> False
        "false" -> False
        _       -> error "expected 'yes' or 'no'."
-- | Human-readable description handed to 'makeSnaplet' by the initialisers.
description :: Text
description = "MySQL abstraction"
-- | Location of the snaplet's bundled resources: the @resources/db@
-- directory below the package data directory reported by 'getDataDir'.
datadir :: Maybe (IO FilePath)
datadir = Just (fmap (++ "/resources/db") getDataDir)
-- | Initialise the snaplet from the snaplet's own user configuration
-- (obtained with 'getSnapletUserConfig').
mysqlInit :: SnapletInit b Mysql
mysqlInit =
    makeSnaplet "mysql-simple" description datadir $
        getSnapletUserConfig >>= initHelper
-- | Initialise the snaplet from an explicitly supplied configuration
-- instead of the snaplet's user configuration.
mysqlInit' :: C.Config -> SnapletInit b Mysql
mysqlInit' config =
    makeSnaplet "mysql-simple" description datadir (initHelper config)
-- | Shared initialisation logic: read the connection settings and create
-- the connection pool.
--
-- Pool settings read from the config (with their defaults):
-- @numStripes@ (1), @idleTime@ (5) and @maxResourcesPerStripe@ (20).
initHelper :: MonadIO m => C.Config -> m Mysql
initHelper config = liftIO $ do
    connInfo    <- getConnectionInfo config
    stripeCount <- C.lookupDefault 1 config "numStripes"
    idleTime    <- C.lookupDefault 5 config "idleTime" :: IO Double
    perStripe   <- C.lookupDefault 20 config "maxResourcesPerStripe"
    liftM Mysql $
        createPool (M.connect connInfo) M.close stripeCount
                   (realToFrac idleTime) perStripe
-- | Run an 'IO' action with a connection borrowed from the pool.  Each call
-- borrows a connection independently via 'withResource'.
withMysql :: (HasMysql m)
          => (M.Connection -> IO b) -> m b
withMysql action =
    getMysqlState >>= \st -> liftIO (withResource (mysqlPool st) action)
-- | Pooled-connection version of 'M.query': run a query with parameters and
-- return the resulting rows.
query :: (HasMysql m, QueryParams q, QueryResults r)
      => M.Query -> q -> m [r]
query template qs = withMysql $ \conn -> M.query conn template qs
-- | Pooled-connection version of 'M.query_': run a query that takes no
-- parameters and return the resulting rows.
query_ :: (HasMysql m, QueryResults r) => M.Query -> m [r]
query_ template = withMysql $ \conn -> M.query_ conn template
-- | Pooled-connection version of 'M.fold': run a query with parameters and
-- fold the supplied 'IO' step function over its rows, starting from the
-- given accumulator.
fold :: (HasMysql m,
         QueryResults row,
         QueryParams params,
         MonadCatchIO m)
     => M.Query -> params -> b -> (b -> row -> IO b) -> m b
fold template qs acc step =
    withMysql $ \conn -> M.fold conn template qs acc step
-- | Pooled-connection version of 'M.fold_': like 'fold', for a query that
-- takes no parameters.
fold_ :: (HasMysql m,
          QueryResults row,
          MonadCatchIO m)
      => M.Query -> b -> (b -> row -> IO b) -> m b
fold_ template acc step =
    withMysql $ \conn -> M.fold_ conn template acc step
-- | Pooled-connection version of 'M.forEach': run a query with parameters
-- and invoke the supplied 'IO' action on its rows.
forEach :: (HasMysql m,
            QueryResults r,
            QueryParams q,
            MonadCatchIO m)
        => M.Query -> q -> (r -> IO ()) -> m ()
forEach template qs visit =
    withMysql $ \conn -> M.forEach conn template qs visit
-- | Pooled-connection version of 'M.forEach_': like 'forEach', for a query
-- that takes no parameters.
forEach_ :: (HasMysql m,
             QueryResults r,
             MonadCatchIO m)
         => M.Query -> (r -> IO ()) -> m ()
forEach_ template visit =
    withMysql $ \conn -> M.forEach_ conn template visit
-- | Pooled-connection version of 'M.execute': run a statement with
-- parameters and return the 'Int64' count reported by 'M.execute'.
execute :: (HasMysql m, QueryParams q, MonadCatchIO m)
        => M.Query -> q -> m Int64
execute template qs = withMysql $ \conn -> M.execute conn template qs
-- | Pooled-connection version of 'M.execute_': like 'execute', for a
-- statement that takes no parameters.
execute_ :: (HasMysql m, MonadCatchIO m)
         => M.Query -> m Int64
execute_ template = withMysql $ \conn -> M.execute_ conn template
-- | Pooled-connection version of 'M.executeMany': run one statement template
-- against a list of parameter sets.
executeMany :: (HasMysql m, QueryParams q, MonadCatchIO m)
            => M.Query -> [q] -> m Int64
executeMany template rows =
    withMysql $ \conn -> M.executeMany conn template rows
-- | Call 'M.rollback' on a connection borrowed from the pool.
-- NOTE(review): the connection is borrowed afresh for this call, so it is
-- presumably not guaranteed to be the one earlier statements ran on --
-- confirm before relying on this for transaction control.
rollback :: (HasMysql m, MonadCatchIO m) => m ()
rollback = withMysql $ \conn -> M.rollback conn
-- | Call 'M.commit' on a connection borrowed from the pool.
-- NOTE(review): the connection is borrowed afresh for this call, so it is
-- presumably not guaranteed to be the one earlier statements ran on --
-- confirm before relying on this for transaction control.
commit :: (HasMysql m, MonadCatchIO m) => m ()
commit = withMysql $ \conn -> M.commit conn
-- | Pooled-connection version of 'M.formatMany': render a multi-row
-- statement template with its parameter sets to a 'ByteString'.
formatMany :: (QueryParams q, HasMysql m, MonadCatchIO m)
           => M.Query -> [q] -> m ByteString
formatMany template rows =
    withMysql $ \conn -> M.formatMany conn template rows
-- | Pooled-connection version of 'M.formatQuery': render a query template
-- with its parameters to a 'ByteString'.
formatQuery :: (QueryParams q, HasMysql m, MonadCatchIO m)
            => M.Query -> q -> m ByteString
formatQuery template qs =
    withMysql $ \conn -> M.formatQuery conn template qs