module Linear.Simplex.Primal
( module X
, simplexPrimal
, nextRow
, nextColumn
, coeffRatio
, pivot
, flatten
, compensate
, diffZip
, getSubst
, makeSlackVars
, populate
) where
import Linear.Simplex.Primal.Types as X
import Linear.Grammar
import Data.List
import Data.Maybe
import Data.Bifunctor
import Control.Monad.State
import Control.Applicative
-- | Run the primal simplex method.
--
-- The first argument is the objective row and must be an equality
-- ('EquStd'); an inequality is rejected with an error. The remaining
-- arguments are the constraints. Every row is converted to slack form with
-- 'makeSlackVars' (the slack counter starts at 0), all rows are padded to a
-- common set of columns by 'populate', and the tableau is pivoted until
-- either no entering column ('nextColumn') or no leaving row ('nextRow')
-- exists. The final assignment is read off with 'getSubst'.
simplexPrimal :: IneqStdForm -> [IneqStdForm] -> [(String, Rational)]
simplexPrimal (LteStd _ _) _ = error "Can't run simplex over an inequality - objective function must be ==."
simplexPrimal (GteStd _ _) _ = error "Can't run simplex over an inequality - objective function must be ==."
simplexPrimal f cs = getSubst (run tableau)
  where
    tableau = populate $ evalState (mapM makeSlackVars (f : cs)) 0

    -- Pivot until the tableau is final. The head row is always the objective.
    run :: [IneqSlack] -> [IneqSlack]
    run [] = []
    run t@(objective : constrs) =
      case pivotPosition objective constrs of
        Nothing  -> t
        Just pos -> run (pivot pos objective constrs)

    -- (row, column) of the next pivot, or Nothing when no pivot applies.
    -- The row is only searched for once a column has been found.
    pivotPosition :: IneqSlack -> [IneqSlack] -> Maybe (Int, Int)
    pivotPosition objective constrs = do
      col <- nextColumn objective
      row <- nextRow constrs col
      pure (row, col)
-- | Choose the entering column for the next pivot.
--
-- Looks at the coefficients of the objective row's regular variables. If the
-- smallest one is negative, returns the index of the first variable for
-- which 'hasCoeff' holds with that smallest value. Otherwise, and also when
-- the row has no variables at all, returns 'Nothing'. Calling it on an
-- inequality is an error.
--
-- NOTE(review): only 'getStdVars' columns are inspected. The slack-variable
-- coefficients of the objective row are never entering candidates. Confirm
-- this is intended.
nextColumn :: IneqSlack -> Maybe Int
nextColumn (IneqSlack (EquStd xs _) _) =
  case map varCoeff xs of
    -- 'minimum' is partial; an objective without variables has nothing to
    -- improve.
    [] -> Nothing
    coeffs ->
      let lowest = minimum coeffs
      in if lowest < 0
           then findIndex (hasCoeff lowest) xs
           else Nothing
nextColumn _ = error "`nextColumn` called on an inequality."
-- | Choose the leaving row for pivot column @col@ (minimum-ratio test).
--
-- Each constraint row is scored with 'coeffRatio'. Rows scored 'Nothing' are
-- skipped, and the index of the smallest score is returned; on a tie the
-- earliest row wins. Returns 'Nothing' when no row qualifies. An empty list
-- of rows is an error.
nextRow :: [IneqSlack] -> Int -> Maybe Int
nextRow [] _ = error "Empty tableau supplied to `nextRow`."
nextRow rows col =
  case scored of
    [] -> Nothing
    _  -> Just (fst (minimumBy byRatio scored))
  where
    -- (row index, ratio) for every row that passes the ratio test.
    scored = [ (i, r) | (i, Just r) <- zip [0 :: Int ..] (map (`coeffRatio` col) rows) ]
    -- minimumBy keeps the first of equal elements, preserving the tie-break.
    byRatio a b = compare (snd a) (snd b)
