{-# LANGUAGE CPP #-}
{-# LANGUAGE TypeFamilies #-}

#if __GLASGOW_HASKELL__ >= 800
{-# LANGUAGE LambdaCase #-}
#endif /* __GLASGOW_HASKELL__ >= 800 */

module CRDT.LWW
    ( LWW (..)
      -- * CvRDT
    , initialize
    , assign
    , query
      -- * Implementation detail
    , advanceFromLWW
    ) where

import           Data.Semilattice (Semilattice)

#if __GLASGOW_HASKELL__ >= 800
import           CRDT.Cm (CausalOrd (..), CmRDT (..))
#endif /* __GLASGOW_HASKELL__ >= 800 */

import           CRDT.LamportClock (Clock, LamportTime (LamportTime), advance,
                                    getTime)

-- | Last-write-wins register: a single value paired with the Lamport
-- timestamp of the write that produced it.
-- This type is both a 'CmRDT' (on GHC >= 8.0) and a 'CvRDT'.
--
-- Timestamps are assumed to be unique, totally ordered,
-- and consistent with causal order;
-- i.e., if assignment 1 happened-before assignment 2,
-- the former’s timestamp is less than the latter’s.
data LWW a = LWW
    { value :: !a
      -- ^ The value written by the winning (latest) write.
    , time  :: !LamportTime
      -- ^ Timestamp of the write that produced 'value'.
    }
    deriving (Eq, Show)

--------------------------------------------------------------------------------
-- CvRDT -----------------------------------------------------------------------

-- | Merge by choosing the state with the more recent timestamp.
--
-- Two states with equal timestamps are expected to come from the same
-- write, so they must carry equal values; in that case the left argument
-- is returned. Equal timestamps with different values violate the
-- uniqueness assumption, and the merge calls 'error'.
instance Eq a => Semigroup (LWW a) where
    x@(LWW xv xt) <> y@(LWW yv yt) = case compare xt yt of
        LT -> y
        GT -> x
        EQ | xv == yv  -> x
           | otherwise -> error "LWW assumes timestamps to be unique"

-- | State-based (CvRDT) merge is '<>'. It is idempotent (@x <> x == x@),
-- and under the unique-timestamp assumption it is also commutative, since
-- it always keeps the state with the greater timestamp.
instance Eq a => Semilattice (LWW a)

-- | Initialize state: wrap the given value together with the current
-- timestamp read from the clock via 'getTime'.
initialize :: Clock m => a -> m (LWW a)
initialize val = LWW val <$> getTime

-- | Change state as a CvRDT operation.
--
-- First advances the clock from the old state's timestamp
-- ('advanceFromLWW'), then builds a fresh state with 'initialize'.
-- The current value is ignored, because the new timestamp is always greater.
assign :: Clock m => a -> LWW a -> m (LWW a)
assign val old = do
    advanceFromLWW old
    initialize val

-- | Query state: return the currently stored value.
query :: LWW a -> a
query = value

--------------------------------------------------------------------------------
-- CmRDT -----------------------------------------------------------------------

#if __GLASGOW_HASKELL__ >= 800

-- | LWW operations impose no causal ordering on each other: 'precedes'
-- is always 'False'. Applied operations are combined with '<>', which
-- keeps the state with the greater timestamp.
instance CausalOrd (LWW a) where
    precedes _ _ = False

-- | Operation-based LWW.
--
-- The payload is 'Nothing' until the first operation is applied. An
-- operation is itself a complete 'LWW' state, which 'apply' merges into the
-- existing payload with '<>'.
instance Eq a => CmRDT (LWW a) where
    type Intent  (LWW a) = a
    type Payload (LWW a) = Maybe (LWW a)

    initial = Nothing

    -- Always yields an operation: 'initialize' when there is no payload
    -- yet, otherwise 'assign' over the existing payload.
    makeOp val = Just . maybe (initialize val) (assign val)

    -- The operation replaces an empty payload, or is merged into an
    -- existing one with '<>'.
    apply op = Just . maybe op (op <>)

#endif /* __GLASGOW_HASKELL__ >= 800 */

-- | Advance the clock using the local-time component of the state's
-- Lamport timestamp; the process-id component is ignored.
advanceFromLWW :: Clock m => LWW a -> m ()
advanceFromLWW LWW{time = LamportTime t _} = advance t