{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeFamilies #-}

module TX
    ( Persistable(..)

      -- * Managing the database
    , Database
    , openDatabase
    , closeDatabase
    , withUserData

      -- * The TX monad
    , TX
    , persistently
    , record
    , getData
    , liftSTM
    , throwTX
    , unsafeIOToTX

      -- * Utility functions
    , (<?>)
    ) where

import Control.Applicative
import Control.Concurrent
import Control.Concurrent.STM
import Control.Exception
import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Trans.Reader
import qualified Data.ByteString as B
import Data.Maybe
import Data.SafeCopy
import Data.Serialize
import GHC.Conc
import System.IO

------------------------------------------------------------------------------

-- | The type family at the heart of TX.
--
-- You make any data type you want to use with the TX monad an instance of
-- 'Persistable' and define 'Update' constructors for each of the methods
-- acting on this type that you want to be able to record during a transaction.
-- Then you implement 'replay' in such a way that for each of the Update
-- constructors, the appropiate method is called.
--
-- Example:
--
-- > data MyDB = MyDB { posts :: TVar [String] }
-- >
-- > instance Persistable MyDB where
-- >     data Update MyDB = CreatePost String
-- >                      | ModifyPost Int String
-- >
-- >     replay (CreatePost p)   = void $ createPost p
-- >     replay (ModifyPost n p) = modifyPost n p
--
-- where @createPost@ and @modifyPost@ are functions in the TX monad:
--
-- > createPost :: String -> TX MyDB Int
-- > createPost p = do
-- >     record (CreatePost p)
-- >     (MyDB posts) <- getData
-- >     liftSTM $ do
-- >         ps <- readTVar posts
-- >         writeTVar posts (ps ++ [p])
-- >         return $ length ps
-- >
-- > modifyPost :: Int -> String -> TX MyDB ()
-- > modifyPost n p = do
-- >     record (ModifyPost n p)
-- >     (MyDB posts) <- getData
-- >     liftSTM $ do
-- >         ps <- readTVar posts
-- >         let (xs,ys) = splitAt n ps
-- >             ps'     = xs ++ p : (tail ys)
-- >         writeTVar posts ps'
--
-- Note that @Update@ also needs to be an instance of 'SafeCopy'. Currently,
-- it's not possible to derive SafeCopy instances for associated types
-- automatically, so you have to do it by hand:
--
-- > instance SafeCopy (Update MyDB) where
-- >     putCopy (CreatePost p)   = contain $ putWord8 0 >> safePut p
-- >     putCopy (ModifyPost n p) = contain $ putWord8 1 >> safePut n >> safePut p
-- >     getCopy = contain $ do
-- >         tag <- getWord8
-- >         case tag of
-- >             0 -> CreatePost <$> safeGet
-- >             1 -> ModifyPost <$> safeGet <*> safeGet
-- >             _ -> fail $ "unknown tag: " ++ show tag
class SafeCopy (Update d) => Persistable d where
    data Update d
    replay :: Update d -> TX d ()

-- TODO: provide automatic derivation using TH

------------------------------------------------------------------------------

-- | An opaque type wrapping any kind of user data for use in the 'TX' monad.
data Database d = Database { userData :: d
                           , logHandle :: Handle
                           , logQueue :: TQueue (Update d)
                           , _record :: TQueue (Update d) -> (Update d) -> STM ()
                           , serializerTid :: MVar ThreadId
                           }

-- | Opens the database at the given path or creates a new one.
openDatabase :: Persistable d
             => FilePath  -- ^ Location of the log file.
             -> d  -- ^ Base data. Any existing log is replayed on top of this.
             -> IO (Database d)
openDatabase logPath userData = do
    putStr ("Opening " ++ logPath ++ " database... ")
    logHandle <- openBinaryFile logPath ReadWriteMode
    logQueue <- newTQueueIO
    serializerTid <- newEmptyMVar
    let _db = Database { _record = const $ const $ return (), .. }
    hIsEOF logHandle >>= flip unless (replayUpdates _db)
    putMVar serializerTid =<< forkIO (serializer _db)
    let db = _db { _record = writeTQueue }
    putStrLn ("DONE")
    return db

-- TODO: throw actual error when operating on a closed database
-- | Close a database. Blocks until all pending recordings been serialized.
-- Using a database after it has been closed is an error.
closeDatabase :: Database d -> IO ()
closeDatabase Database {..} = do
    atomically $ check =<< isEmptyTQueue logQueue
    killThread =<< takeMVar serializerTid
    hClose logHandle

replayUpdates :: Persistable d => Database d -> IO ()
replayUpdates db = mapDecode (persistently db . replay)
                             (B.hGetSome (logHandle db) 1024)

-- | @mapDecode f nextChunk@ repeatedly calls @nextChunk@ to get a
-- 'B.ByteString', (partially) decodes this string using 'safeGet' and
-- and then applies @f@ to the (final) result. This continues until
-- @nextChunk@ returns an empty ByteString.
mapDecode :: SafeCopy a => (a -> IO ()) -> IO B.ByteString -> IO ()
mapDecode f nextChunk = go run =<< nextChunk
    where
        run = runGetPartial safeGet
        go k c = case k c of
            Fail    err  -> error ("TX.mapDecode: " ++ err)
            Partial k'   -> go k' =<< nextChunk
            Done    u c' -> f u >> if B.null c'
                                       then do c'' <- nextChunk
                                               if B.null c''
                                                   then return ()
                                                   else go run c''
                                       else go run c'

-- | Operate non-persistently on the user data contained in the database.
withUserData :: Database d -> (d -> a) -> a
withUserData db act = act (userData db)

------------------------------------------------------------------------------

serializer :: Persistable d => Database d -> IO ()
serializer Database {..} = forever $ do
    u <- atomically $ readTQueue logQueue
    let str = runPut (safePut u)
    B.hPut logHandle str

------------------------------------------------------------------------------

-- | A thin wrapper around STM. The main feature is the ability to 'record'
-- updates of the underlying data.
newtype TX d a = TX (ReaderT (Database d) STM a)
    deriving (Functor, Applicative, Monad)

-- | Perform a series of TX actions persistently.
--
-- Note that there is no guarantee that all recorded updates have been serialized
-- when the functions returns. As such, durability is only partially guaranteed.
--
-- Since this calls 'atomically' on the underlying STM actions,
-- the same caveats apply (e.g. you can't use it inside 'unsafePerformIO').
persistently :: Database d -> TX d a -> IO a
persistently db (TX action) = atomically $ runReaderT action db

-- | Record an 'Update' to be serialized to disk when the transaction commits.
-- If the transaction retries, the update is still only recorded once.
-- If the transaction aborts, the update is not recorded at all.
record :: Update d -> TX d ()
record u = do
    Database {..} <- TX ask
    liftSTM $ _record logQueue u
{-# INLINE record #-}

-- | Get the user data from the database.
getData :: TX d d
getData = userData <$> TX ask
{-# INLINE getData #-}

-- | Run STM actions inside TX.
liftSTM :: STM a -> TX d a
liftSTM = TX . lift
{-# INLINE liftSTM #-}

-- | Throw an exception in TX, which will abort the transaction.
--
-- @throwTX = liftSTM . throwSTM@
throwTX :: Exception e => e -> TX d a
throwTX = liftSTM . throwSTM
{-# INLINE throwTX #-}

-- | Unsafely performs IO in the TX monad. Highly dangerous!
-- The same caveats as with 'unsafeIOToSTM' apply.
--
-- @unsafeIOToTX = liftSTM . unsafeIOToSTM@
unsafeIOToTX :: IO a -> TX d a
unsafeIOToTX = liftSTM . unsafeIOToSTM
{-# INLINE unsafeIOToTX #-}

------------------------------------------------------------------------------

-- | @act \<?\> err = maybe (throwTX err) return =<< act@
(<?>) :: Exception e => TX d (Maybe a) -> e -> TX d a
act <?> err = maybe (throwTX err) return =<< act