-- | Ratio-test score of one constraint row for pivot column @col@.
--
-- Returns @Just (constant / coefficient)@ when the row's coefficient in
-- column @col@ is strictly positive and the ratio is not negative.
-- Otherwise returns 'Nothing', meaning the row cannot serve as the pivot row.
-- Errors if @col@ is not smaller than the number of regular variables.
--
-- Only strictly positive pivot entries are admissible. A negative entry in a
-- row whose constant is 0 would score a ratio of 0, and pivoting on it can
-- undo the previous pivot. For example, objective @[-1, 0] | 0@ with the
-- single row @[-1, 1] | 0@ would alternate between two tableaux forever.
coeffRatio :: IneqSlack -> Int -> Maybe Rational
coeffRatio x col
  | col >= length xs =
      error "`coeffRatio` called with a column index larger than the length of variables."
  | pivotCoeff > 0 =
      let ratio = xc / pivotCoeff
      in if ratio < 0 then Nothing else Just ratio
  | otherwise = Nothing
  where
    xs = getStdVars (slackIneq x)
    xc = getStdConst (slackIneq x)
    -- Only forced after the bounds guard above has passed.
    pivotCoeff = varCoeff (xs !! col)
-- | Pivot the tableau on (@row@, @col@), where @row@ indexes @constrs@.
--
-- The pivot row is normalised with 'flatten', and column @col@ is then
-- eliminated from the objective and from every other constraint with
-- 'compensate'. The objective stays at the head of the result, and the
-- constraints keep their original order.
pivot :: (Int, Int) -> Objective -> [IneqSlack] -> [IneqSlack]
pivot (row, col) objective constrs =
  eliminate objective : map eliminate above ++ focal : map eliminate below
  where
    focal = flatten (constrs !! row) col
    eliminate target = compensate focal target col
    above = take row constrs
    below = drop (row + 1) constrs
-- | Normalise a row on column @n@.
--
-- Every regular-variable coefficient, the constant, and every slack
-- coefficient is multiplied by the reciprocal of the row's coefficient in
-- column @n@. The column-@n@ variable is then set to exactly 1. The caller
-- must ensure that coefficient is non-zero ('recip' of 0 fails).
flatten :: IneqSlack -> Int -> IneqSlack
flatten (IneqSlack equ slacks) n =
  IneqSlack (mapStdVars setPivotToOne scaledEqu) (scaleAll slacks)
  where
    factor = recip . varCoeff $ getStdVars equ !! n
    scaleAll = map (mapCoeff (factor *))
    scaledEqu = mapStdVars scaleAll (mapStdConst (factor *) equ)
    setPivotToOne vs = replaceNth n (LinVar (varName (vs !! n)) 1) vs
-- | Eliminate column @col@ from @target@ using the normalised pivot row
-- @focal@.
--
-- @focal@ is scaled by @target@'s coefficient in column @col@ (variables,
-- constant and slack coefficients alike), and the result is subtracted from
-- @target@ with 'diffZip'.
compensate :: IneqSlack -> IneqSlack -> Int -> IneqSlack
compensate focal target col = target `diffZip` scaledFocal
  where
    factor = varCoeff (getStdVars (slackIneq target) !! col)
    scale = mapCoeff (factor *)
    scaledFocal =
      focal
        { slackIneq = mapStdVars (map scale) (mapStdConst (factor *) (slackIneq focal))
        , slackVars = map scale (slackVars focal)
        }
-- | Subtract one row from another: @diffZip a b@ is @a - b@.
--
-- Regular-variable coefficients, the constant and slack coefficients are
-- subtracted position by position, keeping the variable names of @a@. Both
-- rows are assumed to list their columns in the same order (presumably
-- guaranteed by 'populate'). 'zipWith' silently truncates to the shorter
-- list. Both rows must be equalities.
diffZip :: IneqSlack -> IneqSlack -> IneqSlack
diffZip (IneqSlack (EquStd xs xc) xvs) (IneqSlack (EquStd ys yc) yvs) =
  IneqSlack (EquStd (zipDiff xs ys) (xc - yc)) (zipDiff xvs yvs)
  where
    zipDiff = zipWith (\x y -> LinVar (varName x) (varCoeff x - varCoeff y))
