module UniqueLogic.ST.Rule (
   -- * Custom rules
   generic2, generic3,
   -- * Common rules
   equ, pair, max, add, mul, square, pow,
   ) where

import qualified UniqueLogic.ST.System as Sys

import Control.Monad.ST (runST, )
import Control.Monad (liftM4, )

import qualified Prelude as P
import Prelude hiding (max)


-- | Build a bidirectional rule between two variables from a pair of
-- functions that are meant to be inverse to each other.
--
-- @generic2 name f g x y@ registers two assignments, named @name@
-- suffixed with @0@ and @1@ respectively:
--
-- * @x@ is computed as @f y@,
--
-- * @y@ is computed as @g x@.
--
-- Nothing checks that @f@ and @g@ really are inverses; that is up to
-- the caller.
generic2 :: String ->
   (b -> a) -> (a -> b) ->
   Sys.Variable s a -> Sys.Variable s b -> Sys.M s ()
generic2 name f g x y = do
   Sys.assignment2 (name ++ "0") f y x
   Sys.assignment2 (name ++ "1") g x y

-- | Build a rule relating three variables, given one function per
-- variable that computes it from the other two.
--
-- @generic3 name f g h x y z@ registers three assignments, named @name@
-- suffixed with @0@, @1@ and @2@ respectively:
--
-- * @x@ is computed as @f y z@,
--
-- * @y@ is computed as @g z x@,
--
-- * @z@ is computed as @h x y@.
--
-- The argument order is cyclic. Each function receives the two other
-- variables, starting with the one that follows its target in
-- @x, y, z@.
generic3 :: String ->
   (b -> c -> a) -> (c -> a -> b) -> (a -> b -> c) ->
   Sys.Variable s a -> Sys.Variable s b -> Sys.Variable s c -> Sys.M s ()
generic3 name f g h x y z = do
   Sys.assignment3 (name ++ "0") f y z x
   Sys.assignment3 (name ++ "1") g z x y
   Sys.assignment3 (name ++ "2") h x y z

-- | @equ x y@ constrains two variables to hold the same value.
--
-- The value is copied in whichever direction is possible: @x@ from @y@
-- and @y@ from @x@.
--
-- No 'Eq' constraint is required. The rule only copies values with 'id'
-- and never compares them, so the constraint was redundant.
equ ::
   Sys.Variable s a -> Sys.Variable s a -> Sys.M s ()
equ = generic2 "Equ" id id

-- | @max x y z@ computes @z@ as the larger of @x@ and @y@.
--
-- Only this forward assignment is registered. Neither argument can be
-- recovered from the maximum, so there are no inverse rules.
max :: (Ord a) =>
   Sys.Variable s a -> Sys.Variable s a -> Sys.Variable s a -> Sys.M s ()
max x y z =
   Sys.assignment3 "Max" P.max x y z

-- | @pair x y xy@ relates a pair variable to its two components.
--
-- Three assignments are registered:
--
-- * @xy@ is built as @(x, y)@ from both components,
--
-- * @x@ is projected out of @xy@ with 'fst',
--
-- * @y@ is projected out of @xy@ with 'snd'.
pair ::
   Sys.Variable s a -> Sys.Variable s b -> Sys.Variable s (a,b) -> Sys.M s ()
pair x y xy = do
   Sys.assignment3 "Pair" (,) x y xy
   Sys.assignment2 "Fst" fst xy x
   Sys.assignment2 "Snd" snd xy y

-- | @add x y z@ enforces @x + y = z@.
--
-- Any one of the three variables can be derived from the other two:
-- @x = z - y@, @y = z - x@ and @z = x + y@.
add :: (Num a) =>
   Sys.Variable s a -> Sys.Variable s a -> Sys.Variable s a -> Sys.M s ()
add = generic3 "Add" (\y z -> z - y) (\z x -> z - x) (+)

-- | @mul x y z@ enforces @x * y = z@.
--
-- The product @z@ is computed as @x * y@. Each factor is recovered by
-- division: @x@ as @z \/ y@ and @y@ as @z \/ x@. Zero divisors are not
-- guarded against, so the outcome is whatever @(\/)@ yields for the
-- element type (for 'Double', an infinity or NaN).
mul :: (Fractional a) =>
   Sys.Variable s a -> Sys.Variable s a -> Sys.Variable s a -> Sys.M s ()
mul = generic3 "Mul" (\y z -> z / y) (\z x -> z / x) (*)

-- | @square x y@ enforces @x^2 = y@.
--
-- @y@ is computed as @x^2@, and @x@ as @sqrt y@. The inverse direction
-- uses 'sqrt', so this rule by itself never produces a negative @x@.
square :: (Floating a) =>
   Sys.Variable s a -> Sys.Variable s a -> Sys.M s ()
square = generic2 "Square" sqrt (\v -> v ^ (2::Int))

-- | @pow x y z@ enforces @x ** y = z@.
--
-- Each variable is derived from the other two:
--
-- * the power @z@ as @x ** y@,
--
-- * the base @x@ as the @y@-th root of @z@, i.e. @z ** recip y@,
--
-- * the exponent @y@ as the logarithm of @z@ to base @x@.
pow :: (Floating a) =>
   Sys.Variable s a -> Sys.Variable s a -> Sys.Variable s a -> Sys.M s ()
pow = generic3 "Pow" root logarithm (**)
   where
      -- base from exponent and power
      root e z = z ** recip e
      -- exponent from power and base
      logarithm z b = logBase b z


-- * Example equation system

{- |
Small demonstration system. The equations

> x+y=3
> y*z=6
> z=3
> y^w=8

have the unique solution (up to floating-point rounding)

> x=1
> y=2
> z=3
> w=3

Each tuple component is the result of querying the corresponding
variable after the system has been solved.
-}
_example :: (Maybe Double, Maybe Double, Maybe Double, Maybe Double)
_example =
   runST (do
      x <- Sys.globalVariable
      y <- Sys.globalVariable
      z <- Sys.globalVariable
      w <- Sys.globalVariable
      let system = do
             three <- Sys.constant 3
             six   <- Sys.constant 6
             eight <- Sys.constant 8
             add x y three
             mul y z six
             equ z three
             pow y w eight
      Sys.solve system
      liftM4 (,,,)
         (Sys.query x) (Sys.query y) (Sys.query z) (Sys.query w))