{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Numeric.Tensor
-- Copyright   :  (c) Artem Chirkin
-- License     :  MIT
--
-- Maintainer  :  chirkin@arch.ethz.ch
--
--
-----------------------------------------------------------------------------

module Numeric.Tensor where

import GHC.TypeLits
import GHC.Prim
import Data.Proxy

-- | Value-level stand-in for a type-level list of naturals.
--   The single constructor carries no runtime data.
data Dim (ds :: [Nat]) = Dim

-- | Type-level dimension lists whose entries can be read back as 'Int's.
class Dimensions (ds :: [Nat]) where
  -- | Sizes listed in @ds@, in the same order as in the type.
  dims :: Dim ds -> [Int]

-- | Proxy for the first entry of a non-empty dimension list.
--   The argument is never inspected.
headDim :: Dim (d ': ds) -> Proxy d
headDim = const Proxy

-- | 'Dim' for everything after the first entry of a non-empty dimension list.
--   The argument is never inspected.
tailDim :: Dim (d ': ds) -> Dim ds
tailDim = const Dim

-- | The empty list has no dimensions.
instance Dimensions '[] where
  dims = const []

-- | Read the head via 'natVal' and recurse on the tail.
instance (KnownNat d, Dimensions ds) => Dimensions (d ': ds) where
  dims ds = fromInteger (natVal (headDim ds)) : dims (tailDim ds)

-- | 'show' the list produced by 'dims'.
printCrazy :: Dimensions d => Dim d -> String
printCrazy = show . dims

-- | 'Dim' for the first index list of a tensor (the one the commented-out
--   operators in 'TensorCalculus' call contravariant).
contraDimsType :: Tensor t n m -> Dim n
contraDimsType = const Dim

-- | 'Dim' for the second index list of a tensor (the one the commented-out
--   operators in 'TensorCalculus' call covariant).
coDimsType :: Tensor t n m -> Dim m
coDimsType = const Dim

-- | Sizes of the first index list of a tensor.
contraDims :: Dimensions n => Tensor t n m -> [Int]
contraDims tensor = dims (contraDimsType tensor)

-- | Sizes of the second index list of a tensor.
coDims :: Dimensions m => Tensor t n m -> [Int]
coDims tensor = dims (coDimsType tensor)

-- | Tensor of 'Double's with first index list @[2]@ and empty second list.
type Vec2 = Tensor Double '[2] '[]

-- | Build a 'Vec2' from its two components.
vec2 :: Double -> Double -> Vec2
vec2 x y = T10 (Vector2 x y)

-- | Selects, per element type and pair of index lists, the concrete tensor
--   representation and the storage type it wraps.
class TensorCalculus t (ns :: [Nat]) (ms :: [Nat]) where
  -- | Tensor representation; every instance defines it as a newtype.
  data Tensor t ns ms
  -- | Type wrapped by the corresponding 'Tensor' constructor.
  type TensorStore t ns ms
--  -- | Add a contravariant rank
--  infixr 5 .<.
--  (.<.) :: Tensor t ns ms -> Tensor t ns ms -> Tensor t (2 ': ns) ms
--  -- | Add a covariant rank
--  infixr 5 .>.
--  (.>.) :: Tensor t ns ms -> Tensor t ns ms -> Tensor t ns (2 ': ms)
--  -- | Append dimension of the first contravariant rank
--  infixr 5 .<
--  (.<) :: AppendDim (TensorStore t ns ms) (TensorStore t (nb ': ns) ms) (TensorStore t ((nb + 1) ': ns) ms)
--       => Tensor t ns ms -> Tensor t (nb ': ns) ms -> Tensor t ((nb + 1) ': ns) ms
--  -- | Append dimension of the first covariant rank
--  infixr 5 .>
--  (.>) :: AppendDim (TensorStore t ns ms) (TensorStore t ns (mb ': ms)) (TensorStore t ns ((mb + 1) ': ms))
--       => Tensor t ns ms -> Tensor t ns (mb ': ms) -> Tensor t ns ((mb + 1) ': ms)

-- AppendDim (Tensor t ns ms) (Tensor t ns (mb ': ms)) (Tensor t ns ((mb + 1) ': ms))

-- Both index lists empty: the tensor is a bare element, and the numeric
-- classes of the element type are lifted through the newtype.
instance TensorCalculus t '[] '[] where
  newtype Tensor t '[] '[] = T00 t
    deriving
      ( Bounded, Enum, Eq, Integral
      , Num, Fractional, Floating
      , Ord, Read, Real, RealFrac, RealFloat, Show
      )
  type TensorStore t '[] '[] = t
--  T00 a .<. T00 b = T10 $ Vector2 a b
--  T00 a .>. T00 b = T01 $ Vector2 a b
--  T00 a .< T10 b = T10 $ appendDim a b
--  T00 a .> T01 b = T01 $ appendDim a b

-- One entry in the first list only: stored as a 'SomeVector'.
instance TensorCalculus t '[n] '[] where
  newtype Tensor t '[n] '[] = T10 (SomeVector t n)
  type TensorStore t '[n] '[] = SomeVector t n
--  contraV (T00 a) (T00 b) = T10 $ Vector2 a b

-- One entry in the second list only: stored as a 'SomeVector'.
instance TensorCalculus t '[] '[m] where
  newtype Tensor t '[] '[m] = T01 (SomeVector t m)
  type TensorStore t '[] '[m] = SomeVector t m

-- Two entries in the first list: stored as a 'SomeMatrix'.
instance TensorCalculus t '[n1, n2] '[] where
  newtype Tensor t '[n1, n2] '[] = T20 (SomeMatrix t n1 n2)
  type TensorStore t '[n1, n2] '[] = SomeMatrix t n1 n2

-- One entry in each list: stored as a 'SomeMatrix'.
instance TensorCalculus t '[n] '[m] where
  newtype Tensor t '[n] '[m] = T11 (SomeMatrix t n m)
  type TensorStore t '[n] '[m] = SomeMatrix t n m

-- Two entries in the second list: stored as a 'SomeMatrix'.
instance TensorCalculus t '[] '[m1,m2] where
  newtype Tensor t '[] '[m1,m2] = T02 (SomeMatrix t m1 m2)
  type TensorStore t '[] '[m1,m2] = SomeMatrix t m1 m2

-- Three or more entries in the first list, none in the second:
-- stored as an 'NDArray'.
instance TensorCalculus t (n1 ': n2 ': n3 ': ns) '[] where
  newtype Tensor t (n1 ': n2 ': n3 ': ns) '[] = Tn0 (NDArray t)
    deriving Show
  type TensorStore t (n1 ': n2 ': n3 ': ns) '[] = NDArray t
-- Three or more entries in the first index list, one in the second:
-- stored as an 'NDArray'.
instance TensorCalculus t (n1 ': n2 ': n3 ': ns) '[m0] where
  newtype Tensor t (n1 ': n2 ': n3 ': ns) '[m0] = Tn1 (NDArray t)
    deriving Show
  type TensorStore t (n1 ': n2 ': n3 ': ns) '[m0] = NDArray t

-- Three or more entries in the first index list, two in the second.
instance TensorCalculus t (n1 ': n2 ': n3 ': ns) '[m0,m1] where
  newtype Tensor t (n1 ': n2 ': n3 ': ns) '[m0, m1] = Tn2 (NDArray t)
    deriving Show
  type TensorStore t (n1 ': n2 ': n3 ': ns) '[m0, m1] = NDArray t

-- No entries in the first index list, three or more in the second.
instance TensorCalculus t '[] (m1 ': m2 ': m3 ': ms) where
  newtype Tensor t '[] (m1 ': m2 ': m3 ': ms) = T0m (NDArray t)
    deriving Show
  type TensorStore t '[] (m1 ': m2 ': m3 ': ms) = NDArray t

-- One entry in the first index list, three or more in the second.
instance TensorCalculus t '[n1] (m1 ': m2 ': m3 ': ms) where
  newtype Tensor t '[n1] (m1 ': m2 ': m3 ': ms) = T1m (NDArray t)
    deriving Show
  type TensorStore t '[n1] (m1 ': m2 ': m3 ': ms) = NDArray t

-- Two entries in the first index list, three or more in the second.
instance TensorCalculus t '[n1, n2] (m1 ': m2 ': m3 ': ms) where
  newtype Tensor t '[n1, n2] (m1 ': m2 ': m3 ': ms) = T2m (NDArray t)
    deriving Show
  type TensorStore t '[n1, n2] (m1 ': m2 ': m3 ': ms) = NDArray t

-- Three or more entries in both index lists.
instance TensorCalculus t (n1 ': n2 ': n3 ': ns) (m1 ': m2 ': m3 ': ms) where
  newtype Tensor t (n1 ': n2 ': n3 ': ns) (m1 ': m2 ': m3 ': ms) = Tnm (NDArray t)
    deriving Show
  type TensorStore t (n1 ': n2 ': n3 ': ns) (m1 ': m2 ': m3 ': ms) = NDArray t


-- Instances for the vector- and matrix-backed tensors are derived standalone
-- because their contexts mention the 'SomeVector' / 'SomeMatrix' families.
deriving instance Show (SomeVector t n) => Show (Tensor t '[n] '[])
deriving instance Plus (SomeVector t n) => Plus (Tensor t '[n] '[])
deriving instance Num  (SomeVector t n) => Num  (Tensor t '[n] '[])

deriving instance Show (SomeVector t m) => Show (Tensor t '[] '[m])
deriving instance Plus (SomeVector t m) => Plus (Tensor t '[] '[m])
-- Counterpart of the 'Num' instance for @Tensor t '[n] '[]@, so that both
-- vector-backed tensors offer the same set of classes.
deriving instance Num  (SomeVector t m) => Num  (Tensor t '[] '[m])

deriving instance Show (SomeMatrix t n0 n1) => Show (Tensor t '[n0, n1] '[])
deriving instance Show (SomeMatrix t n0 m0) => Show (Tensor t '[n0] '[m0])
deriving instance Show (SomeMatrix t m0 m1) => Show (Tensor t '[] '[m0,m1])

--contraV :: Tensor t ns ms -> Tensor t ns ms -> Tensor t (2 ':. ns) ms
--contraV

-- | Types with a binary addition.
class Plus a where
  -- | Combine two values; the vector instances add component by component.
  plus :: a -> a -> a

instance Num t => Plus (Vector1 t) where
  plus (Vector1 a) (Vector1 b) = Vector1 (a+b)

instance Num t => Plus (Vector2 t) where
  plus (Vector2 a1 a2) (Vector2 b1 b2) = Vector2 (a1+b1) (a2+b2)

instance Num t => Plus (Vector3 t) where
  plus (Vector3 a1 a2 a3) (Vector3 b1 b2 b3) = Vector3 (a1+b1) (a2+b2) (a3+b3)

instance Num t => Plus (Vector4 t) where
  plus (Vector4 a1 a2 a3 a4) (Vector4 b1 b2 b3 b4) = Vector4 (a1+b1) (a2+b2) (a3+b3) (a4+b4)

-- The phantom length @n@ is not checked against the lists; 'zipWith' stops
-- at the end of the shorter one.
instance Num t => Plus (VectorN t n) where
  plus (VectorN as) (VectorN bs) = VectorN $ zipWith (+) as bs


-- | One-component vector.
newtype Vector1 t = Vector1 t
  deriving Show

-- | Two-component vector.
data Vector2 t = Vector2 t t
  deriving Show

-- | Three-component vector.
data Vector3 t = Vector3 t t t
  deriving Show

-- | Four-component vector.
data Vector4 t = Vector4 t t t t
  deriving Show

-- | List-backed vector; @n@ is a phantom parameter and the list length is
--   not checked against it.
newtype VectorN t (n::Nat) = VectorN [t]
  deriving Show

-- | Matrix holding a single element.
newtype Matrix1x1 t = Matrix1x1 t
  deriving Show

-- | Matrix holding four elements.
data Matrix2x2 t = Matrix2x2 t t t t
  deriving Show

-- | Matrix backed by a list of lists; @n@ and @m@ are phantom parameters.
newtype MatrixNxM t (n::Nat) (m::Nat) = MatrixNxM [[t]]
  deriving Show

-- | Wrapper around a raw 'ByteArray#'; the element type @t@ is phantom.
data NDArray t = NDArray ByteArray#

-- | Prints a fixed placeholder; the array contents are not shown.
instance Show (NDArray t) where
  show _ = "Big array"

--data TensorContraOnly t (ns :: [Nat]) = TContra Addr#
--data TensorCoOnly t (ns :: [Nat]) = TCo Addr#
--data TensorCo t (ns :: [Nat]) (ms :: [Nat]) = T Addr#

-- | Vector representation chosen by length: dedicated types for lengths
--   1 to 4, 'VectorN' otherwise.  The result determines both @t@ and @n@.
type family SomeVector t (n :: Nat) = v | v -> t n where
  SomeVector t 1 = Vector1 t
  SomeVector t 2 = Vector2 t
  SomeVector t 3 = Vector3 t
  SomeVector t 4 = Vector4 t
  SomeVector t n = VectorN t n

-- | Matrix representation chosen by size: dedicated types for 1x1 and 2x2,
--   'MatrixNxM' otherwise.  The result determines @t@, @n@ and @m@.
type family SomeMatrix t (n :: Nat) (m :: Nat) = v | v -> t n m where
  SomeMatrix t 1 1 = Matrix1x1 t
  SomeMatrix t 2 2 = Matrix2x2 t
  SomeMatrix t n m = MatrixNxM t n m

-- | Combine a value of type @a@ with a container @b@ into a container @c@;
--   the two argument types determine the result type.
class AppendDim a b c | a b -> c where
  -- | In the instances below the first argument becomes the first component
  --   of the result.
  appendDim :: a -> b -> c

-- First link of the chain: without it a @SomeVector t 1@ (which is
-- 'Vector1') could never be extended, because the other instances start at
-- 'Vector2'.
instance AppendDim t (Vector1 t) (Vector2 t) where
  appendDim a (Vector1 b) = Vector2 a b

instance AppendDim t (Vector2 t) (Vector3 t) where
  appendDim a (Vector2 b1 b2) = Vector3 a b1 b2

instance AppendDim t (Vector3 t) (Vector4 t) where
  appendDim a (Vector3 b1 b2 b3) = Vector4 a b1 b2 b3

-- Past four components the representation switches to the list-backed one.
instance AppendDim t (Vector4 t) (VectorN t 5) where
  appendDim a (Vector4 b1 b2 b3 b4) = VectorN [a,b1,b2,b3,b4]

-- The result length is tied to the input length through @m ~ (n+1)@.
instance (m ~ (n+1), 5 <= m, 4 <= n) => AppendDim t (VectorN t n) (VectorN t m) where
  appendDim a (VectorN bs) = VectorN $ a : bs


-- NOTE(review): the disabled experiment below is kept verbatim.  Before
-- re-enabling it, check the last group: the Vector2/3/4 ++ VectorN instances
-- fix the result length at 6/7/8 and never use @k@, and the final instance
-- returns @VectorN t m@ where @VectorN t k@ looks intended.
--class VectorOps a b c | a b -> c where
--  appendVecs :: a -> b -> c
--
--
--instance VectorOps (Vector1 t) (Vector1 t) (Vector2 t) where
--  appendVecs (Vector1 a) (Vector1 b) = Vector2 a b
--instance VectorOps (Vector2 t) (Vector1 t) (Vector3 t) where
--  appendVecs (Vector2 a1 a2) (Vector1 b) = Vector3 a1 a2 b
--instance VectorOps (Vector3 t) (Vector1 t) (Vector4 t) where
--  appendVecs (Vector3 a1 a2 a3) (Vector1 b) = Vector4 a1 a2 a3 b
--instance VectorOps (Vector4 t) (Vector1 t) (VectorN t 5) where
--  appendVecs (Vector4 a1 a2 a3 a4) (Vector1 b) = VectorN [a1,a2,a3,a4,b]
--instance m ~ (n+1) => VectorOps (VectorN t n) (Vector1 t) (VectorN t m) where
--  appendVecs (VectorN as) (Vector1 b) = VectorN $ as ++ [b]
--
--instance VectorOps (Vector1 t) (Vector2 t) (Vector3 t) where
--  appendVecs (Vector1 a) (Vector2 b1 b2) = Vector3 a b1 b2
--instance VectorOps (Vector2 t) (Vector2 t) (Vector4 t) where
--  appendVecs (Vector2 a1 a2) (Vector2 b1 b2) = Vector4 a1 a2 b1 b2
--instance VectorOps (Vector3 t) (Vector2 t) (VectorN t 5) where
--  appendVecs (Vector3 a1 a2 a3) (Vector2 b1 b2) = VectorN [a1,a2,a3,b1,b2]
--instance VectorOps (Vector4 t) (Vector2 t) (VectorN t 6) where
--  appendVecs (Vector4 a1 a2 a3 a4) (Vector2 b1 b2) = VectorN [a1,a2,a3,a4,b1,b2]
--instance m ~ (n+2) => VectorOps (VectorN t n) (Vector2 t) (VectorN t m) where
--  appendVecs (VectorN as) (Vector2 b1 b2) = VectorN $ as ++ [b1,b2]
--
--instance VectorOps (Vector1 t) (Vector3 t) (Vector4 t) where
--  appendVecs (Vector1 a) (Vector3 b1 b2 b3) = Vector4 a b1 b2 b3
--instance VectorOps (Vector2 t) (Vector3 t) (VectorN t 5) where
--  appendVecs (Vector2 a1 a2) (Vector3 b1 b2 b3) = VectorN [a1,a2,b1,b2,b3]
--instance VectorOps (Vector3 t) (Vector3 t) (VectorN t 6) where
--  appendVecs (Vector3 a1 a2 a3) (Vector3 b1 b2 b3) = VectorN [a1,a2,a3,b1,b2,b3]
--instance VectorOps (Vector4 t) (Vector3 t) (VectorN t 7) where
--  appendVecs (Vector4 a1 a2 a3 a4) (Vector3 b1 b2 b3) = VectorN [a1,a2,a3,a4,b1,b2,b3]
--instance m ~ (n+3) => VectorOps (VectorN t n) (Vector3 t) (VectorN t m) where
--  appendVecs (VectorN as) (Vector3 b1 b2 b3) = VectorN $ as ++ [b1,b2,b3]
--
--instance VectorOps (Vector1 t) (Vector4 t) (VectorN t 5) where
--  appendVecs (Vector1 a) (Vector4 b1 b2 b3 b4) = VectorN [a,b1,b2,b3,b4]
--instance VectorOps (Vector2 t) (Vector4 t) (VectorN t 6) where
--  appendVecs (Vector2 a1 a2) (Vector4 b1 b2 b3 b4) = VectorN [a1,a2,b1,b2,b3,b4]
--instance VectorOps (Vector3 t) (Vector4 t) (VectorN t 7) where
--  appendVecs (Vector3 a1 a2 a3) (Vector4 b1 b2 b3 b4) = VectorN [a1,a2,a3,b1,b2,b3,b4]
--instance VectorOps (Vector4 t) (Vector4 t) (VectorN t 8) where
--  appendVecs (Vector4 a1 a2 a3 a4) (Vector4 b1 b2 b3 b4) = VectorN [a1,a2,a3,a4,b1,b2,b3,b4]
--instance m ~ (n+4) => VectorOps (VectorN t n) (Vector4 t) (VectorN t m) where
--  appendVecs (VectorN as) (Vector4 b1 b2 b3 b4) = VectorN $ as ++ [b1,b2,b3,b4]
--
--instance k ~ (m+1) => VectorOps (Vector1 t) (VectorN t m) (VectorN t k) where
--  appendVecs (Vector1 a) (VectorN bs) = VectorN $ a : bs
--instance k ~ (m+2) => VectorOps (Vector2 t) (VectorN t m) (VectorN t 6) where
--  appendVecs (Vector2 a1 a2) (VectorN bs) = VectorN $ a1 : a2 : bs
--instance k ~ (m+3) => VectorOps (Vector3 t) (VectorN t m) (VectorN t 7) where
--  appendVecs (Vector3 a1 a2 a3) (VectorN bs) = VectorN $ a1 : a2 : a3 : bs
--instance k ~ (m+4) => VectorOps (Vector4 t) (VectorN t m) (VectorN t 8) where
--  appendVecs (Vector4 a1 a2 a3 a4) (VectorN bs) = VectorN $ a1 : a2 : a3 : a4 : bs
--instance k ~ (m+n) => VectorOps (VectorN t n) (VectorN t m) (VectorN t m) where
--  appendVecs (VectorN as) (VectorN bs) = VectorN $ as ++ bs