module Ipopt.AnyRF where
import Data.Sequence (Seq)
import Data.Vector (Vector)
import Data.Monoid
import Control.Monad.Identity
import qualified Data.VectorSpace as VectorSpace
import Data.VectorSpace (VectorSpace, Scalar)
import qualified Numeric.AD as AD
import qualified Numeric.AD.Types as AD
import qualified Numeric.AD.Internal.Classes as AD
-- | A function from a 'Vector' of values to a result held in the container
-- @cb@, stored so that it stays polymorphic in the element type @a@: the
-- wrapped function can be run at any @a@ that satisfies 'AnyRFCxt'.
data AnyRF cb = AnyRF (forall a. AnyRFCxt a => Vector a -> cb a)
-- | Constraints an element type must satisfy for an 'AnyRF' to be run at it:
-- it is a 'RealFloat', it is a 'VectorSpace', and its 'Scalar' type is the
-- type itself.
type AnyRFCxt a = (RealFloat a, VectorSpace a, Scalar a ~ a)
-- | Lift a constant that exists at every 'AnyRFCxt' type into an 'AnyRF'.
-- The resulting function ignores its input vector and always yields the
-- constant, wrapped in 'Identity'.
liftOp0 :: (forall a. AnyRFCxt a => a) -> AnyRF Identity
liftOp0 op = AnyRF (\_ -> Identity op)
-- | Lift a unary operation that works at every 'AnyRFCxt' type so that it
-- post-processes the result of an existing 'AnyRF': the wrapped function is
-- run on the input vector and @op@ is applied to what it returns.
liftOp1 :: (forall a. AnyRFCxt a => a -> a) -> AnyRF Identity -> AnyRF Identity
liftOp1 op (AnyRF f) = AnyRF (fmap op . f)
-- | Lift a binary operation that works at every 'AnyRFCxt' type so that it
-- combines two 'AnyRF's: both wrapped functions are run on the same input
-- vector and @op@ is applied to their results (left argument first).
liftOp2 :: (forall a. AnyRFCxt a => a -> a -> a) -> AnyRF Identity -> AnyRF Identity -> AnyRF Identity
liftOp2 op (AnyRF f) (AnyRF g) = AnyRF (\v -> liftM2 op (f v) (g v))
-- | Arithmetic on 'AnyRF's, defined by lifting the corresponding operation
-- on the underlying element type: binary operators run both operands on the
-- same input vector, unary ones transform the operand's result, and integer
-- literals become functions that ignore their input.
instance Num (AnyRF Identity) where
    (+) = liftOp2 (+)
    (*) = liftOp2 (*)
    -- Subtraction must be given explicitly: with neither '(-)' nor 'negate'
    -- defined, the class defaults refer to each other and never terminate.
    (-) = liftOp2 (-)
    abs = liftOp1 abs
    signum = liftOp1 signum
    fromInteger n = liftOp0 (fromInteger n)
-- | Division, reciprocal and rational literals for 'AnyRF's, each lifted
-- from the same operation on the underlying element type.
instance Fractional (AnyRF Identity) where
    fromRational r = liftOp0 (fromRational r)
    recip = liftOp1 recip
    (/) = liftOp2 (/)
-- | Elementary functions on 'AnyRF's. Every method is the corresponding
-- 'Floating' method of the underlying element type, lifted with 'liftOp0',
-- 'liftOp1' or 'liftOp2' according to its arity.
instance Floating (AnyRF Identity) where
    -- Constant.
    pi = liftOp0 pi

    -- Exponentials, logarithms and powers.
    exp = liftOp1 exp
    log = liftOp1 log
    sqrt = liftOp1 sqrt
    (**) = liftOp2 (**)
    logBase = liftOp2 logBase

    -- Trigonometric functions and their inverses.
    sin = liftOp1 sin
    cos = liftOp1 cos
    tan = liftOp1 tan
    asin = liftOp1 asin
    acos = liftOp1 acos
    atan = liftOp1 atan

    -- Hyperbolic functions and their inverses.
    sinh = liftOp1 sinh
    cosh = liftOp1 cosh
    tanh = liftOp1 tanh
    asinh = liftOp1 asinh
    acosh = liftOp1 acosh
    atanh = liftOp1 atanh
