{-# options_ghc -XGADTs -XKindSignatures -XFlexibleInstances -XOverlappingInstances -XScopedTypeVariables -XEmptyDataDecls #-} 
module PGames where 

import Random
import Iso
import Debug.Trace

-- A game for type t, Game t, is a potentially infinite decision tree
-- with extra information about how to ask questions in the branches,
-- and elements of the datatype in the leaves.
-- In this version each branch also carries integer weights (relative
-- probabilities), which drive the arithmetic coder below.

-- /Game/
-- A more general design would use n-ary nodes subsuming both Split and
-- Single. That is easier in Coq, where we can define
-- Split : forall (a : list {t:Type & nat * Game t}), ISO t (SumOver a) -> Type
-- | An uninhabited type; it terminates the nested 'Either' index of a
-- 'GamesOver' list, so the sum has no case beyond the last alternative.
data Void

-- | A list of weighted alternatives.  The index @s@ records the sum of the
-- alternatives' types: 'ConsGames' prepends a weight and a game for type
-- @t@, extending the index to @Either t s@; 'NilGames' ends it with 'Void'.
data GamesOver s where
  NilGames  :: GamesOver Void
  ConsGames :: Int -> Game t -> GamesOver s -> GamesOver (Either t s)

-- | A game for type @t@: a possibly infinite decision tree.
data Game t where
  -- Leaf: @t@ is isomorphic to @()@, so nothing needs to be encoded.
  Single :: ISO t () -> Game t
  -- Branch: @t@ is isomorphic to the sum @s@ indexed by the alternatives.
  Split  :: ISO t s -> GamesOver s -> Game t
    
-- | Sum of the weights of every alternative in the list; this is the common
-- denominator used when narrowing the coding interval.
totalWeight :: GamesOver s -> Int
totalWeight gos = case gos of
  NilGames            -> 0
  ConsGames wt _ rest -> wt + totalWeight rest

-- | Three-way split with explicit weights.  The isomorphism must already
-- target the right-nested, 'Void'-terminated sum that 'Split' expects.
split3 :: ISO t (Either t1 (Either t2 (Either t3 Void))) -> Int -> Game t1 -> Int -> Game t2 -> Int -> Game t3 -> Game t
split3 iso wa ga wb gb wc gc = Split iso branches
  where branches = ConsGames wa ga (ConsGames wb gb (ConsGames wc gc NilGames))
        
-- | Turn an isomorphism onto a binary sum into one onto the right-nested,
-- 'Void'-terminated sum used by 'GamesOver'.
flat2 :: ISO t (Either t1 t2) -> ISO t (Either t1 (Either t2 Void))
flat2 (Iso fwd bwd) = Iso (either Left (Right . Left) . fwd)
                          (either (bwd . Left) (\r -> case r of Left z -> bwd (Right z)))
          
-- | Turn an isomorphism onto a three-way right-nested sum into one onto the
-- 'Void'-terminated sum used by 'GamesOver'.
flat3 :: ISO t (Either t1 (Either t2 t3)) -> ISO t (Either t1 (Either t2 (Either t3 Void)))
flat3 (Iso fwd bwd) = Iso (reassoc . fwd) back
  where
    -- Push the last summand one level deeper, next to 'Void'.
    reassoc = either Left (either (Right . Left) (Right . Right . Left))
    -- Inverse: match the input before calling 'bwd' in every branch.
    back (Left a)                 = bwd (Left a)
    back (Right (Left b))         = bwd (Right (Left b))
    back (Right (Right (Left c))) = bwd (Right (Right c))
          
-- | Binary split with explicit weights.  The binary-sum isomorphism is
-- flattened with 'flat2' into the shape 'Split' expects.
split2 :: ISO t (Either t1 t2) -> Int -> Game t1 -> Int -> Game t2 -> Game t
split2 iso wl gl wr gr = Split (flat2 iso) branches
  where branches = ConsGames wl gl (ConsGames wr gr NilGames)

-- | Binary split in which both alternatives get the same weight.
split :: ISO t (Either t1 t2) -> Game t1 -> Game t2 -> Game t
split iso gl gr = split2 iso weight gl weight gr
  where weight = 1
-- /End/
                             
-- Coerce a game, via an isomorphism 
-- /coerceGame/
-- | Re-target a game along an isomorphism: the game's top-level
-- isomorphism is precomposed with @iso@ via 'seqI'; sub-games are untouched.
(+>) :: Game t -> ISO s t -> Game s
game +> iso = case game of
  Single leaf      -> Single (iso `seqI` leaf)
  Split  ask games -> Split  (iso `seqI` ask) games
-- /End/ 

-- Left-associative at precedence 4: @g +> i +> j@ parses as @(g +> i) +> j@,
-- so successive coercions can be chained.
infixl 4 +>

-- /Bit/
-- | A single bit of coder output, represented as an 'Int'.
type Bit = Int  -- 0 or 1
-- /End/

-- | A symbol's slice of the current interval, written @(p, q, d)@: 'narrow'
-- keeps the part from fraction @p/d@ to fraction @q/d@ of the interval.
type MInterval = (Int,Int,Int) 

-- | An interval, specified by its lower and upper bounds.
type Interval = (Int,Int)

-- | An expanded interval: the number of pending expansion steps (counted up
-- by 'expand', later emitted as complement bits via 'bits' in 'nextBits')
-- paired with the interval itself.
type EInterval = (Int,Interval)

-- | Bits of precision of the integer coder.  The decoder's initial window is
-- @e@ bits wide, and the full coding range is @[0, 2^e]@.  Every other
-- range constant below is derived from @e@, so precision is set here only.
e :: Int
e = 15

-- | Quarter points of the full coding range @[0, w4]@:
--   w1 = w4/4, w2 = w4/2 (the midpoint), w3 = 3*w4/4, w4 = 2^e.
w1, w2, w3, w4 :: Int
w4 = 2 ^ e
w1 = w4 `div` 4
w2 = w4 `div` 2
w3 = 3 * w1

-- | The full coding range @[0, w4]@.  Both 'enc' and 'dec' start from this
-- interval with zero pending expansions.
unit :: Interval
unit = (0,w4)

-- | Keep the sub-interval of @(lo, hi)@ running from fraction @p/d@ to
-- fraction @q/d@ of its width.  Both ends are offsets floored by 'div'.
narrow :: Interval -> MInterval -> Interval
narrow (lo,hi) (p,q,d) = (at p, at q)
  where width = hi - lo
        at k  = lo + (width * k) `div` d

-- | If the interval lies entirely in the lower or the upper half of the
-- range, decide the next bit.  Return it together with the owed complement
-- bits for the pending expansions, plus the interval rescaled back to the
-- full range with zero pending expansions.  Otherwise return 'Nothing'.
nextBits :: EInterval -> Maybe ([Bit],EInterval)
nextBits (pending, (lo, hi))
  | hi <= w2  = emit 0 0
  | lo >= w2  = emit 1 w4
  | otherwise = Nothing
  where
    -- Double the bounds, subtracting @offset@ so the half the interval lies
    -- in maps back onto [0, w4].
    emit b offset = Just (bits pending b, (0, (2*lo - offset, 2*hi - offset)))

-- | Expand @ei@ as far as 'expand' allows, then narrow the resulting interval
-- to the given symbol slice, keeping the pending-expansion count.
enarrow :: EInterval -> MInterval -> EInterval
enarrow ei slice =
  let (pending, iv) = expand ei
  in  (pending, narrow iv slice)

-- | While the interval sits inside the middle half @[w1, w3]@, zoom in by
-- doubling it about the midpoint.  Each zoom adds one pending expansion.
expand :: EInterval -> EInterval
expand = until (not . inMiddleHalf) zoom
  where
    inMiddleHalf (_, (lo, hi)) = w1 <= lo && hi <= w3
    zoom (pending, (lo, hi))   = (pending + 1, (2*lo - w2, 2*hi - w2))

-- | @bits n b@ is the decided bit @b@ followed by @n@ copies of its
-- complement: the bits owed for @n@ pending expansions.
bits :: Int -> Bit -> [Bit]
bits n b = b : take n (repeat (1 - b))

-- | Turn a list of symbol slices into output bits.  Emit any bits the
-- current interval already determines; otherwise narrow by the next slice.
-- Stop once no slices remain and no further bit is decided.
stream :: EInterval -> [MInterval] -> [Bit]
stream z syms = maybe continue emit (nextBits z)
  where
    emit (out, z') = out ++ stream z' syms
    continue = case syms of
      []     -> []
      m : ms -> stream (enarrow z m) ms

-- | Arithmetic-encode a value under a game, starting from interval @ei@.
-- 'encodeSyms' produces the value's symbol slices, and 'stream' turns them
-- into bits.
arithEncAux :: EInterval -> Game t -> t -> [Bit]
arithEncAux ei g = stream ei . encodeSyms g

-- | The symbol slices describing the path to a value through the game tree.
-- At each 'Split' the value goes through the isomorphism, and the chosen
-- alternative's cumulative weight range @(lo, lo + weight, total)@ is
-- recorded; encoding then continues in that alternative's sub-game.  A
-- 'Single' leaf contributes nothing.
encodeSyms :: Game t -> t -> [MInterval]
encodeSyms (Single _) _ = []
encodeSyms (Split (Iso ask _) games) x = walk 0 games (ask x)
  where
    denom = totalWeight games
    -- Skip alternatives, accumulating their weights, until the 'Left'
    -- holding the value is reached.
    walk :: Int -> GamesOver s -> s -> [MInterval]
    walk lo (ConsGames wt sub rest) v = case v of
      Left  inner -> (lo, lo + wt, denom) : encodeSyms sub inner
      Right other -> walk (lo + wt) rest other
        
-- | Encode a value as bits, starting from the full range 'unit' with no
-- pending expansions.
enc :: Game t -> t -> [Bit]
enc g x = arithEncAux (0, unit) g x

-- | Decode a value from bits, starting from interval @ei@.  The input is
-- padded with @1@ followed by @e-1@ zeros.  The first @e@ bits of the
-- padded stream form the initial integer window; 'destream' shifts in the
-- remainder later.
decode :: EInterval -> [Bit] -> Game t -> t
decode ei bs g = destream ei (window, rest) g
  where
    padded        = bs ++ 1 : replicate (e - 1) 0
    (front, rest) = splitAt e padded
    window        = foldl (\acc b -> acc * 2 + b) 0 front

-- | Shift bits the encoder has already emitted out of the decoder window
-- @(c, ds)@.  For each emitted bit @b@: double the window value, remove the
-- bit's weight @w4*b@, and shift in the next unread input bit from @ds@.
--
-- 'decode' pads the input and 'destream' mirrors the encoder's bit emission,
-- so @ds@ should not run out on input produced by 'enc' with the same game.
-- If it does, the input is truncated or mismatched.  In that case fail with
-- a descriptive error rather than a bare 'head'/'tail' exception.
ominus :: (Int,[Bit]) -> [Bit] -> (Int,[Bit])
ominus (c,ds) bs = foldl op (c,ds) bs
  where op (acc, d:rest) b = (2*acc - w4*b + d, rest)
        op (_, [])       _ = error "PGames.ominus: ran out of input bits (truncated or mismatched code?)"

-- | Scale the window value by the pending expansion count.  For each of the
-- next @pending@ unread bits: double the value, add the bit, and subtract
-- 'w2'.  This mirrors the @2*l - w2@ rescaling in 'expand'.
fscale :: (Int,(Int,[Bit])) -> Int
fscale (pending, (val, upcoming)) = foldl zoom val (take pending upcoming)
  where zoom acc bit = acc * 2 + bit - w2

-- | Decode a value of type @t@ from game @g@.  @ei@ is the current expanded
-- interval; @w@ is the decoder window (an integer value plus the unread
-- input bits).
--
-- While 'nextBits' can decide a bit, the same bits the encoder emitted are
-- shifted out of the window with 'ominus'.  Otherwise, at a 'Split', the
-- scaled window value is mapped to a cumulative-weight target @t@.  The
-- alternative whose weight range contains @t@ is chosen, and decoding
-- continues in that sub-game with the correspondingly narrowed interval.
-- At a 'Single' leaf the value is rebuilt from @()@.
destream :: EInterval -> (Int, [Bit]) -> Game t -> t
destream ei w g = case nextBits ei of
  Just (y,ei')  ->   destream ei' (ominus w y) g
  Nothing      ->
    case g of
      Single (Iso _ bld) -> bld ()
      Split (Iso _ bld) gs -> decodeSym bld gs 0
        where
          (n,(l,r)) = expand ei
          k = fscale (n,w)
          -- Cumulative-weight target; expected to lie in [0, d) for
          -- well-formed input.
          t = ((k-l+1)*d - 1) `div` (r-l)
          d = totalWeight gs

          -- Walk the alternatives with running cumulative weight @n@.  The
          -- first alternative whose upper bound @n'@ exceeds @t@ is the one
          -- that was encoded.
          decodeSym :: (s -> t) -> GamesOver s -> Int -> t
          decodeSym bld (ConsGames weight g gs) n =
            if n' > t then bld (Left (destream (enarrow ei (n,n',d)) w g))
            else decodeSym (bld . Right) gs n'
                 where n' = n+weight
          -- Running past the last alternative means t >= d.  That should not
          -- happen for bits produced by 'enc' with the same game, so report it
          -- explicitly instead of failing with a bare pattern-match error.
          decodeSym _ NilGames _ =
            error "PGames.destream: no alternative matches the decoded value (corrupt or mismatched input?)"

-- | Decode a bit list produced by 'enc' for the same game, starting from the
-- full range 'unit' with no pending expansions.
dec :: Game t -> [Bit] -> t
dec g bs = decode (0,unit) bs g
                                
-- | Round-trip a value through 'enc' and then 'dec' under the same game.
-- The result is expected to equal the input.
testGame :: Game t -> t -> t
testGame g x = dec g (enc g x)