{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -Wno-incomplete-patterns #-}

{- | This implements the model environments that users must provide when running a model;
     such environments assign traces of values to the "observable variables" (random
     variables which can be conditioned on) of a model.
-}

module Env
  ( -- * Observable variable
    ObsVar(..)
  , varToStr
    -- * Model environment
  , Assign(..)
  , Env(..)
  , (<:>)
  , nil
  , Observable(..)
  , Observables(..)
  , UniqueKey
  , LookupType) where

import Data.Kind ( Constraint, Type )
import Data.Proxy ( Proxy(Proxy) )
import GHC.OverloadedLabels ( IsLabel(..) )
import GHC.TypeLits ( KnownSymbol, Symbol, symbolVal )
import Unsafe.Coerce ( unsafeCoerce )

import FindElem ( FindElem(..), Idx(..) )

-- | Containers for observable variables. The variable's name is the type-level
--   string @x@; the 'KnownSymbol' constraint stored in the constructor lets the
--   name be recovered at the value level (see 'varToStr').
data ObsVar (x :: Symbol) where
  ObsVar :: KnownSymbol x => ObsVar x

-- | Allows the syntax @#x@ to be automatically lifted to the type @ObsVar "x"@.
--   The instance head matches any @ObsVar x'@. The equality constraint
--   @x ~ x'@ then forces the label and the result type to agree, instead of
--   requiring them to already coincide during instance selection.
instance (KnownSymbol x, x ~ x') => IsLabel x (ObsVar x') where
  fromLabel = ObsVar

-- | Convert an observable variable from a type-level string to a value-level string.
--   Matching on the 'ObsVar' constructor brings its 'KnownSymbol' dictionary into
--   scope, which is what 'symbolVal' needs to reflect @x@.
varToStr :: forall x. ObsVar x -> String
varToStr ObsVar = symbolVal (Proxy @x)

-- * Model Environments

{- | A model environment assigning traces (lists) of observed values to observable
     variables, i.e. the type @Env ((x := a) : env)@ indicates @x@ is assigned a value
     of type @[a]@.
-}
data Env (env :: [Assign Symbol Type]) where
  -- | The empty environment.
  ENil  :: Env '[]
  -- | Prepend the trace for one variable. The variable's name @x@ exists only
  --   in the type index; the value level stores just the list of values.
  ECons :: [a] -> Env env -> Env ((x ':= a) ': env)

-- | Pairs a variable @x@ with a value of type @a@. Promoted with DataKinds, it
--   forms the entries of an 'Env' index; at the value level it is the left
--   argument of '<:>'.
data Assign x a = (:=) x a

-- | Empty model environment; the usual terminator for a chain of '<:>'.
nil :: Env '[]
nil = ENil

infixr 5 <:>
-- | Prepend a variable assignment to a model environment.
--   The constraint @UniqueKey x env ~ True@ makes the type checker reject an
--   assignment to a variable that already appears in @env@.
(<:>) :: UniqueKey x env ~ True => Assign (ObsVar x) [a] -> Env env -> Env ((x ':= a) ': env)
(_ := vals) <:> env = ECons vals env

-- | Shows each assignment as @name:=values@, separated by @", "@ and ending
--   with the empty environment, e.g. @x:=[1,2], []@.
instance (KnownSymbol x, Show a, Show (Env env)) => Show (Env ((x := a) ': env)) where
  show (ECons vals rest) = varToStr (ObsVar @x) ++ ":=" ++ show vals ++ ", " ++ show rest

-- | The empty environment is shown as @[]@.
instance Show (Env '[]) where
  show ENil = "[]"

-- | Base case: the head entry is named @x@, so @x@ sits at index 0.
instance FindElem x ((x := a) : env) where
  findElem = Idx 0

-- | Step case: the head entry has some other name, so the index of @x@ is one
--   more than its index in the tail. Marked OVERLAPPABLE so that the more
--   specific base case is chosen whenever the head entry is named @x@.
instance {-# OVERLAPPABLE #-} FindElem x env => FindElem x ((x' := a) : env) where
  findElem = Idx $ 1 + unIdx (findElem :: Idx x env)

-- | Retrieve the type of an observable variable @x@ from an environment @env@.
--   Equations are tried in order, so the first entry named @x@ determines the result.
type family LookupType x env where
  LookupType key ((key ':= val) ': rest) = val
  LookupType key ((other ':= val) ': rest) = LookupType key rest

-- | Specifies that an environment @Env env@ has an observable variable @x@
--   whose observed values are of type @a@.
--
--   The superclass constraints carry the evidence: 'FindElem' locates @x@
--   inside @env@, and @LookupType x env ~ a@ pins down the value type.
class (FindElem x env, LookupType x env ~ a) => Observable env x a where
  -- | Read the trace assigned to a variable.
  get :: ObsVar x  -- ^ variable to read
      -> Env env   -- ^ environment to read from
      -> [a]
  -- | Replace the trace assigned to a variable.
  set :: ObsVar x  -- ^ variable to overwrite
      -> [a]       -- ^ new trace
      -> Env env   -- ^ environment to update
      -> Env env

-- | The single instance: every environment satisfying the superclass
--   constraints is observable. Both methods use the index computed by
--   'findElem' to walk down the 'ECons' chain to the entry for @x@.
--
--   Why 'unsafeCoerce' is sound here: the local walkers are polymorphic in
--   the environment (@env'@), so the type of the entry they reach is not
--   connected to @a@ by the type checker. However, 'FindElem' and
--   'LookupType' both resolve to the first entry named @x@, and the instance
--   context states @LookupType x env ~ a@, so that entry's values have type @a@.
instance (FindElem x env, LookupType x env ~ a) => Observable env x a where
  get _ env = go (unIdx (findElem @x @env)) env
    where
      -- Step @n@ entries down the environment and return that entry's trace.
      go :: Int -> Env env' -> [a]
      go 0 (ECons vals _) = unsafeCoerce vals
      go n (ECons _ rest) = go (n - 1) rest
      -- Unreachable when the FindElem evidence is correct; fail loudly rather
      -- than with an anonymous pattern-match failure.
      go _ ENil = error "Env.get: index out of range (FindElem should make this unreachable)"

  set _ newVals env = go (unIdx (findElem @x @env)) env
    where
      -- Step @n@ entries down the environment, replace that entry's trace,
      -- and rebuild the entries passed on the way unchanged.
      go :: Int -> Env env' -> Env env'
      go 0 (ECons _ rest)    = ECons (unsafeCoerce newVals) rest
      go n (ECons vals rest) = ECons vals (go (n - 1) rest)
      -- Unreachable when the FindElem evidence is correct.
      go _ ENil = error "Env.set: index out of range (FindElem should make this unreachable)"

-- | For each observable variable in the list, construct the constraint
--   @Observable env x a@. Every listed variable must share the same value type @a@.
type family Observables env (ks :: [Symbol]) a :: Constraint where
  Observables env '[] a = ()
  Observables env (k ': rest) a = (Observable env k a, Observables env rest a)

-- | Check whether an observable variable @x@ is unique in model environment @env@.
--   Reduces to @'True@ when no entry of @env@ is named @x@, and to @'False@
--   as soon as one is found.
type family UniqueKey x env where
  UniqueKey key '[] = 'True
  UniqueKey key ((key ':= val) ': rest) = 'False
  UniqueKey key ((other ':= val) ': rest) = UniqueKey key rest