{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE DeriveFunctor #-}
module Language.Passage.Term where

import Control.Monad(mplus)
import qualified Data.IntSet as IS
import Data.Ratio(numerator,denominator)
import Data.Maybe(fromMaybe)

import Language.Passage.Utils
import qualified Language.Passage.Lang.LaTeX as LaTeX
import Language.Passage.Lang.LaTeX(LaTeX(..))

-- | Operators that may appear at the head of a 'TApp' node.
--
-- NOTE: the constructor order is significant for the derived 'Ord' instance.
data Op
  = TLog
  | TNeg
  | TAdd
  | TMul
  | TSub
  | TDiv
  | TPow
  | TLogGamma
  | TExp
  | TCase        -- for "arrays" of det. nodes. 1st arg index, rest array.
  | TIx
  | TLit String
  deriving (Eq, Show, Ord)

-- | A term in a stochastic context
data Term a
  = TVar a
  | TArr a             -- A node corresponding to an array
  | TConst Double
  | TApp Op [Term a]
  deriving (Eq, Ord, Show, Functor)

{-
sizeOf :: Term a -> Int
sizeOf (TVar{})    = 1
sizeOf (TConst{})  = 0
sizeOf (TApp _ xs) = 1 + sum (map sizeOf xs)
-}

-- | Maps an array node to the set of nodes it stands for.
type ArrVars = NodeIdx -> IS.IntSet

-- | Nodes hanging off of a term: a variable contributes itself, an array
-- node contributes whatever the 'ArrVars' lookup returns for it, constants
-- contribute nothing, and an application contributes the union over its
-- arguments.
leavesOfTerm :: ArrVars -> Term NodeIdx -> IS.IntSet
leavesOfTerm arr = go
  where
  go term =
    case term of
      TVar v      -> IS.singleton v
      TArr a      -> arr a
      TConst _    -> IS.empty
      TApp _ args -> IS.unions [ go arg | arg <- args ]

-- | Returns 'True' for terms that are not applications.
isSimpleTerm :: Term a -> Bool
isSimpleTerm t =
  case t of
    TApp _ _ -> False
    TArr _   -> True
    TVar _   -> True
    TConst _ -> True

-- | Build a 'TCase' application: the selector comes first, followed by the
-- alternatives.
tcase :: Term a -> [Term a] -> Term a
tcase e es = TApp TCase (e:es)

-- | Variable leaf.
tvar :: a -> Term a
tvar = TVar

-- | Array leaf.
tarr :: a -> Term a
tarr = TArr

-- | Constant leaf.
tconst :: Double -> Term a
tconst = TConst

-- | The value of a term if it is a literal constant, 'Nothing' otherwise.
isConst :: Term a -> Maybe Double
isConst (TConst d) = Just d
isConst _          = Nothing

-- | Apply a unary operator.
un :: Op -> Term a -> Term a
un op x = TApp op [x]

-- | Apply a binary operator.
bin :: Op -> Term a -> Term a -> Term a
bin op x y = TApp op [x,y]

-- | The head operator of an application, 'Nothing' for leaves.
termOp :: Term a -> Maybe Op
termOp (TApp op _) = Just op
termOp _           = Nothing

-- | Symbolic log-gamma.  Constant arguments are folded numerically; any
-- other argument is wrapped in a 'TLogGamma' application.
logGamma :: Term a -> Term a
logGamma t =
  case t of
    TConst a -> TConst (lgg a)
    _        -> un TLogGamma t
  where
  -- NOTE(review): presumably a closed-form approximation of log(Gamma z);
  -- confirm that its accuracy is adequate for the constants folded here.
  lgg :: Double -> Double
  lgg z = 0.5 * (log (2*pi) - log z) + z*(log (z+1/(12*z-0.1/z)) - 1)

-- | Index into an array term: @tIx a i@ builds @TApp TIx [a,i]@.
tIx :: Term a -> Term a -> Term a
tIx a i = TApp TIx [a,i]

-- Split a term into an array and indexes.
-- Example: a[1][2][3] ---> (a,[1,2,3])
splitArray :: Term a -> (Term a, [Term a])
splitArray t0 = loop t0 []
  where
  -- Peel 'TIx' applications from the outside in, so the accumulated
  -- indexes end up in source order.
  loop (TApp TIx [a,i]) is = loop a (i:is)
  loop t is                = (t,is)

--------------------------------------------------------------------------------
-- Smarter printing infrastructure

-- | Fixity and binding strength of an operator, as consumed by 'ppTerm'.
-- A larger number binds tighter.
precedence :: Op -> (Fixity, Rational)
precedence op =
  case op of
    TExp      -> (Prefix, 100)
    TLog      -> (Prefix, 100)
    TLogGamma -> (Prefix, 100)
    TCase     -> (Prefix, 100)
    TNeg      -> (Prefix, 100)
    TAdd      -> (Infix ToLeft, 6)
    TSub      -> (Infix ToLeft, 6)
    TMul      -> (Infix ToLeft, 7)
    TDiv      -> (Infix ToLeft, 7)
    TPow      -> (Infix ToRight, 8)
    TIx       -> (Infix ToLeft, 9)
    TLit _    -> (Prefix, 100)

instance PP Op where
  pp op =
    case op of
      TCase     -> text "choose"
      TExp      -> text "exp"
      TLogGamma -> text "logGamma"
      TLog      -> text "log"
      TNeg      -> text "-"
      TAdd      -> text "+"
      TSub      -> text "-"
      TMul      -> text "*"
      TDiv      -> text "/"
      TPow      -> text "^"
      TIx       -> text "!"
      TLit s    -> text s

-- | Decide whether the LaTeX rendering of an argument needs parentheses.
--
-- Arguments: the parent operator, the position of the argument under that
-- parent, the head operator of the argument itself ('Nothing' for leaves,
-- which are never wrapped), and the already rendered argument.
wrapLatex :: Op -> Posn -> Maybe Op -> Doc -> Doc
wrapLatex _ _ Nothing doc = doc
wrapLatex op pos (Just op1) doc = if shouldWrap then wrap else doc
  where
  shouldWrap =
    case (pos,op) of
      (ToLeft,  TPow) -> True
      (_,       TPow) -> False
      (_,       TAdd) -> op1 == TIx
      (ToRight, TIx)  -> op1 == TIx
      (_,       TIx)  -> False
      (ToLeft,  TSub) -> False
      (_,       TSub) -> op1 == TSub || op1 == TAdd || op1 == TNeg || op1 == TIx
      (ToLeft,  TMul) -> op1 == TSub || op1 == TAdd
      (_,       TMul) -> op1 == TSub || op1 == TAdd || op1 == TNeg || op1 == TIx
      (_,       TDiv) -> False
      -- The prefix operators below inspect the *argument's* operator (op1).
      -- Testing the parent (op) against TMul could never succeed inside
      -- these branches, which left product arguments unparenthesised.
      (_, TLog)       -> op1 == TAdd || op1 == TSub || op1 == TMul
      (_, TExp)       -> op1 == TAdd || op1 == TSub || op1 == TMul
      (_, TLogGamma)  -> op1 == TAdd || op1 == TSub || op1 == TMul
      (_, TNeg)       -> op1 == TAdd || op1 == TSub || op1 == TMul || op1 == TIx
      (_, TCase)      -> False
      -- Keeps the match total: without this alternative a 'TLit'
      -- application with an application argument hit a pattern-match
      -- failure.
      -- NOTE(review): not wrapping is the conservative choice; revisit if
      -- the rendering produced by LaTeX.literal needs explicit grouping.
      (_, TLit _)     -> False

  wrap = text "\\left(" <> doc <> text "\\right)"

instance LaTeX a => LaTeX (Term a) where
  latex (TApp op ts) =
    case op of
      TAdd      -> dL <+> char '+' <+> dR
      TSub      -> dL <+> char '-' <+> dR
      TMul      -> dL <+> dR
      TDiv      -> LaTeX.frac dL dR
      TPow      -> LaTeX.pow dL dR
      TLog      -> LaTeX.lg <+> dL
      TExp      -> LaTeX.expon <+> dL
      TLogGamma -> LaTeX.logGamma dL
      TNeg      -> char '-' <> dL
      TIx       -> dL <+> char '!' <+> dR
      -- XXX: This could be nicer
      TCase     -> let a : as = map latex ts
                   in commaSep as <> char '_' <> braces a
      TLit s    -> (LaTeX.literal s) dL
    where
    -- Rendered arguments, each parenthesised as required by its position
    -- under 'op'.  The first two positions are ToLeft and ToRight; any
    -- further ones are None.
    ds       = zipWith pr (ToLeft : ToRight : repeat None) ts
    pr pos t = wrapLatex op pos (termOp t) (latex t)
    -- Lazy patterns: these are only forced by the branches that use them.
    dL : ds1 = ds
    dR : _   = ds1

  latex (TVar x)   = latex x
  latex (TArr x)   = latex x
  latex (TConst a) = double a

-- | Pretty-print a term in a context given by the position under the
-- parent operator and the parent's precedence.  A binary infix application
-- is left bare when it binds tighter than the context, or equally tightly
-- while sitting on the side its fixity associates to; otherwise it is
-- parenthesised.
ppTerm :: (PP a) => (Posn,Rational) -> Term a -> Doc
ppTerm (pos,prec) (TApp op [l,r])
  | (Infix dir, myprec) <- precedence op =
      let this = ppTerm (ToLeft, myprec) l <+> pp op <+> ppTerm (ToRight, myprec) r
      in if myprec > prec || (myprec == prec && pos == dir)
           then this
           else parens this

-- Every other application is printed as the operator followed by its
-- comma-separated arguments.
ppTerm (_,p) (TApp op ts) =
  let this = pp op <+> commaSep [ ppTerm (None,0) t | t <- ts ]
  in if snd (precedence op) > p then this else parens this

ppTerm (_,n) (TArr x)   = text "!" <> ppPrec n x
ppTerm (_,n) (TVar x)   = text "?" <> ppPrec n x
ppTerm _     (TConst a) = double a

instance PP a => PP (Term a) where
  ppPrec n = ppTerm (None,n)

-- | Lift a unary operation: fold it numerically on a constant, otherwise
-- apply the symbolic version.
liftTerm1 :: (Term a -> Term a) -> (Double -> Double) -> Term a -> Term a
liftTerm1 _ c (TConst a) = TConst (c a)
liftTerm1 s _ a          = s a

-- | Lift a binary operation: fold it numerically when both arguments are
-- constants, otherwise apply the symbolic version.
liftTerm2 :: (Term a -> Term a -> Term a) ->
             (Double -> Double -> Double) ->
             Term a -> Term a -> Term a
liftTerm2 _ c (TConst a) (TConst b) = TConst (a `c` b)
liftTerm2 s _ a b                   = a `s` b

-- | Placeholder for a unary operation without symbolic support; passes a
-- message naming the operation and its argument to 'tbd'.
tbd1 :: Show a => String -> Term a -> b
tbd1 w x = tbd ("Term." ++ w ++ " " ++ show x)

-- | Binary counterpart of 'tbd1'.
tbd2 :: Show a => String -> Term a -> Term a -> b
tbd2 w x y = tbd ("Term." ++ w ++ show (x, y))

-- The equations below are tried top to bottom, so the order of the
-- rewrite rules is significant.
instance (Eq a, Show a) => Num (Term a) where

  -- addition
  TConst 0 + b = b
  a + TConst 0 = a
  TConst x + TConst y = TConst (x + y)
  a + (TApp TAdd [b,c]) = (a + b) + c
  a + TApp TNeg [b] = a - b
  TApp TNeg [a] + b = b - a
  TApp TDiv [a, x] + TApp TDiv [b, y] | x == y = (a+b) / x
  a + b = liftTerm2 (bin TAdd) (+) a b

  {-
  a + b | sizeOf a < sizeOf b = walkAdd a b
        | True                = walkAdd b a
  -}

  -- multiplication
  TConst 0 * _ = TConst 0
  TConst 1 * b = b
  TConst (-1) * b = negate b
  TConst x * TConst y = TConst (x * y)
  TConst x * TApp TDiv [TConst y, z] = TConst (x*y) / z
  TConst x * TApp TDiv [z, TConst y] = TConst (x/y) * z
  a * b@(TConst _) = b * a    -- constants float left
  TApp TNeg [a] * b = negate (a * b)
  a * TApp TNeg [b] = negate (a * b)
  a * b = liftTerm2 (bin TMul) (*) a b

  -- subtraction
  a - TConst 0 = a
  a - TApp TNeg [b] = a + b
  TApp TNeg [a] - b = negate (a + b)
  TApp TSub [b, c] - d = TApp TSub [b, c+d]
  a - b = liftTerm2 (bin TSub) (-) a b

  -- negation
  negate (TApp TNeg [x]) = x
  negate x = liftTerm1 (un TNeg) negate x

  -- others
  abs         = liftTerm1 (tbd1 "abs") abs
  signum      = liftTerm1 (tbd1 "signum") signum
  fromInteger = TConst . fromInteger

instance (Eq a, Show a) => Fractional (Term a) where

  -- division
  a / TConst 1 = a
  TConst x / TConst y | y /= 0 = TConst (x / y)
  (TApp TDiv [TConst c1, x]) / TConst c2 = TConst (c1/c2) / x
  a / b = liftTerm2 (bin TDiv) (/) a b

  recip x = 1 / x
  fromRational x = fromInteger (numerator x) / fromInteger (denominator x)

instance (Eq a, Show a) => Floating (Term a) where
  pi   = TConst pi
  exp  = liftTerm1 (un TExp) exp
  sqrt = liftTerm1 (tbd1 "sqrt") sqrt
  log  = liftTerm1 (un TLog) log

  -- power
  _ ** TConst 0 = TConst 1
  TConst 0 ** _ = TConst 0
  a ** b = liftTerm2 (bin TPow) (**) a b

  -- TBD: Add support for these as needed
  logBase = liftTerm2 (tbd2 "logBase") logBase
  sin     = liftTerm1 (tbd1 "sin") sin
  tan     = liftTerm1 (tbd1 "tan") tan
  cos     = liftTerm1 (tbd1 "cos") cos
  asin    = liftTerm1 (tbd1 "asin") asin
  atan    = liftTerm1 (tbd1 "atan") atan
  acos    = liftTerm1 (tbd1 "acos") acos
  sinh    = liftTerm1 (tbd1 "sinh") sinh
  tanh    = liftTerm1 (tbd1 "tanh") tanh
  cosh    = liftTerm1 (tbd1 "cosh") cosh
  asinh   = liftTerm1 (tbd1 "asinh") asinh
  atanh   = liftTerm1 (tbd1 "atanh") atanh
  acosh   = liftTerm1 (tbd1 "acosh") acosh

-- Add two terms symbolically, trying to simplify as much as possible.
-- TODO: This process is currently quite ad-hoc, due to the lack of a
-- well defined normal form for terms. Also, developers will likely want
-- to be able to add their own rules if necessary.
--
-- sAdd returns Nothing if it did nothing, and Just t, if it was able
-- to do something interesting.
sAdd :: (Show t, Eq t) => Term t -> Term t -> Maybe (Term t)

-- The rules are tried in order; a clause whose guards all fail falls
-- through to the next one.
sAdd p q
  -- x + x == 2x
  | p == q             = Just (2 * p)
  -- x - x == 0
  | p == -q || -p == q = Just 0

-- a + ab = a (b+1)
-- b + ab = b (a+1)
sAdd p (TApp TMul [l, r])
  | p == l = Just (l * (r + 1))
  | p == r = Just (r * (l + 1))

-- x + (q + x) = q + 2*x
-- x + (x + q) = q + 2*x
sAdd p (TApp TAdd [l, r])
  | p == r = Just (l + 2*p)
  | p == l = Just (r + 2*p)

-- -x + (q - x) = q - 2*x
sAdd (TApp TNeg [p]) (TApp TSub [l, r])
  | p == r = Just (l - 2*p)

-- ay + ab = a (y+b)
-- xb + ab = (x+a) b
sAdd (TApp TMul [x, y]) (TApp TMul [a, b])
  | x == a = Just (x * plus y b)
  | y == b = Just (plus x a * y)
  where
  -- Simplifying sum when a rule applies, plain sum otherwise.
  plus u v = fromMaybe (u + v) (sAdd u v)

-- Try associative rule to see if it simplifies things
-- We arbitrarily prefer the first rule below if both apply
--    x + (a+b) = (x+a) + b
--    x + (a+b) = a + (x+b)
sAdd p (TApp TAdd [l, r]) =
  case viaLeft `mplus` viaRight of
    -- The regrouped result is itself a sum: give the rules one more chance
    -- on its two halves, keeping the regrouped term if none applies.
    regrouped@(Just (TApp TAdd [t1, t2])) -> sAdd t1 t2 `mplus` regrouped
    other                                 -> other
  where
  viaLeft  = fmap (+ r) (sAdd p l)
  viaRight = fmap (l +) (sAdd p r)

-- x + (a - b) = (x+a) - b
sAdd p (TApp TSub [l, r]) = fmap (subtract r) (sAdd p l)

-- Otherwise we weren't able to do anything smart; so just report Nothing
sAdd _ _ = Nothing

-- Walk a deep-tree and add a term; taking care of collapsing if possible.
walkAdd :: (Eq a, Show a) => Term a -> Term a -> Term a
walkAdd a b = fromMaybe (bin TAdd a b) (walk b)
  where
  -- 'walk' yields the rewritten tree if 'a' could be merged somewhere
  -- inside it, and Nothing otherwise.
  walk t | a == t = Just (2*t)
  walk (TApp TMul [c, t]) | a == t = Just ((c+1) * t)
  -- Descend into a sum, trying the left summand before the right one.
  walk (TApp TAdd [x, y]) =
    fmap (\x' -> bin TAdd x' y) (walk x) `mplus` fmap (bin TAdd x) (walk y)
  walk _ = Nothing

-- Split a term into its summands.
-- We use distributivity laws to try to get terms that are as small as
-- possible, in the hope that they might contain fewer variables.
-- (i.e., we convert the term to a sum-of-products).
summands :: (Eq a, Show a) => Term a -> [Term a]
summands te = go [te]
  where
  -- Work-list traversal: the head of the list is either split further or
  -- emitted as a finished summand.
  go [] = []
  go (t : rest)
    -- x+y; split them
    | TApp TAdd [x, y] <- t = go (x : y : rest)

    -- x-y; split them
    | TApp TSub [x, y] <- t = go (x : negate y : rest)

    -- x*(y1+y2+..+yN); distribute
    | TApp TMul [x, y] <- t, ys <- summands y, composite ys
    = go (map (x *) ys ++ rest)

    -- (x1+x2+..+xN)*y; distribute
    | TApp TMul [x, y] <- t, xs <- summands x, composite xs
    = go (map (* y) xs ++ rest)

    -- (x1+x2+..+xN)/y; distribute
    | TApp TDiv [x, y] <- t, xs <- summands x, composite xs
    = go (map (/ y) xs ++ rest)

    -- Otherwise keep t; and process the rest
    | otherwise = t : go rest

  -- Does it have at least two terms?
  composite (_ : _ : _) = True
  composite _           = False

-- Split a term into a product of two terms, with the property that
-- only the first term contains the given variable.
factorVar :: ArrVars -> NodeIdx -> Term NodeIdx -> (Term NodeIdx, Term NodeIdx)
factorVar arr x t =
  case t of
    -- -(a*b) = a * (-b): the sign goes with the variable-free part.
    TApp TNeg [t1] ->
      let (withX, without) = split t1
      in (withX, negate without)

    -- The lazy pattern assumes a binary product.
    TApp TMul ts ->
      let ([a, b], withouts) = unzip (map split ts)
      in (mulOpt a b, product withouts)

    -- The lazy pattern assumes a binary quotient.
    TApp TDiv ts ->
      let ([a, b], [c, d]) = unzip (map split ts)
      in (a / b, c / d)

    -- Anything else is kept whole, on whichever side it belongs.
    _ | x `IS.member` leavesOfTerm arr t -> (t, 1)
      | otherwise                        -> (1, t)

  where
  split = factorVar arr x

  -- Multiply, merging equal bases into powers where 'merge' knows how.
  mulOpt a b = fromMaybe (a * b) (merge a b)

  merge a b | a == b = Just (a ** 2)
  merge (TApp TPow [a, b]) (TApp TPow [a1, c]) | a == a1 = Just (a ** (b + c))
  merge a (TApp TPow [b, c]) | a == b = Just (a ** (1 + c))
  merge (TApp TPow [b, c]) a | a == b = Just (a ** (1 + c))
  -- Try to absorb 'a' into the left factor first, then into the right one.
  merge a (TApp TMul [b, c]) =
    fmap (\b1 -> mulOpt b1 c) (merge a b) `mplus` fmap (mulOpt b) (merge a c)
  merge a (TApp TDiv [b, c]) = Just (mulOpt a b / c)
  merge _ _ = Nothing