diffZip _ _ = error "`diffZip` called with non `EquStd` input."
-- | Read the variable assignment off a tableau.
--
-- The rows (objective first) are transposed into one coefficient column per
-- variable: regular variables followed by slack variables, in the order of
-- the first row. A variable is basic when its column is a unit vector, i.e.
-- exactly one non-zero entry and that entry equal to 1. A basic variable
-- takes the constant of that row; every other variable is 0. An empty
-- tableau yields no assignments.
getSubst :: [IneqSlack] -> [(String, Rational)]
getSubst xs =
  let (vars, solutions) = transposeTableau xs
      valueOf column = fromMaybe 0 ((solutions !!) <$> basicRow column)
  in map (\(name, column) -> (name, valueOf column)) vars
  where
    -- Row index in which the column is basic. A column such as [0, 1, -2]
    -- is not a unit column and therefore denotes a non-basic variable.
    basicRow :: [Rational] -> Maybe Int
    basicRow ds =
      case filter ((/= 0) . snd) (zip [0 ..] ds) of
        [(i, 1)] -> Just i
        _        -> Nothing

    -- Transpose the rows into (name, column) pairs plus the list of row
    -- constants. Column lists are truncated to the shortest row.
    transposeTableau :: [IneqSlack] -> ([(String, [Rational])], [Rational])
    transposeTableau = foldl step ([], [])
      where
        step (vars, solutions) (IneqSlack x ys) =
          ( if null vars then fresh else zipWith appendColumn vars fresh
          , solutions ++ [getStdConst x]
          )
          where
            fresh = [ (varName z, [varCoeff z]) | z <- getStdVars x ++ ys ]
            appendColumn (name, cs) (_, c) = (name, cs ++ c)
-- | Convert one standard-form row into slack form.
--
-- * An equality gets no slack variable.
-- * A @<=@ row becomes an equality with one fresh slack variable
--   @"s" ++ show n@ (coefficient 1). @n@ is read from the state counter,
--   which is then incremented.
-- * A @>=@ row is multiplied by -1 (coefficients and constant) to obtain the
--   equivalent @<=@ row, which is then handled as above.
--
-- NOTE(review): negating a @>=@ row with a positive constant yields a
-- negative constant, so the all-slack starting basis is infeasible. Nothing
-- visible in this module performs a phase-one step. Confirm callers only
-- pass rows for which this is acceptable.
makeSlackVars :: MonadState Integer m => IneqStdForm -> m IneqSlack
makeSlackVars a@(EquStd _ _) = return $ IneqSlack a []
makeSlackVars (LteStd xs xc) = do
  suffix <- get
  put $ suffix + 1
  return $ IneqSlack (EquStd xs xc) [LinVar ("s" ++ show suffix) 1]
makeSlackVars (GteStd xs xc) =
  makeSlackVars $ LteStd (map (mapCoeff negate) xs) (negate xc)
-- | Give every row the same set of columns.
--
-- Collects every regular-variable name and every slack-variable name that
-- occurs in any row. Each row gets its missing names added with coefficient
-- 0, and its regular variables and slack variables are each sorted. This is
-- presumably so columns line up by position across rows (it assumes the
-- 'Ord' instance of the variables orders by name; verify).
populate :: [IneqSlack] -> [IneqSlack]
populate rows = map pad rows
  where
    (allVars, allSlacks) =
      bimap (nub . concat) (nub . concat) . unzip $ map namesOf rows

    namesOf :: IneqSlack -> ([String], [String])
    namesOf r =
      ( map varName (getStdVars (slackIneq r))
      , map varName (slackVars r)
      )

    zeroVars = map (\name -> LinVar name 0)

    pad :: IneqSlack -> IneqSlack
    pad r =
      let (ownVars, ownSlacks) = namesOf r
          missingVars = allVars \\ ownVars
          missingSlacks = allSlacks \\ ownSlacks
      in r { slackIneq =
               mapStdVars (\vs -> sort (vs ++ zeroVars missingVars)) (slackIneq r)
           , slackVars = sort (slackVars r ++ zeroVars missingSlacks)
           }
-- | Replace the element at index @n@ (0-based) with @newVal@.
--
-- If @n@ is negative or past the end of the list, the list is returned
-- unchanged instead of failing on a missing pattern.
replaceNth :: Int -> a -> [a] -> [a]
replaceNth _ _ [] = []
replaceNth n newVal (x:xs)
  | n == 0 = newVal : xs
  | otherwise = x : replaceNth (n - 1) newVal xs