module Database.Persist.MongoDB
( MongoDBReader
, withMongoDBConn
, runMongoDBConn
, HostName
, PersistValue(..)
, module Database.Persist
) where
import Database.Persist
import Database.Persist.Base
import Control.Monad.Trans.Reader
import qualified Control.Monad.IO.Class as Trans
import qualified Database.MongoDB as DB
import Control.Applicative (Applicative)
import Control.Exception (toException)
import Data.UString (u)
import qualified Data.CompactString.UTF8 as CS
import Data.Enumerator hiding (map, length)
import Network.Socket (HostName)
import qualified Network.Abstract(NetworkIO)
import Data.Maybe (mapMaybe, fromJust)
import qualified Data.Text as T
import qualified Data.Text.Encoding as E
import qualified Data.Serialize as S
-- | Backend monad for MongoDB: a reader over a connection pool paired with a
-- name. Although the second component is typed 'HostName', 'execute' passes
-- it to 'runPool' as the database name.
newtype MongoDBReader t m a =
    MongoDBReader (ReaderT (DB.ConnPool t, HostName) m a)
    deriving (Functor, Applicative, Monad, Trans.MonadIO)
-- | Create a connection pool for @hostname@ (passing 1 as the pool size
-- argument to 'DB.newConnPool') and hand the pool, paired with @dbname@, to
-- the callback. The pool is not closed by this function.
withMongoDBConn :: (Network.Abstract.NetworkIO m) => t -> HostName -> ((DB.ConnPool DB.Host, t) -> m b) -> m b
withMongoDBConn dbname hostname connectionReader =
    DB.newConnPool 1 (DB.host hostname) >>= \pool ->
        connectionReader (pool, dbname)
-- | Run a 'MongoDBReader' action against a pool / database-name pair, such as
-- the pair 'withMongoDBConn' passes to its callback.
runMongoDBConn :: MongoDBReader t m a -> (DB.ConnPool t, HostName) -> m a
runMongoDBConn reader env =
    case reader of
        MongoDBReader r -> runReaderT r env
-- | Run a database action against the database named @dbname@, using a
-- connection from @pool@ and the 'DB.safe' / 'DB.Master' access settings.
-- Driver failures are returned as 'Left' rather than thrown.
runPool :: (DB.Service s, Trans.MonadIO m) => DB.ConnPool s -> String
        -> ReaderT DB.Database (DB.Action m) a -> m (Either DB.Failure a)
runPool pool dbname action = DB.access DB.safe DB.Master pool scoped
  where
    database = DB.Database (u dbname)
    scoped = DB.use database action
-- | Run a database action inside 'MongoDBReader', using the pool and database
-- name stored in the reader environment. A driver 'DB.Failure' is turned into
-- a call to 'fail' carrying the failure's 'show'n text.
execute :: (DB.Service s, Trans.MonadIO m) =>
    ReaderT DB.Database (DB.Action (MongoDBReader s m)) b -> MongoDBReader s m b
execute action =
    MongoDBReader ask >>= \(pool, dbname) ->
        runPool pool dbname action >>= either (fail . show) return
-- | The BSON value stored in a document field, with its label discarded.
value :: DB.Field -> DB.Value
value field = DB.value field
-- | Decode document fields into an entity. Calls 'error' with the decoding
-- message if 'wrapFromPersistValues' returns 'Left'.
rightPersistVals :: (PersistEntity val) => EntityDef -> [DB.Field] -> val
rightPersistVals ent = either error id . wrapFromPersistValues ent
-- | First component of a triple (used on entity column descriptions).
fst3 :: (a, b, c) -> a
fst3 triple = case triple of
    (first, _, _) -> first
-- | Selector document matching the single document whose @_id@ is the
-- ObjectId encoded in the given key.
filterByKey :: (PersistEntity val) => Key val -> DB.Document
filterByKey key = [idLabel DB.=: keyToDbOid key]
  where
    idLabel = u "_id"
-- | A 'DB.Query' for the document with the given key, in the collection
-- named after the entity.
queryByKey :: (PersistEntity val) => Key val -> EntityDef -> DB.Query
queryByKey k entity = DB.select selector collection
  where
    selector = filterByKey k
    collection = u (entityName entity)
