module Data.Symbolic.Diff where
import Data.Symbolic.TypedCode
-- | Arithmetic on code values.  Every operator is expressed as the matching
-- @op'...@ primitive combined with its operands through 'appC'; integer
-- literals are converted with 'integerC'.
--
-- 'abs' and 'signum' are not defined by this instance.
instance Num a => Num (Code a) where
    x + y = op'add `appC` x `appC` y
    -- The operator has to be spelled out: a bare @x y = ...@ would declare a
    -- method named @x@, which is not a member of 'Num'.
    x - y = op'sub `appC` x `appC` y
    x * y = op'mul `appC` x `appC` y
    negate x = op'negate `appC` x
    fromInteger = integerC
-- | Division, reciprocal and rational literals for code values: '/' and
-- 'recip' combine @op'div@ / @op'recip@ with their operands through 'appC',
-- and rational literals go through 'rationalC'.
instance Fractional a => Fractional (Code a) where
    (/) num den = (op'div `appC` num) `appC` den
    recip arg = appC op'recip arg
    fromRational = rationalC
-- | The subset of 'Floating' used in this file: 'pi', 'sin' and 'cos'.
-- The remaining 'Floating' methods are not defined by this instance.
instance Floating a => Floating (Code a) where
    pi = op'pi
    sin arg = appC op'sin arg
    cos arg = appC op'cos arg
-- | The expression @1 + 2@, polymorphic in the numeric type so that it can be
-- instantiated both at ordinary numbers and at code values.
testf1 :: Num a => a
testf1 = (+) 1 2
-- | @testf1@ instantiated at @Code Int@ and injected with 'return'.
testf1' = return intCode
  where
    intCode :: Code Int
    intCode = testf1

-- | The result of handing @testf1'@ to 'showQC'.
testf1'' = showQC testf1'
-- | @x*x + 1@, with the square bound to a local name before it is used.
test1f x = squared + 1
  where
    squared = x * x

-- | 'test1f' evaluated numerically at @2.0 :: Float@.
test1 = test1f (2.0 :: Float)
-- | Obtain a variable (at type 'Float') from 'new'diffVar' and pair the code
-- produced by applying 'test1f' to that variable's expression with the
-- variable itself.
test1c = do
    (v :: Var Float) <- new'diffVar
    return $ (test1f (var'exp v), v)

-- | Feed the code and variable from 'test1c' to 'reflectDF'.
test1r = do
    (c, v) <- test1c
    reflectDF v c

-- | 'test1r' handed to 'showQC'.
test1cp = showQC test1r
-- | Symbolically differentiate the code @c@ with respect to the variable @v@.
--
-- The code is inspected with the @on'...@ matchers, tried in order, and the
-- usual differentiation rules are applied recursively to the operands.
-- Code that matches none of the recognised shapes aborts with 'error',
-- reporting the offending code via 'show'.
diffC :: (Floating a, Floating b) => Var b -> Code a -> Code a
diffC v c
    -- Literal: derivative is 0.
    | Just _ <- on'litC c = 0
    -- Variable: a 'Left' result yields 1 and a 'Right' result yields 0
    -- (presumably "this is @v@" versus "some other variable" -- confirm
    -- against the definition of 'on'varC').
    | Just ev <- on'varC v c = either (const 1) (const 0) ev
    -- Sum rule: (x + y)' = x' + y'
    | Just (x, y) <- on'2opC op'add c =
        diffC v x + diffC v y
    -- Difference rule: (x - y)' = x' - y'
    | Just (x, y) <- on'2opC op'sub c =
        diffC v x - diffC v y
    -- Product rule: (x * y)' = x' * y + x * y'
    | Just (x, y) <- on'2opC op'mul c =
        (diffC v x * y) + (x * diffC v y)
    -- Quotient rule: (x / y)' = (x' * y - x * y') / (y * y)
    | Just (x, y) <- on'2opC op'div c =
        (diffC v x * y - x * diffC v y) / (y * y)
    -- (negate x)' = negate x'
    | Just x <- on'1opC op'negate c =
        negate (diffC v x)
    -- (recip x)' = negate x' / (x * x)
    | Just x <- on'1opC op'recip c =
        negate (diffC v x) / (x * x)
    -- Chain rule for sine: (sin x)' = x' * cos x
    | Just x <- on'1opC op'sin c =
        diffC v x * cos x
    -- Chain rule for cosine: (cos x)' = negate (x' * sin x)
    | Just x <- on'1opC op'cos c =
        negate (diffC v x * sin x)
    | otherwise = error $ "Cannot handle code: " ++ show c
-- | Differentiate the code from 'test1c' with respect to its variable using
-- 'diffC', then pass the result to 'reflectDF'.
test1d = do
    (c, v) <- test1c
    reflectDF v (diffC v c)

-- | 'test1d' handed to 'showQC'.
test1dp = showQC test1d
-- | Simplify code by applying 'simpleCL' repeatedly: as long as a step
-- returns a rewritten term, keep going from that term; once a step returns
-- 'Nothing', the current term is the result.
simpleC :: Floating a => Var b -> Code a -> Code a
simpleC v c = maybe c (simpleC v) (simpleCL v c)
-- | Perform one simplification step on the code @c@.
--
-- Returns @Just c'@ when the term was rewritten and 'Nothing' when no rule
-- applies.  Literals, variables and any code shape not listed below are
-- never rewritten.  For the binary operators and 'negate' the operands are
-- simplified first (see 'simple'recur' / 'simple'recur1'); the local rule is
-- tried only when no operand changed.
simpleCL :: Floating a => Var b -> Code a -> Maybe (Code a)
simpleCL v c | Just _ <- on'litC c = Nothing
simpleCL v c | Just _ <- on'varC v c = Nothing
simpleCL v c | Just (x, y) <- on'2opC op'add c =
    simple'recur op'add sadd v x y
  where
    -- 0 + r = r;  l + 0 = l;  literal + literal is folded.
    sadd l r | Just 0 <- on'litRationalC l = Just r
    sadd l r | Just 0 <- on'litRationalC r = Just l
    sadd l r | (Just p, Just q) <- (on'litRationalC l, on'litRationalC r)
        = Just (fromRational $ p + q)
    sadd _ _ = Nothing
simpleCL v c | Just (x, y) <- on'2opC op'sub c =
    simple'recur op'sub ssub v x y
  where
    -- l - 0 = l;  literal - literal is folded.
    ssub l r | Just 0 <- on'litRationalC r = Just l
    ssub l r | (Just p, Just q) <- (on'litRationalC l, on'litRationalC r)
        = Just (fromRational $ p - q)
    ssub _ _ = Nothing
simpleCL v c | Just (x, y) <- on'2opC op'mul c =
    simple'recur op'mul smul v x y
  where
    -- 0 * r = 0;  l * 0 = 0;  1 * r = r;  l * 1 = l;
    -- literal * literal is folded.
    smul l _ | Just 0 <- on'litRationalC l = Just (fromRational 0)
    smul _ r | Just 0 <- on'litRationalC r = Just (fromRational 0)
    smul l r | Just 1 <- on'litRationalC l = Just r
    smul l r | Just 1 <- on'litRationalC r = Just l
    smul l r | (Just p, Just q) <- (on'litRationalC l, on'litRationalC r)
        = Just (fromRational $ p * q)
    smul _ _ = Nothing
simpleCL v c | Just (x, y) <- on'2opC op'div c =
    simple'recur op'div sdiv v x y
  where
    -- 0 / r = 0.  The divisor is not inspected, so this also rewrites a
    -- division of zero by a zero literal.
    sdiv l _ | Just 0 <- on'litRationalC l = Just (fromRational 0)
    sdiv _ _ = Nothing
simpleCL v c | Just x <- on'1opC op'negate c =
    simple'recur1 op'negate sneg v x
  where
    -- negate 0 = 0
    sneg e | Just 0 <- on'litRationalC e = Just (fromRational 0)
    sneg _ = Nothing
simpleCL v c = Nothing
-- | Simplification driver for a binary operator @op@ with operands @x@, @y@.
--
-- Both operands are first given one 'simpleCL' step.  If at least one of
-- them was rewritten, the operator is re-applied to the updated operands and
-- the result is reported as a change.  Only when neither operand changed is
-- the operator-specific rule @fn@ consulted on the operands themselves.
simple'recur op fn v x y =
    case (simpleCL v x, simpleCL v y) of
        (Nothing, Nothing) -> fn x y
        (Just x', Nothing) -> rebuilt x' y
        (Nothing, Just y') -> rebuilt x y'
        (Just x', Just y') -> rebuilt x' y'
  where
    rebuilt l r = Just (op `appC` l `appC` r)
-- | Simplification driver for a unary operator @op@ with operand @x@.
--
-- The operand is given one 'simpleCL' step.  If it was rewritten, @op@ is
-- re-applied to the new operand; otherwise the operator-specific rule @fn@
-- is consulted on the unchanged operand.
simple'recur1 op fn v x =
    maybe (fn x) reapply (simpleCL v x)
  where
    reapply x' = Just (op `appC` x')
-- | Differentiate the code from 'test1c' with 'diffC', simplify the result
-- with 'simpleC', then pass it to 'reflectDF'.
test1ds = do
    (c, v) <- test1c
    reflectDF v (simpleC v (diffC v c))

-- | 'test1ds' handed to 'showQC'.
test1dsp = showQC test1ds
-- | Differentiate a numeric function.
--
-- The function must be polymorphic over every 'Floating' type so that it
-- can be applied to the expression of a variable obtained from
-- 'new'diffVar'.  The resulting code is differentiated with 'diffC',
-- simplified with 'simpleC', and handed to 'reflectDF' together with the
-- variable.
diff_fn :: Floating b => (forall a. Floating a => a -> a) -> QCode (b -> b)
diff_fn f =
    new'diffVar >>= \v ->
        reflectDF v (simpleC v (diffC v (f (var'exp v))))
-- | Show a numeric function as code: apply it to the expression of a
-- variable obtained from 'new'diffVar', pass the result to 'reflectDF'
-- with that variable, and hand the whole computation to 'showQC'.
show_fn :: (forall a. Floating a => a -> a) -> IO ()
show_fn f = showQC reflected
  where
    reflected = new'diffVar >>= \v -> reflectDF v (f (var'exp v))
-- | Left fold over the coefficients @[1,2,3]@ starting from 0, where each
-- step multiplies the accumulator by @x@ and adds the next coefficient.
test2f x = foldl step 0 [1,2,3]
  where
    step acc coeff = x*acc + coeff

-- | 'test2f' evaluated numerically at @4 :: Float@.
test2n = test2f (4 :: Float)
-- | 'test2f' shown as code via 'show_fn'.
test2s = show_fn test2f

-- | The result of 'diff_fn' on 'test2f', shown via 'showQC'.
test2ds = showQC $ diff_fn test2f
-- | @2*x + 3*x@, kept as two separate products rather than collected.
test11f x = doubled + tripled
  where
    doubled = 2*x
    tripled = 3*x
-- | The result of 'diff_fn' on 'test11f', shown via 'showQC'.
test11ds = showQC $ diff_fn test11f
-- | @sin (5*x + pi/2) + cos (1/x)@: exercises 'sin', 'cos', 'pi' and
-- division together.
test5f x = sin phase + cos inverse
  where
    phase = 5*x + pi/2
    inverse = 1 / x

-- | 'test5f' evaluated numerically at @pi :: Float@.
test5n = test5f (pi :: Float)
-- | The result of 'diff_fn' on 'test5f', shown via 'showQC'.
test5ds = showQC $ diff_fn test5f
-- | Two-argument test function: @(x*y + 5*x*x) / y@.
test3f x y = numerator / y
  where
    numerator = x*y + (5*x*x)
-- | Apply 'diff_fn' to 'test3f' viewed as a function of its first argument,
-- with the second argument fixed to the given integral value (converted
-- with 'fromIntegral').
test4x y = diff_fn (\x -> test3f x (fromIntegral y))
-- | Apply 'diff_fn' to 'test3f' viewed as a function of its second argument,
-- with the first argument fixed to the given 'Integer' (converted with
-- 'fromInteger').
test4y x = diff_fn (test3f (fromInteger x))
-- | 'test4x' with the second argument fixed to 1, shown via 'showQC'.
test4xds = showQC (test4x 1)
-- | 'test4y' with the first argument fixed to 5, shown via 'showQC'.
test4yds = showQC (test4y 5)