{-# OPTIONS_GHC -fno-warn-orphans #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE Rank2Types #-} module LLVM.Core.Vector (MkVector(..), vector, cyclicVector, ) where import qualified LLVM.ExecutionEngine.Target as Target import qualified LLVM.Core.UnaryVector as UnaryVector import qualified LLVM.Util.Proxy as Proxy import LLVM.Core.Type (IsPrimitive, unsafeTypeRef) import LLVM.Core.Data (Vector(Vector), FixedList) import qualified Type.Data.Num.Decimal.Proof as DecProof import qualified Type.Data.Num.Decimal.Number as Dec import qualified Type.Data.Num.Unary as Unary import Type.Data.Num.Decimal.Literal (D2, D4, D8) import qualified Foreign.Storable.Traversable as Store import Foreign.Storable (Storable(..)) import qualified Test.QuickCheck as QC import qualified Control.Monad.Trans.State as MS import Control.Applicative (Applicative, pure, liftA2, (<*>)) import Control.Functor.HT (unzip, outerProduct) import qualified Data.Traversable as Trav import qualified Data.Foldable as Fold import qualified Data.NonEmpty as NonEmpty import qualified Data.Empty as Empty import Data.Traversable (Traversable, foldMapDefault) import Data.Foldable (Foldable, foldMap) import Data.NonEmpty ((!:)) import System.IO.Unsafe (unsafePerformIO) import Prelude hiding (replicate, map, head, unzip, zipWith, uncurry) -- XXX Should these really be here? 
-- | Conversion between fixed-size vectors and plain Haskell tuples.
--
-- Instances exist for the sizes 2, 4 and 8; components appear in the tuple
-- in the same order as in the vector.
class (Dec.Positive n, IsPrimitive a) => MkVector n a where
  -- | The tuple with @n@ components of type @a@.
  type Tuple n a :: *
  -- | Pack a tuple into a vector.
  toVector :: Tuple n a -> Vector n a
  -- | Unpack a vector into a tuple.
  fromVector :: Vector n a -> Tuple n a

instance (IsPrimitive a) => MkVector D2 a where
  type Tuple D2 a = (a, a)
  toVector (a1, a2) = vector (a1 !: a2 !: Empty.Cons)
  fromVector = uncurry (,)

instance (IsPrimitive a) => MkVector D4 a where
  type Tuple D4 a = (a, a, a, a)
  toVector (a1, a2, a3, a4) =
    vector (a1 !: a2 !: a3 !: a4 !: Empty.Cons)
  fromVector = uncurry (,,,)

instance (IsPrimitive a) => MkVector D8 a where
  type Tuple D8 a = (a, a, a, a, a, a, a, a)
  toVector (a1, a2, a3, a4, a5, a6, a7, a8) =
    vector (a1 !: a2 !: a3 !: a4 !: a5 !: a6 !: a7 !: a8 !: Empty.Cons)
  fromVector = uncurry (,,,,,,,)

-- | The first element of a vector.
--
-- Total, because the 'Dec.Positive' constraint guarantees at least one
-- element.
head :: (Dec.Positive n) => Vector n a -> a
head =
  withPosDict1 $ \dict v ->
    case dict of
      DecProof.UnaryPos -> UnaryVector.head (unaryFromDecimalVector v)

-- | Reinterpret a decimally-indexed vector as a unary-indexed one.
unaryFromDecimalVector :: Vector n a -> UnaryVector.T (Dec.ToUnary n) a
unaryFromDecimalVector (Vector xs) = UnaryVector.fromFixedList xs

-- | Inverse of 'unaryFromDecimalVector'.
decimalFromUnaryVector :: UnaryVector.T (Dec.ToUnary n) a -> Vector n a
decimalFromUnaryVector = Vector . UnaryVector.toFixedList

-- | Curried function type taking @n@ arguments of type @a@ and yielding @b@.
type Curried n a b = UnaryVector.Curried (Dec.ToUnary n) a b

-- | Apply a curried @n@-ary function to the elements of a vector.
uncurry :: (Dec.Natural n) => Curried n a b -> Vector n a -> b
uncurry f =
  withNatDict1 $ \dict v ->
    case dict of
      DecProof.UnaryNat -> UnaryVector.uncurry f (unaryFromDecimalVector v)

-- | Supply the decimal-to-unary proof for @n@ to a vector producer.
withNatDict ::
  (Dec.Natural n) => (DecProof.UnaryNat n -> Vector n a) -> Vector n a
withNatDict f = f DecProof.unaryNat

-- | Supply the decimal-to-unary proof for @n@ to a vector consumer.
withNatDict1 ::
  (Dec.Natural n) =>
  (DecProof.UnaryNat n -> Vector n a -> b) ->
  Vector n a ->
  b
withNatDict1 f = f DecProof.unaryNat

-- | Supply the positivity proof for @n@ to a vector consumer.
withPosDict1 ::
  (Dec.Positive n) =>
  (DecProof.UnaryPos n -> Vector n a -> b) ->
  Vector n a ->
  b
withPosDict1 f = f DecProof.unaryPos

-- | Build a decimally-indexed vector from a unary-indexed construction that
-- only needs to know the unary length @m@.
withUnaryDecVector ::
  (Dec.Natural n) =>
  (forall m. (Dec.ToUnary n ~ m, Unary.Natural m) => UnaryVector.T m a) ->
  Vector n a
withUnaryDecVector v =
  withNatDict $ \dict ->
    case dict of
      DecProof.UnaryNat -> decimalFromUnaryVector v

-- | Size and alignment are taken from the target data for the vector's LLVM
-- type. Element-wise reading and writing is delegated to
-- "Foreign.Storable.Traversable".
instance (Storable a, Dec.Positive n, IsPrimitive a) => Storable (Vector n a) where
  sizeOf a =
    Target.storeSizeOfType ourTargetData $ unsafeTypeRef $ Proxy.fromValue a
  alignment a =
    Target.abiAlignmentOfType ourTargetData $ unsafeTypeRef $ Proxy.fromValue a
  peek = Store.peekApplicative
  poke = Store.poke

-- XXX The JITer target data. This isn't really right.
-- | Target data obtained via 'Target.getTargetData'.
--
-- The NOINLINE pragma keeps this 'unsafePerformIO' CAF from being inlined
-- and its action re-run at every use site.
ourTargetData :: Target.TargetData
ourTargetData = unsafePerformIO Target.getTargetData
{-# NOINLINE ourTargetData #-}

--------------------------------------

{- maybe we should export this in order to allow NumericPrelude instances
unVector :: (Dec.Positive n) => Vector n a -> FixedList n a
unVector (Vector xs) = xs
-}

-- | Wrap a fixed-length list as a vector. This is the 'Vector' constructor.
vector :: (Dec.Positive n) => FixedList (Dec.ToUnary n) a -> Vector n a
vector = Vector

{- |
Make a constant vector. Replicates or truncates the list to get length /n/.
This behaviour is consistent with that of 'LLVM.Core.CodeGen.constCyclicVector'.
May be abused for constructing vectors from lists with statically unknown size.
-}
cyclicVector :: (Dec.Positive n) => NonEmpty.T [] a -> Vector n a
cyclicVector xs = withUnaryDecVector (UnaryVector.cyclicVector xs)

-- | A vector whose elements are all the given value.
replicate :: (Dec.Positive n) => a -> Vector n a
replicate a = withUnaryDecVector (pure a)

instance (Dec.Positive n) => Functor (Vector n) where
  fmap f a = withUnaryDecVector (fmap f $ unaryFromDecimalVector a)

-- | Element-wise (zip-like) applicative: 'pure' replicates, '<*>' applies
-- position by position.
instance (Dec.Positive n) => Applicative (Vector n) where
  pure = replicate
  f <*> a =
    withUnaryDecVector (unaryFromDecimalVector f <*> unaryFromDecimalVector a)

instance (Dec.Positive n) => Foldable (Vector n) where
  foldMap = foldMapDefault

instance (Dec.Positive n) => Traversable (Vector n) where
  sequenceA =
    withNatDict1 $ \dict v ->
      case dict of
        DecProof.UnaryNat ->
          fmap decimalFromUnaryVector $ Trav.sequenceA $ unaryFromDecimalVector v

instance (Eq a, Dec.Positive n) => Eq (Vector n a) where
  x == y = Fold.and $ liftA2 (==) x y

-- | Lexicographic order: the first position whose elements differ decides.
instance (Ord a, Dec.Positive n) => Ord (Vector n a) where
  -- The Ordering monoid keeps the first non-EQ result.
  compare x y = Fold.fold $ liftA2 compare x y

-- | Element-wise arithmetic; 'fromInteger' replicates the literal.
instance (Num a, Dec.Positive n) => Num (Vector n a) where
  (+) = liftA2 (+)
  (-) = liftA2 (-)
  (*) = liftA2 (*)
  negate = fmap negate
  abs = fmap abs
  signum = fmap signum
  fromInteger = pure . fromInteger

-- | 'fromEnum' has no meaningful vector result and fails at runtime.
instance (Enum a, Dec.Positive n) => Enum (Vector n a) where
  succ = fmap succ
  pred = fmap pred
  fromEnum = error "Vector fromEnum"
  toEnum = pure . toEnum

-- | Exists only as a superclass of 'Integral' / 'RealFrac'; 'toRational'
-- fails at runtime.
instance (Real a, Dec.Positive n) => Real (Vector n a) where
  toRational = error "Vector toRational"

-- | Element-wise integer division; 'toInteger' fails at runtime.
instance (Integral a, Dec.Positive n) => Integral (Vector n a) where
  quot = liftA2 quot
  rem = liftA2 rem
  div = liftA2 div
  mod = liftA2 mod
  quotRem xs ys = unzip $ liftA2 quotRem xs ys
  divMod xs ys = unzip $ liftA2 divMod xs ys
  toInteger = error "Vector toInteger"

instance (Fractional a, Dec.Positive n) => Fractional (Vector n a) where
  (/) = liftA2 (/)
  fromRational = pure . fromRational

-- | Exists only as a superclass of 'RealFloat'; 'properFraction' fails at
-- runtime.
instance (RealFrac a, Dec.Positive n) => RealFrac (Vector n a) where
  properFraction = error "Vector properFraction"

-- | Element-wise transcendental functions.
instance (Floating a, Dec.Positive n) => Floating (Vector n a) where
  pi = pure pi
  sqrt = fmap sqrt
  log = fmap log
  logBase = liftA2 logBase
  (**) = liftA2 (**)
  exp = fmap exp
  sin = fmap sin
  cos = fmap cos
  tan = fmap tan
  asin = fmap asin
  acos = fmap acos
  atan = fmap atan
  sinh = fmap sinh
  cosh = fmap cosh
  tanh = fmap tanh
  asinh = fmap asinh
  acosh = fmap acosh
  atanh = fmap atanh

-- | Format properties ('floatRadix', 'floatDigits', 'floatRange', 'isIEEE')
-- are read from the first element.
--
-- Methods whose result cannot be expressed per vector fail at runtime. Only
-- @scaleFloat 0@ is supported.
instance (RealFloat a, Dec.Positive n) => RealFloat (Vector n a) where
  floatRadix = floatRadix . head
  floatDigits = floatDigits . head
  floatRange = floatRange . head
  decodeFloat = error "Vector decodeFloat"
  encodeFloat = error "Vector encodeFloat"
  exponent _ = 0
  scaleFloat 0 x = x
  scaleFloat _ _ = error "Vector scaleFloat"
  isNaN = error "Vector isNaN"
  isInfinite = error "Vector isInfinite"
  isDenormalized = error "Vector isDenormalized"
  isNegativeZero = error "Vector isNegativeZero"
  isIEEE = isIEEE . head

-- | The vector @0, 1, ..., n-1@, numbering positions in traversal order.
indices :: (Dec.Positive n) => Vector n Int
indices =
  flip MS.evalState 0 $
    Trav.sequenceA $
      replicate $
        MS.state (\k -> (k, k + 1))

instance (Dec.Positive n, QC.Arbitrary a) => QC.Arbitrary (Vector n a) where
  -- Each element is generated independently.
  arbitrary = Trav.sequenceA $ replicate QC.arbitrary

  -- Shrink one element at a time. For every position, build a Bool mask that
  -- is True only at that position, then shrink just the masked element while
  -- keeping the others fixed.
  shrink v =
    -- NOTE(review): a case binding (rather than a generalising let) keeps
    -- 'ixs' monomorphic, presumably so that both of its uses share the
    -- length of 'v'.
    case indices of
      ixs ->
        concatMap
          ( Trav.sequenceA
              . liftA2
                (\x doShrink -> if doShrink then QC.shrink x else [x])
                v
          )
          (outerProduct (==) (Fold.toList ixs) ixs)