{-# LANGUAGE Safe #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveAnyClass #-}

-----------------------------------------------------------------------------
{-|
Module      : Math.Tensor.LinearAlgebra.Scalar
Description : Scalar types for usage as Tensor values.
Copyright   : (c) Nils Alex, 2020
License     : MIT
Maintainer  : nils.alex@fau.de

Scalar types for usage as Tensor values.
-}
-----------------------------------------------------------------------------
module Math.Tensor.LinearAlgebra.Scalar
  ( Lin(..)
  , Poly(..)
  , singletonPoly
  , polyMap
  , getVars
  , shiftVars
  , normalize
  ) where

import qualified Data.IntMap.Strict as IM
  ( IntMap
  , singleton
  , null
  , keys
  , map
  , filter
  , mapKeysMonotonic
  , unionWith
  , findMin
  )

import GHC.Generics (Generic)
import Control.DeepSeq (NFData)

-- |Linear combination \(\sum_i a_i\cdot x_i\), represented as a mapping
-- from variable number \(i\) to prefactor \(a_i\).
newtype Lin a = Lin (IM.IntMap a)
  deriving (Show, Ord, Eq, Generic, NFData)

-- |Polynomial in numbered variables \(x_i\). Only constant and affine
-- values are represented explicitly; anything of higher degree is
-- collapsed into 'NotSupported'.
data Poly a = Const !a              -- ^ constant value
            | Affine !a !(Lin a)    -- ^ constant value plus linear term
            | NotSupported          -- ^ higher rank (not implemented)
  deriving (Show, Ord, Eq, Generic, NFData)

-- |Produces an affine value \(c + a\cdot x_i\) with a single linear term.
--
-- The prefactor is stored as given; it is not checked against zero.
singletonPoly :: a       -- ^ constant \(c\)
              -> Int     -- ^ variable number \(i\)
              -> a       -- ^ prefactor \(a\)
              -> Poly a
singletonPoly c i a = Affine c $ Lin $ IM.singleton i a

-- |Maps a function over every coefficient of a 'Poly': the constant term
-- and all linear prefactors. 'NotSupported' is passed through.
--
-- The linear part is not filtered afterwards, so the result may contain
-- prefactors that the function mapped to zero.
polyMap :: (a -> b) -> Poly a -> Poly b
polyMap f (Const a)            = Const (f a)
polyMap f (Affine a (Lin lin)) = Affine (f a) $ Lin $ IM.map f lin
polyMap _ NotSupported         = NotSupported

-- |Arithmetic on 'Poly'.
--
-- Sums of constants and affine values stay affine; a sum of two affine
-- values whose linear parts cancel completely collapses to a 'Const'.
-- Products are only supported when at least one factor is constant, and
-- 'abs' / 'signum' only for constants. Every other case yields
-- 'NotSupported', which absorbs under both addition and multiplication.
instance (Num a, Eq a) => Num (Poly a) where
  Const a + Const b = Const (a + b)
  Const a + Affine b lin = Affine (a + b) lin
  Affine a lin + Const b = Affine (a + b) lin
  Affine a (Lin m1) + Affine b (Lin m2)
      | IM.null m' = Const (a + b)
      | otherwise  = Affine (a + b) (Lin m')
    where
      -- prefactors that cancel are dropped, so a fully cancelled linear
      -- part turns the sum into a plain constant
      m' = IM.filter (/= 0) $ IM.unionWith (+) m1 m2
  NotSupported + _ = NotSupported
  _ + NotSupported = NotSupported

  negate = polyMap negate

  abs (Const a) = Const (abs a)
  abs _         = NotSupported

  signum (Const a) = Const (signum a)
  signum _         = NotSupported

  fromInteger = Const . fromInteger

  Const a * Const b = Const (a * b)
  Const a * Affine b (Lin lin)
    | a == 0    = Const 0   -- zero factor annihilates the linear part
    | otherwise = Affine (a * b) $ Lin $ IM.map (a *) lin
  Affine a (Lin lin) * Const b
    | b == 0    = Const 0   -- zero factor annihilates the linear part
    | otherwise = Affine (a * b) $ Lin $ IM.map (* b) lin
  _ * _ = NotSupported

-- |Returns the variable numbers present in the linear part of the
-- polynomial, in ascending order. Constants and 'NotSupported' have
-- no variables.
getVars :: Poly a -> [Int]
getVars (Affine _ (Lin lm)) = IM.keys lm
getVars _                   = []

-- |Shifts every variable number in the polynomial by a constant offset,
-- i.e. \(x_i \mapsto x_{i+s}\). Constants and 'NotSupported' are returned
-- unchanged.
--
-- Uses 'IM.mapKeysMonotonic', so the shifted numbers must not overflow
-- 'Int' (otherwise @(+ s)@ is no longer strictly monotonic).
shiftVars :: Int -> Poly a -> Poly a
shiftVars s (Affine a (Lin lin)) = Affine a $ Lin $ IM.mapKeysMonotonic (+ s) lin
shiftVars _ p                    = p

-- |Normalizes a polynomial by dividing all of its coefficients by the
-- first (lowest-numbered) non-zero linear prefactor \(a_1\):
-- \[
--    \mathrm{normalize}(c + a_1\cdot x_1 + a_2\cdot x_2 + \dots + a_n\cdot x_n) = \frac{c}{a_1} + 1\cdot x_1 + \frac{a_2}{a_1}\cdot x_2 + \dots + \frac{a_n}{a_1}\cdot x_n
-- \]
-- A constant is normalized to \(1\), or to \(0\) if it is zero:
-- \[
--    \mathrm{normalize}(c) = \begin{cases} 0 & c = 0 \\ 1 & c \neq 0 \end{cases}
-- \]
-- An affine value without any non-zero prefactor (including an empty
-- linear part) is normalized like the constant \(c\).
-- 'NotSupported' is returned unchanged.
normalize :: (Fractional a, Eq a) => Poly a -> Poly a
normalize (Const c) = Const $ if c == 0 then 0 else 1
normalize NotSupported = NotSupported
normalize (Affine c (Lin lin))
    | IM.null nonZero = normalize (Const c)
    | otherwise       = Affine (c / v) $ Lin $ IM.map (/ v) lin
  where
    -- pivot on the first non-zero prefactor: avoids both the partial
    -- 'IM.findMin' on an empty map and a division by zero
    nonZero = IM.filter (/= 0) lin
    -- only forced in the 'otherwise' branch, where nonZero is non-empty
    (_, v)  = IM.findMin nonZero