{-# LANGUAGE CPP #-}

#if __GLASGOW_HASKELL__ >= 708

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveGeneric #-}

{- |
Module      :  Internal.Static
Copyright   :  (c) Alberto Ruiz 2006-14
License     :  BSD3
Stability   :  provisional

-}

module Internal.Static where


import GHC.TypeLits
import qualified Numeric.LinearAlgebra as LA
import Numeric.LinearAlgebra hiding (konst,size,R,C)
import Internal.Vector as D hiding (R,C)
import Internal.ST
import Control.DeepSeq
import Data.Proxy(Proxy)
import Foreign.Storable(Storable)
import Text.Printf

import Data.Binary
import GHC.Generics (Generic)
import Data.Proxy (Proxy(..))

--------------------------------------------------------------------------------

type ℝ = Double
type ℂ = Complex Double

-- | A value of type @t@ tagged with a type-level dimension @n@.
-- The tag is phantom; only the 'Binary' instance compares it with data.
newtype Dim (n :: Nat) t = Dim t
  deriving (Show, Generic)

-- Serialises the type-level dimension before the payload. On decoding, the
-- stored dimension must equal @n@, otherwise decoding fails.
instance (KnownNat n, Binary a) => Binary (Dim n a) where
  get = do
    k <- get
    let n = natVal (Proxy :: Proxy n)
    if n == k
      then Dim <$> get
      else fail ("Expected dimension " ++ (show n) ++ ", but found dimension " ++ (show k))

  put (Dim x) = do
    put (natVal (Proxy :: Proxy n))
    put x

-- | Apply a unary function to the payload of a 'Dim', keeping the tag.
lift1F
  :: (c t -> c t)
  -> Dim n (c t) -> Dim n (c t)
lift1F f (Dim v) = Dim (f v)

-- | Apply a binary function to the payloads of two equally tagged 'Dim's.
lift2F
  :: (c t -> c t -> c t)
  -> Dim n (c t) -> Dim n (c t) -> Dim n (c t)
lift2F f (Dim u) (Dim v) = Dim (f u v)

instance NFData t => NFData (Dim n t) where
    rnf (Dim (force -> !_)) = ()

--------------------------------------------------------------------------------

-- | Real vector with static length @n@.
newtype R n = R (Dim n (Vector ℝ))
  deriving (Num,Fractional,Floating,Generic,Binary)

-- | Complex vector with static length @n@.
newtype C n = C (Dim n (Vector ℂ))
  deriving (Num,Fractional,Floating,Generic)

-- | Real matrix with static shape @m@ by @n@.
newtype L m n = L (Dim m (Dim n (Matrix ℝ)))
  deriving (Generic, Binary)

-- | Complex matrix with static shape @m@ by @n@.
newtype M m n = M (Dim m (Dim n (Matrix ℂ)))
  deriving (Generic)

-- | Wrap a vector as an 'R' without checking its length.
mkR :: Vector ℝ -> R n
mkR = R . Dim

-- | Wrap a vector as a 'C' without checking its length.
mkC :: Vector ℂ -> C n
mkC = C . Dim

-- | Wrap a matrix as an 'L' without checking its shape.
mkL :: Matrix ℝ -> L m n
mkL x = L (Dim (Dim x))

-- | Wrap a matrix as an 'M' without checking its shape.
mkM :: Matrix ℂ -> M m n
mkM x = M (Dim (Dim x))

instance NFData (R n) where
    rnf (R (force -> !_)) = ()

instance NFData (C n) where
    rnf (C (force -> !_)) = ()

instance NFData (L n m) where
    rnf (L (force -> !_)) = ()

instance NFData (M n m) where
    rnf (M (force -> !_)) = ()

--------------------------------------------------------------------------------

type V n t = Dim n (Vector t)

-- | Strip the dimension tag from a vector.
ud :: Dim n (Vector t) -> Vector t
ud (Dim v) = v

-- | Attach a dimension tag without any check.
mkV :: forall (n :: Nat) t . t -> Dim n t
mkV = Dim


-- | Concatenate two statically sized vectors.
-- A length-1 payload whose declared length is larger is treated as a constant
-- vector and expanded to its declared length before joining.
vconcat :: forall n m t . (KnownNat n, KnownNat m, Numeric t)
    => V n t -> V m t -> V (n+m) t
(ud -> u) `vconcat` (ud -> v) = mkV (vjoin [u', v'])
  where
    du = fromIntegral . natVal $ (undefined :: Proxy n)
    dv = fromIntegral . natVal $ (undefined :: Proxy m)
    u' | du > 1 && LA.size u == 1 = LA.konst (u D.@> 0) du
       | otherwise = u
    v' | dv > 1 && LA.size v == 1 = LA.konst (v D.@> 0) dv
       | otherwise = v


-- | Build a 2-element vector.
gvec2 :: Storable t => t -> t -> V 2 t
gvec2 a b = mkV $ runSTVector $ do
    v <- newUndefinedVector 2
    writeVector v 0 a
    writeVector v 1 b
    return v

-- | Build a 3-element vector.
gvec3 :: Storable t => t -> t -> t -> V 3 t
gvec3 a b c = mkV $ runSTVector $ do
    v <- newUndefinedVector 3
    writeVector v 0 a
    writeVector v 1 b
    writeVector v 2 c
    return v


-- | Build a 4-element vector.
gvec4 :: Storable t => t -> t -> t -> t -> V 4 t
gvec4 a b c d = mkV $ runSTVector $ do
    v <- newUndefinedVector 4
    writeVector v 0 a
    writeVector v 1 b
    writeVector v 2 c
    writeVector v 3 d
    return v


-- | Build a vector of static length @n@ from a list.
-- Calls 'error' (using @st@ as the type name in the message) unless the list
-- has exactly @n@ elements.
gvect :: forall n t . (Show t, KnownNat n, Numeric t) => String -> [t] -> V n t
gvect st xs'
    | ok = mkV v
    | not (null rest) && null (tail rest) = abort (show xs')
    | not (null rest) = abort (init (show (xs++take 1 rest))++", ... ]")
    | otherwise = abort (show xs)
  where
    (xs,rest) = splitAt d xs'
    ok = LA.size v == d && null rest
    v = LA.fromList xs
    d = fromIntegral . natVal $ (undefined :: Proxy n)
    abort info = error $ st++" "++show d++" can't be created from elements "++info


--------------------------------------------------------------------------------

type GM m n t = Dim m (Dim n (Matrix t))


-- | Build an @m@ by @n@ matrix from a list, reshaped with @n@ columns.
-- Calls 'error' (using @st@ as the type name in the message) unless the list
-- has exactly @m*n@ elements.
gmat :: forall m n t . (Show t, KnownNat m, KnownNat n, Numeric t) => String -> [t] -> GM m n t
gmat st xs'
    | ok = Dim (Dim x)
    | not (null rest) && null (tail rest) = abort (show xs')
    | not (null rest) = abort (init (show (xs++take 1 rest))++", ... ]")
    | otherwise = abort (show xs)
  where
    (xs,rest) = splitAt (m'*n') xs'
    v = LA.fromList xs
    x = reshape n' v
    ok = null rest && ((n' == 0 && dim v == 0) || n'> 0 && (rem (LA.size v) n' == 0) && LA.size x == (m',n'))
    m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
    n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
    abort info = error $ st ++" "++show m' ++ " " ++ show n'++" can't be created from elements " ++ info

--------------------------------------------------------------------------------

-- | Statically sized containers @s@ with element type @t@ and dynamic
-- counterpart @d@ ('Vector' or 'Matrix').
class Num t => Sized t s d | s -> t, s -> d
  where
    -- | Value with every element equal to the scalar (stored as one element).
    konst    ::  t   -> s
    -- | Raw payload, possibly in a compact (single-element or diagonal) form.
    unwrap   ::  s   -> d t
    -- | Build from a list; calls 'error' on a length mismatch.
    fromList :: [t] -> s
    -- | Payload expanded to its full static size.
    extract  ::  s   -> d t
    -- | Wrap a dynamic value when its size matches the static size.
    create   ::  d t -> Maybe s
    -- | Static size, read from the type.
    size     ::  s   -> IndexOf d

-- | True when a vector payload has one element (compact constant form).
singleV v = LA.size v == 1

-- | True when a matrix payload is 1 by 1 (compact constant form).
singleM m = rows m == 1 && cols m == 1


instance KnownNat n => Sized ℂ (C n) Vector
  where
    size _ = fromIntegral . natVal $ (undefined :: Proxy n)
    konst x = mkC (LA.scalar x)
    unwrap (C (Dim v)) = v
    fromList xs = C (gvect "C" xs)
    extract s@(unwrap -> v)
      | singleV v = LA.konst (v!0) (size s)
      | otherwise = v
    create v
        | LA.size v == size r = Just r
        | otherwise = Nothing
      where
        r = mkC v :: C n


instance KnownNat n => Sized ℝ (R n) Vector
  where
    size _ = fromIntegral . natVal $ (undefined :: Proxy n)
    konst x = mkR (LA.scalar x)
    unwrap (R (Dim v)) = v
    fromList xs = R (gvect "R" xs)
    extract s@(unwrap -> v)
      | singleV v = LA.konst (v!0) (size s)
      | otherwise = v
    create v
        | LA.size v == size r = Just r
        | otherwise = Nothing
      where
        r = mkR v :: R n



instance (KnownNat m, KnownNat n) => Sized ℝ (L m n) Matrix
  where
    size _ = ((fromIntegral . natVal) (undefined :: Proxy m)
             ,(fromIntegral . natVal) (undefined :: Proxy n))
    konst x = mkL (LA.scalar x)
    fromList xs = L (gmat "L" xs)
    unwrap (L (Dim (Dim m))) = m
    extract (isDiag -> Just (z,y,(m',n'))) = diagRect z y m' n'
    extract s@(unwrap -> a)
        | singleM a = LA.konst (a `atIndex` (0,0)) (size s)
        | otherwise = a
    create x
        | LA.size x == size r = Just r
        | otherwise = Nothing
      where
        r = mkL x :: L m n


instance (KnownNat m, KnownNat n) => Sized ℂ (M m n) Matrix
  where
    size _ = ((fromIntegral . natVal) (undefined :: Proxy m)
             ,(fromIntegral . natVal) (undefined :: Proxy n))
    konst x = mkM (LA.scalar x)
    fromList xs = M (gmat "M" xs)
    unwrap (M (Dim (Dim m))) = m
    extract (isDiagC -> Just (z,y,(m',n'))) = diagRect z y m' n'
    extract s@(unwrap -> a)
        | singleM a = LA.konst (a `atIndex` (0,0)) (size s)
        | otherwise = a
    create x
        | LA.size x == size r = Just r
        | otherwise = Nothing
      where
        r = mkM x :: M m n

--------------------------------------------------------------------------------

instance (KnownNat n, KnownNat m) => Transposable (L m n) (L n m)
  where
    -- Expand a compact diagonal directly in the transposed (n by m) shape,
    -- so the payload shape matches the result type even when m /= n.
    tr (isDiag -> Just (z,y,(m',n'))) = mkL (diagRect z y n' m')
    tr (extract -> a) = mkL (tr a)
    tr' = tr


instance (KnownNat n, KnownNat m) => Transposable (M m n) (M n m)
  where
    -- Conjugate transpose always takes the dense path so that the diagonal
    -- and fill values are conjugated and the shape is swapped.
    tr (extract -> a) = mkM (tr a)
    -- Plain transpose of a compact diagonal: same values, swapped shape.
    tr' (isDiagC -> Just (z,y,(m',n'))) = mkM (diagRect z y n' m')
    tr' (extract -> a) = mkM (tr' a)

--------------------------------------------------------------------------------

-- | Recognise the compact diagonal representation of a real matrix.
-- See 'isDiagg'.
isDiag :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ, Vector ℝ, (Int,Int))
isDiag (L x) = isDiagg x

-- | Recognise the compact diagonal representation of a complex matrix.
-- See 'isDiagg'.
isDiagC :: forall m n . (KnownNat m, KnownNat n) => M m n -> Maybe (ℂ, Vector ℂ, (Int,Int))
isDiagC (M x) = isDiagg x


-- | A payload stored as a single row (when @m > 1@) or a single column (when
-- @n > 1@), but not 1 by 1, is read as a compact diagonal matrix.
-- Returns the fill value (first element), the diagonal (remaining elements,
-- zero-padded up to @min m n@) and the static shape.
isDiagg :: forall m n t . (Numeric t, KnownNat m, KnownNat n) => GM m n t -> Maybe (t, Vector t, (Int,Int))
isDiagg (Dim (Dim x))
    | singleM x = Nothing
    | rows x == 1 && m' > 1 || cols x == 1 && n' > 1 = Just (z,yz,(m',n'))
    | otherwise = Nothing
  where
    m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
    n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
    v = flatten x
    z = v `atIndex` 0
    y = subVector 1 (LA.size v-1) v
    ny = LA.size y
    zeros = LA.konst 0 (max 0 (min m' n' - ny))
    yz = vjoin [y,zeros]

--------------------------------------------------------------------------------

instance KnownNat n => Show (R n)
  where
    show s@(R (Dim v))
      | singleV v = "("++show (v!0)++" :: R "++show d++")"
      | otherwise   = "(vector"++ drop 8 (show v)++" :: R "++show d++")"
      where
        d = size s

instance KnownNat n => Show (C n)
  where
    show s@(C (Dim v))
      | singleV v = "("++show (v!0)++" :: C "++show d++")"
      | otherwise   = "(vector"++ drop 8 (show v)++" :: C "++show d++")"
      where
        d = size s

instance (KnownNat m, KnownNat n) => Show (L m n)
  where
    show (isDiag -> Just (z,y,(m',n'))) = printf "(diag %s %s :: L %d %d)" (show z) (drop 9 $ show y) m' n'
    show s@(L (Dim (Dim x)))
       | singleM x = printf "(%s :: L %d %d)" (show (x `atIndex` (0,0))) m' n'
       | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: L "++show m'++" "++show n'++")"
      where
        (m',n') = size s

instance (KnownNat m, KnownNat n) => Show (M m n)
  where
    show (isDiagC -> Just (z,y,(m',n'))) = printf "(diag %s %s :: M %d %d)" (show z) (drop 9 $ show y) m' n'
    show s@(M (Dim (Dim x)))
       | singleM x = printf "(%s :: M %d %d)" (show (x `atIndex` (0,0))) m' n'
       | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: M "++show m'++" "++show n'++")"
      where
        (m',n') = size s

--------------------------------------------------------------------------------

instance (Num (Vector t), Numeric t )=> Num (Dim n (Vector t))
  where
    (+) = lift2F (+)
    (*) = lift2F (*)
    (-) = lift2F (-)
    abs = lift1F abs
    signum = lift1F signum
    negate = lift1F negate
    fromInteger x = Dim (fromInteger x)

instance (Num (Vector t), Num (Matrix t), Fractional t, Numeric t) => Fractional (Dim n (Vector t))
  where
    fromRational x = Dim (fromRational x)
    (/) = lift2F (/)

instance (Fractional t, Floating (Vector t), Numeric t) => Floating (Dim n (Vector t)) where
    sin   = lift1F sin
    cos   = lift1F cos
    tan   = lift1F tan
    asin  = lift1F asin
    acos  = lift1F acos
    atan  = lift1F atan
    sinh  = lift1F sinh
    cosh  = lift1F cosh
    tanh  = lift1F tanh
    asinh = lift1F asinh
    acosh = lift1F acosh
    atanh = lift1F atanh
    exp   = lift1F exp
    log   = lift1F log
    sqrt  = lift1F sqrt
    (**)  = lift2F (**)
    pi    = Dim pi


instance (Num (Matrix t), Numeric t) => Num (Dim m (Dim n (Matrix t)))
  where
    (+) = (lift2F . lift2F) (+)
    (*) = (lift2F . lift2F) (*)
    (-) = (lift2F . lift2F) (-)
    abs = (lift1F . lift1F) abs
    signum = (lift1F . lift1F) signum
    negate = (lift1F . lift1F) negate
    fromInteger x = Dim (Dim (fromInteger x))

instance (Num (Vector t), Num (Matrix t), Fractional t, Numeric t) => Fractional (Dim m (Dim n (Matrix t)))
  where
    fromRational x = Dim (Dim (fromRational x))
    (/) = (lift2F.lift2F) (/)

instance (Num (Vector t), Floating (Matrix t), Fractional t, Numeric t) => Floating (Dim m (Dim n (Matrix t))) where
    sin   = (lift1F . lift1F) sin
    cos   = (lift1F . lift1F) cos
    tan   = (lift1F . lift1F) tan
    asin  = (lift1F . lift1F) asin
    acos  = (lift1F . lift1F) acos
    atan  = (lift1F . lift1F) atan
    sinh  = (lift1F . lift1F) sinh
    cosh  = (lift1F . lift1F) cosh
    tanh  = (lift1F . lift1F) tanh
    asinh = (lift1F . lift1F) asinh
    acosh = (lift1F . lift1F) acosh
    atanh = (lift1F . lift1F) atanh
    exp   = (lift1F . lift1F) exp
    log   = (lift1F . lift1F) log
    sqrt  = (lift1F . lift1F) sqrt
    (**)  = (lift2F . lift2F) (**)
    pi    = Dim (Dim pi)

--------------------------------------------------------------------------------

-- | When one operand is a compact diagonal and the other is a full matrix,
-- expand the diagonal one before applying @f@; otherwise apply @f@ directly.
adaptDiag f a@(isDiag -> Just _) b | isFull b = f (mkL (extract a)) b
adaptDiag f a b@(isDiag -> Just _) | isFull a = f a (mkL (extract b))
adaptDiag f a b = f a b

-- | A payload that is neither compact diagonal nor a single element.
isFull m = isDiag m == Nothing && not (singleM (unwrap m))


lift1L f (L v) = L (f v)
lift2L f (L a) (L b) = L (f a b)
lift2LD f = adaptDiag (lift2L f)


instance (KnownNat n, KnownNat m) =>  Num (L n m)
  where
    (+) = lift2LD (+)
    (*) = lift2LD (*)
    (-) = lift2LD (-)
    abs = lift1L abs
    signum = lift1L signum
    negate = lift1L negate
    fromInteger = L . Dim . Dim . fromInteger

instance (KnownNat n, KnownNat m) => Fractional (L n m)
  where
    fromRational = L . Dim . Dim . fromRational
    (/) = lift2LD (/)

instance (KnownNat n, KnownNat m) => Floating (L n m) where
    sin   = lift1L sin
    cos   = lift1L cos
    tan   = lift1L tan
    asin  = lift1L asin
    acos  = lift1L acos
    atan  = lift1L atan
    sinh  = lift1L sinh
    cosh  = lift1L cosh
    tanh  = lift1L tanh
    asinh = lift1L asinh
    acosh = lift1L acosh
    atanh = lift1L atanh
    exp   = lift1L exp
    log   = lift1L log
    sqrt  = lift1L sqrt
    (**)  = lift2LD (**)
    pi    = konst pi

--------------------------------------------------------------------------------

-- | Complex counterpart of 'adaptDiag'.
adaptDiagC f a@(isDiagC -> Just _) b | isFullC b = f (mkM (extract a)) b
adaptDiagC f a b@(isDiagC -> Just _) | isFullC a = f a (mkM (extract b))
adaptDiagC f a b = f a b

-- | Complex counterpart of 'isFull'.
isFullC m = isDiagC m == Nothing && not (singleM (unwrap m))

lift1M f (M v) = M (f v)
lift2M f (M a) (M b) = M (f a b)
lift2MD f = adaptDiagC (lift2M f)

instance (KnownNat n, KnownNat m) =>  Num (M n m)
  where
    (+) = lift2MD (+)
    (*) = lift2MD (*)
    (-) = lift2MD (-)
    abs = lift1M abs
    signum = lift1M signum
    negate = lift1M negate
    fromInteger = M . Dim . Dim . fromInteger

instance (KnownNat n, KnownNat m) => Fractional (M n m)
  where
    fromRational = M . Dim . Dim . fromRational
    (/) = lift2MD (/)

instance (KnownNat n, KnownNat m) => Floating (M n m) where
    sin   = lift1M sin
    cos   = lift1M cos
    tan   = lift1M tan
    asin  = lift1M asin
    acos  = lift1M acos
    atan  = lift1M atan
    sinh  = lift1M sinh
    cosh  = lift1M cosh
    tanh  = lift1M tanh
    asinh = lift1M asinh
    acosh = lift1M acosh
    atanh = lift1M atanh
    exp   = lift1M exp
    log   = lift1M log
    sqrt  = lift1M sqrt
    (**)  = lift2MD (**)
    pi    = M pi


instance Additive (R n) where
    add = (+)

instance Additive (C n) where
    add = (+)

instance (KnownNat m, KnownNat n) => Additive (L m n) where
    add = (+)

instance (KnownNat m, KnownNat n) => Additive (M m n) where
    add = (+)

--------------------------------------------------------------------------------

-- | Pretty-print a statically sized value with @Int@ decimal places.
class Disp t
  where
    disp :: Int -> t -> IO ()


instance (KnownNat m, KnownNat n) => Disp (L m n)
  where
    disp n x = do
        let a = extract x
        let su = LA.dispf n a
        printf "L %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su)

instance (KnownNat m, KnownNat n) => Disp (M m n)
  where
    disp n x = do
        let a = extract x
        let su = LA.dispcf n a
        printf "M %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su)


instance KnownNat n => Disp (R n)
  where
    disp n v = do
        let su = LA.dispf n (asRow $ extract v)
        putStr "R " >> putStr (tail . dropWhile (/='x') $ su)

instance KnownNat n => Disp (C n)
  where
    disp n v = do
        let su = LA.dispcf n (asRow $ extract v)
        putStr "C " >> putStr (tail . dropWhile (/='x') $ su)

--------------------------------------------------------------------------------

-- | Apply a function to the raw (unexpanded) payload of an 'L'.
overMatL' :: (KnownNat m, KnownNat n)
          => (LA.Matrix ℝ -> LA.Matrix ℝ) -> L m n -> L m n
overMatL' f = mkL . f . unwrap
{-# INLINE overMatL' #-}

-- | Apply a function to the raw (unexpanded) payload of an 'M'.
overMatM' :: (KnownNat m, KnownNat n)
          => (LA.Matrix ℂ -> LA.Matrix ℂ) -> M m n -> M m n
overMatM' f = mkM . f . unwrap
{-# INLINE overMatM' #-}


#else

module Numeric.LinearAlgebra.Static.Internal where

#endif