{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveGeneric #-}
module Numeric.SGD.ParamSet
(
ParamSet(..)
, GPMap
, GAdd
, GSub
, GDiv
, GMul
, GNorm2
) where
import GHC.Generics
import GHC.TypeNats (KnownNat)
import Prelude hiding (div)
import qualified Data.Map.Strict as M
import qualified Numeric.LinearAlgebra.Static as LA
-- | Types that can be treated as a set of 'Double'-valued parameters.
--
-- 'zero' defaults to a 'pmap' with a constant-zero function.  Every other
-- method has a default implementation obtained through "GHC.Generics"; that
-- default is available whenever the type is 'Generic' and its representation
-- satisfies the corresponding generic class ('GPMap', 'GAdd', ...).
class ParamSet a where
  -- | Transform the parameter set using the given function on 'Double's.
  pmap :: (Double -> Double) -> a -> a
  default pmap
    :: (Generic a, GPMap (Rep a))
    => (Double -> Double) -> a -> a
  pmap = genericPMap
  {-# INLINE pmap #-}

  -- | A parameter set derived from the argument in which the mapped values
  -- are replaced by zero (by default: 'pmap' with a constant function).
  zero :: a -> a
  zero = pmap (\_ -> 0.0)

  -- | Addition of two parameter sets.
  add :: a -> a -> a
  default add :: (Generic a, GAdd (Rep a)) => a -> a -> a
  add = genericAdd
  {-# INLINE add #-}

  -- | Subtraction of the second parameter set from the first.
  sub :: a -> a -> a
  default sub :: (Generic a, GSub (Rep a)) => a -> a -> a
  sub = genericSub
  {-# INLINE sub #-}

  -- | Multiplication of two parameter sets.
  mul :: a -> a -> a
  default mul :: (Generic a, GMul (Rep a)) => a -> a -> a
  mul = genericMul
  {-# INLINE mul #-}

  -- | Division of the first parameter set by the second.
  div :: a -> a -> a
  default div :: (Generic a, GDiv (Rep a)) => a -> a -> a
  div = genericDiv
  {-# INLINE div #-}

  -- | Norm of the parameter set, as a single 'Double'.
  norm_2 :: a -> Double
  default norm_2 :: (Generic a, GNorm2 (Rep a)) => a -> Double
  norm_2 = genericNorm2
  {-# INLINE norm_2 #-}
-- Rewrite rule: two consecutive 'pmap' applications are replaced by a single
-- 'pmap' of the composed function.
-- NOTE(review): this presumes every instance's 'pmap' respects composition,
-- and rules headed by a class method may not fire reliably -- confirm.
{-# RULES
"ParamSet pmap/pmap" forall outer inner ps.
    pmap outer (pmap inner ps) = pmap (outer . inner) ps
  #-}
-- | 'add' implemented through the generic representation: both arguments
-- are converted with 'from', combined with 'gadd', and converted back
-- with 'to'.
genericAdd :: (Generic a, GAdd (Rep a)) => a -> a -> a
genericAdd lhs rhs = to (from lhs `gadd` from rhs)
{-# INLINE genericAdd #-}
-- | 'sub' implemented through the generic representation: both arguments
-- are converted with 'from', combined with 'gsub', and converted back
-- with 'to'.
genericSub :: (Generic a, GSub (Rep a)) => a -> a -> a
genericSub lhs rhs = to (from lhs `gsub` from rhs)
{-# INLINE genericSub #-}
-- | 'div' implemented through the generic representation: both arguments
-- are converted with 'from', combined with 'gdiv', and converted back
-- with 'to'.
genericDiv :: (Generic a, GDiv (Rep a)) => a -> a -> a
genericDiv lhs rhs = to (from lhs `gdiv` from rhs)
{-# INLINE genericDiv #-}
-- | 'mul' implemented through the generic representation: both arguments
-- are converted with 'from', combined with 'gmul', and converted back
-- with 'to'.
genericMul :: (Generic a, GMul (Rep a)) => a -> a -> a
genericMul lhs rhs = to (from lhs `gmul` from rhs)
{-# INLINE genericMul #-}
-- | 'norm_2' implemented through the generic representation: the argument
-- is converted with 'from' and handed to 'gnorm_2'.
genericNorm2 :: (Generic a, GNorm2 (Rep a)) => a -> Double
genericNorm2 val = gnorm_2 rep
  where
    rep = from val
{-# INLINE genericNorm2 #-}
-- | 'pmap' implemented through the generic representation: the argument is
-- converted with 'from', transformed with 'gpmap', and converted back
-- with 'to'.
genericPMap :: (Generic a, GPMap (Rep a)) => (Double -> Double) -> a -> a
genericPMap fn val = to (gpmap fn (from val))
{-# INLINE genericPMap #-}
-- | Generic counterpart of 'add', defined over "GHC.Generics"
-- representation types.
class GAdd f where
  gadd :: f t -> f t -> f t

-- | Field: delegate to the 'add' of the wrapped parameter set.
instance ParamSet a => GAdd (K1 i a) where
  gadd l r = K1 (unK1 l `add` unK1 r)
  {-# INLINE gadd #-}

-- | Product: combine the halves pairwise; both results are forced before
-- the product is rebuilt.
instance (GAdd f, GAdd g) => GAdd (f :*: g) where
  gadd (l1 :*: r1) (l2 :*: r2) =
    let !l = gadd l1 l2
        !r = gadd r1 r2
     in l :*: r
  {-# INLINE gadd #-}

-- | Uninhabited representation: there is no value to operate on.
instance GAdd V1 where
  gadd = \v -> case v of {}
  {-# INLINE gadd #-}

-- | Constructor without fields: the result is always 'U1'.
instance GAdd U1 where
  gadd _ _ = U1
  {-# INLINE gadd #-}

-- | Metadata wrapper: unwrap, combine, and wrap again.
instance GAdd f => GAdd (M1 i c f) where
  gadd l r = M1 (unM1 l `gadd` unM1 r)
  {-# INLINE gadd #-}
-- | Generic counterpart of 'sub', defined over "GHC.Generics"
-- representation types.
class GSub f where
  gsub :: f t -> f t -> f t

-- | Field: delegate to the 'sub' of the wrapped parameter set.
instance ParamSet a => GSub (K1 i a) where
  gsub l r = K1 (unK1 l `sub` unK1 r)
  {-# INLINE gsub #-}

-- | Product: combine the halves pairwise; both results are forced before
-- the product is rebuilt.
instance (GSub f, GSub g) => GSub (f :*: g) where
  gsub (l1 :*: r1) (l2 :*: r2) =
    let !l = gsub l1 l2
        !r = gsub r1 r2
     in l :*: r
  {-# INLINE gsub #-}

-- | Uninhabited representation: there is no value to operate on.
instance GSub V1 where
  gsub = \v -> case v of {}
  {-# INLINE gsub #-}

-- | Constructor without fields: the result is always 'U1'.
instance GSub U1 where
  gsub _ _ = U1
  {-# INLINE gsub #-}

-- | Metadata wrapper: unwrap, combine, and wrap again.
instance GSub f => GSub (M1 i c f) where
  gsub l r = M1 (unM1 l `gsub` unM1 r)
  {-# INLINE gsub #-}
-- | Generic counterpart of 'mul', defined over "GHC.Generics"
-- representation types.
class GMul f where
  gmul :: f t -> f t -> f t

-- | Field: delegate to the 'mul' of the wrapped parameter set.
instance ParamSet a => GMul (K1 i a) where
  gmul l r = K1 (unK1 l `mul` unK1 r)
  {-# INLINE gmul #-}

-- | Product: combine the halves pairwise; both results are forced before
-- the product is rebuilt.
instance (GMul f, GMul g) => GMul (f :*: g) where
  gmul (l1 :*: r1) (l2 :*: r2) =
    let !l = gmul l1 l2
        !r = gmul r1 r2
     in l :*: r
  {-# INLINE gmul #-}

-- | Uninhabited representation: there is no value to operate on.
instance GMul V1 where
  gmul = \v -> case v of {}
  {-# INLINE gmul #-}

-- | Constructor without fields: the result is always 'U1'.
instance GMul U1 where
  gmul _ _ = U1
  {-# INLINE gmul #-}

-- | Metadata wrapper: unwrap, combine, and wrap again.
instance GMul f => GMul (M1 i c f) where
  gmul l r = M1 (unM1 l `gmul` unM1 r)
  {-# INLINE gmul #-}
-- | Generic counterpart of 'div', defined over "GHC.Generics"
-- representation types.
class GDiv f where
  gdiv :: f t -> f t -> f t

-- | Field: delegate to the 'div' of the wrapped parameter set.
instance ParamSet a => GDiv (K1 i a) where
  gdiv l r = K1 (unK1 l `div` unK1 r)
  {-# INLINE gdiv #-}

-- | Product: combine the halves pairwise; both results are forced before
-- the product is rebuilt.
instance (GDiv f, GDiv g) => GDiv (f :*: g) where
  gdiv (l1 :*: r1) (l2 :*: r2) =
    let !l = gdiv l1 l2
        !r = gdiv r1 r2
     in l :*: r
  {-# INLINE gdiv #-}

-- | Uninhabited representation: there is no value to operate on.
instance GDiv V1 where
  gdiv = \v -> case v of {}
  {-# INLINE gdiv #-}

-- | Constructor without fields: the result is always 'U1'.
instance GDiv U1 where
  gdiv _ _ = U1
  {-# INLINE gdiv #-}

-- | Metadata wrapper: unwrap, combine, and wrap again.
instance GDiv f => GDiv (M1 i c f) where
  gdiv l r = M1 (unM1 l `gdiv` unM1 r)
  {-# INLINE gdiv #-}
-- | Generic counterpart of 'norm_2', defined over "GHC.Generics"
-- representation types.
class GNorm2 f where
  gnorm_2 :: f t -> Double

-- | Field: the 'norm_2' of the wrapped parameter set.
instance ParamSet a => GNorm2 (K1 i a) where
  gnorm_2 leaf = norm_2 (unK1 leaf)
  {-# INLINE gnorm_2 #-}

-- | Product: square root of the sum of the squared norms of both halves.
-- Both partial norms are forced before they are combined.
instance (GNorm2 f, GNorm2 g) => GNorm2 (f :*: g) where
  gnorm_2 (l :*: r) =
    let !normL = gnorm_2 l
        !normR = gnorm_2 r
     in sqrt (normL ^ (2 :: Int) + normR ^ (2 :: Int))
  {-# INLINE gnorm_2 #-}

-- | Uninhabited representation: there is no value to measure.
instance GNorm2 V1 where
  gnorm_2 = \v -> case v of {}
  {-# INLINE gnorm_2 #-}

-- | Constructor without fields: contributes a norm of zero.
instance GNorm2 U1 where
  gnorm_2 _ = 0
  {-# INLINE gnorm_2 #-}

-- | Metadata wrapper: the norm of the wrapped representation.
instance GNorm2 f => GNorm2 (M1 i c f) where
  gnorm_2 wrapped = gnorm_2 (unM1 wrapped)
  {-# INLINE gnorm_2 #-}
-- | Generic counterpart of 'pmap', defined over "GHC.Generics"
-- representation types.
class GPMap f where
  gpmap :: (Double -> Double) -> f t -> f t

-- | Field: delegate to the 'pmap' of the wrapped parameter set.
instance ParamSet a => GPMap (K1 i a) where
  gpmap fn leaf = K1 (pmap fn (unK1 leaf))
  {-# INLINE gpmap #-}

-- | Product: map over both halves; both results are forced before the
-- product is rebuilt.
instance (GPMap f, GPMap g) => GPMap (f :*: g) where
  gpmap fn (l :*: r) =
    let !l' = gpmap fn l
        !r' = gpmap fn r
     in l' :*: r'
  {-# INLINE gpmap #-}

-- | Uninhabited representation: there is no value to map over.
instance GPMap V1 where
  gpmap _ = \v -> case v of {}
  {-# INLINE gpmap #-}

-- | Constructor without fields: the result is always 'U1'.
instance GPMap U1 where
  gpmap _ _ = U1
  {-# INLINE gpmap #-}

-- | Metadata wrapper: unwrap, map, and wrap again.
instance GPMap f => GPMap (M1 i c f) where
  gpmap fn wrapped = M1 (gpmap fn (unM1 wrapped))
  {-# INLINE gpmap #-}
-- | A single 'Double' is a parameter set on its own: 'pmap' applies the
-- function directly to the value, the arithmetic methods are the ordinary
-- numeric operators, and the norm is the absolute value.
instance ParamSet Double where
  zero _ = 0
  pmap fn = fn
  add x y = x + y
  sub x y = x - y
  mul x y = x * y
  div x y = x / y
  norm_2 x = abs x
-- | A pair of parameter sets.  All operations work component by component;
-- the norm is the square root of the sum of the squared component norms.
instance (ParamSet a, ParamSet b) => ParamSet (a, b) where
  pmap fn (l, r) = (pmap fn l, pmap fn r)
  add (l1, r1) (l2, r2) = (add l1 l2, add r1 r2)
  sub (l1, r1) (l2, r2) = (sub l1 l2, sub r1 r2)
  mul (l1, r1) (l2, r2) = (mul l1 l2, mul r1 r2)
  div (l1, r1) (l2, r2) = (div l1 l2, div r1 r2)
  norm_2 (l, r) =
    sqrt (sum [n ^ (2 :: Int) | n <- [norm_2 l, norm_2 r]])
-- | Statically sized vector.  'pmap' is 'LA.dvmap', the arithmetic methods
-- are those of the vector's numeric instances, and the norm is 'LA.norm_2'.
instance (KnownNat n) => ParamSet (LA.R n) where
  zero _ = 0
  pmap fn vec = LA.dvmap fn vec
  add u v = u + v
  sub u v = u - v
  mul u v = u * v
  div u v = u / v
  norm_2 vec = LA.norm_2 vec
-- | Statically sized matrix.  'pmap' is 'LA.dmmap', the arithmetic methods
-- are those of the matrix's numeric instances, and the norm is 'LA.norm_2'.
instance (KnownNat n, KnownNat m) => ParamSet (LA.L n m) where
  zero _ = 0
  pmap fn mat = LA.dmmap fn mat
  add a b = a + b
  sub a b = a - b
  mul a b = a * b
  div a b = a / b
  norm_2 mat = LA.norm_2 mat
-- | Optional parameter set.  A binary operation yields 'Just' only when
-- both operands are 'Just'; if either operand is 'Nothing', so is the
-- result.  The norm of 'Nothing' is zero.
instance (ParamSet a) => ParamSet (Maybe a) where
  zero mx = zero <$> mx
  pmap fn = fmap (pmap fn)
  add mx my = add <$> mx <*> my
  sub mx my = sub <$> mx <*> my
  mul mx my = mul <$> mx <*> my
  div mx my = div <$> mx <*> my
  norm_2 Nothing = 0
  norm_2 (Just x) = norm_2 x
-- | Map from keys to parameter sets.  Binary operations combine the values
-- stored under keys present in both maps; keys found in only one of the
-- maps are dropped from the result.  The norm is the square root of the sum
-- of the squared norms of all values.
instance (Ord k, ParamSet a) => ParamSet (M.Map k a) where
  zero m = zero <$> m
  pmap fn m = pmap fn <$> m
  add m1 m2 = M.intersectionWith add m1 m2
  sub m1 m2 = M.intersectionWith sub m1 m2
  mul m1 m2 = M.intersectionWith mul m1 m2
  div m1 m2 = M.intersectionWith div m1 m2
  norm_2 m = sqrt (sum [norm_2 v ^ (2 :: Int) | v <- M.elems m])