{-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeFamilies #-} {-# OPTIONS_GHC -fno-warn-orphans #-} -- | -- Module : Data.Array.Accelerate.Data.Monoid -- Copyright : [2016..2017] Trevor L. McDonell -- License : BSD3 -- -- Maintainer : Trevor L. McDonell -- Stability : experimental -- Portability : non-portable (GHC extensions) -- -- Monoid instances for Accelerate -- module Data.Array.Accelerate.Data.Monoid ( Monoid(..), (<>), Sum(..), Product(..), ) where import Data.Array.Accelerate as A import Data.Array.Accelerate.Type as A import Data.Array.Accelerate.Smart as A ( Exp(..), PreExp(..) ) import Data.Array.Accelerate.Product as A import Data.Array.Accelerate.Array.Sugar as A import Data.Function import Data.Monoid hiding ( mconcat ) import qualified Prelude as P -- Sum: Monoid under addition -- -------------------------- type instance EltRepr (Sum a) = ((), EltRepr a) instance Elt a => Elt (Sum a) where eltType _ = PairTuple UnitTuple (eltType (undefined::a)) toElt ((),x) = Sum (toElt x) fromElt (Sum x) = ((), fromElt x) instance Elt a => IsProduct Elt (Sum a) where type ProdRepr (Sum a) = ((), a) toProd _ ((),a) = Sum a fromProd _ (Sum a) = ((),a) prod _ _ = ProdRsnoc ProdRunit instance (Lift Exp a, Elt (Plain a)) => Lift Exp (Sum a) where type Plain (Sum a) = Sum (Plain a) lift (Sum a) = Exp $ Tuple $ NilTup `SnocTup` lift a instance Elt a => Unlift Exp (Sum (Exp a)) where unlift t = Sum . Exp $ ZeroTupIdx `Prj` t instance A.Num a => Monoid (Exp (Sum a)) where mempty = 0 mappend = lift2 (mappend :: Sum (Exp a) -> Sum (Exp a) -> Sum (Exp a)) instance A.Num a => P.Num (Exp (Sum a)) where (+) = lift2 ((+) :: Sum (Exp a) -> Sum (Exp a) -> Sum (Exp a)) (-) = lift2 ((-) :: Sum (Exp a) -> Sum (Exp a) -> Sum (Exp a)) (*) = lift2 ((*) :: Sum (Exp a) -> Sum (Exp a) -> Sum (Exp a)) negate = lift1 (negate :: Sum (Exp a) -> Sum (Exp a)) signum = lift1 (signum :: Sum (Exp a) -> Sum (Exp a)) abs = lift1 (signum :: Sum (Exp a) -> Sum (Exp a)) fromInteger x = lift (P.fromInteger x :: Sum (Exp a)) instance A.Eq a => A.Eq (Sum a) where (==) = lift2 ((==) `on` getSum) (/=) = lift2 ((/=) `on` getSum) instance A.Ord a => A.Ord (Sum a) where (<) = lift2 ((<) `on` getSum) (>) = lift2 ((>) `on` getSum) (<=) = lift2 ((<=) `on` getSum) (>=) = lift2 ((>=) `on` getSum) min x y = lift . Sum $ lift2 (min `on` getSum) x y max x y = lift . Sum $ lift2 (max `on` getSum) x y -- Product: Monoid under multiplication -- ------------------------------------ type instance EltRepr (Product a) = ((), EltRepr a) instance Elt a => Elt (Product a) where eltType _ = PairTuple UnitTuple (eltType (undefined::a)) toElt ((),x) = Product (toElt x) fromElt (Product x) = ((), fromElt x) instance Elt a => IsProduct Elt (Product a) where type ProdRepr (Product a) = ((), a) toProd _ ((),a) = Product a fromProd _ (Product a) = ((),a) prod _ _ = ProdRsnoc ProdRunit instance (Lift Exp a, Elt (Plain a)) => Lift Exp (Product a) where type Plain (Product a) = Product (Plain a) lift (Product a) = Exp $ Tuple $ NilTup `SnocTup` lift a instance Elt a => Unlift Exp (Product (Exp a)) where unlift t = Product . Exp $ ZeroTupIdx `Prj` t instance A.Num a => Monoid (Exp (Product a)) where mempty = 1 mappend = lift2 (mappend :: Product (Exp a) -> Product (Exp a) -> Product (Exp a)) instance A.Num a => P.Num (Exp (Product a)) where (+) = lift2 ((+) :: Product (Exp a) -> Product (Exp a) -> Product (Exp a)) (-) = lift2 ((-) :: Product (Exp a) -> Product (Exp a) -> Product (Exp a)) (*) = lift2 ((*) :: Product (Exp a) -> Product (Exp a) -> Product (Exp a)) negate = lift1 (negate :: Product (Exp a) -> Product (Exp a)) signum = lift1 (signum :: Product (Exp a) -> Product (Exp a)) abs = lift1 (signum :: Product (Exp a) -> Product (Exp a)) fromInteger x = lift (P.fromInteger x :: Product (Exp a)) instance A.Eq a => A.Eq (Product a) where (==) = lift2 ((==) `on` getProduct) (/=) = lift2 ((/=) `on` getProduct) instance A.Ord a => A.Ord (Product a) where (<) = lift2 ((<) `on` getProduct) (>) = lift2 ((>) `on` getProduct) (<=) = lift2 ((<=) `on` getProduct) (>=) = lift2 ((>=) `on` getProduct) min x y = lift . Product $ lift2 (min `on` getProduct) x y max x y = lift . Product $ lift2 (max `on` getProduct) x y -- Instances for unit and tuples -- ----------------------------- instance Monoid (Exp ()) where mempty = constant () mappend _ _ = constant () instance (Elt a, Elt b, Monoid (Exp a), Monoid (Exp b)) => Monoid (Exp (a,b)) where mempty = lift (mempty :: Exp a, mempty :: Exp b) mappend x y = let (a1,b1) = unlift x :: (Exp a, Exp b) (a2,b2) = unlift y in lift (a1<>a2, b1<>b2) instance (Elt a, Elt b, Elt c, Monoid (Exp a), Monoid (Exp b), Monoid (Exp c)) => Monoid (Exp (a,b,c)) where mempty = lift (mempty :: Exp a, mempty :: Exp b, mempty :: Exp c) mappend x y = let (a1,b1,c1) = unlift x :: (Exp a, Exp b, Exp c) (a2,b2,c2) = unlift y in lift (a1<>a2, b1<>b2, c1<>c2) instance (Elt a, Elt b, Elt c, Elt d, Monoid (Exp a), Monoid (Exp b), Monoid (Exp c), Monoid (Exp d)) => Monoid (Exp (a,b,c,d)) where mempty = lift (mempty :: Exp a, mempty :: Exp b, mempty :: Exp c, mempty :: Exp d) mappend x y = let (a1,b1,c1,d1) = unlift x :: (Exp a, Exp b, Exp c, Exp d) (a2,b2,c2,d2) = unlift y in lift (a1<>a2, b1<>b2, c1<>c2, d1<>d2) instance (Elt a, Elt b, Elt c, Elt d, Elt e, Monoid (Exp a), Monoid (Exp b), Monoid (Exp c), Monoid (Exp d), Monoid (Exp e)) => Monoid (Exp (a,b,c,d,e)) where mempty = lift (mempty :: Exp a, mempty :: Exp b, mempty :: Exp c, mempty :: Exp d, mempty :: Exp e) mappend x y = let (a1,b1,c1,d1,e1) = unlift x :: (Exp a, Exp b, Exp c, Exp d, Exp e) (a2,b2,c2,d2,e2) = unlift y in lift (a1<>a2, b1<>b2, c1<>c2, d1<>d2, e1<>e2)