-- | A 'DB.Selection' for the document with the given key, in the collection
-- named after the entity (used for replace/update operations).
selectByKey :: (PersistEntity val) => Key val -> EntityDef -> DB.Selection
selectByKey k entity = DB.select selector collection
  where
    selector = filterByKey k
    collection = u (entityName entity)
-- | Translate each persistent update into its MongoDB update-operator field.
updateFields :: (PersistEntity val) => [Update val] -> [DB.Field]
updateFields = map updateField
-- | Translate a single persistent 'Update' into a MongoDB update-operator
-- field of the form @op: {fieldName: operand}@.
--
-- 'Update' becomes @$set@. 'Add' and 'Subtract' both become @$inc@; because
-- there is no separate decrement operator, the operand of a 'Subtract' is
-- negated (for Int64 and Double values). 'Multiply' and 'Divide' are not
-- supported and call 'error'.
updateField :: (PersistEntity val) => Update val -> DB.Field
updateField upd =
    opName DB.:= DB.Doc [u (persistUpdateToFieldName upd) DB.:= opValue]
  where
    updateKind = persistUpdateToUpdate upd
    opValue = DB.val $ adjust $ persistUpdateToValue upd
    -- $inc adds its operand, so Subtract must send the negated amount.
    adjust v = case updateKind of
        Subtract -> negateValue v
        _ -> v
    negateValue (PersistInt64 i) = PersistInt64 (negate i)
    negateValue (PersistDouble d) = PersistDouble (negate d)
    negateValue x = x
    opName = case updateKind of
        Update -> u "$set"
        Add -> u "$inc"
        Subtract -> u "$inc"
        Multiply -> error "multiply not supported yet"
        Divide -> error "divide not supported yet"
-- | Selector document matching a unique constraint: one field per column of
-- the constraint, each equal to the corresponding value.
uniqSelector :: (PersistEntity val) => Unique val -> [DB.Field]
uniqSelector uniq =
    zipWith toField (persistUniqueToFieldNames uniq) (persistUniqueToValues uniq)
  where
    toField name v = u name DB.:= DB.val v
-- | Split a stored document into its key and decoded entity.
--
-- The code treats the first field as the document's @_id@ and converts it
-- into the entity key. The remaining fields are decoded with
-- 'wrapFromPersistValues'. Every failure, including an empty document or an
-- @_id@ that cannot be cast, is reported as 'Left' rather than through a
-- partial 'fromJust'.
pairFromDocument :: forall val val1. (PersistEntity val, PersistEntity val1) => EntityDef -> [DB.Field] -> Either String (Key val, val1)
pairFromDocument _ [] = Left "error in fromPersistValues'"
pairFromDocument ent (idField : rest) =
    case wrapFromPersistValues ent rest of
        Left e -> Left e
        Right entity -> case DB.cast' (value idField) of
            Nothing -> Left "pairFromDocument: could not convert the _id field into a key"
            Just key -> Right (toPersistKey key, entity)
-- | Build the document for a record by pairing each entity column name with
-- the BSON form of the record's corresponding field value. Extra columns or
-- values beyond the shorter list are ignored, as with 'zipWith'.
insertFields :: forall val. (PersistEntity val) => EntityDef -> val -> [DB.Field]
insertFields t record = zipWith toField (entityColumns t) (toPersistFields record)
  where
    toField column field = u (fst3 column) DB.:= DB.val (toPersistValue field)
