{-# LANGUAGE FlexibleContexts , TupleSections , MultiWayIf #-} module Linear.Simplex.Primal ( module X , simplexPrimal , nextRow , nextColumn , coeffRatio , pivot , flatten , compensate , diffZip , getSubst , makeSlackVars , populate ) where import Linear.Simplex.Primal.Types as X import Linear.Grammar import Data.List import Data.Maybe import Data.Bifunctor import Control.Monad.State import Control.Applicative -- | Takes an objective function, a set of constraints, and an operation mode, -- then returns a substitution. -- . -- Objective function should be in the form of @Ax + By + Cz + P = 0@, where @P@ is the -- resule, and free in the constraint set. simplexPrimal :: IneqStdForm -> [IneqStdForm] -> [(String, Rational)] simplexPrimal (LteStd _ _) _ = error "Can't run simplex over an inequality - objective function must be ==." simplexPrimal (GteStd _ _) _ = error "Can't run simplex over an inequality - objective function must be ==." simplexPrimal f cs = let -- objective function is not an inequality, so no slacks will be introduced. tableau = populate $ evalState (mapM makeSlackVars (f:cs)) 0 in getSubst $ run tableau where -- list of inequalities includes objective function. run :: [IneqSlack] -> [IneqSlack] run (objective:constrs) = let mCol = nextColumn objective mRow = nextRow constrs =<< mCol in if isNothing mCol || isNothing mRow then objective:constrs -- solved else run $ pivot (fromJust mRow, fromJust mCol) objective constrs -- | finds next column index from objective function nextColumn :: IneqSlack -> Maybe Int nextColumn (IneqSlack (EquStd xs _) _) | minimum (map varCoeff xs) < 0 = findIndex (hasCoeff $ minimum (map varCoeff xs)) xs | otherwise = Nothing -- simplex is finished nextColumn _ = error "`nextColumn` called on an inequality." -- | Finds next pivot row by the smallest ratio in each row. -- Note: row list should be non-empty nextRow :: [IneqSlack] -> Int -> Maybe Int nextRow xs col = if null xs then error "Empty tableau supplied to `nextRow`." else minIdxMaybe $ map (`coeffRatio` col) xs where minIdxMaybe :: Ord a => [Maybe a] -> Maybe Int minIdxMaybe xs = fst <$> foldl go Nothing (catMaybes $ mapIdxs xs) where go Nothing x = Just x go (Just n) x | snd x < snd n = Just x | otherwise = Just n mapIdxs :: Functor f => [f a] -> [f (Int, a)] mapIdxs = go 0 where go n [] = [] go n (fx:xs) = ((n,) <$> fx):go (n+1) xs -- | Computes coefficient ratio to constant, based on a column index. Warning: -- @Int@ parameter must be less than the length of the primal variables. coeffRatio :: IneqSlack -> Int -> Maybe Rational coeffRatio x col = let xs = getStdVars $ slackIneq x xc = getStdConst $ slackIneq x in if | col >= length xs -> error "`coeffRatio` called with a column index larger than the length of variables." | varCoeff (xs !! col) /= 0 -> let ratio = xc / varCoeff (xs !! col) in if ratio < 0 then Nothing -- negative ratio else Just ratio | otherwise -> Nothing -- undefined ratio -- | flattens targeted row to form the identity at it's column coefficient, then -- reduces each non-zero row at this column, via a multiple of this flattened row. -- Heart of the simplex method. Also prepends @objective@ back on @constrs@. pivot :: (Int, Int) -> Objective -> [IneqSlack] -> [IneqSlack] pivot (row,col) objective constrs = let focal = flatten (constrs !! row) col initConstrs = map (\x -> compensate focal x col) $ take row constrs tailConstrs = map (\x -> compensate focal x col) $ drop (row+1) constrs objective' = compensate focal objective col in objective':(initConstrs ++ (focal:tailConstrs)) -- | "Flattens" a row for further processing. flatten :: IneqSlack -> Int -> IneqSlack flatten (IneqSlack x ys) n = let invertedCoeff = recip $ varCoeff $ getStdVars x !! n mapRecip = map $ mapCoeff (invertedCoeff *) newStdIneq = mapStdVars mapRecip $ mapStdConst (invertedCoeff *) x newStdIneq' = mapStdVars (\xs -> replaceNth n (LinVar (varName (xs !! n)) 1) xs) newStdIneq in IneqSlack newStdIneq' $ mapRecip ys -- | Takes the focal row, the row to affect, and the column in question to facilitate -- the sum-oriented part of the pivot. compensate :: IneqSlack -> IneqSlack -> Int -> IneqSlack compensate focal target col = let coeff = varCoeff $ getStdVars (slackIneq target) !! col -- multiply all literals by @coeff@ newFocal = focal { slackIneq = mapStdVars (map $ mapCoeff (coeff *)) $ mapStdConst (coeff *) $ slackIneq focal , slackVars = map (mapCoeff (coeff *)) $ slackVars focal } in target `diffZip` newFocal -- | Note: Must have identical occurrances of variables, and must be @EquStd@. -- subtracts rhs from lhs. diffZip :: IneqSlack -> IneqSlack -> IneqSlack diffZip (IneqSlack (EquStd xs xc) xvs) (IneqSlack (EquStd ys yc) yvs) = IneqSlack (EquStd (zipDiff xs ys) $ xc - yc) $ zipDiff xvs yvs where zipDiff = zipWith (\x y -> LinVar (varName x) $ varCoeff x - varCoeff y) diffZip _ _ = error "`diffZip` called with non `EquStd` input." -- | Extracts resulting data from tableau, excluding junk data getSubst :: [IneqSlack] -> [(String, Rational)] getSubst xs = let (vars, solutions) = transposeTableau xs solutionIdxs = getSolutionIdxs vars in map (`getSolution` solutions) solutionIdxs where -- Takes the list of constants as a paramter getSolution :: (String, Maybe Int) -> [Rational] -> (String, Rational) getSolution (n, Nothing) _ = (n, 0) getSolution (n, Just i) ss = (n, ss !! i) -- dependent on rightward bias getSolutionIdxs :: [(String, [Rational])] -> [(String, Maybe Int)] getSolutionIdxs [] = [] getSolutionIdxs ((v,ds):vs) = (v,findIdx ds):getSolutionIdxs vs where findIdx :: [Rational] -> Maybe Int findIdx ds | maximum ds == 1 -- the index of the only `1` element && length (filter (== 1) ds) == 1 = elemIndex 1 ds | otherwise = Nothing transposeTableau :: [IneqSlack] -> ([(String, [Rational])], [Rational]) transposeTableau = foldl go ([],[]) where go :: ([(String, [Rational])], [Rational]) -> IneqSlack -> ([(String, [Rational])], [Rational]) go (vars,solutions) (IneqSlack x ys) = let newvarsmap :: [(String, [Rational])] newvarsmap = map (\z -> (varName z,[varCoeff z])) $ getStdVars x ++ ys -- result after combining acc & current varsvals :: [[Rational]] varsvals = if null vars then snd $ unzip newvarsmap -- pointwise combine acc vars & current vars else zipWith (++) (snd $ unzip vars) (snd $ unzip newvarsmap) in ( if null vars then newvarsmap else zip (fst $ unzip vars) varsvals , solutions ++ [getStdConst x] ) -- | Also translates @Ax >= Q@ to @-Ax <= -Q@. Ie; result will __exclude__ @GteStd@. makeSlackVars :: MonadState Integer m => IneqStdForm -> m IneqSlack makeSlackVars a@(EquStd _ _) = return $ IneqSlack a [] makeSlackVars (LteStd xs xc) = do suffix <- get put $ suffix + 1 return $ IneqSlack (EquStd xs xc) [LinVar ("s" ++ show suffix) 1] makeSlackVars (GteStd xs xc) = -- invert equation to <= form makeSlackVars $ LteStd (map (mapCoeff ((-1) *)) xs) ((-1) * xc) -- | Fills missing variables. List of inequalities includes objective function. populate :: [IneqSlack] -> [IneqSlack] populate xs = let allnames :: ([String], [String]) allnames = bimap (nub . concat) (nub . concat) $ unzip $ map varNames xs in map (fill allnames) xs where -- left is user-level vars, right are slack vars varNames :: IneqSlack -> ([String], [String]) varNames x = ( map varName $ getStdVars $ slackIneq x , map varName $ slackVars x ) -- populates missing variables with @0@ as coefficients, and sorts the result. fill :: ([String], [String]) -> IneqSlack -> IneqSlack fill (allns,allss) x = let (oldns,oldss) = varNames x in case (allns \\ oldns, allss \\ oldss) of (names,slacks) -> x { slackIneq = -- instantiate empty user-level vars mapStdVars (\xs -> sort $ xs ++ map (flip LinVar 0) names) $ slackIneq x , slackVars = -- instantiate empty slack vars sort $ slackVars x ++ map (flip LinVar 0) slacks } replaceNth :: Int -> a -> [a] -> [a] replaceNth n newVal (x:xs) | n == 0 = newVal:xs | otherwise = x:replaceNth (n-1) newVal xs