{-# LANGUAGE DeriveFunctor, GADTs, PatternGuards #-}
module DecisionTree where
import Contract
import Observable (Steps(..))
import qualified Observable as Obs
import Display
import Prelude hiding (product, until, and)
import Data.List hiding (and)
import Control.Monad hiding (when)
import Text.XML.HaXml.Namespaces (localName)
import Text.XML.HaXml.XmlContent
data DecisionStep x = Done
| Trade TradeDir Double Tradeable x
| Choose Party ChoiceId x x
| ObserveCond (Obs Bool) x x
| ObserveValue (Obs Double) (Double -> x)
| Wait [(Obs Bool, Time -> x)] [(ChoiceId, Time -> x)]
deriving Functor
data Party = FirstParty | Counterparty | ThirdParty PartyName
deriving (Eq, Show)
data DecisionTree = DecisionTree Time (DecisionStep DecisionTree)
unfoldDecisionTree :: (a -> (DecisionStep a, Time)) -> a -> DecisionTree
unfoldDecisionTree next = unfold
where
unfold x = case next x of
(step, time) -> DecisionTree time (fmap unfold step)
decisionTree :: Time -> Contract -> DecisionTree
decisionTree t c = unfoldDecisionTree
(\pst@(PSt t' _ _) -> (decisionStep pst, t'))
(initialProcessState t c)
data ProcessState = PSt Time
[Blocked ThreadState]
[ThreadState]
deriving (Show, Eq)
data ThreadState = TSt Contract
[Obs Bool]
ScaleFactor
TradeDir
deriving (Show, Eq)
data Blocked c =
BlockedOnWhen (Obs Bool) c
| BlockedOnAnytime Bool ChoiceId (Obs Bool) c
deriving (Show, Eq)
initialProcessState :: Time -> Contract -> ProcessState
initialProcessState time contract =
let initialThread = TSt contract [] 1 TradeDir2To1
in PSt time [] [initialThread]
currentContract :: ProcessState -> Contract
currentContract (PSt _time blocked runnable) =
allOf
[ evalTradeDir tradeDir
(scaleBy scaleFactor
(foldr until contract untilObss))
| let threads = runnable ++ map unBlocked blocked
, TSt contract untilObss scaleFactor tradeDir <- threads
, contract /= zero ]
where
scaleBy 1.0 c = c
scaleBy s c = scale (konst s) c
unBlocked (BlockedOnWhen o cth) = update (when o) cth
unBlocked (BlockedOnAnytime _ cid o cth) = update (anytime cid o) cth
update f (TSt c uos sf d) = TSt (f c) uos sf d
allOf [] = zero
allOf xs = foldr1 and xs
decisionStep :: ProcessState -> DecisionStep ProcessState
decisionStep (PSt time blocked runnable) =
go blocked runnable
where
go bs (TSt c uos sf d:rs) = case c of
Zero -> go bs rs
One t -> Trade d sf t (PSt time bs rs)
Give c1 -> let r' = TSt c1 uos sf (flipTradeDir d)
in go bs (r':rs)
Party p c1 -> let d' = setThirdParty p d
r' = TSt c1 uos sf d'
in go bs (r':rs)
And c1 c2 -> let r1 = TSt c1 uos sf d
r2 = TSt c2 uos sf d
in go bs (r1:r2:rs)
Or cid c1 c2 -> let r1 = TSt c1 uos sf d
r2 = TSt c2 uos sf d
(p,_) = tradeDirParties d
in Choose p cid (PSt time bs (r1:rs))
(PSt time bs (r2:rs))
Cond o c1 c2 -> let r1 = TSt c1 uos sf d
r2 = TSt c2 uos sf d
in ObserveCond o (PSt time bs (r1:rs))
(PSt time bs (r2:rs))
Scale o c1 -> let r' v = TSt c1 uos (v * sf) d
in ObserveValue o (\v -> PSt time bs (r' v:rs))
Read n o c1 -> let r' v = TSt (subst n v c1) uos sf d
in ObserveValue o (\v -> PSt time bs (r' v:rs))
When o c1 -> let b = BlockedOnWhen o (TSt c1 uos sf d)
in go (b:bs) rs
Anytime cid o c1 -> let b = BlockedOnAnytime True cid o (TSt c1 uos sf d)
in go (b:bs) rs
Until o c1 -> let r' = TSt c1 (o:uos) sf d
in ObserveCond o (PSt time bs rs)
(PSt time bs (r':rs))
go [] [] = Done
go bs [] = Wait (whens ++ untils) opts
where
whens = [ ( blockedObs b
, case nextThread b of
Left r' -> \time' -> PSt time' bs' [r']
Right b' -> \time' -> PSt time' (b':bs') [] )
| (b, bs') <- each bs ]
untils = [ ( uo
, \time' -> PSt time' bs' [] )
| (b, bs') <- each bs
, let TSt _ uos _ _ = blockedThr b
, uo <- uos ]
opts = [ (cid, \time' -> PSt time' bs' [r'])
| (BlockedOnAnytime False cid _ r', bs') <- each bs ]
blockedObs (BlockedOnWhen o _) = o
blockedObs (BlockedOnAnytime True _ o _) = o
blockedObs (BlockedOnAnytime False _ o _) = Obs.not o
blockedThr (BlockedOnWhen _ x) = x
blockedThr (BlockedOnAnytime _ _ _ x) = x
nextThread (BlockedOnWhen _ r) = Left r
nextThread (BlockedOnAnytime v cid o b) =
Right (BlockedOnAnytime (not v) cid o b)
subst :: String -> Double -> Contract -> Contract
subst n v c = case c of
Zero -> c
One _ -> c
Give c1 -> Give (subst n v c1)
Party p c1 -> Party p (subst n v c1)
And c1 c2 -> And (subst n v c1) (subst n v c2)
Or id c1 c2 -> Or id (subst n v c1) (subst n v c2)
Cond o c1 c2 -> Cond (Obs.subst n v o) (subst n v c1) (subst n v c2)
Scale o c1 -> Scale (Obs.subst n v o) (subst n v c1)
Read n' o c1
| n == n' -> Read n' (Obs.subst n v o) c1
| otherwise-> Read n' (Obs.subst n v o) (subst n v c1)
When o c1 -> When (Obs.subst n v o) (subst n v c1)
Anytime cid o c1 -> Anytime cid (Obs.subst n v o) (subst n v c1)
Until o c1 -> Until (Obs.subst n v o) (subst n v c1)
each :: [a] -> [(a, [a])]
each xs = [ (xs !! n, [ x' | (n',x') <- nxs, n' /= n ] )
| n <- [0..length xs-1] ]
where
nxs = zip [0..] xs
data TradeDir
= TradeDir2To1
| TradeDir1To2
| TradeDirPTo1 PartyName
| TradeDirPTo2 PartyName
| TradeDir1ToP PartyName
| TradeDir2ToP PartyName
| TradeDirPToQ PartyName PartyName
| TradeDirQToP PartyName PartyName
deriving (Show, Eq)
evalTradeDir :: TradeDir -> (Contract -> Contract)
evalTradeDir TradeDir2To1 = id
evalTradeDir TradeDir1To2 = give
evalTradeDir (TradeDirPTo1 p) = party p
evalTradeDir (TradeDirPTo2 p) = give . party p
evalTradeDir (TradeDir1ToP p) = party p . give
evalTradeDir (TradeDir2ToP p) = give . party p . give
evalTradeDir (TradeDirPToQ p q) = party q . give . party p
evalTradeDir (TradeDirQToP p q) = party q . give . party p . give
setThirdParty :: PartyName -> TradeDir -> TradeDir
setThirdParty p TradeDir2To1 = TradeDirPTo1 p
setThirdParty p TradeDir1To2 = TradeDirPTo2 p
setThirdParty p (TradeDirPTo1 q) = TradeDirPTo1 p
setThirdParty p (TradeDirPTo2 q) = TradeDirPTo2 p
setThirdParty p (TradeDir1ToP q) = TradeDirPToQ p q
setThirdParty p (TradeDir2ToP q) = TradeDirPToQ p q
setThirdParty p (TradeDirPToQ q r) = TradeDirPToQ p r
setThirdParty p (TradeDirQToP q r) = TradeDirPToQ p q
flipTradeDir :: TradeDir -> TradeDir
flipTradeDir TradeDir2To1 = TradeDir1To2
flipTradeDir TradeDir1To2 = TradeDir2To1
flipTradeDir (TradeDirPTo1 p) = TradeDir1ToP p
flipTradeDir (TradeDirPTo2 p) = TradeDir2ToP p
flipTradeDir (TradeDir1ToP p) = TradeDirPTo1 p
flipTradeDir (TradeDir2ToP p) = TradeDirPTo2 p
flipTradeDir (TradeDirPToQ p q) = TradeDirQToP p q
flipTradeDir (TradeDirQToP p q) = TradeDirPToQ p q
tradeDirParties :: TradeDir -> (Party, Party)
tradeDirParties TradeDir2To1 = (FirstParty, Counterparty)
tradeDirParties TradeDir1To2 = (Counterparty, FirstParty)
tradeDirParties (TradeDirPTo1 p) = (FirstParty, ThirdParty p)
tradeDirParties (TradeDirPTo2 p) = (Counterparty, ThirdParty p)
tradeDirParties (TradeDir1ToP p) = (ThirdParty p, FirstParty)
tradeDirParties (TradeDir2ToP p) = (ThirdParty p, Counterparty)
tradeDirParties (TradeDirPToQ p q) = (ThirdParty q, ThirdParty p)
tradeDirParties (TradeDirQToP p q) = (ThirdParty p, ThirdParty q)
instance Show (DecisionStep x) where
show Done = "Done"
show (Trade dir n t _) = case dir of
TradeDir2To1 -> "Receive " ++ quantityOfStuff
TradeDir1To2 -> "Provide " ++ quantityOfStuff
TradeDirPTo1 p -> "Receive from " ++ partyQuantityOfStuff p
TradeDirPTo2 p -> "Counterparty receives from " ++ partyQuantityOfStuff p
TradeDir1ToP p -> "Provide to " ++ partyQuantityOfStuff p
TradeDir2ToP p -> "Counterparty provides to " ++ partyQuantityOfStuff p
TradeDirPToQ p q -> p ++ " provides to " ++ q ++ " " ++ quantityOfStuff
TradeDirQToP p q -> q ++ " provides to " ++ p ++ " " ++ quantityOfStuff
where
quantityOfStuff = show n ++ " " ++ show t
partyQuantityOfStuff p = p ++ " " ++ quantityOfStuff
show (Choose p cid _ _) = "Choose " ++ show p ++ " " ++ cid
show (ObserveCond o _ _) = "ObserveCond " ++ show o
show (ObserveValue o _) = "ObserveValue " ++ show o
show (Wait conds opts) = "Wait for one to become true...\n"
++ unlines (map (show . fst) conds)
++ "Or pick an available option\n"
++ unlines (map (show . fst) opts)
instance Display DecisionTree where
toTree (DecisionTree time st) = case st of
Done -> Node "done" []
Trade dir n t st1 -> Node descr [toTree st1]
where
descr = dirDescr ++ " " ++ show n ++ " " ++ show t
++ "\n" ++ show time
dirDescr = case dir of
TradeDir2To1 -> "receive"
TradeDir1To2 -> "provide"
TradeDirPTo1 p -> "receive from " ++ p
TradeDirPTo2 p -> "counterparty receives from " ++ p
TradeDir1ToP p -> "provide to " ++ p
TradeDir2ToP p -> "counterparty provides to " ++ p
TradeDirPToQ p q -> p ++ " provides to " ++ q
TradeDirQToP p q -> q ++ " provides to " ++ p
Choose p cid st1 st2 -> Node (partyDescr ++ cid ++ "\n" ++ show time)
[toTree st1, toTree st2]
where
partyDescr = case p of
FirstParty -> "choose "
Counterparty -> "counterparty choice "
ThirdParty p -> "3rd party " ++ p ++ " choice "
ObserveCond obs st1 st2 -> Node "observe cond" [toTree obs
,toTree st1
,toTree st2]
ObserveValue obs st1 -> Node "observe val" [ toTree obs
, case Obs.eval time obs of
Result v -> toTree (st1 v)
_ -> toTree (st1 0) ]
Wait conds opts -> Node ("wait\n" ++ show time)
$ [ Node (show obs)
[ case Obs.timeHorizon time obs of
Nothing -> toTree (cont time)
Just time' -> toTree (cont time') ]
| (obs, cont) <- conds ]
++ [ Node "option" [toTree (cont time)]
| (_c, cont) <- opts ]
instance HTypeable Party where
toHType _ = Defined "Party" [] []
instance XmlContent Party where
parseContents = do
e@(Elem t _ _) <- element ["Party", "Counterparty", "ThirdParty"]
commit $ interior e $ case localName t of
"Party" -> return FirstParty
"Counterparty" -> return Counterparty
"ThirdParty" -> liftM ThirdParty text
toContents FirstParty = [mkElemC "Party" []]
toContents Counterparty = [mkElemC "Counterparty" []]
toContents (ThirdParty p) = [mkElemC "ThirdParty" (toText p)]
instance HTypeable TradeDir where
toHType _ = Defined "TradeDir" [] []
instance XmlContent TradeDir where
parseContents = do
e@(Elem t _ _) <- element ["TradeDir2To1","TradeDir1To2"
,"TradeDirPTo1","TradeDirPTo2"
,"TradeDir1ToP","TradeDir2ToP"
,"TradeDirPToQ","TradeDirQToP"]
commit $ interior e $ case localName t of
"TradeDir2To1" -> return TradeDir2To1
"TradeDir1To2" -> return TradeDir1To2
"TradeDirPTo1" -> liftM TradeDirPTo1 text
"TradeDirPTo2" -> liftM TradeDirPTo2 text
"TradeDir1ToP" -> liftM TradeDir1ToP text
"TradeDir2ToP" -> liftM TradeDir2ToP text
"TradeDirPToQ" -> liftM2 TradeDirPToQ text text
"TradeDirQToP" -> liftM2 TradeDirQToP text text
toContents TradeDir2To1 = [mkElemC "TradeDir2To1" []]
toContents TradeDir1To2 = [mkElemC "TradeDir1To2" []]
toContents (TradeDirPTo1 p) = [mkElemC "TradeDirPTo1" (toText p)]
toContents (TradeDirPTo2 p) = [mkElemC "TradeDirPTo2" (toText p)]
toContents (TradeDir1ToP p) = [mkElemC "TradeDir1ToP" (toText p)]
toContents (TradeDir2ToP p) = [mkElemC "TradeDir2ToP" (toText p)]
toContents (TradeDirPToQ p q) = [mkElemC "TradeDirPToQ" (toText p ++ toText q)]
toContents (TradeDirQToP p q) = [mkElemC "TradeDirQToP" (toText p ++ toText q)]
instance HTypeable (Blocked c) where
toHType _ = Defined "Blocked" [] []
instance XmlContent c => XmlContent (Blocked c) where
parseContents = do
e@(Elem t _ _) <- element ["BlockedOnWhen", "BlockedOnAnytime"]
commit $ interior e $ case localName t of
"BlockedOnWhen" -> liftM2 BlockedOnWhen (fmap unObsCondition parseContents)
parseContents
"BlockedOnAnytime" -> liftM4 BlockedOnAnytime parseContents
(inElement "ChoiceId" text)
(fmap unObsCondition parseContents)
parseContents
toContents (BlockedOnWhen obs c) =
[mkElemC "BlockedOnWhen" (toContents (ObsCondition obs) ++ toContents c)]
toContents (BlockedOnAnytime val cid obs c) =
[mkElemC "BlockedOnAnytime" (toContents val ++ [mkElemC "ChoiceId" (toText cid)] ++
toContents (ObsCondition obs) ++ toContents c)]
newtype ObsCondition = ObsCondition { unObsCondition :: (Obs Bool) }
instance HTypeable ObsCondition where
toHType _ = Defined "ObsCondition" [] []
instance XmlContent ObsCondition where
parseContents = inElement "ObsCondition" $
liftM ObsCondition Obs.parseObsCond
toContents (ObsCondition obs) = [mkElemC "ObsCondition" [Obs.printObs obs]]
instance HTypeable ThreadState where
toHType _ = Defined "ThreadState" [] []
instance XmlContent ThreadState where
parseContents =
inElement "ThreadState" $
liftM4 TSt parseContents
(inElement "UntilConditions" (fmap (map unObsCondition) parseContents))
parseContents parseContents
toContents (TSt c obss sf dir) =
[mkElemC "ThreadState" (toContents c ++
mkElemC "UntilConditions" (toContents (map ObsCondition obss)) :
toContents sf ++ toContents dir)]
instance HTypeable ProcessState where
toHType _ = Defined "ProcessState" [] []
instance XmlContent ProcessState where
parseContents =
inElement "ProcessState" $
liftM3 PSt parseContents
(inElement "BlockedThreads" parseContents)
(inElement "RunnableThreads" parseContents)
toContents (PSt t blocked runnable) =
[mkElemC "ProcessState" (toContents t ++
[ mkElemC "BlockedThreads" (toContents blocked)
, mkElemC "RunnableThreads" (toContents runnable) ])]