-- | MongoDB implementation of persistent's backend interface.
--
-- Each entity lives in a collection named after its 'entityName'. Entity keys
-- are the documents' @_id@ ObjectIds, wrapped by 'dbOidToKey'. Every driver
-- call goes through 'execute', which turns a driver failure into 'fail'.
instance (DB.DbAccess m, DB.Service t) => PersistBackend (MongoDBReader t m) where
    -- Insert a new document and use the ObjectId returned by the insert as
    -- the entity key.
    insert record = do
        (DB.ObjId oid) <- execute $ DB.insert (u $ entityName t) (insertFields t record)
        return $ toPersistKey $ dbOidToKey oid
      where
        t = entityDef record

    replace k record = do
        execute $ DB.replace (selectByKey k t) (insertFields t record)
        return ()
      where
        t = entityDef record

    -- An empty update list is a no-op. Otherwise, apply all update operators
    -- to the document with the given key (selected the same way as replace).
    update _ [] = return ()
    update k upds =
        execute $ DB.modify (selectByKey k t) (updateFields upds)
      where
        t = entityDef $ dummyFromKey k

    updateWhere _ [] = return ()
    updateWhere filts upds =
        execute $ DB.modify
            (DB.Select
                { DB.coll = u $ entityName t
                , DB.selector = filterToSelector filts
                })
            (updateFields upds)
      where
        t = entityDef $ dummyFromFilts filts

    delete k =
        execute $ DB.deleteOne DB.Select
            { DB.coll = u $ entityName t
            , DB.selector = filterByKey k
            }
      where
        t = entityDef $ dummyFromKey k

    deleteWhere filts =
        execute $ DB.delete DB.Select
            { DB.coll = u $ entityName t
            , DB.selector = filterToSelector filts
            }
      where
        t = entityDef $ dummyFromFilts filts

    deleteBy uniq =
        execute $ DB.delete DB.Select
            { DB.coll = u $ entityName t
            , DB.selector = uniqSelector uniq
            }
      where
        t = entityDef $ dummyFromUnique uniq

    get k = do
        mdoc <- execute $ DB.findOne (queryByKey k t)
        -- The code treats the first field as _id and drops it before decoding
        -- the entity columns. 'drop 1' avoids the partial 'tail'.
        return $ fmap (rightPersistVals t . drop 1) mdoc
      where
        t = entityDef $ dummyFromKey k

    getBy uniq = do
        mdocument <- execute $ DB.findOne $
            DB.select (uniqSelector uniq) (u $ entityName t)
        case mdocument of
            Nothing -> return Nothing
            Just document -> case pairFromDocument t document of
                Left s -> error s
                Right pair -> return $ Just pair
      where
        t = entityDef $ dummyFromUnique uniq

    count filts = do
        i <- execute $ DB.count query
        return $ fromIntegral i
      where
        query = DB.select (filterToSelector filts) (u $ entityName t)
        t = entityDef $ dummyFromFilts filts

    -- Stream matching (key, entity) pairs into the iteratee one document at a
    -- time. Documents that fail to decode end the stream with a
    -- PersistMarshalException.
    selectEnum filts ords limit offset = Iteratee . start
      where
        start x = do
            cursor <- execute $ DB.find query
            loop x cursor
        query = (DB.select (filterToSelector filts) (u $ entityName t))
            { DB.limit = fromIntegral limit
            , DB.skip = fromIntegral offset
            , DB.sort = map orderClause ords
            }
        t = entityDef $ dummyFromFilts filts
        -- MongoDB sort specifications use 1 for ascending and -1 for
        -- descending order.
        orderClause o = u (persistOrderToFieldName o) DB.=: direction (persistOrderToOrder o)
        direction Asc = 1 :: Int
        direction Desc = -1
        loop (Continue k) curs = do
            doc <- execute $ DB.next curs
            case doc of
                Nothing -> return $ Continue k
                Just document -> case pairFromDocument t document of
                    Left s -> return $ Error $ toException
                        $ PersistMarshalException s
                    Right row -> do
                        step <- runIteratee $ k $ Chunks [row]
                        loop step curs
        loop step _ = return step

    -- Stream only the keys of matching documents. The query projects just
    -- the _id field, so each returned document is expected to hold exactly
    -- one ObjectId field.
    selectKeys filts = Iteratee . start
      where
        start x = do
            cursor <- execute $ DB.find query
            loop x cursor
        loop (Continue k) curs = do
            doc <- execute $ DB.next curs
            case doc of
                Nothing -> return $ Continue k
                Just [_ DB.:= (DB.ObjId oid)] -> do
                    step <- runIteratee $ k $ Chunks [toPersistKey $ dbOidToKey oid]
                    loop step curs
                Just y -> return $ Error $ toException $ PersistMarshalException
                    $ "Unexpected in selectKeys: " ++ show y
        loop step _ = return step
        query = (DB.select (filterToSelector filts) (u $ entityName t))
            { DB.project = [u "_id" DB.=: (1 :: Int)]
            }
        t = entityDef $ dummyFromFilts filts
