{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE NoMonomorphismRestriction #-} {-# LANGUAGE OverloadedStrings #-} module Sync.MerkleTree.Sync ( child , local , parent , openStreams , mkChanStreams , StreamPair(..) , Direction(..) , tests ) where import Control.Concurrent(newChan) import Control.Concurrent.MVar import Control.Monad import Control.Monad.State import Data.Monoid import System.FilePath import Prelude hiding (lookup) import Sync.MerkleTree.Trie hiding (tests) import Sync.MerkleTree.Types import System.IO import System.IO.Error import System.IO.Temp import System.IO.Streams(InputStream, OutputStream, connect) import Data.ByteString(ByteString) import qualified Data.Bytes.Serial as SE import qualified Data.Bytes.Put as P import qualified Data.ByteString as BS import qualified Data.Text as T import qualified System.IO.Streams as ST import qualified System.IO.Streams.Concurrent as ST import qualified Test.HUnit as H import Sync.MerkleTree.Analyse import Sync.MerkleTree.CommTypes import Sync.MerkleTree.Client import Sync.MerkleTree.Server import Sync.MerkleTree.Util.RequestMonad import Sync.MerkleTree.Util.GetFromInputStream data StreamPair = StreamPair { sp_in :: InputStream ByteString , sp_out :: OutputStream ByteString } openStreams :: Handle -> Handle -> IO StreamPair openStreams hIn hOut = do inStream <- ST.handleToInputStream hIn outStream <- ST.handleToOutputStream hOut return $ StreamPair { sp_in = inStream, sp_out = outStream } mkChanStreams :: IO (InputStream ByteString, OutputStream ByteString) mkChanStreams = do chan <- newChan liftM2 (,) (ST.chanToInput chan) (ST.chanToOutput chan) instance Protocol RequestMonad where queryHashReq = request . QueryHash querySetReq = request . QuerySet queryFileReq = request . QueryFile queryFileContReq = request . QueryFileCont logReq = request . Log queryTime = request QueryTime terminateReq = request . Terminate instance ClientMonad RequestMonad where split = splitRequests instance ClientMonad ServerMonad where split xs = liftM mconcat $ sequence xs data Direction = FromRemote | ToRemote child :: MVar () -> StreamPair -> IO () child gotMessage streams = do launchMessage <- getFromInputStream (sp_in streams) putMVar gotMessage () _ <- serverOrClient (read launchMessage) streams return () parent :: StreamPair -> FilePath -> FilePath -> Direction -> ClientServerOptions -> IO (Maybe T.Text) parent streams source destination direction clientServerOpts = case direction of FromRemote -> do respond (sp_out streams) $ show $ mkLaunchMessage Server source serverOrClient (mkLaunchMessage Client destination) streams ToRemote -> do respond (sp_out streams) $ show $ mkLaunchMessage Client destination serverOrClient (mkLaunchMessage Server source) streams where mkLaunchMessage side dir = LaunchMessage { lm_dir = dir , lm_clientServerOptions = clientServerOpts , lm_protocolVersion = thisProtocolVersion , lm_side = side } respond :: (SE.Serial a) => OutputStream ByteString -> a -> IO () respond os = mapM_ (flip ST.write os . Just) . (:[BS.empty]) . P.runPutS . SE.serialize local :: ClientServerOptions -> FilePath -> FilePath -> IO (Maybe T.Text) local cs source destination = do sourceDir <- liftM (mkTrie 0) $ analyse source (cs_ignore cs) destinationDir <- liftM (mkTrie 0) $ analyse destination (cs_ignore cs) serverState <- startServerState source sourceDir evalStateT (abstractClient cs destination destinationDir) serverState serverOrClient :: LaunchMessage -> StreamPair -> IO (Maybe T.Text) serverOrClient lm streams | lm_protocolVersion lm == thisProtocolVersion = let side = case lm_side lm of Server -> server Client -> client (lm_clientServerOptions lm) in do entries <- analyse (lm_dir lm) (cs_ignore $ lm_clientServerOptions lm) side entries (lm_dir lm) streams | otherwise = fail "Incompatible sync-mht versions." server :: [Entry] -> FilePath -> StreamPair -> IO (Maybe T.Text) server entries fp streams = (startServerState fp $ mkTrie 0 entries) >>= evalStateT loop where serverRespond = liftIO . respond (sp_out streams) loop = do req <- liftIO $ getFromInputStream (sp_in streams) case req of QueryHash l -> queryHashReq l >>= serverRespond >> loop QuerySet l -> querySetReq l >>= serverRespond >> loop QueryFile f -> queryFileReq f >>= serverRespond >> loop QueryFileCont c -> queryFileContReq c >>= serverRespond >> loop Log t -> logReq t >>= serverRespond >> loop QueryTime -> queryTime >>= serverRespond >> loop Terminate mMsg -> (terminateReq mMsg >>= serverRespond) >> return mMsg client :: ClientServerOptions -> [Entry] -> FilePath -> StreamPair -> IO (Maybe T.Text) client cs entries fp streams = runRequestMonad (sp_in streams) (sp_out streams) $ abstractClient cs fp $ mkTrie 0 entries tests :: H.Test tests = H.TestList $ [ H.TestLabel "testOpenStreams" $ H.TestCase $ withSystemTempDirectory "testStreams" $ \dir -> do let testStr = "31456" writeFile (dir "read.in") testStr hIn <- openFile (dir "read.in") ReadMode hOut <- openFile (dir "write.out") WriteMode st <- openStreams hIn hOut connect (sp_in st) (sp_out st) hClose hIn hClose hOut got <- readFile $ dir "write.out" testStr H.@=? got , H.TestLabel "testProtocolVersion" $ H.TestCase $ withSystemTempDirectory "testProtocolVersion" $ \dir -> do r <- flip catchIOError (\_ -> return True) $ do inst <- ST.fromByteString $ P.runPutS $ SE.serialize $ show $ LaunchMessage { lm_protocolVersion = ProtocolVersion 1 , lm_dir = dir , lm_side = Client , lm_clientServerOptions = ClientServerOptions { cs_add = False , cs_update = False , cs_delete = False , cs_ignore = [] , cs_compareClocks = Nothing } } out <- ST.nullOutput r <- newEmptyMVar child r $ StreamPair { sp_in = inst, sp_out = out } return False True H.@=? r ]