{- |
Rules relating variables of the constraint system in "UniqueLogic.ST.System".

A rule registers one named assignment (via 'Sys.assignment2' or
'Sys.assignment3') for each direction in which its relation can be evaluated.
-}
module UniqueLogic.ST.Rule (
   -- * Custom rules
   generic2,
   generic3,
   -- * Common rules
   equ,
   pair,
   max,
   add,
   mul,
   square,
   pow,
   ) where

import qualified UniqueLogic.ST.System as Sys

import Control.Monad.ST (runST, )
import Control.Monad (liftM4, )

import qualified Prelude as P
import Prelude hiding (max)


{- |
Build a rule relating two variables from a pair of inverse functions.

@generic2 name f g x y@ registers two assignments:

* @x@ computed as @f y@, named @name ++ \"0\"@,

* @y@ computed as @g x@, named @name ++ \"1\"@.

It is the caller's responsibility that @f@ and @g@ are mutually consistent.
-}
generic2 ::
   String ->
   (b -> a) -> (a -> b) ->
   Sys.Variable s a -> Sys.Variable s b ->
   Sys.M s ()
generic2 name f g x y =
   sequence_
      [ Sys.assignment2 (name ++ "0") f y x
      , Sys.assignment2 (name ++ "1") g x y
      ]

{- |
Build a rule relating three variables, given one solver function per variable.

@generic3 name f g h x y z@ registers three assignments.
The two known values are passed in cyclic order:

* @x@ computed as @f y z@, named @name ++ \"0\"@,

* @y@ computed as @g z x@, named @name ++ \"1\"@,

* @z@ computed as @h x y@, named @name ++ \"2\"@.
-}
generic3 ::
   String ->
   (b -> c -> a) -> (c -> a -> b) -> (a -> b -> c) ->
   Sys.Variable s a -> Sys.Variable s b -> Sys.Variable s c ->
   Sys.M s ()
generic3 name f g h x y z =
   sequence_
      [ Sys.assignment3 (name ++ "0") f y z x
      , Sys.assignment3 (name ++ "1") g z x y
      , Sys.assignment3 (name ++ "2") h x y z
      ]


{- |
@equ x y@ states that @x@ and @y@ hold the same value.

The value is only copied (via 'id') in either direction.
Hence no 'Eq' constraint is required.
The previous, unused @Eq a@ constraint was dropped.
That change is backward-compatible for all callers.
-}
equ :: Sys.Variable s a -> Sys.Variable s a -> Sys.M s ()
equ = generic2 "Equ" id id

{- |
@max x y z@ states that @z@ is the maximum of @x@ and @y@.

Only @z@ can be derived.
Neither argument can be recovered from the maximum, so a single assignment
is registered.
-}
max ::
   (Ord a) =>
   Sys.Variable s a -> Sys.Variable s a -> Sys.Variable s a -> Sys.M s ()
max = Sys.assignment3 "Max" P.max

{- |
@pair x y xy@ states that @xy@ is the tuple @(x, y)@.

@xy@ can be built from both components.
Conversely, each component can be projected out of @xy@ on its own.
-}
pair :: Sys.Variable s a -> Sys.Variable s b -> Sys.Variable s (a,b) -> Sys.M s ()
pair x y xy =
   Sys.assignment3 "Pair" (,) x y xy >>
   Sys.assignment2 "Fst" fst xy x >>
   Sys.assignment2 "Snd" snd xy y

{- |
@add x y z@ states that @x + y = z@.

Any one of the three variables is computed from the other two:
@x = z - y@, @y = z - x@, @z = x + y@.
-}
add ::
   (Num a) =>
   Sys.Variable s a -> Sys.Variable s a -> Sys.Variable s a -> Sys.M s ()
add = generic3 "Add" subtract (-) (+)

{- |
@mul x y z@ states that @x * y = z@.

The factors are recovered by division: @x = z / y@ and @y = z / x@.
Consequently, a zero factor makes the other factor's assignment divide by zero.
It does not leave that factor undetermined.
-}
mul ::
   (Fractional a) =>
   Sys.Variable s a -> Sys.Variable s a -> Sys.Variable s a -> Sys.M s ()
mul = generic3 "Mul" (flip (/)) (/) (*)

{- |
@square x y@ states that @x^2 = y@.

When @x@ is derived from @y@, it is computed with 'sqrt'.
That yields the non-negative root, so a negative @x@ cannot be recovered.
-}
square :: (Floating a) => Sys.Variable s a -> Sys.Variable s a -> Sys.M s ()
square = generic2 "Square" sqrt (^(2::Int))

{- |
@pow x y z@ states that @x ** y = z@.

The base is recovered as the @y@-th root, @x = z ** recip y@.
The exponent is recovered as the logarithm, @y = logBase x z@.
-}
pow ::
   (Floating a) =>
   Sys.Variable s a -> Sys.Variable s a -> Sys.Variable s a -> Sys.M s ()
pow =
   generic3 "Pow"
      (\expo result -> result ** recip expo)
      (flip logBase)
      (**)


-- Example equation system

{- |
Solve the equation system

> x+y=3
> y*z=6
> z=3
> y^w=8

whose solution is

> x=1
> y=2
> z=3
> w=3

The result tuple holds the queried values of @(x, y, z, w)@.
-}
_example :: (Maybe Double, Maybe Double, Maybe Double, Maybe Double)
_example =
   runST (do
      x <- Sys.globalVariable
      y <- Sys.globalVariable
      z <- Sys.globalVariable
      w <- Sys.globalVariable
      Sys.solve $ do
         c3 <- Sys.constant 3
         c6 <- Sys.constant 6
         c8 <- Sys.constant 8
         add x y c3
         mul y z c6
         equ z c3
         pow y w c8
      liftM4 (,,,) (Sys.query x) (Sys.query y) (Sys.query z) (Sys.query w))