{-# LANGUAGE MultiParamTypeClasses #-}

module Linear.Grammar where

import Data.List (find)
import Data.String

-- * User-facing API

-- | Abstract syntax tree that users write linear expressions in: variables,
-- literals, an expression scaled by a coefficient, and sums.
data LinAst
  = EVar String
  | ELit Double
  | ECoeff LinAst Double
  | EAdd LinAst LinAst
  deriving (Show, Eq)

-- | Doesn't solve the riddle, but it helps: a string is read as a variable
-- name.
instance IsString LinAst where
  fromString name = EVar name

-- | Sum of two expressions.
(.+.) :: LinAst -> LinAst -> LinAst
lhs .+. rhs = EAdd lhs rhs

infixr 8 .+.

-- | Scaling of an expression by a coefficient, written with the coefficient
-- on either side.
class Coefficient x y where
  (.*.) :: x -> y -> LinAst

infixr 9 .*.

-- | Expression on the left, coefficient on the right.
instance Coefficient LinAst Double where
  expr .*. k = ECoeff expr k

-- | Coefficient on the left, expression on the right.
instance Coefficient Double LinAst where
  k .*. expr = ECoeff expr k

-- | Pushes @ECoeff@ down the tree, leaving @EAdd@ at the top level.
-- After using this function, the @LinAst@ parameter of every @ECoeff@
-- constructor will be an @EVar@.
multLin :: LinAst -> LinAst
multLin ast = case ast of
  EAdd l r -> EAdd (multLin l) (multLin r)
  ECoeff inner k -> scale k (multLin inner)
  leaf -> leaf
  where
    -- Apply the outer coefficient to an already-normalised subtree.
    scale :: Double -> LinAst -> LinAst
    scale k (ELit v) = ELit (v * k)
    scale k (EVar n) = ECoeff (EVar n) k
    scale k (ECoeff e c) = ECoeff e (c * k)
    scale k (EAdd l r) = EAdd (multLin $ ECoeff l k) (multLin $ ECoeff r k)

-- * Linear Expressions

-- | A named variable together with its coefficient.
data LinVar = LinVar
  { varName :: String
  , varCoeff :: Double
  }
  deriving (Show, Eq)

-- | For sorting tableaus. Only the names are compared; the coefficients are
-- ignored.
instance Ord LinVar where
  compare (LinVar nameA _) (LinVar nameB _) = nameA `compare` nameB

-- | Whether two variables carry the same name, whatever their coefficients.
hasName :: LinVar -> LinVar -> Bool
hasName (LinVar nameA _) (LinVar nameB _) = nameA == nameB

-- | Linear expressions suited for normal and standard form.
data LinExpr = LinExpr
  { exprCoeffs :: [LinVar]
  , exprConst :: Double
  }
  deriving (Show, Eq)

-- | Concatenates the variable lists and adds the constants. Duplicate
-- variables are kept as they are.
mergeLinExpr :: LinExpr -> LinExpr -> LinExpr
mergeLinExpr (LinExpr leftVars leftConst) (LinExpr rightVars rightConst) =
  LinExpr allVars total
  where
    allVars = leftVars ++ rightVars
    total = leftConst + rightConst

-- | Turns @LinAst@ to @LinExpr@ - should be done /after/ @multLin@.
--
-- An @ECoeff@ is expected to wrap an @EVar@ directly, which is the shape
-- @multLin@ produces. Any other @ECoeff@ is first normalised with @multLin@
-- so that the conversion is total.
addLin :: LinAst -> LinExpr
addLin = go (LinExpr [] 0)
  where
    go :: LinExpr -> LinAst -> LinExpr
    go (LinExpr vs c) (EVar n) = LinExpr (LinVar n 1 : vs) c
    go (LinExpr vs c) (ELit x) = LinExpr vs (c + x)
    go (LinExpr vs c) (ECoeff (EVar n) x) = LinExpr (LinVar n x : vs) c
    -- @multLin@ never returns an @ECoeff@ around anything but an @EVar@, so
    -- this re-dispatch lands on one of the other equations.
    go le (ECoeff e x) = go le (multLin (ECoeff e x))
    go le (EAdd e1 e2) = mergeLinExpr (go le e1) (go le e2)

-- | Merges duplicate @LinVar@s in a @LinExpr@ by summing the coefficients of
-- variables that share a name. Should be used /after/ @addLin@.
--
-- Variables that occur only once are kept unchanged.
removeDupLin :: LinExpr -> LinExpr
removeDupLin (LinExpr vs c) = LinExpr (foldr go [] vs) c
  where
    go :: LinVar -> [LinVar] -> [LinVar]
    go v acc =
      case find (hasName v) acc of
        -- Same name already collected: fold the coefficients together.
        Just (LinVar m y) ->
          LinVar m (y + varCoeff v) : filter (not . hasName v) acc
        -- First occurrence of this name: keep the variable.
        Nothing -> v : acc

-- | Full pipeline from the user-facing syntax tree to a linear expression:
-- @multLin@, then @addLin@, then @removeDupLin@.
makeLinExpr :: LinAst -> LinExpr
makeLinExpr = removeDupLin . addLin . multLin

-- * Linear Inequalities

-- | A relation between two linear expressions: @Equ l r@ is built by
-- @.==.@ and @Lte l r@ by @.<=.@.
data Ineq
  = Equ LinExpr LinExpr
  | Lte LinExpr LinExpr
  deriving (Show, Eq)

-- | Equality of two expressions.
(.==.) :: LinAst -> LinAst -> Ineq
x .==. y = Equ (makeLinExpr x) (makeLinExpr y)

infixl 7 .==.

-- | Left expression is less than or equal to the right one.
(.<=.) :: LinAst -> LinAst -> Ineq
x .<=. y = Lte (makeLinExpr x) (makeLinExpr y)

infixl 7 .<=.

-- | @.<=.@ with its arguments swapped.
(.=>.) :: LinAst -> LinAst -> Ineq
(.=>.) = flip (.<=.)

infixl 7 .=>.

-- * Standard Form

-- | Variables on the left, constant on the right: @EquStd vs c@ stands for
-- @vs == c@ and @LteStd vs c@ for @vs <= c@.
data IneqStdForm
  = EquStd [LinVar] Double
  | LteStd [LinVar] Double
  deriving (Show, Eq)

-- TODO: Test idempotency

-- | Rearranges a relation so that one side holds only variables (constant
-- @0@) and the other side holds only a constant (no variables).
--
-- If one side has no variables, the constants are moved onto that side.
-- Otherwise the right-hand variables are negated and moved to the left, and
-- the constant goes to the right.
standardize :: Ineq -> Ineq
standardize ineq = case ineq of
  Equ l r -> uncurry Equ (isolate l r)
  Lte l r -> uncurry Lte (isolate l r)
  where
    -- Shared rearrangement; it only moves terms across the relation, so it
    -- applies unchanged to both constructors.
    isolate :: LinExpr -> LinExpr -> (LinExpr, LinExpr)
    isolate (LinExpr xs xc) (LinExpr ys yc)
      | null xs = (LinExpr [] (xc - yc), LinExpr ys 0)
      | null ys = (LinExpr xs 0, LinExpr [] (yc - xc))
      | otherwise =
          ( removeDupLin $ LinExpr (map flipSign ys ++ xs) 0
          , LinExpr [] (yc - xc)
          )

    flipSign :: LinVar -> LinVar
    flipSign v = v {varCoeff = varCoeff v * (-1)}

-- | Standardizes a relation and converts it to @IneqStdForm@.
--
-- Calls @error@ with @"Non-standard Ineq"@ if the standardized relation does
-- not have one variable-only side and one constant-only side.
standardForm :: Ineq -> IneqStdForm
standardForm = go . standardize
  where
    go :: Ineq -> IneqStdForm
    go (Equ (LinExpr xs xc) (LinExpr ys yc))
      | null xs && yc == 0 = EquStd ys xc
      | null ys && xc == 0 = EquStd xs yc
    go (Lte (LinExpr xs xc) (LinExpr ys yc))
      -- Here the relation reads @xc <= ys@. Putting the variables on the
      -- left of a @<=@ requires negating both sides: @-ys <= -xc@.
      | null xs && yc == 0 = LteStd (map flipSign ys) (0 - xc)
      | null ys && xc == 0 = LteStd xs yc
    go _ = error "Non-standard Ineq"

    flipSign :: LinVar -> LinVar
    flipSign v = v {varCoeff = varCoeff v * (-1)}