module UniqueLogic.ST.Expression (
   T,
   -- * Construct primitive expressions
   constant, fromVariable,
   -- * Operators from rules with small numbers of arguments
   fromRule1, fromRule2, fromRule3,
   -- * Operators from rules with any number of arguments
   Apply, arg, runApply,
   -- * Predicates on expressions
   (=:=),
   -- * Common operators (see also 'Num' and 'Fractional' instances)
   (=!=),
   sqr, sqrt,
   max, maximum,
   pair,
   ) where

import qualified UniqueLogic.ST.Rule as Rule
import qualified UniqueLogic.ST.System as Sys

import Control.Monad.ST (runST, )
import Control.Monad (liftM2, ap, )
import Control.Applicative (Applicative, pure, liftA, liftA2, (<*>), )

-- import Control.Category ((.))
-- import Data.Maybe (Maybe)

-- import Prelude (Double, Eq, Ord, (+), (*), (/))
import qualified Prelude as P
import Prelude hiding (max, maximum, sqrt)


{- |
An expression is a system action that sets up its equations
and yields the variable standing for the expression.
The value of the expression is the value of that variable.
-}
newtype T s a =
   Cons (Sys.M s (Sys.Variable s a))


{- |
Build an expression whose value is fixed to the given value.
The value may be of any type; the underlying variable
is created by 'Sys.constant'.
-}
constant :: a -> T s a
constant x = Cons (Sys.constant x)

{- |
Wrap an existing system variable as an expression.
No new variable and no new equation is introduced.
-}
fromVariable :: Sys.Variable s a -> T s a
fromVariable v = Cons (return v)


{- |
Turn a rule on a single variable into an expression.
A fresh local variable is created, the rule is applied to it,
and that variable becomes the value of the expression.
-}
fromRule1 ::
   (Sys.Variable s a -> Sys.M s ()) ->
   T s a
fromRule1 rule =
   Cons $
   Sys.localVariable >>= \result ->
   rule result >>
   return result

{- |
Turn a rule relating two variables into a unary operator on expressions.
The operand's action runs first. Then a fresh local variable is created
for the result, and the rule is applied as @rule input output@.
-}
fromRule2, _fromRule2 ::
   (Sys.Variable s a -> Sys.Variable s b -> Sys.M s ()) ->
   T s a -> T s b
fromRule2 rule (Cons x) =
   Cons $
   x >>= \input ->
   Sys.localVariable >>= \output ->
   rule input output >>
   return output

{- |
Turn a rule relating three variables into a binary operator on expressions.
The actions of the first and then the second operand run first.
Then a fresh local variable is created for the result,
and the rule is applied as @rule lhs rhs result@.
-}
fromRule3, _fromRule3 ::
   (Sys.Variable s a -> Sys.Variable s b -> Sys.Variable s c -> Sys.M s ()) ->
   T s a -> T s b -> T s c
fromRule3 rule (Cons x) (Cons y) =
   Cons $
   x >>= \lhs ->
   y >>= \rhs ->
   Sys.localVariable >>= \result ->
   rule lhs rhs result >>
   return result


{- |
Applicative wrapper around a system action.
Rules with any number of arguments can be assembled with
'pure', 'fmap' and '<*>' and then closed with 'runApply'.
-}
newtype Apply s f = Apply (Sys.M s f)

instance Functor (Apply s) where
   fmap f (Apply action) = Apply (fmap f action)

instance Applicative (Apply s) where
   pure x = Apply (return x)
   -- sequence both actions monadically: the function action runs first
   Apply mf <*> Apply mx = Apply (mf `ap` mx)


{- |
Use an expression as one argument of a rule that is assembled
with 'Applicative' combinators.
Together with 'runApply', this generalizes 'fromRule2' and 'fromRule3'
to rules with any number of arguments.

Example:

> fromRule3 rule x y
>    = runApply $ liftA2 rule (arg x) (arg y)
>    = runApply $ pure rule <*> arg x <*> arg y

Building rules with 'arg' gives finer granularity
than combining arguments with auxiliary 'pair' rules!
-}
arg ::
   T s a -> Apply s (Sys.Variable s a)
arg (Cons action) = Apply action

{- |
Close a rule that was assembled with 'arg' and 'Applicative' combinators.
The argument actions run first. Then a fresh local variable is created,
the assembled rule is applied to it, and it becomes the value of the
resulting expression.
-}
runApply ::
   Apply s (Sys.Variable s a -> Sys.M s ()) ->
   T s a
