{-# LANGUAGE DeriveAnyClass    #-}
{-# LANGUAGE DeriveGeneric     #-}
{-# LANGUAGE FlexibleContexts  #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase        #-}
{-# LANGUAGE OverloadedStrings #-}
module Haskoin.Store.Database.Writer
    ( WriterT
    , MemoryTx
    , runWriter
    , runTx
    ) where

import           Control.Applicative           ((<|>))
import           Control.DeepSeq               (NFData)
import           Control.Monad.Reader          (ReaderT (..))
import qualified Control.Monad.Reader          as R
import           Control.Monad.Trans.Maybe     (MaybeT (..), runMaybeT)
import qualified Data.ByteString.Short         as B.Short
import           Data.Hashable                 (Hashable)
import           Data.HashMap.Strict           (HashMap)
import qualified Data.HashMap.Strict           as M
import           Data.Maybe                    (maybeToList)
import           Database.RocksDB              (BatchOp)
import           Database.RocksDB.Query        (deleteOp, insertOp, writeBatch)
import           GHC.Generics                  (Generic)
import           Haskoin                       (Address, BlockHash, BlockHeight,
                                                Network, OutPoint (..), TxHash,
                                                headerHash, txHash)
import           Haskoin.Store.Common
import           Haskoin.Store.Data            (Balance (..), BlockData (..),
                                                BlockRef (..), Spender,
                                                TxData (..), TxRef (..),
                                                Unspent (..))
import           Haskoin.Store.Database.Reader
import           Haskoin.Store.Database.Types
import           UnliftIO                      (MonadIO, STM, TVar, atomically,
                                                modifyTVar, modifyTVar',
                                                newTVarIO, readTVar, readTVarIO)

-- | A pending change to a single database entry: either a new value that
-- must be written ('Modified') or a removal that must be applied
-- ('Deleted') when the cache is flushed.
data Dirty a
    = Modified a
    | Deleted
    deriving (Eq, Show, Generic, Hashable, NFData)

-- | Map over the pending value; a pending deletion stays a deletion.
instance Functor Dirty where
    fmap f = \case
        Modified a -> Modified (f a)
        Deleted    -> Deleted

-- | Environment of a 'WriterT' computation.
data Writer = Writer
    { getReader :: !DatabaseReader
      -- ^ on-disk reader consulted when an item is not cached in memory
    , getState  :: !(TVar Memory)
      -- ^ in-memory write-back cache, flushed to disk by 'runWriter'
    }

-- | Computations that can read through the cache to the database and
-- eventually flush it; see 'runWriter'.
type WriterT = ReaderT Writer

-- | An STM transaction over the shared in-memory cache; run with 'runTx'.
type MemoryTx = ReaderT (TVar Memory) STM

-- | Read-through access: every lookup consults the in-memory cache first
-- and falls back to the underlying 'DatabaseReader' (see the @*I@ helpers).
instance MonadIO m => StoreReadBase (WriterT m) where
    getNetwork          = ReaderT getNetworkI
    getBestBlock        = ReaderT getBestBlockI
    getBlocksAtHeight h = ReaderT (getBlocksAtHeightI h)
    getBlock b          = ReaderT (getBlockI b)
    getTxData t         = ReaderT (getTxDataI t)
    getSpender p        = ReaderT (getSpenderI p)
    getUnspent a        = ReaderT (getUnspentI a)
    getBalance a        = ReaderT (getBalanceI a)
    getMempool          = ReaderT getMempoolI

-- | In-memory write-back cache accumulated during a 'WriterT' session and
-- turned into a RocksDB write batch by 'runWriter'.
--
-- Maps whose values are wrapped in 'Dirty' can also record deletions; the
-- remaining maps only record inserted or overwritten values.
data Memory = Memory
    { hNet     :: !Network
      -- ^ network this cache belongs to
    , hBest    :: !(Maybe BlockHash)
      -- ^ new best block; 'Nothing' if none was set this session
    , hBlock   :: !(HashMap BlockHash BlockData)
    , hHeight  :: !(HashMap BlockHeight [BlockHash])
    , hTx      :: !(HashMap TxHash TxData)
    , hSpender :: !(HashMap OutPoint (Dirty Spender))
    , hUnspent :: !(HashMap OutPoint (Dirty UnspentVal))
    , hBalance :: !(HashMap Address BalVal)
    , hAddrTx  :: !(HashMap (Address, TxRef) (Dirty ()))
    , hAddrOut :: !(HashMap (Address, BlockRef, OutPoint) (Dirty OutVal))
    , hMempool :: !(Maybe [TxRef])
      -- ^ replacement mempool list; 'Nothing' if it was not set this session
    } deriving (Eq, Show)

-- | Write operations against the in-memory write-back cache.
--
-- Each method applies a pure @Memory -> Memory@ update to the shared
-- 'TVar' within the current STM transaction. Nothing reaches disk until
-- 'runWriter' flushes the cache.
instance StoreWrite MemoryTx where
    setBest h             = modifyMemoryTx (setBestH h)
    insertBlock b         = modifyMemoryTx (insertBlockH b)
    setBlocksAtHeight h g = modifyMemoryTx (setBlocksAtHeightH h g)
    insertTx t            = modifyMemoryTx (insertTxH t)
    insertSpender p s     = modifyMemoryTx (insertSpenderH p s)
    deleteSpender p       = modifyMemoryTx (deleteSpenderH p)
    insertAddrTx a t      = modifyMemoryTx (insertAddrTxH a t)
    deleteAddrTx a t      = modifyMemoryTx (deleteAddrTxH a t)
    insertAddrUnspent a u = modifyMemoryTx (insertAddrUnspentH a u)
    deleteAddrUnspent a u = modifyMemoryTx (deleteAddrUnspentH a u)
    setMempool xs         = modifyMemoryTx (setMempoolH xs)
    setBalance b          = modifyMemoryTx (setBalanceH b)
    insertUnspent h       = modifyMemoryTx (insertUnspentH h)
    deleteUnspent p       = modifyMemoryTx (deleteUnspentH p)

-- | Apply a pure update to the 'Memory' held in the transaction's 'TVar'.
--
-- Uses the strict 'modifyTVar'' so the new 'Memory' is forced to weak head
-- normal form right away. Because every 'Memory' field is strict, that
-- also evaluates each field. The lazy 'modifyTVar' would instead stack up
-- a chain of unevaluated updates inside the 'TVar' across many writes.
modifyMemoryTx :: (Memory -> Memory) -> MemoryTx ()
modifyMemoryTx f = ReaderT $ \v -> modifyTVar' v f

-- | Reads served purely from the in-memory cache, inside STM, with no
-- fallback to the on-disk database.
--
-- The network is always available. Spender and unspent lookups treat a
-- missing entry the same as a deleted one and return 'Nothing'. All other
-- lookups call 'error' if the requested item was never written to memory.
instance StoreReadBase MemoryTx where
    getNetwork =
        ReaderT $ \v -> hNet <$> readTVar v
    getBestBlock = ReaderT $ \v -> do
        mem <- readTVar v
        maybe (error "Best block not set in STM") (return . Just)
              (getBestH mem)
    getBlocksAtHeight h = ReaderT $ \v -> do
        mem <- readTVar v
        maybe (error "Blocks at height not set in STM") return
              (getBlocksAtHeightH h mem)
    getBlock h = ReaderT $ \v -> do
        mem <- readTVar v
        maybe (error "Block not set in STM") (return . Just)
              (getBlockH h mem)
    getTxData t = ReaderT $ \v -> do
        mem <- readTVar v
        maybe (error "Tx data not set in STM") (return . Just)
              (getTxDataH t mem)
    getSpender op = ReaderT $ \v -> do
        mem <- readTVar v
        case getSpenderH op mem of
            Just (Modified s) -> return (Just s)
            _                 -> return Nothing
    getUnspent op = ReaderT $ \v -> do
        mem <- readTVar v
        case getUnspentH op mem of
            Just (Modified u) -> return (Just (valToUnspent op u))
            _                 -> return Nothing
    getBalance a = ReaderT $ \v -> do
        mem <- readTVar v
        maybe (error "Balance not set in STM")
              (return . Just . valToBalance a)
              (getBalanceH a mem)
    getMempool = ReaderT $ \v -> do
        mem <- readTVar v
        maybe (error "Mempool not set in STM") return (getMempoolH mem)

-- | Run a 'WriterT' computation against a fresh, empty in-memory cache.
--
-- After the computation returns, every change it accumulated is written
-- to the database in a single 'writeBatch'. The computation's result is
-- then returned.
runWriter
    :: MonadIO m
    => DatabaseReader
    -> WriterT m a
    -> m a
runWriter bdb@DatabaseReader {databaseHandle = db, databaseNetwork = net} f = do
    memVar <- newTVarIO (emptyMemory net)
    result <- R.runReaderT f (Writer bdb memVar)
    readTVarIO memVar >>= writeBatch db . hashMapOps
    return result

-- | Translate every pending change held in 'Memory' into batch operations.
--
-- The order is fixed: best block, blocks, heights, transactions, spenders,
-- balances, address transactions, address outputs, the mempool (only if it
-- was set), then unspent outputs.
hashMapOps :: Memory -> [BatchOp]
hashMapOps mem =
    concat
        [ bestBlockOp (hBest mem)
        , blockHashOps (hBlock mem)
        , blockHeightOps (hHeight mem)
        , txOps (hTx mem)
        , spenderOps (hSpender mem)
        , balOps (hBalance mem)
        , addrTxOps (hAddrTx mem)
        , addrOutOps (hAddrOut mem)
        , maybe [] (pure . mempoolOp) (hMempool mem)
        , unspentOps (hUnspent mem)
        ]

-- | Store the best block under 'BestKey' if one was set; otherwise emit
-- nothing.
bestBlockOp :: Maybe BlockHash -> [BatchOp]
bestBlockOp = maybeToList . fmap (insertOp BestKey)

-- | One insert per cached block, keyed by its hash under 'BlockKey'.
blockHashOps :: HashMap BlockHash BlockData -> [BatchOp]
blockHashOps blocks =
    [insertOp (BlockKey bh) bd | (bh, bd) <- M.toList blocks]

-- | One insert per cached height, storing its block-hash list under
-- 'HeightKey'.
blockHeightOps :: HashMap BlockHeight [BlockHash] -> [BatchOp]
blockHeightOps heights =
    [insertOp (HeightKey g) hs | (g, hs) <- M.toList heights]

-- | One insert per cached transaction, keyed by its hash under 'TxKey'.
txOps :: HashMap TxHash TxData -> [BatchOp]
txOps txs =
    [insertOp (TxKey th) td | (th, td) <- M.toList txs]

-- | For each outpoint, insert its spender or delete the 'SpenderKey'
-- entry, depending on the pending change.
spenderOps :: HashMap OutPoint (Dirty Spender)
           -> [BatchOp]
spenderOps spenders =
    [ case change of
          Modified s -> insertOp (SpenderKey o) s
          Deleted    -> deleteOp (SpenderKey o)
    | (o, change) <- M.toList spenders
    ]

-- | One insert per cached balance, keyed by address under 'BalKey'.
-- Balances are never deleted here.
balOps :: HashMap Address BalVal -> [BatchOp]
balOps balances =
    [insertOp (BalKey addr) bv | (addr, bv) <- M.toList balances]

-- | Insert or delete the 'AddrTxKey' index entry for each
-- (address, tx reference) pair; the stored value is always @()@.
addrTxOps :: HashMap (Address, TxRef) (Dirty ()) -> [BatchOp]
addrTxOps entries =
    [ case change of
          Modified () -> insertOp key ()
          Deleted     -> deleteOp key
    | ((a, t), change) <- M.toList entries
    , let key = AddrTxKey a t
    ]

-- | Insert or delete the 'AddrOutKey' entry for each
-- (address, block reference, outpoint) triple.
addrOutOps
    :: HashMap (Address, BlockRef, OutPoint) (Dirty OutVal) -> [BatchOp]
addrOutOps = M.foldrWithKey step []
  where
    -- Folding right and consing keeps the same order as 'M.toList'.
    step (a, b, p) change acc =
        let key = AddrOutKey { addrOutKeyA = a
                             , addrOutKeyB = b
                             , addrOutKeyP = p }
         in case change of
                Modified l -> insertOp key l : acc
                Deleted    -> deleteOp key : acc

-- | Batch operation that stores the mempool list under 'MemKey'.
--
-- Each entry is persisted as a @(time, hash)@ pair taken from its
-- 'MemRef'. Entries whose 'txRefBlock' is some other 'BlockRef'
-- (presumably an already-confirmed transaction) have no representation in
-- that format. They are skipped rather than aborting the whole write batch
-- with a non-exhaustive pattern-match error.
mempoolOp :: [TxRef] -> BatchOp
mempoolOp refs =
    insertOp MemKey
        [ (t, h)
        | TxRef { txRefBlock = MemRef t, txRefHash = h } <- refs
        ]

-- | For each outpoint, insert its unspent value or delete the
-- 'UnspentKey' entry, depending on the pending change.
unspentOps :: HashMap OutPoint (Dirty UnspentVal)
           -> [BatchOp]
unspentOps unspents = map toOp (M.toList unspents)
  where
    toOp (p, Modified u) = insertOp (UnspentKey p) u
    toOp (p, Deleted)    = deleteOp (UnspentKey p)

-- | Network recorded in the writer's in-memory cache.
getNetworkI :: MonadIO m => Writer -> m Network
getNetworkI = fmap hNet . readTVarIO . getState

-- | Best block set in memory this session, or else whatever the database
-- reports.
getBestBlockI :: MonadIO m => Writer -> m (Maybe BlockHash)
getBestBlockI Writer {getState = hm, getReader = db} = do
    mem <- readTVarIO hm
    case getBestBlockH mem of
        Just b  -> return (Just b)
        Nothing -> withDatabaseReader db getBestBlock

-- | Block hashes at a height. A list cached in memory wins; otherwise the
-- database is queried.
getBlocksAtHeightI :: MonadIO m
                   => BlockHeight
                   -> Writer
                   -> m [BlockHash]
getBlocksAtHeightI bh w = do
    mem <- readTVarIO (getState w)
    maybe (withDatabaseReader (getReader w) (getBlocksAtHeight bh))
          return
          (getBlocksAtHeightH bh mem)

-- | Block data for a hash: the in-memory copy if present, otherwise the
-- database's answer.
getBlockI :: MonadIO m
          => BlockHash
          -> Writer
          -> m (Maybe BlockData)
getBlockI bh Writer {getReader = db, getState = hm} =
    let fromMemory = getBlockH bh <$> readTVarIO hm
        fromDisk   = withDatabaseReader db (getBlock bh)
     in runMaybeT (MaybeT fromMemory <|> MaybeT fromDisk)

-- | Transaction data for a hash: the in-memory copy if present, otherwise
-- the database's answer.
getTxDataI :: MonadIO m
           => TxHash
           -> Writer
           -> m (Maybe TxData)
getTxDataI th Writer {getReader = db, getState = hm} =
    runMaybeT (cached <|> stored)
  where
    cached = MaybeT (getTxDataH th <$> readTVarIO hm)
    stored = MaybeT (withDatabaseReader db (getTxData th))

-- | Spender of an outpoint. A pending deletion in memory yields 'Nothing'
-- without touching the database. An absent memory entry defers to the
-- database.
getSpenderI :: MonadIO m => OutPoint -> Writer -> m (Maybe Spender)
getSpenderI op Writer {getReader = db, getState = hm} = do
    mem <- readTVarIO hm
    case getSpenderH op mem of
        Nothing           -> withDatabaseReader db (getSpender op)
        Just Deleted      -> return Nothing
        Just (Modified s) -> return (Just s)

-- | Balance of an address, rebuilt from the cached 'BalVal' if present,
-- otherwise read from the database.
getBalanceI :: MonadIO m => Address -> Writer -> m (Maybe Balance)
getBalanceI a w = do
    mem <- readTVarIO (getState w)
    maybe (withDatabaseReader (getReader w) (getBalance a))
          (return . Just . valToBalance a)
          (getBalanceH a mem)

-- | Unspent output at an outpoint. A pending deletion in memory yields
-- 'Nothing', a pending insert is rebuilt from its cached value, and an
-- absent memory entry defers to the database.
getUnspentI :: MonadIO m
            => OutPoint
            -> Writer
            -> m (Maybe Unspent)
getUnspentI op Writer {getReader = db, getState = hm} = do
    mem <- readTVarIO hm
    case getUnspentH op mem of
        Nothing           -> withDatabaseReader db (getUnspent op)
        Just Deleted      -> return Nothing
        Just (Modified u) -> return (Just (valToUnspent op u))

-- | Mempool list: the one set in memory this session if any, otherwise the
-- database's.
getMempoolI :: MonadIO m => Writer -> m [TxRef]
getMempoolI w = do
    mem <- readTVarIO (getState w)
    maybe (withDatabaseReader (getReader w) getMempool) return
          (getMempoolH mem)

-- | Execute a 'MemoryTx' atomically against the writer's in-memory cache.
runTx :: MonadIO m => MemoryTx a -> WriterT m a
runTx f = ReaderT $ \w -> atomically (runReaderT f (getState w))

-- | A cache for the given network with no pending changes at all.
emptyMemory :: Network -> Memory
emptyMemory net =
    Memory { hNet     = net
             -- optional single values: unset
           , hBest    = Nothing
           , hMempool = Nothing
             -- per-key change maps: empty
           , hBlock   = M.empty
           , hHeight  = M.empty
           , hTx      = M.empty
           , hSpender = M.empty
           , hUnspent = M.empty
           , hBalance = M.empty
           , hAddrTx  = M.empty
           , hAddrOut = M.empty
           }

-- | Best block set in memory, if any.
getBestBlockH :: Memory -> Maybe BlockHash
getBestBlockH Memory {hBest = best} = best

-- | Cached block-hash list for a height, if any.
getBlocksAtHeightH :: BlockHeight -> Memory -> Maybe [BlockHash]
getBlocksAtHeightH h mem = M.lookup h (hHeight mem)

-- | Cached block data for a hash, if any.
getBlockH :: BlockHash -> Memory -> Maybe BlockData
getBlockH h Memory {hBlock = blocks} = M.lookup h blocks

-- | Cached transaction data for a hash, if any.
getTxDataH :: TxHash -> Memory -> Maybe TxData
getTxDataH t mem = M.lookup t (hTx mem)

-- | Pending spender change for an outpoint, if any.
getSpenderH :: OutPoint -> Memory -> Maybe (Dirty Spender)
getSpenderH op = M.lookup op . hSpender

-- | Cached balance value for an address, if any.
getBalanceH :: Address -> Memory -> Maybe BalVal
getBalanceH a mem = M.lookup a (hBalance mem)

-- | Mempool list set in memory, if any.
getMempoolH :: Memory -> Maybe [TxRef]
getMempoolH Memory {hMempool = mp} = mp

-- | Best block set in memory, if any. Same as 'getBestBlockH'.
getBestH :: Memory -> Maybe BlockHash
getBestH = getBestBlockH

-- | Record a new best block.
setBestH :: BlockHash -> Memory -> Memory
setBestH h mem = mem {hBest = Just h}

-- | Cache block data, keyed by the hash of its header.
insertBlockH :: BlockData -> Memory -> Memory
insertBlockH bd mem = mem {hBlock = M.insert key bd (hBlock mem)}
  where
    key = headerHash (blockDataHeader bd)

-- | Replace the cached block-hash list for a height.
setBlocksAtHeightH :: [BlockHash] -> BlockHeight -> Memory -> Memory
setBlocksAtHeightH hs g mem@Memory {hHeight = heights} =
    mem {hHeight = M.insert g hs heights}

-- | Cache transaction data, keyed by the transaction's hash.
insertTxH :: TxData -> Memory -> Memory
insertTxH tx mem = mem {hTx = M.insert th tx (hTx mem)}
  where
    th = txHash (txData tx)

-- | Record a pending spender write for an outpoint.
insertSpenderH :: OutPoint -> Spender -> Memory -> Memory
insertSpenderH op s mem@Memory {hSpender = spenders} =
    mem {hSpender = M.insert op (Modified s) spenders}

-- | Record a pending spender deletion for an outpoint.
deleteSpenderH :: OutPoint -> Memory -> Memory
deleteSpenderH op mem@Memory {hSpender = spenders} =
    mem {hSpender = M.insert op Deleted spenders}

-- | Cache a balance in its stored 'BalVal' form, keyed by its address.
setBalanceH :: Balance -> Memory -> Memory
setBalanceH bal mem =
    let addr = balanceAddress bal
        bv   = balanceToVal bal
     in mem {hBalance = M.insert addr bv (hBalance mem)}

-- | Record a pending insert of an (address, tx reference) index entry.
insertAddrTxH :: Address -> TxRef -> Memory -> Memory
insertAddrTxH a tr mem@Memory {hAddrTx = idx} =
    mem {hAddrTx = M.insert (a, tr) (Modified ()) idx}

-- | Record a pending deletion of an (address, tx reference) index entry.
deleteAddrTxH :: Address -> TxRef -> Memory -> Memory
deleteAddrTxH a tr mem@Memory {hAddrTx = idx} =
    mem {hAddrTx = M.insert (a, tr) Deleted idx}

-- | Record a pending insert of an address output entry built from an
-- 'Unspent'. The short-bytestring script is converted to 'OutVal' form.
insertAddrUnspentH :: Address -> Unspent -> Memory -> Memory
insertAddrUnspentH a u mem =
    mem {hAddrOut = M.insert key (Modified val) (hAddrOut mem)}
  where
    key = (a, unspentBlock u, unspentPoint u)
    val = OutVal { outValAmount = unspentAmount u
                 , outValScript = B.Short.fromShort (unspentScript u) }

-- | Record a pending deletion of the address output entry for an 'Unspent'.
deleteAddrUnspentH :: Address -> Unspent -> Memory -> Memory
deleteAddrUnspentH a u mem =
    mem {hAddrOut = M.insert key Deleted (hAddrOut mem)}
  where
    key = (a, unspentBlock u, unspentPoint u)

-- | Replace the cached mempool list.
setMempoolH :: [TxRef] -> Memory -> Memory
setMempoolH xs mem = mem {hMempool = Just xs}

-- | Pending unspent-output change for an outpoint, if any.
getUnspentH :: OutPoint -> Memory -> Maybe (Dirty UnspentVal)
getUnspentH op = M.lookup op . hUnspent

-- | Record a pending unspent-output write, split by 'unspentToVal' into
-- its key and stored value.
insertUnspentH :: Unspent -> Memory -> Memory
insertUnspentH u mem =
    mem {hUnspent = M.insert key (Modified val) (hUnspent mem)}
  where
    (key, val) = unspentToVal u

-- | Record a pending unspent-output deletion for an outpoint.
deleteUnspentH :: OutPoint -> Memory -> Memory
deleteUnspentH op mem@Memory {hUnspent = unspents} =
    mem {hUnspent = M.insert op Deleted unspents}
