{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeFamilies #-}

-- | Solve systems of equations
module Data.Random.Generics.Internal.Solver where

import Control.Applicative
import Data.AEq ( (~==) )
import Numeric.AD.Mode
import Numeric.AD.Mode.Forward
import Numeric.LinearAlgebra
import qualified Data.Vector as V
import qualified Data.Vector.Storable as S

-- | Parameters controlling the iteration performed by 'findZero'.
data SolveArgs = SolveArgs
  { accuracy :: Double
    -- ^ The iteration stops successfully once the infinity norm of the
    -- residual is no longer greater than this value.
  , numIterations :: Int
    -- ^ Iteration budget; 'findZero' gives up with 'Nothing' once it
    -- has counted down to zero.
  } deriving (Eq, Ord, Show)

-- | Default parameters: accuracy @1e-8@ and a budget of @20@ iterations.
defSolveArgs :: SolveArgs
defSolveArgs = SolveArgs 1e-8 20

-- | Look for a zero of @f@ by Newton iteration, starting from the given
-- point.
--
-- At every step the value of @f@ and its Jacobian are obtained together
-- through forward-mode automatic differentiation ('jacobian''), and the
-- next point is @x - jacobian '<\>' y@.
--
-- Returns @'Just' x@ as soon as the infinity norm of @f x@ is at most
-- 'accuracy'. Returns 'Nothing' when the iteration budget
-- ('numIterations') is exhausted, or when the residual contains a NaN or
-- an infinite component (the iteration has diverged).
findZero
  :: SolveArgs
  -> (forall s. V.Vector (AD s (Forward R)) -> V.Vector (AD s (Forward R)))
  -> Vector R
  -> Maybe (Vector R)
findZero SolveArgs{..} f = newton numIterations
  where
    newton 0 _ = Nothing
    newton n x
      -- A NaN residual compares False against everything, so without an
      -- explicit check it would fall through to the success branch.
      | S.any nonFinite y = Nothing
      | norm_y > accuracy = newton (n - 1) (x - jacobian <\> y)
      | otherwise = Just x
      where
        norm_y = norm_Inf y
        -- Each entry of yj pairs one component of f x with its gradient;
        -- the gradients are the rows of the Jacobian.
        jacobian = (fromRows . V.toList . fmap (V.convert . snd)) yj
        y = (V.convert . fmap fst) yj
        yj = jacobian' f (S.convert x)

    nonFinite v = isNaN v || isInfinite v

-- | Look for a fixed point of @f@, i.e. a point @x@ with @f x = x@, by
-- handing the componentwise difference @f x - x@ to 'findZero'.
fixedPoint
  :: SolveArgs
  -> (forall a. (Mode a, Scalar a ~ R) => V.Vector a -> V.Vector a)
  -> Vector R
  -> Maybe (Vector R)
fixedPoint args f = findZero args (liftA2 (V.zipWith (-)) f id)

-- | Assuming @p . f@ is satisfied only for positive values in some interval
-- @(0, r]@, find @f r@.
--
-- The search first walks through the candidates @0, 1, 2, 4, ..., 2^100@
-- until the predicate fails, then bisects the last bracket until its two
-- ends are approximately equal ('~=='). The result is the value of @f@ at
-- the last point where the predicate held.
--
-- Calls 'error' when the predicate still holds at @2^100@. When the
-- predicate never held at any probed point, the result is itself an
-- 'error' thunk.
search :: (Double -> a) -> (a -> Bool) -> a
search f p = search' e0 (0 : [2 ^ n | n <- [0 .. 100 :: Int]])
  where
    -- Exponential phase: y is the value at the last accepted point x.
    search' y (x : xs@(x' : _))
      | p y' = search' y' xs
      | otherwise = search'' y x x'
      where
        y' = f x'
    search' _ _ = error "Solution not found. Uncontradictable predicate?"

    -- Bisection phase: the predicate held at x (value y) and failed at x'.
    search'' y x x'
      | x ~== x' = y
      | p y_ = search'' y_ x_ x'
      | otherwise = search'' y x x_
      where
        x_ = (x + x') / 2
        y_ = f x_

    -- Placeholder for "no accepted point yet"; only forced if the
    -- predicate never holds during the search.
    e0 = error "Solution not found. Unsatisfiable predicate?"