module Language.CalDims.Action
( Command (..)
, Result (..)
, Mon
, run
, eval
, doCall
, dependencies
, evalDimsPart
, process) where
import Language.CalDims.Helpers
import Language.CalDims.Types
import Language.CalDims.Texts as Texts
import Language.CalDims.State
import Language.CalDims.Expr ()
import Control.Monad.State hiding (State)
import qualified Control.Monad.State as SM
import Control.Monad.Error
import qualified Data.Map as Map
import Data.Ratio
import Data.List (nub)
#ifdef DEBUG
import Debug.Trace
#endif
-- | A single interpreter command; executed by 'process'.
data Command
  = AddFunction Name Args Expr
    -- ^ Define (or redefine) a function with the given arguments and body.
  | AddBasicUnit Name
    -- ^ Declare a new basic (irreducible) unit.
  | AddUnit Name Expr
    -- ^ Define a derived unit in terms of an expression.
  | Echo String
    -- ^ Return the given string unchanged.
  | Remove Name
    -- ^ Remove an object; fails if other objects depend on it.
  | RemoveCascade Name
    -- ^ Remove an object together with the objects that depend on it.
  | Help
    -- ^ Show the help text.
  | Eval Expr Conversion
    -- ^ Evaluate an expression and present it according to the 'Conversion'.
  | EvalDimsPart Expr
    -- ^ Evaluate an expression and return only its dimensions.
  | DebugExpr Expr
    -- ^ Pretty-print an expression without evaluating it.
  | DebugName Name
    -- ^ Show the stored definition of a name.
  | DebugDependencies Name
    -- ^ Show the forward and reverse dependencies of a name.
  | GetState
    -- ^ Render the whole interpreter state as text.
  | WriteState String
    -- ^ Render the state and pair it with the given target file name.
  deriving Show
-- | Commands are displayed through their derived 'Show' representation.
instance Pretty Command where
  pretty command = show command
-- | The outcome of processing a 'Command'; rendered for the user via 'pretty'.
data Result
  = Ok (Maybe String)
    -- ^ Success, optionally carrying a warning message.
  | StringResult String
    -- ^ Plain text output.
  | EvaledResult (R, Dims)
    -- ^ A numeric value together with its dimensions.
  | ExprResult Expr
    -- ^ An expression (produced when expressing a value in terms of another).
  | DimsResult Dims
    -- ^ Dimensions only.
  | WriteStringToFile String String
    -- ^ File contents and target file name. This module performs no I/O;
    -- the caller is expected to do the actual writing.
  deriving Show
-- | User-facing rendering of command results. An evaluated value is printed
-- as the number followed by its dimensions, separated by a space, unless the
-- dimensions render as the empty string.
instance Pretty Result where
  pretty result = case result of
    Ok Nothing -> "ok."
    Ok (Just s) -> "ok, but " ++ s
    StringResult s -> s
    EvaledResult (r, dims) ->
      let shownDims = pretty dims
      in pretty r ++ (if null shownDims then "" else ' ' : shownDims)
    DimsResult dims -> pretty dims
    ExprResult e -> pretty e
    WriteStringToFile _ fn -> "Writing state to " ++ fn ++ "."
