{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveLift #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

-- |
-- Module      :   Grisette.IR.SymPrim.Data.TabularFun
-- Copyright   :   (c) Sirui Lu 2021-2023
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.IR.SymPrim.Data.TabularFun
  ( type (=->) (..),
  )
where

import Control.DeepSeq
import Data.Hashable
import Data.Maybe (fromMaybe)
import GHC.Generics
import Grisette.Core.Data.Class.Function
import Grisette.IR.SymPrim.Data.Prim.InternedTerm.Term
import Language.Haskell.TH.Syntax

-- $setup
-- >>> import Grisette.Core
-- >>> import Grisette.IR.SymPrim

-- |
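-- A function represented as a lookup table of @(argument, result)@ pairs plus
-- a default result. Use the `#` operator to apply the function: the table is
-- searched front to back, the first entry whose argument equals the input
-- wins, and the default result is returned when no entry matches.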
--
-- >>> :set -XTypeOperators
-- >>> let f = TabularFun [(1, 2), (3, 4)] 0 :: Int =-> Int
-- >>> f # 1
-- 2
-- >>> f # 2
-- 0
-- >>> f # 3
-- 4
data (=->) a b = TabularFun
  { -- | The table of @(argument, result)@ pairs. When applied with '#',
    -- entries are examined front to back and the first matching argument
    -- wins, so later entries with an equal argument are shadowed.
    funcTable :: [(a, b)],
    -- | The result produced for any argument that has no entry in
    -- 'funcTable'.
    defaultFuncValue :: b
  }
  deriving (Show, Eq, Generic, Generic1, Lift, NFData, NFData1)

-- Right-associative with the lowest precedence, so @a =-> b =-> c@ parses as
-- @a =-> (b =-> c)@, mirroring the ordinary function arrow.
infixr 0 =->

instance
  (SupportedPrim a, SupportedPrim b) =>
  SupportedPrim (a =-> b)
  where
  type PrimConstraint (a =-> b) = (SupportedPrim a, SupportedPrim b)

  -- The default tabular function has an empty table, so applying it to any
  -- argument yields the default value of the result type @b@.
  defaultValue = TabularFun [] (defaultValue @b)

instance (Eq a) => Function (a =-> b) where
  type Arg (a =-> b) = a
  type Ret (a =-> b) = b

  -- Applying a tabular function: return the result of the first table entry
  -- whose argument equals the input, or the default value when none does.
  -- 'lookup' scans the list front to back, so earlier entries shadow later
  -- ones with an equal argument.
  (TabularFun table d) # a = fromMaybe d (lookup a table)

-- | Hashing relies entirely on the class's default methods, which are driven
-- by the derived 'Generic' representation of the tabular function.
instance (Hashable b, Hashable a) => Hashable (a =-> b)