{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE InstanceSigs #-}
module Geometry.VertexEnum.LinearCombination
  ( LinearCombination (..)
  , Var
  , newVar
  , VarIndex
  , linearCombination
  , constant
  , cst
  , toRationalLinearCombination
  )
  where
import           Data.AdditiveGroup ( AdditiveGroup(zeroV, negateV, (^+^)) )
import           Data.IntMap.Strict ( IntMap, mergeWithKey )
import qualified Data.IntMap.Strict as IM
import           Data.List          ( intercalate )
import           Data.Tuple         ( swap )
import           Data.VectorSpace   ( linearCombo, VectorSpace(..) )

-- | A linear combination, stored as a map from index to coefficient.
--
-- Key 0 holds the constant term: it is the key 'constant' writes, and the
-- 'Show' instance prints it without a variable. Any other key @i@ holds the
-- coefficient of the variable built by @newVar i@, which 'show' renders as
-- @c*xi@. Keys whose coefficient is 0 are not removed automatically.
newtype LinearCombination a = LinearCombination (IntMap a)

-- | Converts every coefficient to an exact 'Rational' with 'toRational'.
-- The set of keys is left unchanged.
toRationalLinearCombination :: Real a => LinearCombination a -> LinearCombination Rational
toRationalLinearCombination (LinearCombination coeffs) = LinearCombination exact
  where
    -- strict map (Data.IntMap.Strict), so converted values are forced
    exact = IM.map toRational coeffs

-- | Structural equality. Two combinations are equal exactly when their
-- underlying maps have the same keys with equal coefficients.
--
-- A key stored with coefficient 0 is /not/ treated as absent. For example,
-- @LinearCombination (IM.fromList [(0, 0), (1, 1)])@ differs from
-- @LinearCombination (IM.fromList [(1, 1)])@.
instance (Eq a) => Eq (LinearCombination a) where
  (==) :: LinearCombination a -> LinearCombination a -> Bool
  LinearCombination lhs == LinearCombination rhs = lhs == rhs

-- | Renders the terms in ascending key order, joined by @" + "@.
--
-- Key 0 is printed as a bare coefficient (the constant term). Any other key
-- @i@ is printed as @c*xi@. Negative coefficients are printed as-is, so a
-- term may appear as @+ -2*x1@.
--
-- An empty map, which can only be built directly with the exported
-- constructor, is rendered as @"0"@ rather than as the empty string.
instance (Show a) => Show (LinearCombination a) where
  show :: LinearCombination a -> String
  show (LinearCombination x)
    | IM.null x = "0"
    | otherwise = intercalate " + " (map showTerm (IM.toAscList x))
    where
      showTerm :: Show c => (Int, c) -> String
      showTerm (0, r) = show r
      showTerm (i, r) = show r ++ "*x" ++ show i

-- | Coefficient-wise addition and negation.
--
-- 'zeroV' is not the empty map: it stores an explicit constant term
-- (key 0) with coefficient 0. '^+^' keeps every key of either operand, even
-- when the summed coefficient is 0. Because the 'Eq' instance compares the
-- underlying maps structurally, @v ^+^ zeroV == v@ is 'False' whenever @v@
-- has no key 0.
instance Num a => AdditiveGroup (LinearCombination a) where
  zeroV :: LinearCombination a
  zeroV = LinearCombination (IM.singleton 0 0)

  (^+^) :: LinearCombination a -> LinearCombination a -> LinearCombination a
  LinearCombination lhs ^+^ LinearCombination rhs =
    LinearCombination (mergeWithKey addCoeffs id id lhs rhs)
    where
      -- Key present on both sides: add the coefficients (a 0 sum is kept).
      -- Keys present on only one side are copied unchanged by the two 'id's.
      addCoeffs _ cl cr = Just (cl + cr)

  negateV :: LinearCombination a -> LinearCombination a
  negateV (LinearCombination coeffs) = LinearCombination negated
    where
      negated = IM.map negate coeffs

-- | Scalar multiplication: every coefficient, including the constant term,
-- is multiplied by the scalar. Scaling by 0 keeps all keys and gives them
-- zero coefficients.
instance Num a => VectorSpace (LinearCombination a) where
  type Scalar (LinearCombination a) = a

  (*^) :: Scalar (LinearCombination a) -> LinearCombination a -> LinearCombination a
  factor *^ LinearCombination coeffs = LinearCombination (IM.map scale coeffs)
    where
      -- coefficient on the left, scalar on the right (same order as before)
      scale c = c * factor

-- | Synonym for 'LinearCombination'. 'newVar' returns values of this type;
-- nothing prevents a 'Var' from holding several terms.
type Var a = LinearCombination a
-- | Index of a variable, i.e. a key of the underlying 'IntMap'.
-- Key 0 is also used for the constant term.
type VarIndex = Int

-- | Creates the linear combination consisting of the single variable with
-- index @i@ and coefficient 1. Calls 'error' on a negative index.
--
-- NOTE(review): index 0 is accepted, but key 0 is also where 'constant'
-- stores the constant term. So @newVar 0@ is structurally identical to
-- @constant 1@ and is shown as a bare constant. Callers presumably number
-- variables from 1; confirm before relying on index 0.
newVar :: Num a => VarIndex -> Var a
newVar i
  | i < 0     = error "newVar: negative index"
  | otherwise = LinearCombination (IM.singleton i 1)

-- | Builds a linear combination from @(coefficient, combination)@ pairs.
-- Each combination is scaled by its coefficient, and the results are summed
-- with 'linearCombo' from "Data.VectorSpace" (which expects the pairs as
-- @(vector, scalar)@, hence the 'swap').
--
-- NOTE(review): 'linearCombo' presumably starts its sum from 'zeroV', so
-- the result likely carries an explicit constant key 0, even when it is 0.
-- Confirm against the vector-space version in use.
linearCombination :: Num a => [(a, Var a)] -> LinearCombination a
linearCombination = linearCombo . map swap


-- | The linear combination consisting only of the constant term @x@,
-- stored under key 0. No numeric constraint is needed because nothing is
-- computed.
constant :: a -> LinearCombination a
constant = LinearCombination . IM.singleton 0

-- | Short alias for 'constant': the combination made of the constant term
-- @x@ alone.
cst :: a -> LinearCombination a
cst x = constant x