{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE GADTs #-}
module Raft.Log (
EntryIssuer(..),
EntryValue(..),
EntryHash(..),
genesisHash,
hashEntry,
Entry(..),
Entries,
lastEntryIndex,
LastLogEntry(..),
hashLastLogEntry,
lastLogEntryIndex,
lastLogEntryIssuer,
lastLogEntryTerm,
lastLogEntryIndexAndTerm,
RaftInitLog(..),
ReadEntriesSpec(..),
ReadEntriesRes(..),
IndexInterval(..),
RaftReadLog(..),
RaftWriteLog(..),
RaftDeleteLog(..),
DeleteSuccess(..),
RaftLog(..),
RaftLogError,
RaftLogExceptions,
updateLog,
clientReqData,
readEntries,
) where
import Protolude
import qualified Crypto.Hash.SHA256 as SHA256
import qualified Data.ByteString as BS
import qualified Data.ByteString.Base16 as BS16
import qualified Data.Map as Map
import Data.Serialize
import Data.Sequence (Seq(..), (|>), foldlWithIndex)
import qualified Data.Sequence as Seq
import Database.PostgreSQL.Simple.ToField (Action(..), ToField(..))
import Database.PostgreSQL.Simple.FromField (FromField(..), returnError, ResultError(..))
import Raft.Types
-- | Originator of a log entry.
--
-- The derived 'Show' / 'Read' instances are also the database encoding used
-- by this type's 'ToField' / 'FromField' instances, so renaming constructors
-- changes the persisted format.
data EntryIssuer
  = ClientIssuer ClientId SerialNum
    -- ^ Appended on behalf of a client: the client's id and the request's
    -- serial number.
  | LeaderIssuer LeaderId
    -- ^ Appended by the leader itself.
  deriving (Show, Read, Eq, Generic, Serialize)
-- | Payload of a log entry. 'NoValue' marks an entry that carries no payload
-- (NOTE(review): presumably used for leader-issued entries — confirm against
-- callers).
data EntryValue v
  = EntryValue v
  | NoValue
  deriving (Show, Eq, Generic, Serialize)
-- | Hash of a log entry. The values built by 'hashEntry' and 'genesisHash'
-- hold the Base16 (hex) text of a 32-byte digest, not the raw digest bytes.
newtype EntryHash = EntryHash
  { unEntryHash :: ByteString
  } deriving (Show, Read, Eq, Ord, Generic, Serialize)
-- | The "previous hash" expected for the very first log entry: the Base16
-- encoding of 32 zero bytes, i.e. the same length as a 'hashEntry' result.
genesisHash :: EntryHash
genesisHash = EntryHash (BS16.encode zeroDigest)
  where
    zeroDigest = BS.replicate 32 0
-- | Hash an entry: SHA-256 over the entry's cereal encoding, rendered as
-- Base16 text. Every field of the entry (including its 'entryPrevHash')
-- contributes to the result.
hashEntry :: Serialize v => Entry v -> EntryHash
hashEntry entry =
  let digest = SHA256.hash (encode entry)
   in EntryHash (BS16.encode digest)
-- | A single log entry.
--
-- The derived 'Serialize' instance is what 'hashEntry' hashes, so changing
-- the field set or order changes every entry hash.
data Entry v = Entry
  { entryIndex :: Index
    -- ^ Log index of this entry.
  , entryTerm :: Term
    -- ^ Term recorded for this entry.
  , entryValue :: EntryValue v
    -- ^ The entry's payload.
  , entryIssuer :: EntryIssuer
    -- ^ Who caused the entry to be appended.
  , entryPrevHash :: EntryHash
    -- ^ Hash of the preceding entry ('genesisHash' for the first entry, as
    -- checked by 'validateLog').
  } deriving (Show, Eq, Generic, Serialize)

-- | A run of log entries. Code in this module treats the rightmost element as
-- the entry with the highest index, i.e. entries are kept in ascending order.
type Entries v = Seq (Entry v)
-- | Index of the rightmost entry of the sequence, or 'Nothing' when the
-- sequence is empty.
lastEntryIndex :: Entries v -> Maybe Index
lastEntryIndex entries =
  -- 'Seq.lookup' returns 'Nothing' for position -1, which covers the empty case.
  entryIndex <$> Seq.lookup (Seq.length entries - 1) entries
-- | The last entry of a log, or 'NoLogEntries' when the log is empty.
-- The @lastLogEntry*@ accessors map 'NoLogEntries' to 'index0', 'term0' and
-- 'genesisHash' respectively.
data LastLogEntry v
  = LastLogEntry (Entry v)
  | NoLogEntries
  deriving (Show)
-- | Hash of the last log entry; an empty log hashes to 'genesisHash', which
-- is exactly the 'entryPrevHash' the first appended entry must carry.
hashLastLogEntry :: Serialize v => LastLogEntry v -> EntryHash
hashLastLogEntry (LastLogEntry entry) = hashEntry entry
hashLastLogEntry NoLogEntries = genesisHash
-- | Index of the last log entry, or 'index0' for an empty log.
lastLogEntryIndex :: LastLogEntry v -> Index
lastLogEntryIndex (LastLogEntry entry) = entryIndex entry
lastLogEntryIndex NoLogEntries = index0
-- | Term of the last log entry, or 'term0' for an empty log.
lastLogEntryTerm :: LastLogEntry v -> Term
lastLogEntryTerm (LastLogEntry entry) = entryTerm entry
lastLogEntryTerm NoLogEntries = term0
-- | Index and term of the last log entry, or @('index0', 'term0')@ for an
-- empty log.
lastLogEntryIndexAndTerm :: LastLogEntry v -> (Index, Term)
lastLogEntryIndexAndTerm = \case
  LastLogEntry entry -> (entryIndex entry, entryTerm entry)
  NoLogEntries -> (index0, term0)
-- | Issuer of the last log entry; 'Nothing' when the log is empty.
lastLogEntryIssuer :: LastLogEntry v -> Maybe EntryIssuer
lastLogEntryIssuer (LastLogEntry entry) = Just (entryIssuer entry)
lastLogEntryIssuer NoLogEntries = Nothing
-- | Reason 'validateLog' rejects a sequence of entries.
data InvalidLog
  = InvalidIndex
      { expectedIndex :: Index
        -- ^ Index implied by the entry's position in the sequence.
      , actualIndex :: Index
        -- ^ Index stored in the entry.
      }
  | InvalidPrevHash
      { expectedHash :: EntryHash
        -- ^ Hash of the preceding entry ('genesisHash' for the first one).
      , actualHash :: EntryHash
        -- ^ 'entryPrevHash' stored in the entry.
      }
  deriving (Show)
-- | Check that a sequence of entries forms a well-formed log starting at
-- index 1:
--
-- * each entry's 'entryPrevHash' equals the 'hashEntry' of the entry before
--   it ('genesisHash' for the first entry), and
-- * the entry at 0-based position @p@ carries index @p + 1@.
--
-- For each entry the hash is checked before the index. Validation stops at
-- the first violation, which is returned. An empty sequence is valid.
validateLog :: (Serialize v) => Entries v -> Either InvalidLog ()
validateLog es =
    -- 'foldM' in 'Either' short-circuits on the first 'Left', unlike a plain
    -- left fold that keeps walking the rest of the sequence.
    void (foldM validateEntry Nothing (Seq.mapWithIndex (,) es))
  where
    -- Validate one entry against its predecessor (if any); on success the
    -- entry becomes the predecessor for the next step.
    validateEntry mPrevEntry (pos, currEntry) = do
      let expectedPrevHash = maybe genesisHash hashEntry mPrevEntry
      validatePrevHash expectedPrevHash (entryPrevHash currEntry)
      validateIndex (Index (fromIntegral pos + 1)) (entryIndex currEntry)
      pure (Just currEntry)

    validateIndex :: Index -> Index -> Either InvalidLog ()
    validateIndex expected actual
      | expected /= actual = Left (InvalidIndex expected actual)
      | otherwise = Right ()

    validatePrevHash :: EntryHash -> EntryHash -> Either InvalidLog ()
    validatePrevHash expected actual
      | expected /= actual = Left (InvalidPrevHash expected actual)
      | otherwise = Right ()
-- | For every client that issued at least one of the given entries, the
-- serial number and log index of that client's rightmost entry in the
-- sequence. Entries issued by the leader are ignored.
clientReqData :: Entries v -> Map ClientId (SerialNum, Index)
clientReqData = foldl' step mempty
  where
    -- Strict left fold: later entries overwrite earlier ones for the same
    -- client, and the accumulator map is forced at each step instead of
    -- building a chain of pending inserts.
    step acc e =
      case entryIssuer e of
        LeaderIssuer _ -> acc
        ClientIssuer cid sn -> Map.insert cid (sn, entryIndex e) acc
-- | Capability to initialize the log in monad @m@ for entries whose payload
-- has type @v@.
class RaftInitLog m v where
  -- | Error produced when initialization fails.
  type RaftInitLogError m

  -- | Initialize the log. The 'Proxy' pins down @v@, which does not
  -- otherwise occur in the method's type.
  initializeLog :: Proxy v -> m (Either (RaftInitLogError m) ())
-- | Capability to persist log entries.
class (Show (RaftWriteLogError m), Monad m) => RaftWriteLog m v where
  -- | Error produced when writing fails.
  type RaftWriteLogError m

  -- | Write the given entries to the log. ('updateLog' deletes from the first
  -- entry's index before calling this.)
  writeLogEntries
    :: Exception (RaftWriteLogError m)
    => Entries v
    -> m (Either (RaftWriteLogError m) ())
-- | Marker returned by a successful 'deleteLogEntriesFrom'. The phantom @v@
-- makes the entry type appear in that method's result so the
-- @RaftDeleteLog m v@ instance can be selected.
data DeleteSuccess v = DeleteSuccess
-- | Capability to remove entries from the log.
class (Show (RaftDeleteLogError m), Monad m) => RaftDeleteLog m v where
  -- | Error produced when deletion fails.
  type RaftDeleteLogError m

  -- | Delete log entries starting at the given index. 'updateLog' passes the
  -- index of the first entry it is about to write, so presumably that index
  -- and all later ones are removed — confirm in instances.
  deleteLogEntriesFrom
    :: Exception (RaftDeleteLogError m)
    => Index
    -> m (Either (RaftDeleteLogError m) (DeleteSuccess v))
-- | Capability to read the log.
--
-- Instances must define 'readLogEntry' and 'readLastLogEntry';
-- 'readLogEntriesFrom' has a default built from those two.
class (Show (RaftReadLogError m), Monad m) => RaftReadLog m v where
  -- | Error produced when reading fails.
  type RaftReadLogError m

  -- | The entry at the given index; @Right Nothing@ when there is none.
  readLogEntry
    :: Exception (RaftReadLogError m)
    => Index -> m (Either (RaftReadLogError m) (Maybe (Entry v)))

  -- | Every entry whose index is at least the given one, in ascending index
  -- order.
  readLogEntriesFrom
    :: Exception (RaftReadLogError m)
    => Index -> m (Either (RaftReadLogError m) (Entries v))

  -- | The last entry of the log; @Right Nothing@ when the log is empty.
  readLastLogEntry
    :: Exception (RaftReadLogError m)
    => m (Either (RaftReadLogError m) (Maybe (Entry v)))

  -- Default: start from the last entry and walk downwards one index at a
  -- time with 'readLogEntry', prepending each entry so the accumulated
  -- sequence stays in ascending order. Index 0 is never read. A missing entry
  -- below the last one is treated as a corrupted log and aborts via 'panic'.
  default readLogEntriesFrom
    :: Exception (RaftReadLogError m)
    => Index
    -> m (Either (RaftReadLogError m) (Entries v))
  readLogEntriesFrom startIdx = do
    eLastEntry <- readLastLogEntry
    case eLastEntry of
      Left err -> pure (Left err)
      Right Nothing -> pure (Right Empty)
      Right (Just lastEntry)
        | entryIndex lastEntry < startIdx -> pure (Right Empty)
        | otherwise ->
            collect
              (decrIndexWithDefault0 (entryIndex lastEntry))
              (Seq.singleton lastEntry)
    where
      collect curIdx acc
        | curIdx < startIdx || curIdx == 0 = pure (Right acc)
        | otherwise = do
            eEntry <- readLogEntry curIdx
            case eEntry of
              Left err -> pure (Left err)
              Right Nothing -> panic "Malformed log"
              Right (Just entry) ->
                collect (decrIndexWithDefault0 curIdx) (entry Seq.<| acc)
-- | All four log capabilities at once.
type RaftLog m v =
  ( RaftInitLog m v
  , RaftReadLog m v
  , RaftWriteLog m v
  , RaftDeleteLog m v
  )

-- | Requires every log error type to be an 'Exception'.
type RaftLogExceptions m =
  ( Exception (RaftInitLogError m)
  , Exception (RaftReadLogError m)
  , Exception (RaftWriteLogError m)
  , Exception (RaftDeleteLogError m)
  )
-- | Error from any of the log capabilities. Each constructor captures the
-- 'Show' dictionary of its payload, which lets the standalone 'Show'
-- instance below work without constraints on @m@.
data RaftLogError m where
  RaftLogInitError
    :: Show (RaftInitLogError m) => RaftInitLogError m -> RaftLogError m
  RaftLogReadError
    :: Show (RaftReadLogError m) => RaftReadLogError m -> RaftLogError m
  RaftLogWriteError
    :: Show (RaftWriteLogError m) => RaftWriteLogError m -> RaftLogError m
  RaftLogDeleteError
    :: Show (RaftDeleteLogError m) => RaftDeleteLogError m -> RaftLogError m

deriving instance Show (RaftLogError m)
-- | Overwrite the log from the first given entry onwards: delete entries
-- starting at that entry's index, then write all the given entries.
--
-- Returns the index of the last written entry. An empty input is a no-op
-- returning @Right Nothing@. A delete failure skips the write.
updateLog
  :: forall m v.
     ( RaftDeleteLog m v, Exception (RaftDeleteLogError m)
     , RaftWriteLog m v, Exception (RaftWriteLogError m)
     )
  => Entries v
  -> m (Either (RaftLogError m) (Maybe Index))
updateLog entries =
  case entries of
    Empty -> pure (Right Nothing)
    firstEntry :<| _ ->
      deleteLogEntriesFrom @m @v (entryIndex firstEntry) >>= \case
        Left delErr -> pure (Left (RaftLogDeleteError delErr))
        Right DeleteSuccess ->
          bimap RaftLogWriteError (const (lastEntryIndex entries))
            <$> writeLogEntries entries
-- | Lower and upper index bounds used by 'readEntries'; 'Nothing' leaves
-- that side unbounded. Both bounds are treated as inclusive.
data IndexInterval
  = IndexInterval (Maybe Index) (Maybe Index)
  deriving (Show, Generic, Serialize)
-- | Which entries 'readEntries' should fetch.
data ReadEntriesSpec
  = ByIndex Index
    -- ^ The single entry at this index.
  | ByIndices IndexInterval
    -- ^ Every entry whose index lies within the interval.
  deriving (Show, Generic, Serialize)
-- | Failure of 'readEntries'.
data ReadEntriesError m where
  -- | The requested entry is missing, identified by hash ('Left') or by
  -- index ('Right').
  EntryDoesNotExist :: Either EntryHash Index -> ReadEntriesError m
  -- | The requested interval was rejected (lower bound, upper bound).
  InvalidIntervalSpecified :: (Index, Index) -> ReadEntriesError m
  -- | The underlying log read failed.
  ReadEntriesError
    :: Exception (RaftReadLogError m) => RaftReadLogError m -> ReadEntriesError m

deriving instance Show (ReadEntriesError m)
deriving instance Typeable m => Exception (ReadEntriesError m)
-- | Result of 'readEntries': one entry for 'ByIndex', a (possibly empty)
-- sequence for 'ByIndices'.
data ReadEntriesRes v
  = OneEntry (Entry v)
  | ManyEntries (Entries v)
-- | Read log entries according to a 'ReadEntriesSpec'.
--
-- * 'ByIndex': the entry at that index. A missing entry yields
--   'EntryDoesNotExist' with the index in 'Right'.
-- * 'ByIndices': the entries within the interval, via 'readEntriesByIndices'.
--
-- Failures of the underlying log are wrapped in the 'ReadEntriesError'
-- constructor.
readEntries
  :: forall m v. (RaftReadLog m v, Exception (RaftReadLogError m))
  => ReadEntriesSpec
  -> m (Either (ReadEntriesError m) (ReadEntriesRes v))
readEntries spec =
  case spec of
    ByIndex idx -> do
      -- Named distinctly from the argument; the original rebound the
      -- argument's name here, shadowing it.
      eEntry <- readLogEntry idx
      pure $ case eEntry of
        Left err -> Left (ReadEntriesError err)
        Right Nothing -> Left (EntryDoesNotExist (Right idx))
        Right (Just e) -> Right (OneEntry e)
    ByIndices interval -> fmap ManyEntries <$> readEntriesByIndices interval
-- | Read the entries whose index lies in the given interval. Both bounds
-- are inclusive, and a 'Nothing' bound is open on that side.
--
-- The lower bound is honoured by 'readLogEntriesFrom', and the upper bound
-- by trimming its ascending result. An interval whose lower bound exceeds
-- its upper bound is rejected with 'InvalidIntervalSpecified'. Equal bounds
-- are a valid single-index interval.
readEntriesByIndices
  :: forall m v. (RaftReadLog m v, Exception (RaftReadLogError m))
  => IndexInterval
  -> m (Either (ReadEntriesError m) (Entries v))
readEntriesByIndices (IndexInterval mLow mHigh) =
  case (mLow, mHigh) of
    (Nothing, Nothing) ->
      first ReadEntriesError <$> readLogEntriesFrom (Index 0)
    (Nothing, Just hidx) ->
      bimap ReadEntriesError (takeUpTo hidx) <$> readLogEntriesFrom (Index 0)
    (Just lidx, Nothing) ->
      first ReadEntriesError <$> readLogEntriesFrom lidx
    (Just lidx, Just hidx)
      -- Only a strictly inverted interval is invalid; @lidx == hidx@ selects
      -- exactly one index under inclusive bounds.
      | lidx > hidx ->
          pure (Left (InvalidIntervalSpecified (lidx, hidx)))
      | otherwise ->
          bimap ReadEntriesError (takeUpTo hidx) <$> readLogEntriesFrom lidx
  where
    -- Keep the ascending prefix whose indices do not exceed the upper bound.
    takeUpTo hidx = Seq.takeWhileL ((<= hidx) . entryIndex)
-- | Stored as the cereal encoding of the value, escaped as a @bytea@ literal.
instance Serialize v => ToField (EntryValue v) where
  toField ev = EscapeByteA (encode ev)

-- | Inverse of the 'ToField' instance. SQL @NULL@ is reported as
-- 'UnexpectedNull'; a cereal decoding failure as 'ConversionFailed'
-- carrying cereal's error message.
instance (Typeable v, Serialize v) => FromField (EntryValue v) where
  fromField f mdata = do
    mBytes <- fromField f mdata
    case mBytes of
      Nothing -> returnError UnexpectedNull f ""
      Just bytes ->
        either (returnError ConversionFailed f) pure (decode bytes)
-- | Stored as the derived 'Show' rendering of the issuer.
instance ToField EntryIssuer where
  toField = Escape . show

-- | Parses the column text with the derived 'Read' instance. SQL @NULL@ is
-- reported as 'UnexpectedNull'; unparsable text as 'ConversionFailed'.
instance FromField EntryIssuer where
  fromField f = \case
    Nothing -> returnError UnexpectedNull f ""
    Just bytes ->
      case readEither (toS bytes) of
        Left err -> returnError ConversionFailed f err
        Right issuer -> pure issuer
-- | Stored as the wrapped 'ByteString' (the hex text of the digest).
instance ToField EntryHash where
  toField = toField . unEntryHash

-- | Reads the column as a 'ByteString' and wraps it.
instance FromField EntryHash where
  fromField f mdata = EntryHash <$> fromField f mdata