-- | Selector document combining all filters: one selector field per filter.
filterToSelector :: PersistEntity val => [Filter val] -> DB.Document
filterToSelector = map filterField
-- | Translate one persistent 'Filter' into a MongoDB selector field.
--
-- An equality filter becomes @name: value@. Every other comparison becomes
-- @name: {$op: value}@. The persistent field name "id" is mapped to @_id@,
-- and a list of values ('Right') is sent as a BSON array.
filterField :: PersistEntity val => Filter val -> DB.Field
filterField f =
    case persistFilterToFilter f of
        Eq -> name DB.:= filterValue
        op -> name DB.=: [u (mongoOperator op) DB.:= filterValue]
  where
    fieldName = persistFilterToFieldName f
    name
        | fieldName == "id" = u "_id"
        | otherwise = u fieldName
    filterValue = either DB.val (DB.Array . map DB.val) (persistFilterToValue f)
    mongoOperator op = case op of
        Ne -> "$ne"
        Gt -> "$gt"
        Lt -> "$lt"
        Ge -> "$gte"
        Le -> "$lte"
        In -> "$in"
        NotIn -> "$nin"
        Eq -> error ""
-- | Decode a document's fields into an entity.
--
-- The document's fields are matched by label against the entity's declared
-- column names and reordered into column order before being passed to
-- 'fromPersistValues'. Every column must appear in the document, and the
-- document must not contain any field that is not a column (callers drop the
-- @_id@ field first). A mismatch is reported as 'Left' instead of throwing,
-- so callers that already handle 'Left' can report it as a marshalling error.
wrapFromPersistValues :: (PersistEntity val) => EntityDef -> [DB.Field] -> Either String val
wrapFromPersistValues e doc =
    either Left fromPersistValues (reorder columns (mapFromDoc doc) [])
  where
    columns = map (T.pack . fst3) (entityColumns e)

    -- Walk the columns in order and pull each column's value out of the
    -- remaining document fields. Values are accumulated in reverse and
    -- flipped once at the end to avoid quadratic appends.
    reorder :: [T.Text] -> [(T.Text, PersistValue)] -> [PersistValue] -> Either String [PersistValue]
    reorder [] [] acc = Right (reverse acc)
    reorder [] extra _ =
        Left $ "fields don't match: document has fields that are not columns: "
            ++ show (map fst extra)
    reorder (c : cs) fields acc =
        case break ((== c) . fst) fields of
            (before, (_, v) : after) -> reorder cs (before ++ after) (v : acc)
            _ -> Left $ "field doesn't match: column " ++ show c
                ++ " not found among " ++ show (map fst fields)