-- | Stub instance: 'toRational' discards its argument and calls 'error'.
-- NOTE(review): presumably present only because 'Real' is a superclass of
-- 'RealFrac' (and hence of 'RealFloat') -- confirm before relying on it.
instance Real (AnyRF Identity) where
    toRational _ = error "Real AnyRF Identity"
-- | Only 'min' and 'max' are meaningful here: they are lifted from the
-- underlying element type. 'compare' calls 'error', so the comparison
-- operators, whose class defaults are defined via 'compare', fail as well.
instance Ord (AnyRF Identity) where
    min = liftOp2 min
    max = liftOp2 max
    compare _ = error "anyRF compare"
-- | Stub instance: equality on 'AnyRF's is not implemented and '(==)' is
-- bound directly to a call to 'error'.
instance Eq (AnyRF Identity) where
    (==) = error "anyRF =="
-- | Stub instance: 'properFraction' is bound directly to a call to 'error'.
-- No other 'RealFrac' method is overridden.
instance RealFrac (AnyRF Identity) where
    properFraction = error "properFraction AnyRF"
-- | 'RealFloat' for 'AnyRF's. Methods that produce another value of the same
-- type are lifted from the underlying element type; methods that would have
-- to inspect a concrete number are bound to 'error'.
instance RealFloat (AnyRF Identity) where
    -- Lifted from the element type.
    atan2 = liftOp2 atan2
    significand = liftOp1 significand
    scaleFloat n = liftOp1 (scaleFloat n)
    encodeFloat m e = liftOp0 (encodeFloat m e)

    -- Answered with the values for 'Double'; the 'error' argument is only a
    -- type witness handed to the 'Double' method.
    floatRadix _ = floatRadix (error "RealFrac AnyRF Identity floatRadix" :: Double)
    floatDigits _ = floatDigits (error "RealFrac AnyRF Identity floatDigits" :: Double)

    -- Not supported.
    floatRange = error "floatRange AnyRF"
    decodeFloat = error "decodeFloat AnyRF"
    isNaN = error "isNaN AnyRF"
    isInfinite = error "isInfinite AnyRF"
    isDenormalized = error "isDenormalized AnyRF"
    isNegativeZero = error "isNegativeZero AnyRF"
    isIEEE = error "isIEEE AnyRF"
-- | 'AnyRF's that return a 'Seq' combine pointwise: the empty element maps
-- every input to the empty sequence, and 'mappend' runs both functions on
-- the same input and appends the right result to the left one.
instance Monoid (AnyRF Seq) where
    mempty = AnyRF (\_ -> mempty)
    mappend (AnyRF f) (AnyRF g) = AnyRF (\v -> f v `mappend` g v)
-- | 'AnyRF's form a vector space over 'Double': scaling converts the scalar
-- with 'realToFrac' and multiplies it onto the function using the 'Num'
-- instance.
instance VectorSpace.VectorSpace (AnyRF Identity) where
    type Scalar (AnyRF Identity) = Double
    scale *^ v = realToFrac scale * v
-- | Additive group structure on 'AnyRF's, taken from the 'Num' instance:
-- the zero vector is the lifted literal @0@, addition is '(+)' and negation
-- is 'negate'.
instance VectorSpace.AdditiveGroup (AnyRF Identity) where
    zeroV = liftOp0 0
    a ^+^ b = a + b
    negateV a = negate a
-- | Additive group structure for AD values, delegating each method to the
-- operation of the same role exported by the modules imported as @AD@.
-- NOTE(review): neither the class nor the type is defined in this module,
-- so this is an orphan instance.
instance (Num a, AD.Mode f) => VectorSpace.AdditiveGroup (AD.AD f a) where
    zeroV = AD.zero
    x ^+^ y = x AD.<+> y
    negateV x = AD.negate1 x
-- | AD values act as their own scalars, so scaling is 'AD.*!' applied to two
-- AD values. This gives the @Scalar a ~ a@ equation that 'AnyRFCxt' asks for.
-- NOTE(review): neither the class nor the type is defined in this module,
-- so this is an orphan instance.
instance (Num a, AD.Mode f) => VectorSpace.VectorSpace (AD.AD f a) where
    type Scalar (AD.AD f a) = AD.AD f a
    s *^ v = s AD.*! v