{-# LANGUAGE CPP, DeriveDataTypeable, RankNTypes, RecordWildCards, ScopedTypeVariables #-}
-----------------------------------------------------------------------------
{- |
 Module      :  Data.Acid.Remote
 Copyright   :  PublicDomain

 Maintainer  :  lemmih@gmail.com
 Portability :  non-portable (uses GHC extensions)

 This module provides the ability to perform 'update' and 'query' calls
from a remote process.

On the server-side you:

 1. open your 'AcidState' normally

 2. then use 'acidServer' to share the state

On the client-side you:

 1. use 'openRemoteState' to connect to the remote state

 2. use the returned 'AcidState' like any other 'AcidState' handle

'openRemoteState' and 'acidServer' communicate over an unencrypted
socket. If you need an encrypted connection, see @acid-state-tls@.

On Unix®-like systems you can use 'UnixSocket' to create a socket file for
local communication between the client and server. Access can be
controlled by setting the permissions of the parent directory
containing the socket file.

It is also possible to perform some simple authentication using
'sharedSecretCheck' and 'sharedSecretPerform'. Keep in mind that
secrets will be sent in plain-text if you do not use
@acid-state-tls@. If you are using a 'UnixSocket' additional
authentication may not be required, so you can use
'skipAuthenticationCheck' and 'skipAuthenticationPerform'.

Working with a remote 'AcidState' is nearly identical to working with
a local 'AcidState' with a few important differences.

The connection to the remote 'AcidState' can be lost. The client will
automatically attempt to reconnect every second. Because 'query'
events do not affect the state, an aborted 'query' will be retried
automatically after the server is reconnected.

If the connection was lost during an 'update' event, the event will
not be retried. Instead 'RemoteConnectionError' will be raised. This
is because it is impossible for the client to know if the aborted
update completed on the server-side or not.

When using a local 'AcidState', an update event in one thread does not
block query events taking place in other threads. With a remote
connection, all queries and requests are channeled over a single
connection. As a result, updates and queries are performed in the
order they are executed and do block each other. In the rare case
where this is an issue, you could create one remote connection per
thread.

When working with local state, a query or update which returns the
whole state is not usually a problem due to memory sharing. The
update/query event basically just needs to return a pointer to the
data already in memory. But, when working remotely, the entire result
will be serialized and sent to the remote client. Hence, it is good
practice to create queries and updates that will only return the
required data.

This module is designed to be extensible. You can easily add your own
authentication methods by creating a suitable pair of functions and
passing them to 'acidServer' and 'openRemoteState'.

It is also possible to create alternative communication layers using
'CommChannel', 'process', and 'processRemoteState'.

-}
module Data.Acid.Remote
    (
    -- * Server/Client
      acidServer
    , openRemoteState
    -- * Authentication
    , skipAuthenticationCheck
    , skipAuthenticationPerform
    , sharedSecretCheck
    , sharedSecretPerform
    -- * Exception type
    , AcidRemoteException(..)
    -- * Low-Level functions needed to implement additional communication channels
    , CommChannel(..)
    , process
    , processRemoteState
    ) where

import Prelude                                hiding ( catch )
import Control.Concurrent.STM                        ( atomically )
import Control.Concurrent.STM.TMVar                  ( newEmptyTMVar, readTMVar, takeTMVar, tryTakeTMVar, putTMVar )
import Control.Concurrent.STM.TQueue
import Control.Exception                             ( AsyncException(ThreadKilled)
                                                     , Exception(fromException), IOException, Handler(..)
                                                     , SomeException, catch, catches, throw )
import Control.Exception                             ( throwIO, finally )
import Control.Monad                                 ( forever, liftM, join, when )
import Control.Concurrent                            ( ThreadId, forkIO, threadDelay, killThread, myThreadId )
import Control.Concurrent.MVar                       ( MVar, newEmptyMVar, putMVar, takeMVar )
import Control.Concurrent.Chan                       ( newChan, readChan, writeChan )
import Data.Acid.Abstract
import Data.Acid.Core
import Data.Acid.Common
import qualified Data.ByteString                     as Strict
import Data.ByteString.Char8                         ( pack )
import qualified Data.ByteString.Lazy                as Lazy
import Data.IORef                                    ( newIORef, readIORef, writeIORef )
import Data.Serialize
import Data.SafeCopy                                 ( SafeCopy, safeGet, safePut )
import Data.Set                                      ( Set, member )
import Data.Typeable                                 ( Typeable )
import GHC.IO.Exception                              ( IOErrorType(..) )
import Network                                       ( HostName, PortID(..), connectTo, listenOn, withSocketsDo )
import Network.Socket                                ( Socket, accept, sClose )
import Network.Socket.ByteString                     ( recv, sendAll )
import System.Directory                              ( removeFile )
import System.IO                                     ( Handle, hPrint, hFlush, hClose, stderr )
import System.IO.Error                               ( ioeGetErrorType, isFullError, isDoesNotExistError )

-- | Debug tracing hook used throughout this module. It is currently a
-- no-op: the message is ignored and nothing is printed.
debugStrLn :: String -> IO ()
debugStrLn _msg =
    -- replace the body with @putStrLn _msg@ to enable debugging output
    return ()

-- | 'CommChannel' is a record containing the IO functions we need for communication between the server and client.
--
-- We abstract this out of the core processing function so that we can easily add support for SSL/TLS and Unit testing.
data CommChannel = CommChannel
    { ccPut     :: Strict.ByteString -> IO ()
      -- ^ send a chunk of bytes to the peer
    , ccGetSome :: Int -> IO (Strict.ByteString)
      -- ^ receive some bytes from the peer; the 'Int' is the size passed
      -- through to the underlying read (the built-in channels in this
      -- module forward it to 'Strict.hGetSome' or 'recv')
    , ccClose   :: IO ()
      -- ^ close the underlying connection
    }

-- | Exceptions raised by the remote client and server code in this module.
data AcidRemoteException
    = RemoteConnectionError
      -- ^ the connection to the remote state was lost (see the module
      -- description regarding aborted 'update' events)
    | AcidStateClosed
      -- ^ a command was issued on a remote state handle after it had
      -- been closed
    | SerializeError String
      -- ^ the server could not decode an incoming command; carries the
      -- decoder's message
    | AuthenticationError String
      -- ^ client-side authentication failed, e.g. in 'sharedSecretPerform'
      deriving (Eq, Show, Typeable)
instance Exception AcidRemoteException


-- | Wrap a 'Handle' as a 'CommChannel'. The 'Handle' should be a
-- two-way communication channel, such as a socket connection. Passing
-- in a 'Handle' to a normal file is unlikely to do anything useful.
--
-- Every write is followed by a flush of the handle.
handleToCommChannel :: Handle -> CommChannel
handleToCommChannel h =
    CommChannel
      { ccClose   = hClose h
      , ccGetSome = \maxBytes -> Strict.hGetSome h maxBytes
      , ccPut     = putAndFlush
      }
  where
    -- write the bytes, then flush the handle
    putAndFlush bytes =
        do Strict.hPut h bytes
           hFlush h

{- | Wrap a 'Socket' as a 'CommChannel'. The 'Socket' should be
     an accepted (connected) socket, not a listen socket.

     Sending uses 'sendAll', receiving uses 'recv' and closing uses 'sClose'.
-}
socketToCommChannel :: Socket -> CommChannel
socketToCommChannel sock =
    CommChannel
      { ccClose   = sClose sock
      , ccGetSome = \maxBytes -> recv sock maxBytes
      , ccPut     = \bytes -> sendAll sock bytes
      }

{- | Server-side authentication check that accepts every client
     without touching the channel. -}
skipAuthenticationCheck :: CommChannel -> IO Bool
skipAuthenticationCheck = const (return True)

{- | Client-side authentication step that does nothing; the
     counterpart of 'skipAuthenticationCheck'. -}
skipAuthenticationPerform :: CommChannel -> IO ()
skipAuthenticationPerform = const (return ())

{- | Server-side check that the client knows a shared secret.

The function takes a 'Set' of shared secrets. A client that presents
any one of them is considered trusted.

A shared secret is any 'ByteString' of your choice. Handing each
client a different secret lets you revoke access individually.

The bytes obtained from a single read of the channel are looked up in
the set; the server then replies @OK@ or @FAIL@ and returns whether
the client was accepted.

see also: 'sharedSecretPerform'
-}
sharedSecretCheck :: Set Strict.ByteString -- ^ set of shared secrets
                  -> (CommChannel -> IO Bool)
sharedSecretCheck secrets cc =
    do candidate <- ccGetSome cc 1024
       let accepted = candidate `member` secrets
       ccPut cc (pack (if accepted then "OK" else "FAIL"))
       return accepted

-- | Client-side half of shared-secret authentication: send the secret
-- to the server and wait for its verdict. Any reply other than @OK@
-- raises 'AuthenticationError'.
--
-- see also: 'sharedSecretCheck'
sharedSecretPerform :: Strict.ByteString -- ^ shared secret
                    -> (CommChannel -> IO ())
sharedSecretPerform pw cc =
    do ccPut cc pw
       verdict <- ccGetSome cc 1024
       when (verdict /= pack "OK") $
            throwIO (AuthenticationError "shared secret authentication failed.")

{- | Accept connections on @port@ and handle requests using the given 'AcidState'.
     This call doesn't return.

     Each accepted connection is served in its own thread: the
     authentication check is run first and, if it succeeds, the
     connection is handed to 'process'. The connection is always closed
     when that thread finishes, whether it ends normally or by an
     exception.

     On Unix®-like systems you can use 'UnixSocket' to communicate
     using a socket file. To control access, you can set the permissions of
     the parent directory which contains the socket file.

     see also: 'openRemoteState' and 'sharedSecretCheck'.
 -}
acidServer :: SafeCopy st =>
              (CommChannel -> IO Bool) -- ^ check authentication, see 'sharedSecretCheck'
           -> PortID                   -- ^ Port to listen on
           -> AcidState st             -- ^ state to serve
           -> IO ()
acidServer checkAuth port acidState
  = withSocketsDo $
    do listenSocket <- listenOn port
       let loop = forever $
             do (socket, _sockAddr) <- accept listenSocket
                let commChannel = socketToCommChannel socket
                    serveClient = do authorized <- checkAuth commChannel
                                     when authorized $
                                          process commChannel acidState
                -- 'finally' guarantees the socket is released even when
                -- authentication or request processing throws.
                forkIO $ serveClient `finally` ccClose commChannel
           -- restart the accept loop after any error that 'catchSome' swallows
           infi = loop `catchSome` logError >> infi
       infi `finally` (cleanup listenSocket)
    where
      logError :: (Show e) => e -> IO ()
      logError e = hPrint stderr e

      isResourceVanishedError :: IOException -> Bool
      isResourceVanishedError = isResourceVanishedType . ioeGetErrorType

      isResourceVanishedType :: IOErrorType -> Bool
      isResourceVanishedType ResourceVanished = True
      isResourceVanishedType _                = False

      -- Swallow the IO errors the accept loop is expected to survive
      -- (full, does-not-exist, resource-vanished); anything else is
      -- rethrown. The handler argument is currently unused.
      catchSome :: IO () -> (Show e => e -> IO ()) -> IO ()
      catchSome op _h =
          op `catches` [ Handler $ \(e :: IOException)    ->
                           if isFullError e || isDoesNotExistError e || isResourceVanishedError e
                            then return () -- h (toException e) -- we could log the exception, but there could be thousands of them
                            else throwIO e
                       ]
      -- close the listening socket and, for a 'UnixSocket', remove the socket file
      cleanup socket =
          do sClose socket
             case port of
#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
               UnixSocket path -> removeFile path
#endif
               _               -> return ()

-- | Requests sent from the client to the server. The payload of
-- 'RunQuery' and 'RunUpdate' is the serialized event paired with its
-- method tag.
data Command = RunQuery (Tagged Lazy.ByteString)   -- ^ run a query event
             | RunUpdate (Tagged Lazy.ByteString)  -- ^ run an update event
             | CreateCheckpoint                    -- ^ ask the server to create a checkpoint

-- | Wire format: a one-byte tag (0 = 'RunQuery', 1 = 'RunUpdate',
-- 2 = 'CreateCheckpoint') followed by the serialized payload, if any.
--
-- An unknown tag is reported with 'fail', i.e. as a regular decoding
-- failure, so that incremental decoders see it as a failed parse
-- rather than as an exception escaping from pure code.
instance Serialize Command where
  put cmd = case cmd of
              RunQuery query   -> do putWord8 0; put query
              RunUpdate update -> do putWord8 1; put update
              CreateCheckpoint ->    putWord8 2
  get = do tag <- getWord8
           case tag of
             0 -> liftM RunQuery get
             1 -> liftM RunUpdate get
             2 -> return CreateCheckpoint
             _ -> fail $ "Serialize.get for Command, invalid tag: " ++ show tag

-- | Replies delivered to the client for each 'Command'.
--
-- 'Result' carries the serialized result of a query or update and
-- 'Acknowledgement' answers 'CreateCheckpoint'. 'ConnectionError' is
-- what the client hands to callbacks still pending when a connection
-- is torn down.
data Response = Result Lazy.ByteString | Acknowledgement | ConnectionError

-- | Wire format: a one-byte tag (0 = 'Result', 1 = 'Acknowledgement',
-- 2 = 'ConnectionError'); 'Result' is followed by its serialized payload.
--
-- An unknown tag is reported with 'fail', i.e. as a regular decoding
-- failure, so that incremental decoders see it as a failed parse
-- rather than as an exception escaping from pure code.
instance Serialize Response where
  put resp = case resp of
               Result result -> do putWord8 0; put result
               Acknowledgement -> putWord8 1
               ConnectionError -> putWord8 2
  get = do tag <- getWord8
           case tag of
             0 -> liftM Result get
             1 -> return Acknowledgement
             2 -> return ConnectionError
             _ -> fail $ "Serialize.get for Response, invalid tag: " ++ show tag

{- | Server inner-loop

     This function is generally only needed if you are adding a new communication channel.

     Commands are decoded incrementally from the channel and executed
     against the given 'AcidState'. Each command enqueues an IO action
     that yields its 'Response'; a separate writer thread runs those
     actions in order and sends the encoded responses back, so the
     replies go out in the order the commands arrived.

     The loop only terminates by an exception (a decoding failure is
     raised as 'SerializeError'). The writer thread is killed when that
     happens so that it does not outlive the connection.
-}
process :: SafeCopy st =>
           CommChannel  -- ^ a connected, authenticated communication channel
        -> AcidState st -- ^ state to share
        -> IO ()
process CommChannel{..} acidState
  = do chan <- newChan
       -- writer thread: run each queued action to obtain the response, then send it
       writerTID <- forkIO $ forever $ do response <- join (readChan chan)
                                          ccPut (encode response)
       -- make sure the writer thread goes away together with the worker loop
       worker chan (runGetPartial get Strict.empty) `finally` killThread writerTID
  where -- drive the incremental decoder, reading more input whenever it asks for it
        worker chan inp
          = case inp of
              Fail msg      -> throwIO (SerializeError msg)
              Partial cont  -> do bs <- ccGetSome 1024
                                  worker chan (cont bs)
              -- bytes left over after a complete command start the next decode
              Done cmd rest -> do processCommand chan cmd; worker chan (runGetPartial get rest)
        processCommand chan cmd =
          case cmd of
            RunQuery query -> do result <- queryCold acidState query
                                 writeChan chan (return $ Result result)
            -- the update is only scheduled here; the writer thread waits for its result
            RunUpdate update -> do result <- scheduleColdUpdate acidState update
                                   writeChan chan (liftM Result $ takeMVar result)
            CreateCheckpoint -> do createCheckpoint acidState
                                   writeChan chan (return Acknowledgement)


-- | Client-side handle for a remote state. The type parameter @st@
-- does not occur in the fields.
--
-- The first field submits a 'Command' and returns an 'MVar' that will
-- receive the matching 'Response'; the second field is the action
-- run to shut the remote connection down.
data RemoteState st = RemoteState (Command -> IO (MVar Response)) (IO ())
                    deriving (Typeable)

{- | Connect to an acid-state server which is sharing an 'AcidState'.

     The connection is established with 'connectTo' and then
     authenticated with the supplied function. If connecting or
     authenticating raises an 'IOError', the attempt is repeated after a
     delay of one second; other exceptions propagate to the caller.
-}
openRemoteState :: IsAcidic st =>
                   (CommChannel -> IO ()) -- ^ authentication function, see 'sharedSecretPerform'
                -> HostName               -- ^ remote host to connect to (ignored when 'PortID' is 'UnixSocket')
                -> PortID                 -- ^ remote port to connect to
                -> IO (AcidState st)
openRemoteState performAuthorization host port
  = withSocketsDo (processRemoteState reconnect)
    where
      -- (re-)establish an authenticated connection, retrying on 'IOError'
      reconnect :: IO CommChannel
      reconnect = connectOnce `catch` retryLater

      -- a single connection attempt followed by authentication
      connectOnce :: IO CommChannel
      connectOnce =
          do debugStrLn "Reconnecting."
             h <- connectTo host port
             let channel = handleToCommChannel h
             performAuthorization channel
             debugStrLn "Reconnected."
             return channel

      -- wait one second, then try again
      retryLater :: IOError -> IO CommChannel
      retryLater _ = threadDelay 1000000 >> reconnect


{- | Client inner-loop

     This function is generally only needed if you are adding a new communication channel.

     It connects once using the supplied action and starts two threads:
     an actor thread that sends queued commands over the current
     channel, and a listener thread that decodes responses and hands
     each one to the oldest pending callback. The current channel, its
     queue of pending callbacks and the listener's 'ThreadId' are kept
     together in a 'TMVar' so they can be replaced as a unit when the
     connection has to be re-established.
-}
processRemoteState :: IsAcidic st =>
                      IO CommChannel -- ^ (re-)connect function
                   -> IO (AcidState st)
processRemoteState reconnect
  = do cmdQueue    <- atomically newTQueue     -- commands waiting to be sent, each paired with its reply MVar
       ccTMV       <- atomically newEmptyTMVar -- (channel, pending callbacks, listener thread); empty while reconnecting
       isClosed    <- newIORef False           -- set to True by shutdown

       -- Entry point stored in the 'RemoteState': refuse if the state
       -- has been closed, otherwise queue the command and return the
       -- MVar that will receive its response.
       let actor :: Command -> IO (MVar Response)
           actor command =
               do debugStrLn "actor: begin."
                  readIORef isClosed >>= flip when (throwIO AcidStateClosed)
                  ref <- newEmptyMVar
                  atomically $ writeTQueue cmdQueue (command, ref)
                  debugStrLn "actor: end."
                  return ref

           -- Drain a queue of pending callbacks without blocking,
           -- answering each one with ConnectionError.
           expireQueue listenQueue =
               do mCallback <- atomically $ tryReadTQueue listenQueue
                  case mCallback of
                    Nothing         -> return ()
                    (Just callback) ->
                        do callback ConnectionError
                           expireQueue listenQueue

           -- Exception handler shared by the listener and the actor
           -- thread. ThreadKilled is ignored. For anything else,
           -- whoever manages to take ccTMV tears the old connection
           -- down and installs a fresh one; finding ccTMV empty is
           -- treated as "someone else is already doing that".
           handleReconnect :: SomeException -> IO ()
           handleReconnect e
             = case fromException e of
                 (Just ThreadKilled) ->
                     do debugStrLn "handleReconnect: ThreadKilled. Not attempting to reconnect."
                        return ()
                 _ ->
                   do debugStrLn $ "handleReconnect begin."
                      tmv <- atomically $ tryTakeTMVar ccTMV
                      case tmv of
                        Nothing ->
                            do debugStrLn $ "handleReconnect: error handling already in progress."
                               debugStrLn $ "handleReconnect end."
                               return ()
                        (Just (oldCC, oldListenQueue, oldListenerTID)) ->
                            do thisTID <- myThreadId
                               -- do not kill ourselves when the listener itself is handling the error
                               when (thisTID /= oldListenerTID) (killThread oldListenerTID)
                               ccClose oldCC
                               -- everything still awaiting a reply on the old connection gets ConnectionError
                               expireQueue oldListenQueue
                               cc <- reconnect
                               listenQueue <- atomically $ newTQueue
                               listenerTID <- forkIO $ listener cc listenQueue
                               -- refilling ccTMV lets the actor thread proceed again
                               atomically $ putTMVar ccTMV (cc, listenQueue, listenerTID)
                               debugStrLn $ "handleReconnect end."
                               return ()

           -- Listener thread body: decode responses from the channel
           -- and pass each one to the next callback in listenQueue.
           -- Any exception goes to handleReconnect, after which this
           -- listener returns.
           listener :: CommChannel -> TQueue (Response -> IO ()) -> IO ()
           listener cc listenQueue
             = getResponse Strict.empty `catch` handleReconnect
               where
                 -- decode one response starting from the bytes left
                 -- over by the previous decode, then loop
                 getResponse leftover =
                     do debugStrLn $ "listener: listening for Response."
                        let go inp = case inp of
                                   Fail msg       -> error msg
                                   Partial cont   -> do debugStrLn $ "listener: ccGetSome"
                                                        bs <- ccGetSome cc 1024
                                                        go (cont bs)
                                   Done resp rest -> do debugStrLn $ "listener: getting callback"
                                                        -- blocks until a callback is available
                                                        callback <- atomically $ readTQueue listenQueue
                                                        debugStrLn $ "listener: passing Response to callback"
                                                        callback (resp :: Response)
                                                        return rest
                        rest <- go (runGetPartial get leftover) -- `catch` (\e -> do handleReconnect e
                                                                --                   throwIO e
                                                                 --        )
                        getResponse rest

           -- Actor thread body: in a single transaction take the next
           -- command, look at (without taking) the current connection,
           -- and register the reply callback on that connection's
           -- queue; then send the command. readTMVar blocks while
           -- ccTMV is empty, so nothing is dequeued in the meantime.
           actorThread :: IO ()
           actorThread = forever $
             do debugStrLn "actorThread: waiting for something to do."
                (cc, cmd) <- atomically $
                  do (cmd, ref)        <- readTQueue cmdQueue
                     (cc, listenQueue, _) <- readTMVar ccTMV
                     writeTQueue listenQueue (putMVar ref)
                     return (cc, cmd)
                debugStrLn "actorThread: sending command."
                ccPut cc (encode cmd) `catch` handleReconnect
                debugStrLn "actorThread: sent."
                return ()

           -- Close the remote state: mark it closed, stop both
           -- threads, answer pending callbacks with ConnectionError
           -- and close the channel. takeTMVar blocks while ccTMV is
           -- empty (see the FIXME below).
           shutdown :: ThreadId -> IO ()
           shutdown actorTID =
               do debugStrLn "shutdown: update isClosed IORef to True."
                  writeIORef isClosed True
                  debugStrLn "shutdown: killing actor thread."
                  killThread actorTID
                  debugStrLn "shutdown: taking ccTMV."
                  (cc, listenQueue, listenerTID) <- atomically $ takeTMVar ccTMV -- FIXME: or should this by tryTakeTMVar
                  debugStrLn "shutdown: killing listener thread."
                  killThread listenerTID
                  debugStrLn "shutdown: expiring listen queue."
                  expireQueue  listenQueue
                  debugStrLn "shutdown: closing connection."
                  ccClose cc
                  return ()

       -- initial connection
       cc <- reconnect
       listenQueue <- atomically $ newTQueue

       actorTID    <- forkIO $ actorThread
       listenerTID <- forkIO $ listener cc listenQueue

       -- publishing the connection unblocks the actor thread
       atomically $ putTMVar ccTMV (cc, listenQueue, listenerTID)

       return (toAcidState $ RemoteState actor (shutdown actorTID))

-- | Run a typed query event against the remote state: serialize the
-- event, send it via 'remoteQueryCold' and decode the reply. The
-- decoding is lazy; a reply that cannot be decoded results in an
-- 'error' when the result is forced.
remoteQuery :: QueryEvent event => RemoteState (EventState event) -> event -> IO (EventResult event)
remoteQuery acidState event
  = do rawResult <- remoteQueryCold acidState (methodTag event, runPutLazy (safePut event))
       return (either error id (runGetLazyFix safeGet rawResult))

-- | Run an already-serialized query event against the remote state and
-- return the serialized result. When the reply is 'ConnectionError'
-- the query is submitted again.
remoteQueryCold :: RemoteState st -> Tagged Lazy.ByteString -> IO Lazy.ByteString
remoteQueryCold rs@(RemoteState fn _shutdown) event
  = fn (RunQuery event) >>= takeMVar >>= dispatch
  where
    dispatch (Result result) = return result
    dispatch ConnectionError = debugStrLn "retrying query event." >> remoteQueryCold rs event
    dispatch Acknowledgement = error "remoteQueryCold got Acknowledgement. That should never happen."


-- | Schedule a typed update event on the remote state. The returned
-- 'MVar' is filled once the server's response has arrived.
--
-- The 'MVar' is always filled, whatever the response is. If the
-- connection was lost, the stored value raises 'RemoteConnectionError'
-- when it is forced; the update is not retried because the client
-- cannot know whether it was applied. A result that cannot be decoded
-- raises an 'error' when forced.
scheduleRemoteUpdate :: UpdateEvent event => RemoteState (EventState event) -> event -> IO (MVar (EventResult event))
scheduleRemoteUpdate (RemoteState fn _shutdown) event
  = do let encoded = runPutLazy (safePut event)
       parsed <- newEmptyMVar
       respRef <- fn (RunUpdate (methodTag event, encoded))
       -- wait for the response in the background so the caller is not blocked
       forkIO $ do resp <- takeMVar respRef
                   putMVar parsed (decodeResponse resp)
       return parsed
  where
    -- Built lazily: failures are stored in the MVar as throwing thunks
    -- instead of killing the forwarding thread and leaving the MVar empty.
    decodeResponse resp =
        case resp of
          Result bytes    -> case runGetLazyFix safeGet bytes of
                               Left msg     -> error msg
                               Right result -> result
          ConnectionError -> throw RemoteConnectionError
          Acknowledgement -> error "scheduleRemoteUpdate got Acknowledgement. That should never happen."

-- | Schedule an already-serialized update event on the remote state.
-- The returned 'MVar' is filled with the serialized result once the
-- server's response has arrived.
--
-- The 'MVar' is always filled, whatever the response is. If the
-- connection was lost, the stored value raises 'RemoteConnectionError'
-- when it is forced; the update is not retried because the client
-- cannot know whether it was applied.
scheduleRemoteColdUpdate :: RemoteState st -> Tagged Lazy.ByteString -> IO (MVar Lazy.ByteString)
scheduleRemoteColdUpdate (RemoteState fn _shutdown) event
  = do parsed <- newEmptyMVar
       respRef <- fn (RunUpdate event)
       -- wait for the response in the background so the caller is not blocked
       forkIO $ do resp <- takeMVar respRef
                   putMVar parsed (extractResult resp)
       return parsed
  where
    -- Built lazily: failures are stored in the MVar as throwing thunks
    -- instead of killing the forwarding thread and leaving the MVar empty.
    extractResult resp =
        case resp of
          Result bytes    -> bytes
          ConnectionError -> throw RemoteConnectionError
          Acknowledgement -> error "scheduleRemoteColdUpdate got Acknowledgement. That should never happen."

-- | Close a remote state by running the shutdown action stored in it.
closeRemoteState :: RemoteState st -> IO ()
closeRemoteState remote =
    case remote of
      RemoteState _ shutdownAction -> shutdownAction

-- | Ask the server to create a checkpoint and wait for its reply.
--
-- Raises 'RemoteConnectionError' if the connection was lost before the
-- acknowledgement arrived; the request is not retried.
createRemoteCheckpoint :: RemoteState st -> IO ()
createRemoteCheckpoint (RemoteState fn _shutdown)
  = do resp <- takeMVar =<< fn CreateCheckpoint
       case resp of
         Acknowledgement -> return ()
         -- report a lost connection explicitly instead of as a pattern-match failure
         ConnectionError -> throwIO RemoteConnectionError
         Result _        -> error "createRemoteCheckpoint got Result. That should never happen."

-- | Package a 'RemoteState' as an ordinary 'AcidState' handle by
-- filling every field with the corresponding remote operation from
-- this module.
toAcidState :: IsAcidic st => RemoteState st -> AcidState st
toAcidState rs
  = AcidState { _query             = remoteQuery rs
              , queryCold          = remoteQueryCold rs
              , _scheduleUpdate    = scheduleRemoteUpdate rs
              , scheduleColdUpdate = scheduleRemoteColdUpdate rs
              , createCheckpoint   = createRemoteCheckpoint rs
              , closeAcidState     = closeRemoteState rs
              , acidSubState       = mkAnyState rs
              }