-- | Convert every field of a BSON document into a (label, value) pair, with
-- the label as 'T.Text' and the value cast to 'PersistValue'. Uses
-- 'fromJust', so a field whose cast returns 'Nothing' would fail here.
mapFromDoc :: DB.Document -> [(T.Text, PersistValue)]
mapFromDoc doc =
    [ (csToT (DB.label fld), fromJust (DB.cast' (DB.value fld))) | fld <- doc ]
-- | Convert a UTF-8 CompactString (the BSON label/string type) to 'T.Text'
-- by decoding its underlying bytes.
csToT :: CS.CompactString -> T.Text
csToT cs = E.decodeUtf8 (CS.toByteString cs)
-- | Convert 'T.Text' to a UTF-8 CompactString by encoding it to UTF-8 bytes.
tToCS :: T.Text -> CS.CompactString
tToCS txt = CS.fromByteString_ (E.encodeUtf8 txt)
-- | Wrap a MongoDB ObjectId as a persistent key value by serializing it into
-- the bytes of a 'PersistForeignKey'.
dbOidToKey :: DB.ObjectId -> PersistValue
dbOidToKey oid = PersistForeignKey (S.encode oid)
-- | Inverse of 'dbOidToKey': deserialize the ObjectId stored in a
-- 'PersistForeignKey'. Calls 'error' on any other constructor or on bytes
-- that fail to decode.
foreignKeyToDbOid :: PersistValue -> DB.ObjectId
foreignKeyToDbOid pv = case pv of
    PersistForeignKey bytes -> either error id (S.decode bytes)
    _ -> error "expected PersistForeignKey"
-- | The MongoDB ObjectId behind an entity key.
keyToDbOid :: (PersistEntity val) => Key val -> DB.ObjectId
keyToDbOid key = foreignKeyToDbOid (fromPersistKey key)
-- | Conversion between persistent's 'PersistValue' and BSON values.
--
-- 'PersistForeignKey' maps to a BSON ObjectId. 'PersistByteString' maps to
-- BSON binary data, so it reads back through 'cast'' as 'PersistByteString'.
-- Writing 'PersistDay' or 'PersistTimeOfDay', or reading JavaScript, symbol,
-- timestamp or min/max BSON values, calls 'error'.
instance DB.Val PersistValue where
    val (PersistInt64 x) = DB.Int64 x
    val (PersistText x) = DB.String (tToCS x)
    val (PersistDouble x) = DB.Float x
    val (PersistBool x) = DB.Bool x
    val (PersistUTCTime x) = DB.UTC x
    val (PersistNull) = DB.Null
    val (PersistList l) = DB.Array $ map DB.val l
    val (PersistMap m) = DB.Doc $ map (\(k, v) -> (DB.=:) (tToCS k) v) m
    -- Raw bytes are stored as BSON binary rather than as a string: arbitrary
    -- bytes need not be valid UTF-8, and cast' below decodes DB.String as
    -- PersistText but DB.Bin as PersistByteString.
    val (PersistByteString x) = DB.Bin (DB.Binary x)
    val x@(PersistForeignKey _) = DB.ObjId $ foreignKeyToDbOid x
    val (PersistDay _) = error "only PersistUTCTime currently implemented"
    val (PersistTimeOfDay _) = error "only PersistUTCTime currently implemented"

    cast' (DB.Float x) = Just (PersistDouble x)
    cast' (DB.Int32 x) = Just $ PersistInt64 $ fromIntegral x
    cast' (DB.Int64 x) = Just $ PersistInt64 x
    cast' (DB.String x) = Just $ PersistText (csToT x)
    cast' (DB.Bool x) = Just $ PersistBool x
    cast' (DB.UTC d) = Just $ PersistUTCTime d
    cast' DB.Null = Just $ PersistNull
    cast' (DB.Bin (DB.Binary b)) = Just $ PersistByteString b
    cast' (DB.Fun (DB.Function f)) = Just $ PersistByteString f
    cast' (DB.Uuid (DB.UUID uid)) = Just $ PersistByteString uid
    cast' (DB.Md5 (DB.MD5 md5)) = Just $ PersistByteString md5
    cast' (DB.UserDef (DB.UserDefined bs)) = Just $ PersistByteString bs
    cast' (DB.RegEx (DB.Regex us1 us2)) = Just $ PersistByteString $ CS.toByteString $ CS.append us1 us2
    cast' (DB.Doc doc) = Just $ PersistMap $ mapFromDoc doc
    cast' (DB.Array xs) = Just $ PersistList $ mapMaybe DB.cast' xs
    cast' (DB.ObjId x) = Just $ dbOidToKey x
    cast' (DB.JavaScr _) = error "cast operation not supported for javascript"
    cast' (DB.Sym _) = error "cast operation not supported for sym"
    cast' (DB.Stamp _) = error "cast operation not supported for stamp"
    cast' (DB.MinMax _) = error "cast operation not supported for minmax"
-- | Binary encoding of a MongoDB ObjectId: its two components are written in
-- constructor order using their own 'S.Serialize' instances. 'dbOidToKey'
-- uses this to store ObjectIds inside 'PersistForeignKey' bytes.
instance S.Serialize DB.ObjectId where
    put oid = case oid of
        DB.Oid front back -> S.put front >> S.put back
    get = S.get >>= \front -> S.get >>= \back -> return (DB.Oid front back)
-- | Type-level witness only: fixes the entity type of a key so that
-- 'entityDef' can be called. The result must never be evaluated.
dummyFromKey :: Key v -> v
dummyFromKey = const (error "dummyFromKey")
-- | Type-level witness only: fixes the entity type of a unique constraint so
-- that 'entityDef' can be called. The result must never be evaluated.
dummyFromUnique :: Unique v -> v
dummyFromUnique = const (error "dummyFromUnique")
-- | Type-level witness only: fixes the entity type of a filter list so that
-- 'entityDef' can be called. The result must never be evaluated.
dummyFromFilts :: [Filter v] -> v
dummyFromFilts = const (error "dummyFromFilts")