runApply (Apply mkRule) = Cons $ do
   constrain <- mkRule
   result <- Sys.localVariable
   constrain result
   return result

{-
Alternative definitions of 'fromRule2' and 'fromRule3'
written with 'arg' and 'runApply'.
They are not exported; they only demonstrate the intended usage.
-}
_fromRule2 rule = runApply . liftA rule . arg
_fromRule3 rule x y = runApply (liftA2 rule (arg x) (arg y))


-- Arithmetic on expressions: every operator introduces a fresh result
-- variable that is tied to its operands by a rule.
instance (P.Fractional a) => P.Num (T s a) where
   fromInteger n = constant (fromInteger n)
   x + y = fromRule3 Rule.add x y
   -- Subtraction reuses 'Rule.add' with the variables rearranged:
   -- assuming @Rule.add x y z@ encodes @x + y = z@, @a - b = d@
   -- is stated as @b + d = a@.
   (-) =
      fromRule3
         (\minuend subtrahend difference ->
            Rule.add subtrahend difference minuend)
   x * y = fromRule3 Rule.mul x y
   abs x = fromRule2 (Sys.assignment2 "abs" abs) x
   signum x = fromRule2 (Sys.assignment2 "signum" signum) x

instance (P.Fractional a) => P.Fractional (T s a) where
   fromRational r = constant (fromRational r)
   -- Division reuses 'Rule.mul' with the variables rearranged:
   -- assuming @Rule.mul x y z@ encodes @x * y = z@, @a / b = q@
   -- is stated as @b * q = a@.
   (/) =
      fromRule3
         (\dividend divisor quotient ->
            Rule.mul divisor quotient dividend)

{- |
Square of an expression, tied to its operand by 'Rule.square'.
-}
sqr :: P.Floating a => T s a -> T s a
sqr x = fromRule2 Rule.square x

{- |
Square root of an expression.
It uses the same 'Rule.square' relation as 'sqr',
but with the operand and result positions swapped.
-}
sqrt :: P.Floating a => T s a -> T s a
sqrt = fromRule2 (\radicand root -> Rule.square root radicand)


infixl 4 =!=

{- |
Equate two expressions and yield the first one.
Because the result is again an expression,
it can be used inside larger expressions.
-}
(=!=) :: (Eq a) => T s a -> T s a -> T s a
Cons x =!= Cons y =
   Cons $
   x >>= \xv ->
   y >>= \yv ->
   Rule.equ xv yv >>
   return xv

infix 0 =:=

{- |
State that two expressions are equal, as a system action.
Unlike '=!=', no expression is returned.
-}
(=:=) :: (Eq a) => T s a -> T s a -> Sys.M s ()
Cons x =:= Cons y =
   x >>= \xv ->
   y >>= Rule.equ xv


{- |
Maximum of two expressions, tied to its operands by 'Rule.max'.

This replaces the Prelude's 'P.max', because 'T' cannot have a full
'Ord' instance (that would require an 'Eq' superclass and comparisons).
-}
max :: (Ord a) => T s a -> T s a -> T s a
max x y = fromRule3 Rule.max x y

{- |
Maximum of a non-empty list of expressions.
The expressions are combined pairwise with 'max' from left to right,
in the same order as 'foldl1'.
Calls 'error' with a descriptive message on an empty list,
because there is no neutral element to return.
-}
maximum :: (Ord a) => [T s a] -> T s a
maximum [] = error "UniqueLogic.ST.Expression.maximum: empty list"
maximum (x:xs) = foldl max x xs


{- |
Pair up two expressions.
Because the relation is given by the rule 'Rule.pair',
it can be used both to construct a pair and to decompose one.
-}
pair :: T s a -> T s b -> T s (a,b)
pair x y = fromRule3 Rule.pair x y


{-
Small demonstration: two global variables x and y are constrained by
x*3 = y/2 and 5 = 2+x, then both are queried.
Solving these equations should give x = 3 and y = 18.
-}
_example :: (Maybe Double, Maybe Double)
_example =
   runST (do
      xVar <- Sys.globalVariable
      yVar <- Sys.globalVariable
      Sys.solve $ do
         let x = fromVariable xVar
         let y = fromVariable yVar
         x*3 =:= y/2
         5 =:= 2+x
      liftM2 (,) (Sys.query xVar) (Sys.query yVar))