{-# LANGUAGE Rank2Types, TypeFamilies #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Numeric.RAD
-- Copyright   :  (c) Edward Kmett 2010
-- License     :  BSD3
-- Maintainer  :  ekmett@gmail.com
-- Stability   :  experimental
-- Portability :  GHC only
--
-- Reverse Mode Automatic Differentiation via overloading to perform
-- nonstandard interpretation that replaces the original numeric type with
-- a bundle that contains a value of the original type and the tape that
-- will be used to recover the value of the sensitivity.
--
-- This package uses StableNames internally to recover sharing information from
-- the tape to avoid combinatorial explosion, and thus runs asymptotically faster
-- than it could without such sharing information, but the use of side-effects
-- contained herein is benign.
--
-- The API has been built to be close to the design of 'Numeric.FAD' from the 'fad' package
-- by Barak Pearlmutter and Jeffrey Mark Siskind and contains portions of that code, with minor liberties taken.
--
-----------------------------------------------------------------------------

module Numeric.RAD
    (
    -- * First-Order Reverse Mode Automatic Differentiation
      RAD
    , lift
    -- * First-Order Differentiation Operators
    , diffUU
    , diffUF
    , diff2UU
    , diff2UF
    -- * Common access patterns
    , diff
    , diff2
    , jacobian
    , jacobian2
    , grad
    , grad2
    -- * Optimization Routines
    , zeroNewton
    , inverseNewton
    , fixedPointNewton
    , extremumNewton
    , argminNaiveGradient
    ) where

import Prelude hiding (mapM)
import Control.Applicative (Applicative(..),(<$>))
import Control.Monad.ST
import Control.Monad (forM_)
import Data.List (foldl')
import Data.Array.ST
import Data.Array
import Data.Ix
import Text.Show
import Data.Graph (graphFromEdges', topSort, Vertex)
import Data.Reify (reifyGraph, MuRef(..))
import qualified Data.Reify.Graph as Reified
import Data.Traversable (Traversable, mapM)
import System.IO.Unsafe (unsafePerformIO)

-- | A reverse-mode AD number: a tape node whose children are themselves
-- 'RAD' values. The @s@ parameter is a phantom type; the differentiation
-- operators quantify over it in their rank-2 arguments.
newtype RAD s a = RAD (Tape a (RAD s a))

-- | One node of the tape. The first field of every constructor is the
-- primal value of the node (see 'primal').
data Tape a t
    = Literal a         -- ^ a constant; 'runTape' leaves it out of the dependency graph
    | Var a Int         -- ^ an input variable together with its index
    | Binary a a a t t  -- ^ primal, partial w.r.t. 1st operand, partial w.r.t. 2nd operand, operands
    | Unary a a t       -- ^ primal, partial w.r.t. the operand, operand

instance Show a => Show (RAD s a) where
    showsPrec d = disc1 (showsPrec d)

-- | The 'lift' function injects a primal number into the RAD data type with a 0 derivative.
-- If reverse-mode AD numbers formed a monad, then 'lift' would be 'return'.
lift :: a -> RAD s a
lift = RAD . Literal
{-# INLINE lift #-}

-- | Extract the primal value stored in the head node of a tape.
primal :: RAD s a -> a
primal (RAD (Literal y)) = y
primal (RAD (Var y _)) = y
primal (RAD (Binary y _ _ _ _)) = y
primal (RAD (Unary y _ _)) = y
{-# INLINE primal #-}

-- | Build an input variable with primal value @a@ and index @v@.
var :: a -> Int -> RAD s a
var a v = RAD (Var a v)

-- TODO: A higher-order data-reify
-- mapDeRef :: (Applicative f) => (forall a . Num a => RAD s a -> f (u a)) -> a -> f (Tape a (u a))

-- Lets 'reifyGraph' walk a tape: only the operand positions of 'Unary' and
-- 'Binary' nodes are traversed.
instance MuRef (RAD s a) where
    type DeRef (RAD s a) = Tape a

    mapDeRef _ (RAD (Literal a)) = pure (Literal a)
    mapDeRef _ (RAD (Var a v)) = pure (Var a v)
    mapDeRef f (RAD (Binary a jb jc x1 x2)) = Binary a jb jc <$> f x1 <*> f x2
    mapDeRef f (RAD (Unary a j x)) = Unary a j <$> f x

-- | @(f \`on\` g) a b = f (g a) (g b)@.
on :: (a -> a -> c) -> (b -> a) -> b -> b -> c
on f g a b = f (g a) (g b)

-- Comparisons look only at primal values.
instance Eq a => Eq (RAD s a) where
    (==) = (==) `on` primal

instance Ord a => Ord (RAD s a) where
    compare = compare `on` primal

instance Bounded a => Bounded (RAD s a) where
    maxBound = lift maxBound
    minBound = lift minBound

-- | Lift a unary function whose derivative is the constant @g@.
-- Literals are folded eagerly so no tape node is created for them.
unary_ :: (a -> a) -> a -> RAD s a -> RAD s a
unary_ f _ (RAD (Literal b)) = RAD (Literal (f b))
unary_ f g b = RAD (Unary (disc1 f b) g b)
{-# INLINE unary_ #-}

-- | Lift a unary function @f@ whose derivative at primal @x@ is @g x@.
unary :: (a -> a) -> (a -> a) -> RAD s a -> RAD s a
unary f _ (RAD (Literal b)) = RAD (Literal (f b))
unary f g b = RAD (Unary (disc1 f b) (disc1 g b) b)
{-# INLINE unary #-}

-- | Lift a binary function whose partials are the constants @gb@ (w.r.t. the
-- first operand) and @gc@ (w.r.t. the second operand).
binary_ :: (a -> a -> a) -> a -> a -> RAD s a -> RAD s a -> RAD s a
binary_ f _ _ (RAD (Literal b)) (RAD (Literal c)) = RAD (Literal (f b c))
binary_ f gb gc b c = RAD (Binary (f vb vc) gb gc b c)
  where
    vb = primal b
    vc = primal c
{-# INLINE binary_ #-}

-- | binary_ with partials.
--
-- Note the crossed arguments: the partial w.r.t. the first operand is
-- computed as @gb@ applied to the primal of the /second/ operand, and the
-- partial w.r.t. the second operand as @gc@ applied to the primal of the
-- /first/. That fits multiplication; it cannot express a partial that needs
-- both primals (which is why '/' does not use this helper).
binary :: (a -> a -> a) -> (a -> a) -> (a -> a) -> RAD s a -> RAD s a -> RAD s a
binary f _ _ (RAD (Literal b)) (RAD (Literal c)) = RAD (Literal (f b c))
binary f gb gc b c = RAD (Binary (f vb vc) (gb vc) (gc vb) b c)
  where
    vb = primal b
    vc = primal c
{-# INLINE binary #-}

-- | Apply a function to the primal of one argument, discarding the tape.
disc1 :: (a -> b) -> RAD s a -> b
disc1 f x = f (primal x)
{-# INLINE disc1 #-}

-- | Apply a function to the primals of two arguments, discarding the tapes.
disc2 :: (a -> b -> c) -> RAD s a -> RAD s b -> c
disc2 f x y = f (primal x) (primal y)
{-# INLINE disc2 #-}

-- | Apply a function to the primals of three arguments, discarding the tapes.
disc3 :: (a -> b -> c -> d) -> RAD s a -> RAD s b -> RAD s c -> d
disc3 f x y z = f (primal x) (primal y) (primal z)
{-# INLINE disc3 #-}

-- | @from a x@ is a node with primal @x@ and derivative 1 with respect to @a@
-- (or a plain literal when @a@ is a literal).
from :: Num a => RAD s a -> a -> RAD s a
from (RAD (Literal _)) x = RAD (Literal x)
from a x = RAD (Unary x 1 a)

-- | @fromBy a delta n x@ is a node with primal @x@, partial 1 w.r.t. @a@ and
-- partial @n@ w.r.t. @delta@.
--
-- The result is only collapsed to a literal when /both/ @a@ and @delta@ are
-- literals; checking @a@ alone would drop the dependence on @delta@ (e.g.
-- @enumFromThen (lift 0) x@).
fromBy :: Num a => RAD s a -> RAD s a -> Int -> a -> RAD s a
fromBy (RAD (Literal _)) (RAD (Literal _)) _ x = RAD (Literal x)
fromBy a delta n x = RAD (Binary x 1 (fromIntegral n) a delta)

instance (Num a, Enum a) => Enum (RAD s a) where
    succ = unary_ succ 1
    pred = unary_ pred 1
    toEnum = lift . toEnum
    fromEnum = disc1 fromEnum
    -- the enumerated results vary with the lower bound and so their derivatives reflect that
    enumFrom a = from a <$> disc1 enumFrom a
    enumFromTo a b = from a <$> disc2 enumFromTo a b
    -- these results vary with respect to both the lower bound and the delta between that and the second argument
    enumFromThen a b = zipWith (fromBy a delta) [0..] $ disc2 enumFromThen a b
      where
        delta = b - a
    enumFromThenTo a b c = zipWith (fromBy a delta) [0..] $ disc3 enumFromThenTo a b c
      where
        delta = b - a

instance Num a => Num (RAD s a) where
    fromInteger = lift . fromInteger
    (+) = binary_ (+) 1 1
    (-) = binary_ (-) 1 (-1)
    negate = unary_ negate (-1)
    (*) = binary (*) id id
    -- incorrect if the argument is complex
    abs = unary abs signum
    -- signum is treated as a constant: its result carries no tape
    signum = lift . signum . primal

-- notComplex :: Num a => a -> Bool
-- notComplex x = s == 0 || s == 1 || s == -1
--     where s = signum x

instance Real a => Real (RAD s a) where
    toRational = disc1 toRational

instance RealFloat a => RealFloat (RAD s a) where
    floatRadix = disc1 floatRadix
    floatDigits = disc1 floatDigits
    floatRange = disc1 floatRange
    decodeFloat = disc1 decodeFloat
    encodeFloat m e = lift (encodeFloat m e)
    scaleFloat n = unary_ (scaleFloat n) (scaleFloat n 1)
    isNaN = disc1 isNaN
    isInfinite = disc1 isInfinite
    isDenormalized = disc1 isDenormalized
    isNegativeZero = disc1 isNegativeZero
    isIEEE = disc1 isIEEE
    exponent x
        | m == 0 = 0
        | otherwise = n + floatDigits x
      where
        (m, n) = decodeFloat x
    -- significand x == scaleFloat (negate (exponent x)) x, so the local slope
    -- is that same scale factor applied to 1 (not 2^(-floatDigits)).
    significand x = unary_ significand (scaleFloat (negate (exponent x)) 1) x
    atan2 (RAD (Literal x)) (RAD (Literal y)) = RAD (Literal (atan2 x y))
    atan2 x y = RAD (Binary (atan2 vx vy) (vy*r) (-vx*r) x y)
      where
        vx = primal x
        vy = primal y
        r = recip (vx^2 + vy^2)

instance RealFrac a => RealFrac (RAD s a) where
    properFraction (RAD (Literal a)) = (w, RAD (Literal p))
      where
        (w, p) = properFraction a
    properFraction a = (w, RAD (Unary p 1 a))
      where
        (w, p) = properFraction (primal a)
    truncate = disc1 truncate
    -- each rounding mode defers to the matching method on the primal
    round = disc1 round
    ceiling = disc1 ceiling
    floor = disc1 floor

instance Fractional a => Fractional (RAD s a) where
    -- d(b/c)/db = 1/c and d(b/c)/dc = -b/c^2. The second partial needs both
    -- primals, so the node is built directly instead of through 'binary'.
    RAD (Literal b) / RAD (Literal c) = lift (b / c)
    b / c = RAD (Binary (vb / vc) (recip vc) (negate vb / (vc * vc)) b c)
      where
        vb = primal b
        vc = primal c
    -- d(1/x)/dx = -1/x^2
    recip = unary recip (\x -> negate (recip (x * x)))
    fromRational r = lift $ fromRational r

instance Floating a => Floating (RAD s a) where
    pi = lift pi
    exp = unary exp exp
    log = unary log recip
    sqrt = unary sqrt (recip . (2*) . sqrt)
    RAD (Literal x) ** RAD (Literal y) = lift (x ** y)
    x ** y = RAD (Binary vz (vy*vz/vx) (vz*log vx) x y)
      where
        vx = primal x
        vy = primal y
        vz = vx ** vy
    sin = unary sin cos
    cos = unary cos (negate . sin)
    asin = unary asin (recip . sqrt . (1-) . (^2))
    acos = unary acos (negate . recip . sqrt . (1-) . (^2))
    atan = unary atan (recip . (1+) . (^2))
    sinh = unary sinh cosh
    cosh = unary cosh sinh
    asinh = unary asinh (recip . sqrt . (1+) . (^2))
    acosh = unary acosh (recip . sqrt . (-1+) . (^2))
    atanh = unary atanh (recip . (1-) . (^2))

-- back propagate sensitivities along the tape.
--
-- For the node at vertex @v@, its own sensitivity is read from @ss@ and
-- added, scaled by the stored partial(s), into the slot(s) of its operand(s).
-- In the 'Binary' case the second operand's slot is read only after the
-- first has been written, so a node with both operands equal accumulates
-- both contributions.
backprop :: (Ix t, Ord t, Num a) => (Vertex -> (Tape a t, t, [t])) -> STArray s t a -> Vertex -> ST s ()
backprop vmap ss v = case node of
    Unary _ g b -> do
        da <- readArray ss i
        db <- readArray ss b
        writeArray ss b (db + g*da)
    Binary _ gb gc b c -> do
        da <- readArray ss i
        db <- readArray ss b
        writeArray ss b (db + gb*da)
        dc <- readArray ss c
        writeArray ss c (dc + gc*da)
    _ -> return ()
  where
    (node, i, _) = vmap v

-- | Run the tape backwards from the given result node.
--
-- The tape is reified into an explicit graph with 'reifyGraph', the
-- non-literal nodes are visited in topological order starting from a
-- sensitivity of 1 at the root, and the sensitivities that reach 'Var' nodes
-- are summed into an array indexed by variable index over @vbounds@.
runTape :: Num a => (Int, Int) -> RAD s a -> Array Int a
runTape vbounds tape = accumArray (+) 0 vbounds [ (v, sensitivities ! ix) | (ix, Var _ v) <- xs ]
  where
    -- reifyGraph lives in IO; the module header describes this use of side
    -- effects (sharing recovery) as benign.
    Reified.Graph xs start = unsafePerformIO $ reifyGraph tape
    (g, vmap) = graphFromEdges' (edgeSet <$> filter nonConst xs)
    sensitivities = runSTArray $ do
        ss <- newArray (sbounds xs) 0
        writeArray ss start 1
        forM_ (topSort g) $ backprop vmap ss
        return ss
    -- smallest and largest node identifier present in the reified graph
    sbounds ((a,_):as) = foldl' (\(lo,hi) (b,_) -> (min lo b, max hi b)) (a,a) as
    sbounds [] = error "Numeric.RAD.runTape: reified graph has no nodes"
    edgeSet (i, t) = (t, i, successors t)
    nonConst (_, Literal{}) = False
    nonConst _ = True
    successors (Unary _ _ b) = [b]
    successors (Binary _ _ _ b c) = [b,c]
    successors _ = []

    -- this isn't _quite_ right, as it should allow negative zeros to multiply through
    -- but then we have to know what an isNegativeZero looks like, and that rather limits
    -- our underlying data types we can permit.
    -- this approach however, allows for the occasional cycles to be resolved in the
    -- dependency graph by breaking the cycle on 0 edges.

    -- test x = y where y = y * 0 + x

    -- successors (Unary _ db b) = edge db b []
    -- successors (Binary _ db dc b c) = edge db b (edge dc c [])
    -- successors _ = []

    -- edge 0 x xs = xs
    -- edge _ x xs = x : xs

-- | Sensitivity of the result with respect to the single variable of index 0.
d :: Num a => RAD s a -> a
d r = runTape (0,0) r ! 0

-- | Primal value paired with the sensitivity computed by 'd'.
d2 :: Num a => RAD s a -> (a,a)
d2 r = (primal r, d r)

-- | The 'diffUU' function calculates the first derivative of a
-- scalar-to-scalar function.
diffUU :: Num a => (forall s. RAD s a -> RAD s a) -> a -> a
diffUU f a = d $ f (var a 0)

-- | The 'diffUF' function calculates the first derivative of
-- scalar-to-nonscalar function.
diffUF :: (Functor f, Num a) => (forall s. RAD s a -> f (RAD s a)) -> a -> f a
diffUF f a = d <$> f (var a 0)

-- diffMU :: Num a => (forall s. [RAD s a] -> RAD s a) -> [a] -> [a] -> a
-- TODO: finish up diffMU and their ilk

-- avoid dependency on MTL
-- A minimal state monad over an 'Int' counter, used by 'bind' to number the
-- input variables.
newtype S a = S { runS :: Int -> (a,Int) }

-- 'Functor' and 'Applicative' are superclasses of 'Monad' in current GHC, so
-- they must be provided for the 'Monad' instance to be accepted.
instance Functor S where
    fmap f (S g) = S (\s -> let (a,s') = g s in (f a, s'))

instance Applicative S where
    pure a = S (\s -> (a,s))
    S mf <*> S mx = S (\s ->
        let (f,s') = mf s
            (x,s'') = mx s'
        in (f x, s''))

instance Monad S where
    return = pure
    S g >>= f = S (\s -> let (a,s') = g s in runS (f a) s')

-- | Replace every element of a traversable container with a fresh variable,
-- numbered from 0 in traversal order, and return the variable bounds that
-- are handed to @runTape@. The upper bound is the number of elements, so the
-- resulting array has one slot more than there are variables.
bind :: Traversable f => f a -> (f (RAD s a), (Int,Int))
bind xs = (r,(0,s))
  where
    (r,s) = runS (mapM freshVar xs) 0
    -- force the incremented counter so no chain of (+1) thunks builds up
    freshVar a = S (\s -> let s' = s + 1 in s' `seq` (RAD (Var a s), s'))

-- | Replace each variable in the container with the entry of the array at
-- the variable's index. Every element is expected to be a variable made by
-- 'bind'; anything else is reported as an error.
unbind :: Functor f => f (RAD s b) -> Array Int a -> f a
unbind xs ys = fmap sensitivity xs
  where
    sensitivity (RAD (Var _ i)) = ys ! i
    sensitivity _ = error "Numeric.RAD.unbind: expected a variable created by bind"

-- | The 'diff2UU' function calculates the value and derivative, as a
-- pair, of a scalar-to-scalar function.
diff2UU :: Num a => (forall s. RAD s a -> RAD s a) -> a -> (a, a)
diff2UU f a = d2 $ f (var a 0)

-- | Note that the signature differs from that used in Numeric.FAD, because while you can always
-- 'unzip' an arbitrary functor, not all functors can be zipped.
diff2UF :: (Functor f, Num a) => (forall s. RAD s a -> f (RAD s a)) -> a -> f (a, a)
diff2UF f a = d2 <$> f (var a 0)

-- | The 'diff' function is a synonym for 'diffUU'.
diff :: Num a => (forall s. RAD s a -> RAD s a) -> a -> a
diff = diffUU

-- | The 'diff2' function is a synonym for 'diff2UU'.
diff2 :: Num a => (forall s. RAD s a -> RAD s a) -> a -> (a, a)
diff2 = diff2UU

-- requires the input list to be finite in length
grad :: (Traversable f, Num a) => (forall s. f (RAD s a) -> RAD s a) -> f a -> f a
grad f as = unbind s (runTape bounds $ f s)
  where
    (s,bounds) = bind as

-- compute the primal and gradient
grad2 :: (Traversable f, Num a) => (forall s. f (RAD s a) -> RAD s a) -> f a -> (a, f a)
grad2 f as = (primal r, unbind s (runTape bounds r))
  where
    (s,bounds) = bind as
    r = f s

-- | The 'jacobian' function calculates the Jacobian of a
-- nonscalar-to-nonscalar function, using m invocations of reverse AD,
-- where m is the output dimensionality. When the output dimensionality is
-- significantly greater than the input dimensionality you should use 'Numeric.FAD.jacobian' instead.
jacobian :: (Traversable f, Functor g, Num a) => (forall s. f (RAD s a) -> g (RAD s a)) -> f a -> g (f a)
jacobian f as = unbind s . runTape bounds <$> f s
  where
    (s, bounds) = bind as

-- | The 'jacobian2' function calculates both the result and the Jacobian of a
-- nonscalar-to-nonscalar function, using m invocations of reverse AD,
-- where m is the output dimensionality.
-- 'fmap snd' on the result will recover the result of 'jacobian'
jacobian2 :: (Traversable f, Functor g, Num a) => (forall s. f (RAD s a) -> g (RAD s a)) -> f a -> g (a, f a)
jacobian2 f as = row <$> f s
  where
    (s, bounds) = bind as
    row a = (primal a, unbind s (runTape bounds a))

-- | The 'zeroNewton' function finds a zero of a scalar function using
-- Newton's method; its output is a stream of increasingly accurate
-- results.  (Modulo the usual caveats.)
--
-- TEST CASE:
--  @take 10 $ zeroNewton (\\x->x^2-4) 1  -- converge to 2.0@
--
-- TEST CASE
--  :module Data.Complex Numeric.RAD
--  @take 10 $ zeroNewton ((+1).(^2)) (1 :+ 1)  -- converge to (0 :+ 1)@
--
zeroNewton :: Fractional a => (forall s. RAD s a -> RAD s a) -> a -> [a]
zeroNewton f x0 = iterate step x0
  where
    -- one Newton step: x - f(x) / f'(x)
    step x = let (y,y') = diff2UU f x in x - y/y'

-- | The 'inverseNewton' function inverts a scalar function using
-- Newton's method; its output is a stream of increasingly accurate
-- results.  (Modulo the usual caveats.)
--
-- TEST CASE:
--   @take 10 $ inverseNewton sqrt 1 (sqrt 10)  -- converge to 10@
--
inverseNewton :: Fractional a => (forall s. RAD s a -> RAD s a) -> a -> a -> [a]
inverseNewton f x0 y = zeroNewton (\x -> f x - lift y) x0

-- | The 'fixedPointNewton' function finds a fixed point of a scalar
-- function using Newton's method; its output is a stream of
-- increasingly accurate results.  (Modulo the usual caveats.)
fixedPointNewton :: Fractional a => (forall s. RAD s a -> RAD s a) -> a -> [a]
fixedPointNewton f = zeroNewton (\x -> f x - x)

-- | The 'extremumNewton' function finds an extremum of a scalar
-- function using Newton's method; produces a stream of increasingly
-- accurate results.  (Modulo the usual caveats.)
extremumNewton :: Fractional a => (forall s t. RAD t (RAD s a) -> RAD t (RAD s a)) -> a -> [a]
extremumNewton f x0 = zeroNewton (diffUU f) x0

-- | The 'argminNaiveGradient' function performs a multivariate
-- optimization, based on the naive-gradient-descent in the file
-- @stalingrad\/examples\/flow-tests\/pre-saddle-1a.vlad@ from the
-- VLAD compiler Stalingrad sources.  Its output is a stream of
-- increasingly accurate results.  (Modulo the usual caveats.)
-- This is /O(n)/ faster than 'Numeric.FAD.argminNaiveGradient'
argminNaiveGradient :: (Fractional a, Ord a) => (forall s. [RAD s a] -> RAD s a) -> [a] -> [[a]]
argminNaiveGradient f x0 = loop x0 (lowerFU f x0) (gf x0) 0.1 0
  where
    gf = grad f
    -- x: current point, fx: objective at x, gx: gradient at x,
    -- eta: step size, i: number of consecutive accepted steps at this eta
    loop x fx gx eta i
        -- should check gx = 0 here
        | eta == 0 = []
        -- the step made things worse: retry from the same point with half the step
        | fx1 > fx = loop x fx gx (eta/2) 0
        | all (==0) gx = []
        -- else if fx1 == fx then loop x1 fx1 gx1 eta (i+1)
        -- after ten accepted steps in a row, double the step size
        | i == 10 = x1 : loop x1 fx1 gx1 (eta*2) 0
        | otherwise = x1 : loop x1 fx1 gx1 eta (i+1)
      where
        x1 = zipWith (+) x (map ((-eta)*) gx)
        fx1 = lowerFU f x1
        gx1 = gf x1

{-
lowerUU :: (forall s. RAD s a -> RAD s b) -> a -> b
lowerUU f = primal . f . lift

lowerUF :: Functor f => (forall s. RAD s a -> f (RAD s b)) -> a -> f b
lowerUF f = fmap primal . f . lift

lowerFF :: (Functor f, Functor g) => (forall s. f (RAD s a) -> g (RAD s b)) -> f a -> g b
lowerFF f = fmap primal . f . fmap lift
-}

-- | Evaluate a function on AD numbers at plain values: every input is
-- injected with 'lift' and only the primal of the result is returned.
lowerFU :: Functor f => (forall s. f (RAD s a) -> RAD s b) -> f a -> b
lowerFU f = primal . f . fmap lift