{-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE CPP #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE Rank2Types #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeFamilyDependencies #-} {-# LANGUAGE TypeInType #-} {-# LANGUAGE TypeOperators #-} ----------------------------------------------------------------------------- -- | -- Module : Numeric.DataFrame.Internal.Array.Family -- Copyright : (c) Artem Chirkin -- License : BSD3 -- -- Maintainer : chirkin@arch.ethz.ch -- -- ----------------------------------------------------------------------------- module Numeric.DataFrame.Internal.Array.Family ( Array, ScalarBase (..), ArrayBase (..) , ArraySingleton (..) , ArraySing (..), aSingEv, inferASing , inferPrimElem, inferPrim, inferEq, inferShow, inferOrd, inferNum , inferFractional, inferFloating ) where import GHC.Base import Numeric.DataFrame.Internal.Array.Class import Numeric.DataFrame.Internal.Array.Family.ArrayBase import Numeric.DataFrame.Internal.Array.Family.DoubleX2 import Numeric.DataFrame.Internal.Array.Family.DoubleX3 import Numeric.DataFrame.Internal.Array.Family.DoubleX4 import Numeric.DataFrame.Internal.Array.Family.FloatX2 import Numeric.DataFrame.Internal.Array.Family.FloatX3 import Numeric.DataFrame.Internal.Array.Family.FloatX4 import Numeric.DataFrame.Internal.Array.Family.ScalarBase import Numeric.Dimensions import Numeric.PrimBytes -- | This type family aggregates all types used for arrays with different -- dimensioinality. -- The family is injective; thus, it is possible to get type family instance -- given the data constructor (and vice versa). 
-- If GHC knows the dimensionality of an array at compile time, it chooses
-- a more efficient specialized instance of Array, e.g. Scalar newtype wrapper.
-- Otherwise, it falls back to the generic ArrayBase implementation.
--
-- Data family would not work here, because it would give overlapping instances.
--
-- This is a closed type family: equations are tried from top to bottom, and an
-- equation only fires once its arguments are known to be apart from every
-- earlier equation. Consequently the catch-all @ArrayBase@ equation must stay
-- last, and @Array t ds@ cannot reduce to @ArrayBase t ds@ while @ds@ is
-- abstract. The rest of this module recovers the concrete representation at
-- runtime by matching on `ArraySing`.
type family Array (t :: Type) (ds :: [Nat]) = (v :: Type) | v -> t ds where
    Array t '[] = ScalarBase t
    Array Float '[2] = FloatX2
    Array Float '[3] = FloatX3
    Array Float '[4] = FloatX4
    Array Double '[2] = DoubleX2
    Array Double '[3] = DoubleX3
    Array Double '[4] = DoubleX4
    -- Catch-all; must remain the last equation.
    Array t ds = ArrayBase t ds

-- | A framework for using Array type family instances.
--
--   The class has exactly one method and no superclasses. `reifyArraySing`
--   relies on this shape: GHC represents the dictionary of such a class by
--   the method value itself. Adding a method or a superclass would make that
--   coercion unsound.
class ArraySingleton (t :: Type) (ds :: [Nat]) where
    -- | Singleton naming the representation that @Array t ds@ uses.
    aSing :: ArraySing t ds

-- | Singleton naming the concrete representation selected by @Array t ds@.
--   It has one constructor per equation of the `Array` type family.
--
--   Matching on a specialised constructor fixes @t@ and @ds@ through its
--   result type, so @Array t ds@ then reduces to the concrete type. Matching
--   on `ABase` brings @Array t ds ~ ArrayBase t ds@ and @PrimBytes t@ into
--   scope.
--
--   NOTE(review): in `AScalar` through `AD4`, the @ds@ in the context does not
--   occur in the result type. The same holds for the @t@, except in `AScalar`.
--   Such a variable is therefore a separate, implicitly quantified variable
--   of that constructor, not the result index. Confirm this is intended.
data ArraySing t (ds :: [Nat]) where
    -- Zero-dimensional array, represented by @ScalarBase t@.
    AScalar :: (Array t ds ~ ScalarBase t) => ArraySing t '[]
    -- Float vectors of length 2, 3 and 4 with dedicated representations.
    AF2 :: (Array t ds ~ FloatX2) => ArraySing Float '[2]
    AF3 :: (Array t ds ~ FloatX3) => ArraySing Float '[3]
    AF4 :: (Array t ds ~ FloatX4) => ArraySing Float '[4]
    -- Double vectors of length 2, 3 and 4 with dedicated representations.
    AD2 :: (Array t ds ~ DoubleX2) => ArraySing Double '[2]
    AD3 :: (Array t ds ~ DoubleX3) => ArraySing Double '[3]
    AD4 :: (Array t ds ~ DoubleX4) => ArraySing Double '[4]
    -- Every other combination: the generic @ArrayBase t ds@, which needs
    -- @PrimBytes t@.
    ABase :: ( Array t ds ~ ArrayBase t ds
             , PrimBytes t
             ) => ArraySing t ds

-- The constructors carry no value fields, so these instances only look at
-- which constructor was used.
deriving instance Eq (ArraySing t ds)
deriving instance Ord (ArraySing t ds)
deriving instance Show (ArraySing t ds)

-- | Turn an `ArraySing` value into an `ArraySingleton` instance at runtime,
--   and run the continuation with that instance in scope.
--   The trick is taken from Edward Kmett's reflection library explained
--   in https://www.schoolofhaskell.com/user/thoughtpolice/using-reflection
--
--   Safety: the continuation, wrapped in `MagicArraySing`, is a function that
--   expects an `ArraySingleton t ds` dictionary. `unsafeCoerce#` lets it be
--   applied to the `ArraySing t ds` value instead. This is only sound because
--   the class has a single method of type @ArraySing t ds@ and no
--   superclasses. Its dictionary and that value therefore share a runtime
--   representation.
reifyArraySing :: forall r t ds .
                  ArraySing t ds -> ( ArraySingleton t ds => r) -> r
reifyArraySing as k = unsafeCoerce# (MagicArraySing k :: MagicArraySing t ds r) as
{-# INLINE reifyArraySing #-}

-- Wrapper that gives the constrained continuation @ArraySingleton t ds => r@
-- a concrete type, so that `reifyArraySing` can coerce it.
newtype MagicArraySing t (ds :: [Nat]) r = MagicArraySing (ArraySingleton t ds => r)

-- | Use the `ArraySing` GADT to construct an `ArraySingleton` dictionary.
--   In other words, bring evidence of an `ArraySingleton` instance into
--   scope at runtime.
aSingEv :: ArraySing t ds -> Evidence (ArraySingleton t ds)
-- Here the term @ds@ is the singleton value, not the type-level dimension list.
aSingEv ds = reifyArraySing ds E
{-# INLINE aSingEv #-}

-- | Use the `ArraySing` GADT to construct an `ArraySingleton` dictionary.
--   The same as `aSingEv`, but relies on `PrimBytes` and `Dimensions`
--   instead of an explicit singleton. It inspects the runtime dimensions of
--   @ds@ and the primitive tag of @t@ to pick the matching instance.
inferASing :: forall t ds . (PrimBytes t, Dimensions ds) => Evidence (ArraySingleton t ds)
inferASing = case (dims @_ @ds, primTag @t undefined) of
    -- No dimensions: the scalar instance applies to any element type.
    (U, _) -> E
    -- A single dimension with a Float element. Presumably the PTagFloat match
    -- refines @t@ to @Float@ (see Numeric.PrimBytes). Each guard compares the
    -- dimension with 2, 3 or 4 via `sameDim`. If no guard succeeds, matching
    -- falls through to the following alternatives.
    (d :* U, PTagFloat)
      | Just E <- sameDim (D @2) d -> E
      | Just E <- sameDim (D @3) d -> E
      | Just E <- sameDim (D @4) d -> E
    -- The same as above, for Double elements.
    (d :* U, PTagDouble)
      | Just E <- sameDim (D @2) d -> E
      | Just E <- sameDim (D @3) d -> E
      | Just E <- sameDim (D @4) d -> E
    -- Everything else is represented by ArrayBase. For an abstract @ds@,
    -- however, GHC cannot select the OVERLAPPABLE instance: the more specific
    -- instances might still match, and its context does not reduce.
    --
    -- The code therefore fabricates the evidence @ds ~ '[0]@.
    -- @Array t '[0]@ reduces to @ArrayBase t '[0]@, which satisfies that
    -- instance. This relies on two facts:
    --   (a) equality evidence is erased at runtime;
    --   (b) the alternatives above have already handled every @t ds@ for
    --       which `Array` picks a non-ArrayBase representation.
    -- So the resulting dictionary (aSing = ABase) is exactly the one the real
    -- @ds@ would get.
    _ -> case (unsafeCoerce# (E @(ds ~ ds)) :: Evidence (ds ~ '[0])) of E -> E
{-# INLINE inferASing #-}

-- Generic instance, used whenever no OVERLAPPING instance below matches.
-- Its context demands that @Array t ds@ actually reduces to @ArrayBase t ds@.
instance {-# OVERLAPPABLE #-}
         (Array t ds ~ ArrayBase t ds, PrimBytes t)
      => ArraySingleton t ds where
    aSing = ABase

-- One OVERLAPPING instance per specialised equation of `Array`.
instance {-# OVERLAPPING #-} ArraySingleton t '[] where
    aSing = AScalar

instance {-# OVERLAPPING #-} ArraySingleton Float '[2] where
    aSing = AF2

instance {-# OVERLAPPING #-} ArraySingleton Float '[3] where
    aSing = AF3

instance {-# OVERLAPPING #-} ArraySingleton Float '[4] where
    aSing = AF4

instance {-# OVERLAPPING #-} ArraySingleton Double '[2] where
    aSing = AD2

instance {-# OVERLAPPING #-} ArraySingleton Double '[3] where
    aSing = AD3

instance {-# OVERLAPPING #-} ArraySingleton Double '[4] where
    aSing = AD4

-- | Recover @PrimBytes t@ from an `ArraySingleton` instance.
--   This is a special function, because Scalar does not require PrimBytes:
--   neither `AScalar` nor the @ArraySingleton t '[]@ instance carries it.
--   That is why the dimension list in the argument is not empty.
inferPrimElem :: forall t d ds . ArraySingleton t (d ': ds) => Evidence (PrimBytes t)
inferPrimElem = case (aSing :: ArraySing t (d ': ds)) of
    -- `AScalar` is omitted: its index @'[]@ cannot match @d ': ds@.
    -- In the specialised branches @t@ is fixed to Float or Double, whose own
    -- PrimBytes instances are used. `ABase` carries @PrimBytes t@ itself.
    AF2 -> E
    AF3 -> E
    AF4 -> E
    AD2 -> E
    AD3 -> E
    AD4 -> E
    ABase -> E

-- Rather verbose way to show that there is an instance of a required type class
-- for every instance of the type family.
--
-- WITNESS expands to a case over every `ArraySing` constructor. In each branch
-- the concrete type behind @Array t ds@ is known, so @E@ is built from that
-- type's own instance.
--
-- The expansion mentions the type variables @t@ and @ds@, so each use site
-- must bind them with @forall t ds .@ (scoped type variables). Explicit braces
-- and semicolons are needed because the macro body ends up on a single line
-- after preprocessing.
#define WITNESS case (aSing :: ArraySing t ds) of {\
  AScalar -> E;\
  AF2 -> E;\
  AF3 -> E;\
  AF4 -> E;\
  AD2 -> E;\
  AD3 -> E;\
  AD4 -> E;\
  ABase -> E}

-- | Evidence that @Array t ds@ has `PrimBytes` and `PrimArray` instances,
--   obtained by case analysis on `aSing`.
inferPrim :: forall t ds . ( PrimBytes t
                           , ArraySingleton t ds
                           , Dimensions ds
                           )
          => Evidence (PrimBytes (Array t ds), PrimArray t (Array t ds))
inferPrim = WITNESS

-- | Evidence that @Array t ds@ has an `Eq` instance, given @Eq t@.
inferEq :: forall t ds . (Eq t, ArraySingleton t ds) => Evidence (Eq (Array t ds))
inferEq = WITNESS

-- | Evidence that @Array t ds@ has an `Ord` instance, given @Ord t@.
inferOrd :: forall t ds . (Ord t, ArraySingleton t ds) => Evidence (Ord (Array t ds))
inferOrd = WITNESS

-- | Evidence that @Array t ds@ has a `Num` instance, given @Num t@.
inferNum :: forall t ds . (Num t, ArraySingleton t ds) => Evidence (Num (Array t ds))
inferNum = WITNESS

-- | Evidence that @Array t ds@ has a `Fractional` instance, given
--   @Fractional t@.
inferFractional :: forall t ds . (Fractional t, ArraySingleton t ds) => Evidence (Fractional (Array t ds))
inferFractional = WITNESS

-- | Evidence that @Array t ds@ has a `Floating` instance, given
--   @Floating t@.
inferFloating :: forall t ds . (Floating t, ArraySingleton t ds) => Evidence (Floating (Array t ds))
inferFloating = WITNESS

-- | Evidence that @Array t ds@ has a `Show` instance, given @Show t@ and
--   @Dimensions ds@.
inferShow :: forall t ds . (Show t, Dimensions ds, ArraySingleton t ds) => Evidence (Show (Array t ds))
inferShow = WITNESS