-- Both extensions are required by this module: 'SetOperation' is a
-- polymorphic type passed as an argument ('withAccuracy'), and 'cat' uses a
-- pattern type signature. Declaring them here is harmless if they are also
-- enabled project-wide.
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | Streaming implementations of the union, subtract and intersect commands.
--
-- Inputs are read line by line as streams of strict key/value pairs, combined
-- by a 'SetOperation', and the values of the surviving pairs are written out
-- one per line.
module Smap.Commands (run) where

import Prelude hiding (filter, subtract, init, sin)
import qualified Data.HashMap.Strict as Map (insert, empty, member)
import Data.HashMap.Strict as Map (HashMap)
import Data.Hashable (Hashable)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Streaming.Char8 as BS8
import qualified Streaming.Prelude as P
import qualified Streaming as S
import Control.Monad (foldM, void)
import Control.Monad.Trans.Class (lift)
import Data.Strict.Tuple (Pair((:!:)))
import Data.List.NonEmpty (NonEmpty((:|)))
import qualified Control.Monad.Trans.Resource as Resource
import Control.Monad.IO.Class (MonadIO)
import Crypto.MAC.SipHash (SipHash(..), hash)
import Smap.Flags
  ( Hdl(Std, File)
  , Descriptor(Keyed, UnKeyed)
  , Command(Union, Subtract, Intersect)
  , Accuracy(Approximate, Exact)
  )

-- | A stream of strict key/value pairs running in @m@ and returning @()@.
type Stream m k v = S.Stream (S.Of (Pair k v)) m ()

-- | An operation combining one or more input maps into a single output map.
--
-- It is polymorphic in the key type so that the same operation can be run on
-- the raw keys or on hashed keys (see 'withAccuracy').
type SetOperation
  =  forall key
   . (Hashable key, Eq key)
  => NonEmpty (Stream (Resource.ResourceT IO) key ByteString) -- Input maps
  -> Stream (Resource.ResourceT IO) key ByteString -- Output map

-- | Union: walk the inputs in order and yield every pair whose key has not
-- been yielded before, so the first occurrence of a key wins. The set of seen
-- keys is threaded from one input stream to the next.
cat :: SetOperation
cat streams = void (foldM filter Map.empty streams)
 where
  -- The input is hoisted with 'lift' so that the fold itself runs in the
  -- output stream's monad, which lets 'filter'' yield into the output.
  filter seen = P.foldM_ filter' (return seen) return . S.hoist lift
  -- for some strange reason I can't import alterF from Data.HashMap.Strict
  filter' (seen :: HashMap k ()) (bs :!: v) = if bs `Map.member` seen
    then return seen
    else P.yield (bs :!: v) >> return (Map.insert bs () seen)

-- | Filter the first input by key membership in the remaining inputs.
--
-- All keys of the remaining inputs are first collected into a single
-- 'HashMap' (those streams are consumed completely before the first input is
-- read). Each pair of the first input is then kept when @includeIfPresent@
-- returns 'True' for "this key occurs in the collected set".
filterStreamWith :: (Bool -> Bool) -> SetOperation
filterStreamWith includeIfPresent (first :| seconds) = do
  second <- lift $ collects seconds
  P.filter (\(k :!: _) -> includeIfPresent (k `Map.member` second)) first
 where
  collects = foldM collect Map.empty
  collect subs = P.fold_ (\s (k :!: _) -> Map.insert k () s) subs id

-- | Subtraction: keep the pairs of the first input whose key occurs in none
-- of the other inputs.
sub :: SetOperation
sub = filterStreamWith not

-- | Intersection: keep the pairs of the first input whose key occurs in the
-- other inputs.
--
-- NOTE(review): because the other inputs are merged into one key set, a key
-- is kept when it occurs in /any/ of them. With more than two inputs this is
-- @first ∩ (union of the rest)@ rather than the intersection of all inputs;
-- confirm that this is the intended meaning.
int :: SetOperation
int = filterStreamWith id

-- | Read a 'Descriptor' as a stream of key/value pairs, one per line.
--
-- * 'UnKeyed': every line is used as both the key and the value.
-- * 'Keyed': the lines of the key source are paired positionally with the
--   lines of the value source. Presumably pairing stops when either source
--   runs out of lines (behaviour of 'S.zipsWith''); confirm if unequal
--   lengths matter.
load
  :: (MonadIO m, Resource.MonadResource m)
  => Descriptor ty
  -> Stream m ByteString ByteString
load descriptor = case descriptor of
  UnKeyed hdl -> P.map (\x -> x :!: x) (linesOf hdl)
  Keyed keys values -> S.zipsWith'
    (\q (k P.:> ks) (v P.:> vs) -> (k :!: v) P.:> q ks vs)
    (linesOf keys)
    (linesOf values)
 where
  -- Each line is materialised as a strict 'ByteString'.
  linesOf = S.mapsM BS8.toStrict . BS8.lines . hin
  hin Std         = BS8.stdin
  hin (File path) = BS8.readFile path

-- | Run a 'SetOperation' over the inputs and write the resulting values to
-- the output handle, one per line (keys are not written).
--
-- * 'Exact': keys are compared as-is.
-- * 'Approximate': every key is replaced by its SipHash under the supplied
--   hash key before the operation runs, so the operation only ever stores
--   hashes.
withAccuracy
  :: Accuracy
  -> SetOperation
  -> NonEmpty (Stream (Resource.ResourceT IO) ByteString ByteString)
  -> Hdl
  -> IO ()
withAccuracy accuracy op inputs output = case accuracy of
  Exact           -> approximateWith id
  Approximate key -> approximateWith (sip key)
 where
  -- Drop the key and emit the value as one output line.
  format = BS8.unlines
    . S.maps (\((_k :!: v) P.:> r) -> BS8.fromStrict v >> return r)
  hout Std         = BS8.stdout
  hout (File path) = BS8.writeFile path
  sip key bs = let SipHash h = hash key bs in h
  keyMap f = P.map (\(k :!: v) -> f k :!: v)
  approximateWith approximator =
    Resource.runResourceT
      $ hout output
      $ format
      $ op
      $ fmap (keyMap approximator) inputs

-- | Execute a parsed 'Command'.
--
-- A 'Union' with no inputs reads a single unkeyed stream from standard input.
run :: Command -> IO ()
run cmd = case cmd of
  Subtract acc p ms o -> withAccuracy acc sub (load p :| fmap load ms) o
  Intersect acc i is o -> withAccuracy acc int (load i :| fmap load is) o
  Union acc is o -> withAccuracy acc cat (fmap load inputs) o
   where
    inputs = case is of
      []       -> UnKeyed Std :| []
      (x : xs) -> x :| xs