-- | Evaluate an expression, then convert the result into the dimensions
-- chosen by 'minDims' for it.
minEval :: Expr -> Mon (R, Dims)
minEval e = eval e >>= \value -> minDims (snd value) >>= convert value
-- | The interpreter monad: failure with a 'String' message layered over the
-- interpreter 'State'. The instances in this module are written against
-- @ErrorT String (SM.State State)@ and are used at this type.
type Mon a = ErrorAndState String State a
-- | Execute one 'Command' against the interpreter state.
--
-- Definitions go through 'insert', removals through 'remove', and evaluation
-- results are post-processed according to the requested 'Conversion'.
process :: Command -> Mon Result
process command = case command of
  AddFunction n args expr -> insert n (Function args expr)
  AddBasicUnit n -> insert n BasicDimension
  AddUnit n expr -> insert n (Dimension expr)
  Echo s -> return (StringResult s)
  Remove n -> remove False n
  RemoveCascade n -> remove True n
  Help -> return (StringResult (show Texts.helpText))
  Eval expr conv -> eval expr >>= presentAs expr conv
  EvalDimsPart expr -> liftM DimsResult (evalDimsPart expr)
  DebugExpr expr -> return (StringResult (pretty expr))
  DebugName n -> liftM (\entry -> StringResult (pretty (n, entry))) (requireEntry n)
  DebugDependencies n -> do
    forward <- dependencies n
    backward <- reverseDependencies n
    return (StringResult (unlines' [pretty forward, pretty backward]))
  GetState -> liftM (StringResult . pretty) get
  WriteState file -> liftM (\st -> WriteStringToFile (pretty st) file) get
  where
    -- Turn an evaluated value into a 'Result' according to the conversion.
    presentAs :: Expr -> Conversion -> (R, Dims) -> Mon Result
    presentAs expr conv res = case conv of
      Keep -> return (EvaledResult res)
      Explicit dims -> liftM EvaledResult (convert res dims)
      InTermsOf tExpr -> do
        r@(n, _) <- minEval tExpr
        -- Dividing by a zero reference value would be meaningless.
        if n == 0
          then fail "cannot express in terms of 0."
          else do
            r' <- minEval (Bin Div expr (Evaled r))
            return (ExprResult (Bin Mul (Evaled r') tExpr))
      Basic -> liftM EvaledResult (convertBasic res)
      Minimal -> liftM EvaledResult (minDims (snd res) >>= convert res)
-- | Render the interpreter state as text, one definition per line.
--
-- For every name in scope the name and its direct dependencies are collected.
-- Builtins are dropped, then the list is reversed and deduplicated (keeping
-- the first occurrence) before each entry is rendered with pretty'. Any
-- error raised while collecting or rendering is turned into a call to 'error'.
instance Pretty State where
  pretty s = either error unlines' (evalState (runErrorT definitions) s)
    where
      definitions = do
        st <- get
        perName <- mapM dependencies' (Map.keys (getScope st))
        let ordered = nub . reverse . filter (not . isBuiltin st) $ concat perName
        mapM pretty' ordered
-- | Render one named entry: the name paired with its stored definition.
-- Fails if the name is not in scope.
pretty' :: Name -> Mon String
pretty' n = liftM (\entry -> pretty (n, entry)) (requireEntry n)
-- | Remove the object @n@ from the scope.
--
-- Without cascading, removal fails when any other object directly depends on
-- @n@. With cascading, @n@ is removed together with every object that depends
-- on it, directly or transitively. Removing only the direct dependents would
-- leave indirect ones referring to names that no longer exist, so later
-- evaluation or rendering of the state would fail on them.
remove :: Bool -> Name -> Mon Result
remove casc n = do
  rd <- reverseDependencies n
  case (null rd, casc) of
    (True, False) -> remove_ >> return (Ok Nothing)
    (_, True) -> (dependents [] [n] >>= removeList_) >> return (Ok Nothing)
    (False, False) -> throwError $ "Please remove " ++ pretty rd ++ " first, or remove " ++ pretty n ++ " cascading."
  where
    remove_ = modi $ Map.delete n
    removeList_ l = modi $ Map.filterWithKey (\x _ -> not $ x `elem` l)
    -- Worklist closure over reverse dependencies, starting from n itself.
    -- Every name is expanded at most once, which also guards against cycles
    -- in the dependency graph.
    dependents :: [Name] -> [Name] -> Mon [Name]
    dependents seen [] = return seen
    dependents seen (x : queue)
      | x `elem` seen = dependents seen queue
      | otherwise = do
          users <- reverseDependencies x
          dependents (x : seen) (queue ++ users)
-- | Apply a transformation to the scope (the name-to-entry map) in the state.
modi :: (Scope -> Scope) -> Mon ()
modi f = modify $ \st -> st {getScope = f $ getScope st}
-- | Bind @n@ to @newEntry@.
--
-- A fresh name is simply stored. An existing name may only be overwritten by
-- an entry whose type is isomorphic to the old one ('typeIsomorph'); in that
-- case the result carries a warning. Otherwise an error is thrown.
insert :: Name -> StateEntry -> Mon Result
insert n newEntry = do
  existing <- getEntry n
  warning <- maybe (store >> return Nothing) replace existing
  return (Ok warning)
  where
    store = modi (Map.insert n newEntry)
    replace oldEntry = do
      compatible <- typeIsomorph newEntry oldEntry
      if compatible
        then store >> return (Just "an object has changed")
        else throwError $ "There already is an object `" ++ pretty n ++ "` with another type."
-- | The name itself, followed by its direct dependencies.
dependencies' :: Name -> Mon [Name]
dependencies' n = liftM (n :) (dependencies n)
-- | Direct dependencies of a stored entry, without duplicates: the names
-- referenced by its body and by the dimensions of its argument types.
-- Fails if the name is not in scope.
dependencies :: Name -> Mon [Name]
dependencies n = do
  entry <- requireEntry n
  let names = case entry of
        Function args def -> fromSignature args def
        Builtin args def -> fromSignature args def
        Dimension expr -> getDepsExpr expr
        BasicDimension -> []
  return (nub names)
  where
    fromSignature args def = concatMap getDepsArg args ++ getDepsExpr def
-- | Names referenced by an expression: called functions and the units that
-- occur in already-evaluated literals. May contain duplicates.
getDepsExpr :: Expr -> [Name]
getDepsExpr expr = case expr of
  Bin _ lhs rhs -> getDepsExpr lhs ++ getDepsExpr rhs
  Uni _ inner -> getDepsExpr inner
  ArgRef _ -> []
  Call n callArgs -> n : concatMap getDepsExpr callArgs
  Evaled (_, d) -> getDepsDims d
-- | Units referenced by an argument's declared dimensions.
getDepsArg :: Arg -> [Name]
getDepsArg arg = getDepsDims (getArgType arg)
-- | Unit names appearing in a dimension map (including zero exponents).
getDepsDims :: Dims -> [Name]
getDepsDims d = Map.keys (unDims d)
-- | Names in scope whose direct dependencies include @n@.
-- Fails with "no such thing" if @n@ itself is not in scope.
reverseDependencies :: Name -> Mon [Name]
reverseDependencies n = do
  scope <- gets getScope
  let names = Map.keys scope
      dependsOnN other = liftM (n `elem`) (dependencies other)
  users <- filterM dependsOnN names
  if n `notElem` names
    then fail ("no such thing: " ++ unName n)
    else return (nub users)
-- | Look up a name in the current scope.
getEntry :: Name -> Mon (Maybe StateEntry)
getEntry s = gets (Map.lookup s . getScope)
-- | Look up a name in the current scope, failing if it is absent.
requireEntry :: Name -> Mon StateEntry
requireEntry n = getEntry n >>= maybe missing return
  where
    missing :: Mon StateEntry
    missing = throwError $ "No such object: " ++ pretty n
-- | Fetch the signature and body of a user-defined or builtin function.
-- Fails if the name is absent or refers to a unit.
requireFunction :: Name -> Mon (Args, Expr)
requireFunction n = requireEntry n >>= \entry -> case entry of
  Function args expr -> return (args, expr)
  Builtin args expr -> return (args, expr)
  _ -> throwError $ pretty n ++ " is not a function"
-- | Whether one entry may replace another.
--
-- Basic units only match basic units. Derived units match when their defining
-- expressions yield equal dimensions. Functions and builtins (in any
-- combination) match when their argument lists and result dimensions agree.
typeIsomorph :: StateEntry -> StateEntry -> Mon Bool
typeIsomorph BasicDimension BasicDimension = return True
typeIsomorph (Dimension e1) (Dimension e2) = check1 e1 e2
typeIsomorph a b = case (callable a, callable b) of
    (Just (a1, e1), Just (a2, e2)) -> check2 e1 e2 a1 a2
    _ -> return False
  where
    callable (Function args expr) = Just (args, expr)
    callable (Builtin args expr) = Just (args, expr)
    callable _ = Nothing
-- | Whether two argument-free expressions evaluate to equal dimensions.
check1 :: Expr -> Expr -> Mon Bool
check1 e1 e2 = checkReturnDims (noArgs e1) (noArgs e2)
  where
    noArgs e = (e, [])
-- | Whether two function bodies produce equal result dimensions.
--
-- Each body is evaluated with arguments whose dimensions are the declared
-- argument types and whose numeric parts are 'undefined'.
-- NOTE(review): any evaluation path that forces an argument's numeric value
-- (e.g. the integer-exponent test in the 'ExpM' instance, or 'toI' in the
-- 'LogBaseM' instance) will crash on this placeholder.
checkReturnDims :: (Expr, Args) -> (Expr, Args) -> Mon Bool
checkReturnDims (e1, a1) (e2, a2) = do
  (_, d1) <- probe e1 a1
  (_, d2) <- probe e2 a2
  return d1 #==# return d2
  where
    probe :: Expr -> Args -> Mon (R, Dims)
    probe body params = doCall' params body [(undefined, getArgType p) | p <- params]
-- | Whether two functions have equal argument lists and equal result dimensions.
check2 :: Expr -> Expr -> Args -> Args -> Mon Bool
check2 e1 e2 a1 a2 = sameSignature #&&# sameResult
  where
    sameSignature, sameResult :: Mon Bool
    sameSignature = return a1 #==# return a2
    sameResult = checkReturnDims (e1, a1) (e2, a2)
-- | Evaluate an expression and keep only its dimensions.
evalDimsPart :: Expr -> Mon Dims
evalDimsPart = liftM snd . eval
-- | Evaluate an expression to a numeric value with dimensions.
--
-- Function calls evaluate their arguments first. Argument references are
-- resolved against the current call frame held in 'getArgValues'. Operators
-- dispatch to the monadic arithmetic instances for @(R, Dims)@.
eval :: Expr -> Mon (R, Dims)
eval expr = case expr of
  Call n callArgs -> mapM eval callArgs >>= doCall n
  ArgRef (Arg _ i _) -> gets getArgValues >>= \vals -> vals #!!# i
  Bin Add l r -> eval l #+# eval r
  Bin Sub l r -> eval l #-# eval r
  Bin Mul l r -> eval l #*# eval r
  Bin Div l r -> eval l #/# eval r
  Bin Exp l r -> eval l #^# eval r
  Bin LogBase l r -> eval l #~# eval r
  Uni Negate x -> negateM (eval x)
  Uni Expot x -> expM (eval x)
  Uni Log x -> logM (eval x)
  Uni Sin x -> sinM (eval x)
  Uni Cos x -> cosM (eval x)
  Uni Tan x -> tanM (eval x)
  Uni Asin x -> asinM (eval x)
  Uni Acos x -> acosM (eval x)
  Uni Atan x -> atanM (eval x)
  Uni Sinh x -> sinhM (eval x)
  Uni Cosh x -> coshM (eval x)
  Uni Tanh x -> tanhM (eval x)
  Uni Asinh x -> asinhM (eval x)
  Uni Acosh x -> acoshM (eval x)
  Uni Atanh x -> atanhM (eval x)
  Evaled v -> return v
-- | Bounds-checked list indexing. A negative or too-large index raises an
-- error showing the list and the index.
(#!!#) :: Pretty a => [a] -> Int -> Mon a
vals #!!# i = case drop i vals of
  (x : _) | i >= 0 -> return x
  _ -> throwError $ "Index out of bounds: " ++ pretty vals ++ "!!" ++ pretty i
-- | Call the named function (user-defined or builtin) on evaluated arguments.
doCall :: Name -> [(R, Dims)] -> Mon (R, Dims)
doCall n args = requireFunction n >>= \(sig, body) -> doCall' sig body args
-- | Evaluate a function body with the given actual arguments.
--
-- The arguments become the current call frame ('getArgValues') for the
-- duration of the call. Arity and argument dimensions are checked against the
-- signature; mismatches raise "wrong number of arguments" or
-- "wrong typed argument".
--
-- The caller's frame is restored both on success and when anything inside the
-- call throws. 'ErrorT' over 'State' keeps state changes made before an
-- error, so without the restore a failed call would leave the callee's
-- arguments installed in the state returned by 'run'.
doCall' :: Args -> Expr -> [(R, Dims)] -> Mon (R, Dims)
doCall' sig expr args = do
  saved <- gets getArgValues
  modify (setArgs args)
  res <- body `catchError` (\e -> modify (setArgs saved) >> throwError e)
  modify (setArgs saved)
  return res
  where
    setArgs :: [(R, Dims)] -> State -> State
    setArgs vals st = st {getArgValues = vals}
    body = do
      when (length sig /= length args) (throwError "wrong number of arguments")
      c <- check (zip (map snd args) sig)
      when (not c) (throwError "wrong typed argument")
      eval expr
-- | Whether every actual argument's dimensions equal the declared dimensions
-- of the corresponding formal argument. In DEBUG builds the input is traced.
check, check' :: [(Dims, Arg)] -> Mon Bool
#ifdef DEBUG
check l = trace (show l) (check' l)
#else
check = check'
#endif
check' = foldr step (return True)
  where
    step :: (Dims, Arg) -> Mon Bool -> Mon Bool
    step (actual, Arg _ _ declared) rest = (return actual #==# return declared) #&&# rest
-- | Two formal arguments are equal when their declared dimensions are equal
-- and their index fields match. The first field of 'Arg' is ignored.
instance EqM Arg (ErrorT String (SM.State State)) where
  (#==#) ma mb = do
    Arg _ idxA tyA <- ma
    Arg _ idxB tyB <- mb
    sameType <- return tyA #==# return tyB
    return (sameType && idxA == idxB)
-- | Argument lists are equal when they have the same length and corresponding
-- arguments are equal under the 'Arg' instance.
--
-- Each operand action is run exactly once. The list is then compared by
-- pattern matching rather than by repeatedly re-running the actions with
-- partial @head@/@tail@.
instance EqM [Arg] (ErrorT String (SM.State State)) where
  (#==#) l1 l2 = do
      left <- l1
      right <- l2
      pairwise left right
    where
      pairwise :: [Arg] -> [Arg] -> ErrorT String (SM.State State) Bool
      pairwise [] [] = return True
      pairwise (x : restL) (y : restR) = (return x #==# return y) #&&# pairwise restL restR
      pairwise _ _ = return False
-- | Monadic multi-way conditional. Conditions are tried in order, and the
-- action paired with the first one that yields 'True' is run. If none does,
-- the fallback action is run.
caseM :: Monad m => [(m Bool, m a)] -> m a -> m a
caseM branches fallback = foldr try fallback branches
  where
    try (cond, action) otherwiseDo = cond >>= \ok -> if ok then action else otherwiseDo
-- | Monadic if-then-else: run the condition, then exactly one of the branches.
ifM :: Monad m => (m Bool) -> m a -> m a -> m a
ifM cond onTrue onFalse = cond >>= \c -> if c then onTrue else onFalse
-- | Re-express a quantity in basic units only, scaling the numeric part by the
-- conversion factor reported by 'compileDims'.
convertBasic :: (R, Dims) -> Mon (R, Dims)
convertBasic (r1, d1) = liftM (\(factor, basic) -> (r1 * factor, basic)) (compileDims d1)
-- | Express a quantity in the target dimensions. The numeric part is scaled by
-- the ratio of the two basic-unit conversion factors. Fails when the
-- dimensions are not equivalent.
convert :: (R, Dims) -> Dims -> Mon (R, Dims)
convert (r, d) target = ifM compatible rescale incompatible
  where
    compatible :: Mon Bool
    compatible = return d #==# return target
    rescale :: Mon (R, Dims)
    rescale = do
      (fromFactor, _) <- compileDims d
      (toFactor, _) <- compileDims target
      return (r * fromFactor / toFactor, target)
    incompatible :: Mon (R, Dims)
    incompatible = throwError "Conversion to incompatible dimensions not possible."
-- | Reduce dimensions to basic units, returning the numeric conversion factor
-- and the basic-unit dimensions with zero exponents removed.
compileDims :: Dims -> Mon (R, Dims)
compileDims d = liftM dropZeros (compileDims' d)
  where
    dropZeros (factor, raw) = (factor, Dims (Map.filter (/= 0) (unDims raw)))
-- | Worker for 'compileDims'. Each unit with a non-zero exponent is expanded
-- separately: basic units stay as they are, and derived units are replaced by
-- their defining expression raised to the exponent and then reduced
-- recursively. The partial results are multiplied together. Fails on names
-- that are not units.
compileDims' :: Dims -> Mon (R, Dims)
compileDims' d = expand (dims2list d)
  where
    expand :: [(Name, R)] -> Mon (R, Dims)
    expand [] = return (1, noDims)
    expand [(n, i)] = requireEntry n >>= expandUnit n i
    expand ((n, i) : rest) = do
      (r1, d1) <- compileDims (Dims $ Map.singleton n i)
      (r2, d2) <- compileDims (Dims $ Map.fromList rest)
      merged <- return d1 #*# return d2
      return (r1 * r2, merged)

    expandUnit :: Name -> R -> StateEntry -> Mon (R, Dims)
    expandUnit n i entry = case entry of
      BasicDimension -> return (1, Dims $ Map.singleton n i)
      Dimension expr -> do
        (r1, d1) <- eval (Bin Exp expr (Evaled (i, noDims)))
        (r2, d2) <- compileDims d1
        return (r1 * r2, d2)
      _ -> throwError $ pretty n ++ " is not a unit"
-- | The (unit, exponent) pairs with non-zero exponent, in key order.
dims2list :: Dims -> [(Name, R)]
dims2list = Map.toList . Map.filter (/= 0) . unDims
-- | Dimensions reduced to basic units (the conversion factor is discarded).
canonicalizeDims :: Dims -> Mon Dims
canonicalizeDims = liftM snd . compileDims
-- | Dimensions are equal when they reduce to the same basic-unit exponent map.
-- In DEBUG builds the canonical forms and the outcome are traced.
instance EqM Dims (ErrorT String (SM.State State)) where
  (#==#) md1 md2 = do
    raw1 <- md1
    raw2 <- md2
    let canon1 = canonicalizeDims raw1
        canon2 = canonicalizeDims raw2
        sameMap a b = unDims a == unDims b
    equal <- liftM2 sameMap canon1 canon2
#ifdef DEBUG
    shown1 <- canon1
    shown2 <- canon2
    return $ trace (show shown1 ++ " #==# " ++ show shown2 ++ " -> " ++ show equal) equal
#else
    return equal
#endif
-- | Multiplying dimensions adds the exponents of each unit.
instance MulM Dims (ErrorT String (SM.State State)) where
  a #*# b = std (+) a b
-- | Dividing dimensions adds the negated exponents of the divisor.
instance DivM Dims (ErrorT String (SM.State State)) where
  numer #/# denom = std (+) numer (std' negate denom)
-- | Convert an exact rational value to floating point.
--
-- 'fromRational' rounds correctly and handles numerators and denominators
-- beyond the floating-point range. Converting numerator and denominator
-- separately overflows to Infinity as soon as either exceeds that range.
-- For example, the exact value of any Double below about 1e-292 has such a
-- denominator, so it would come out as 0, and other inputs as Infinity or NaN.
dou :: R -> D
dou = fromRational
-- | The exact rational value of a floating-point number.
rat :: D -> R
rat d = toRational d
-- | Exponentiation of quantities. The exponent must be dimensionless.
--
-- The numeric part is computed in floating point. A dimensioned base raised
-- to an integer exponent keeps its units, with exponents scaled. For any
-- other exponent the base is first converted to basic units.
instance ExpM (R, Dims) (ErrorT String (SM.State State)) where
  a #^# b = do
    (base, baseDims) <- a
    (ex, exDims) <- b
    baseIsPure <- nullD baseDims
    exIsPure <- nullD exDims
    let raise x = rat (dou x ** dou ex)
    if not exIsPure
      then throwError $ "Operation requires a number without unit: " ++ pretty (Evaled (base, baseDims)) ++ " #^# " ++ pretty (Evaled (ex, exDims))
      else if baseIsPure
        then return (raise base, noDims)
        else if_ (denominator ex == 1)
          (do scaled <- baseDims *** ex
              return (raise base, scaled))
          (do (r, d) <- convertBasic (base, baseDims)
              d' <- d *** ex
              return (raise r, d'))
-- | The integer value of a ratio, failing with "Integer required." if the
-- ratio has a denominator other than 1.
toI :: Integral a => Ratio a -> Mon a
toI x = if_ (denominator x == 1) (return (numerator x)) (throwError "Integer required.")
-- | Logarithm of the second operand to the base of the first.
--
-- Dimensionless operands use floating-point 'logBase'. With dimensions, both
-- values must be integers related by an integer exponent (via 'logBaseInt'),
-- and the dimensions must obey the same relation. The result is that
-- exponent, dimensionless.
instance LogBaseM (R, Dims) (ErrorT String (SM.State State)) where
  a #~# b = do
    (ra, da) <- a
    (rb, db) <- b
    ifM (nullD da #&&# nullD db)
      (return (rat (logBase (dou ra) (dou rb)), noDims))
      (do ia <- toI ra
          ib <- toI rb
          maybe (throwError "Integer exponent required for logbase on values with dimensions.")
                (\re -> let power = re % 1
                        in ifM ((da *** power) #==# return db)
                               (return (power, noDims))
                               (throwError "Dimensions are in other logbase relation than the numbers"))
                (logBaseInt ia ib))
-- | Raise dimensions to a power by multiplying every exponent by it.
(***) :: Dims -> R -> Mon Dims
Dims d *** r = return . Dims $ Map.map (* r) d
-- | Lift a binary combination of exponents to monadic dimensions. Zero
-- exponents are dropped first. Units present in only one operand keep their
-- exponent (the 'Map.unionWith' semantics), and @f@ is applied to units
-- present in both.
std :: Monad m => (R -> R -> R) -> (m Dims -> m Dims -> m Dims)
std f ma mb = liftM2 combine ma mb
  where
    combine x y = Dims (Map.unionWith f (nonZero x) (nonZero y))
    nonZero = Map.filter (/= 0) . unDims
-- | Apply a function to every non-zero exponent of monadic dimensions.
-- Zero exponents are dropped.
std' :: Monad m => (R -> R) -> (m Dims -> m Dims)
std' f = liftM (Dims . Map.map f . Map.filter (/= 0) . unDims)
-- | Whether dimensions are equivalent to "no dimensions" after reduction to
-- basic units.
nullD :: Dims -> Mon Bool
nullD d = return d #==# dimensionless
  where
    dimensionless = return noDims
-- | Execute one command against a state. Returns either an error message or
-- the result, together with the resulting state. Because the error layer sits
-- over 'State', changes made before an error are kept in the returned state.
run :: State -> Command -> (Either String Result, State)
run state = flip runState state . runErrorT . process
-- | Apply a floating-point function to the numeric part of a quantity and
-- leave its dimensions unchanged.
uniInstance :: (Double -> Double) -> Mon (R, Dims) -> Mon (R, Dims)
uniInstance fr = liftM (\(r, d) -> (rat (fr (dou r)), d))
-- | Look for a short way to write dimensions @d@.
--
-- Candidates are all products of one or two factors drawn from the unit names
-- in scope and their inverses, with units deduplicated by equivalence. The
-- first candidate equivalent to @d@ is returned, or @d@ itself if none is
-- found. Equivalence checks on candidates run purely against a snapshot of
-- the state, and any error there becomes a call to 'error'.
minDims :: Dims -> Mon Dims
minDims d = do
  st <- get
  let pureRun :: ErrorT String (SM.State State) a -> a
      pureRun m = either error id (evalState (runErrorT m) st)

      isUnit :: StateEntry -> Bool
      isUnit entry = case entry of
        BasicDimension -> True
        Dimension _ -> True
        _ -> False

      singletons :: [Dims]
      singletons = [Dims (Map.singleton u 1) | u <- Map.keys (Map.filter isUnit (getScope st))]

      distinct :: [Dims]
      distinct = map fst (eqClassesWith (\x y -> pureRun (return x #==# return y)) singletons)

      inverse :: Dims -> Dims
      inverse x = pureRun (return noDims #/# return x)

      mergeAll :: [Dims] -> Dims
      mergeAll = foldl (\acc x -> pureRun (return acc #*# return x)) noDims

      candidates :: [Dims]
      candidates = map mergeAll (concatMap (\k -> combination k (map inverse distinct ++ distinct)) [1 .. 2 :: Int])

      matches :: Dims -> Mon Bool
      matches c = return c #==# return d
  liftM (headWithDefault d) (filterM matches candidates)
-- | Addition of quantities; the result takes the first operand's units.
instance AddM (R, Dims) (ErrorT String (SM.State State)) where
  a #+# b = linear (+) a b
-- | Combine two quantities with an additive operator (used for @+@ and @-@).
--
-- The operands must have equivalent dimensions. The second operand is then
-- converted into the first operand's units, and the result carries the first
-- operand's dimensions.
linear :: (R -> R -> R) -> Mon (R, Dims) -> Mon (R, Dims) -> Mon (R, Dims)
linear f a b = do
  (ra, da) <- a
  b'@(_, db) <- b
  -- Check compatibility before converting. 'convert' fails on incompatible
  -- dimensions with its own generic message, which would otherwise make the
  -- error below unreachable.
  ifM (return da #==# return db)
    (do (rb, _) <- convert b' da
        return (f ra rb, da))
    (throwError "Incompatible dimensions cannot be added.")
-- | Combine two quantities multiplicatively: @fr@ combines the numeric parts
-- and @fd@ combines the dimensions.
muldiv :: (R -> R -> R) -> (Mon Dims -> Mon Dims -> Mon Dims)
  -> Mon (R, Dims) -> Mon (R, Dims) -> Mon (R, Dims)
muldiv fr fd a b = do
  (ra, da) <- a
  (rb, db) <- b
  liftM (\combined -> (fr ra rb, combined)) (fd (return da) (return db))
-- | Subtraction of quantities; the result takes the first operand's units.
-- Passes the subtraction operator to 'linear'; passing @()@ (unit) is a type
-- error.
instance SubM (R, Dims) (ErrorT String (SM.State State)) where
  ( #-# ) = linear (-)
-- | Multiplication: numeric parts multiply, dimension exponents add.
instance MulM (R, Dims) (ErrorT String (SM.State State)) where
  a #*# b = muldiv (*) ( #*# ) a b
-- | Division: numeric parts divide, dimension exponents subtract.
instance DivM (R, Dims) (ErrorT String (SM.State State)) where
  a #/# b = muldiv (/) ( #/# ) a b
-- | Negation flips the sign of the numeric part and keeps the dimensions.
instance NegateM (R, Dims) (ErrorT String (SM.State State)) where
  negateM m = do
    (r, d) <- m
    return (negate r, d)
-- Elementary functions on quantities. Each applies the corresponding Double
-- function to the numeric part via 'uniInstance' and leaves the dimensions
-- unchanged.
instance ExpotM (R, Dims) (ErrorT String (SM.State State)) where
  expM x = uniInstance exp x

instance LogM (R, Dims) (ErrorT String (SM.State State)) where
  logM x = uniInstance log x

instance SinM (R, Dims) (ErrorT String (SM.State State)) where
  sinM x = uniInstance sin x

instance CosM (R, Dims) (ErrorT String (SM.State State)) where
  cosM x = uniInstance cos x

instance TanM (R, Dims) (ErrorT String (SM.State State)) where
  tanM x = uniInstance tan x

instance AsinM (R, Dims) (ErrorT String (SM.State State)) where
  asinM x = uniInstance asin x

instance AcosM (R, Dims) (ErrorT String (SM.State State)) where
  acosM x = uniInstance acos x

instance AtanM (R, Dims) (ErrorT String (SM.State State)) where
  atanM x = uniInstance atan x

instance SinhM (R, Dims) (ErrorT String (SM.State State)) where
  sinhM x = uniInstance sinh x

instance CoshM (R, Dims) (ErrorT String (SM.State State)) where
  coshM x = uniInstance cosh x

instance TanhM (R, Dims) (ErrorT String (SM.State State)) where
  tanhM x = uniInstance tanh x

instance AsinhM (R, Dims) (ErrorT String (SM.State State)) where
  asinhM x = uniInstance asinh x

instance AcoshM (R, Dims) (ErrorT String (SM.State State)) where
  acoshM x = uniInstance acosh x

instance AtanhM (R, Dims) (ErrorT String (SM.State State)) where
  atanhM x = uniInstance atanh x