{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeOperators #-}
-- | Constrained ("subcategory") counterparts of the 'Semialign' and 'Align'
-- classes: the element types are restricted by the 'Dom' constraint of the
-- underlying 'CFunctor', which allows instances for containers such as
-- unboxed vectors, primitive arrays and monomorphic containers.
module Control.Subcategory.Semialign
  ( CSemialign(..)
  , CAlign(..)
  , csalign
  , cpadZip
  , cpadZipWith
  ) where

import           Control.Applicative                  (ZipList)
import           Control.Monad                        (forM_)
import           Control.Monad.ST.Strict              (runST)
import           Control.Subcategory.Functor
import           Control.Subcategory.Wrapper.Internal
import           Data.Bifunctor                       (Bifunctor (bimap))
import           Data.Coerce
import           Data.Containers
import           Data.Functor.Compose                 (Compose (..))
import           Data.Functor.Identity                (Identity)
import qualified Data.Functor.Product                 as SOP
import           Data.Hashable                        (Hashable)
import           Data.HashMap.Strict                  (HashMap)
import           Data.IntMap.Strict                   (IntMap)
import qualified Data.IntSet                          as IS
import           Data.List.NonEmpty                   (NonEmpty)
import           Data.Map                             (Map)
import           Data.MonoTraversable
import qualified Data.Primitive.Array                 as A
import qualified Data.Primitive.PrimArray             as PA
import qualified Data.Primitive.SmallArray            as SA
import           Data.Proxy                           (Proxy)
import           Data.Semialign
import           Data.Semigroup                       (Option (..))
import           Data.Sequence                        (Seq)
import qualified Data.Sequences                       as MT
import           Data.These                           (These (..), fromThese,
                                                       mergeThese)
import           Data.Tree                            (Tree)
import qualified Data.Vector                          as V
import qualified Data.Vector.Primitive                as P
import qualified Data.Vector.Storable                 as S
import qualified Data.Vector.Unboxed                  as U
import           GHC.Generics                         ((:*:) (..), (:.:) (..))

-- | Constrained analogue of 'Semialign': two structures are combined
-- position by position, and the combining function is told, via 'These',
-- whether a position is occupied on the left only ('This'), on the right
-- only ('That') or on both sides ('These').
class CFunctor f => CSemialign f where
  {-# MINIMAL calignWith #-}
  -- | Align two structures, combining the elements with the given function.
  calignWith
    :: (Dom f a, Dom f b, Dom f c)
    => (These a b -> c) -> f a -> f b -> f c

  -- | Align two structures, keeping the raw 'These' values.
  -- Defaults to @'calignWith' 'id'@.
  calign
    :: (Dom f a, Dom f b, Dom f (These a b))
    => f a -> f b -> f (These a b)
  {-# INLINE [1] calign #-}
  calign = calignWith id

-- | Any ordinary 'Semialign' is a 'CSemialign' through 'WrapFunctor';
-- both methods delegate directly to 'alignWith' and 'align'.
instance Semialign f => CSemialign (WrapFunctor f) where
  calignWith = alignWith
  {-# INLINE [1] calignWith #-}
  calign = align
  {-# INLINE [1] calign #-}

-- | Aligns two 'IS.IntSet's by membership: an element found only in the left
-- set is passed as 'This', one found only in the right set as 'That', and an
-- element found in both as @'These' x x@.  The images are then unioned.
instance {-# OVERLAPPING #-} CSemialign (WrapMono IS.IntSet) where
  calignWith f =
    withMonoCoercible @IS.IntSet $
      coerce @(IS.IntSet -> IS.IntSet -> IS.IntSet) $ \l r ->
        let both = l `IS.intersection` r
         in IS.unions
              -- Shared elements must be reported only as 'These'; mapping the
              -- full sets here would additionally emit 'This'/'That' for them
              -- and break @calignWith f x x == cmap (\e -> f (These e e)) x@.
              [ IS.map (f . This) (l `IS.difference` r)
              , IS.map (f . That) (r `IS.difference` l)
              , IS.map (\x -> f $ These x x) both
              ]
  {-# INLINE [1] calignWith #-}

-- | Constrained analogue of 'Align': a 'CSemialign' with an empty structure.
class CSemialign f => CAlign f where
  -- | The empty structure.
  cnil :: Dom f a => f a

-- | Delegates to 'nil' of the wrapped 'Align' instance.
instance Align f => CAlign (WrapFunctor f) where
  cnil = WrapFunctor nil
  {-# INLINE [1] cnil #-}

-- Instances for ordinary functors, all borrowed from 'WrapFunctor'.
deriving via WrapFunctor [] instance CSemialign []
deriving via WrapFunctor [] instance CAlign []
deriving via WrapFunctor Maybe instance CSemialign Maybe
deriving via WrapFunctor Maybe instance CAlign Maybe

#if MIN_VERSION_semialign(1,1,0)
deriving via WrapFunctor Option instance CSemialign Option
deriving via WrapFunctor Option instance CAlign Option
#else
deriving newtype instance CSemialign Option
deriving newtype instance CAlign Option
#endif

deriving via WrapFunctor ZipList instance CSemialign ZipList
deriving via WrapFunctor ZipList instance CAlign ZipList
deriving via WrapFunctor Identity instance CSemialign Identity
deriving via WrapFunctor NonEmpty instance CSemialign NonEmpty
deriving via WrapFunctor IntMap instance CSemialign IntMap
deriving via WrapFunctor IntMap instance CAlign IntMap
deriving via WrapFunctor Tree instance CSemialign Tree
deriving via WrapFunctor Seq instance CSemialign Seq
deriving via WrapFunctor Seq instance CAlign Seq
deriving via WrapFunctor V.Vector instance CSemialign V.Vector
deriving via WrapFunctor V.Vector instance CAlign V.Vector
deriving via WrapFunctor Proxy instance CSemialign Proxy
deriving via WrapFunctor Proxy instance CAlign Proxy
deriving via WrapFunctor (Map k)
  instance Ord k => CSemialign (Map k)
deriving via WrapFunctor (Map k)
  instance Ord k => CAlign (Map k)
deriving via WrapFunctor (HashMap k)
  instance (Eq k, Hashable k) => CSemialign (HashMap k)
deriving via WrapFunctor (HashMap k)
  instance (Eq k, Hashable k) => CAlign (HashMap k)
deriving via WrapFunctor ((->) s) instance CSemialign ((->) s)

-- | Component-wise alignment of a 'SOP.Product'.
instance (CSemialign f, CSemialign g) => CSemialign (SOP.Product f g) where
  calign (SOP.Pair a b) (SOP.Pair c d) = SOP.Pair (calign a c) (calign b d)
  {-# INLINE [1] calign #-}
  calignWith f (SOP.Pair a b) (SOP.Pair c d) =
    SOP.Pair (calignWith f a c) (calignWith f b d)
  {-# INLINE [1] calignWith #-}

-- | The empty product is the pair of the empty components.
instance (CAlign f, CAlign g) => CAlign (SOP.Product f g) where
  cnil = SOP.Pair cnil cnil
  {-# INLINE [1] cnil #-}

-- | Component-wise alignment of a generic product.
instance (CSemialign f, CSemialign g) => CSemialign (f :*: g) where
  calign ((:*:) a b) ((:*:) c d) = (:*:) (calign a c) (calign b d)
  {-# INLINE [1] calign #-}
  calignWith f ((:*:) a b) ((:*:) c d) =
    (:*:) (calignWith f a c) (calignWith f b d)
  {-# INLINE [1] calignWith #-}

-- | The empty product is the pair of the empty components.
instance (CAlign f, CAlign g) => CAlign (f :*: g) where
  cnil = cnil :*: cnil
  {-# INLINE [1] cnil #-}

-- | The outer layers are aligned first.  Where only one inner structure is
-- present it is mapped with 'This' or 'That'; where both are present the
-- inner structures are aligned with the user's function.
instance (CSemialign f, CSemialign g) => CSemialign (Compose f g) where
  calignWith f (Compose x) (Compose y) = Compose (calignWith g x y)
    where
      g (This ga)     = cmap (f . This) ga
      g (That gb)     = cmap (f . That) gb
      g (These ga gb) = calignWith f ga gb
  {-# INLINE [1] calignWith #-}

-- | Only the outer layer needs to be empty.
instance (CAlign f, CSemialign g) => CAlign (Compose f g) where
  cnil = Compose cnil
  {-# INLINE [1] cnil #-}

-- | Same scheme as the 'Compose' instance, for the generic composition.
instance (CSemialign f, CSemialign g) => CSemialign ((:.:) f g) where
  calignWith f (Comp1 x) (Comp1 y) = Comp1 (calignWith g x y)
    where
      g (This ga)     = cmap (f . This) ga
      g (That gb)     = cmap (f . That) gb
      g (These ga gb) = calignWith f ga gb
  {-# INLINE [1] calignWith #-}

-- | Only the outer layer needs to be empty.
instance (CAlign f, CSemialign g) => CAlign ((:.:) f g) where
  cnil = Comp1 cnil
  {-# INLINE [1] cnil #-}

instance CSemialign U.Vector where
  calignWith = alignVectorWith
  {-# INLINE [1] calignWith #-}

instance CAlign U.Vector where
  cnil = U.empty
  {-# INLINE [1] cnil #-}

instance CSemialign S.Vector where
  calignWith = alignVectorWith
  {-# INLINE [1] calignWith #-}

instance CAlign S.Vector where
  cnil = S.empty
  {-# INLINE [1] cnil #-}

instance CSemialign P.Vector where
  calignWith = alignVectorWith
  {-# INLINE [1] calignWith #-}

instance CAlign P.Vector where
  cnil = P.empty
  {-# INLINE [1] cnil #-}

-- | Index-wise alignment.  The result has the length of the longer input:
-- indices below the shorter length receive 'These', the remaining indices
-- receive 'That' (left input shorter) or 'This' (otherwise).
instance CSemialign SA.SmallArray where
  calignWith f l r = runST $ do
    let !lenL = length l
        !lenR = length r
        -- thresh: length of the shorter input; len: length of the longer one.
        (isLftShort, thresh, len)
          | lenL < lenR = (True, lenL, lenR)
          | otherwise = (False, lenR, lenL)
    -- The placeholder is never observed: every index in [0, len) is written
    -- by the loop below before the array is frozen.
    sa <- SA.newSmallArray len (error "Uninitialised element")
    forM_ [0 .. len - 1] $ \n ->
      if  | n < thresh ->
              SA.writeSmallArray sa n $
                f $ These (SA.indexSmallArray l n) (SA.indexSmallArray r n)
          | isLftShort ->
              SA.writeSmallArray sa n $ f $ That $ SA.indexSmallArray r n
          | otherwise ->
              SA.writeSmallArray sa n $ f $ This $ SA.indexSmallArray l n
    SA.unsafeFreezeSmallArray sa
  {-# INLINE [1] calignWith #-}

-- | The zero-length array.
instance CAlign SA.SmallArray where
  cnil = SA.smallArrayFromListN 0 []
  {-# INLINE [1] cnil #-}

-- | Index-wise alignment.  The result has the length of the longer input:
-- indices below the shorter length receive 'These', the remaining indices
-- receive 'That' (left input shorter) or 'This' (otherwise).
instance CSemialign A.Array where
  calignWith f l r = runST $ do
    let !lenL = length l
        !lenR = length r
        -- thresh: length of the shorter input; len: length of the longer one.
        (isLftShort, thresh, len)
          | lenL < lenR = (True, lenL, lenR)
          | otherwise = (False, lenR, lenL)
    -- The placeholder is never observed: every index in [0, len) is written
    -- by the loop below before the array is frozen.
    sa <- A.newArray len (error "Uninitialised element")
    forM_ [0 .. len - 1] $ \n ->
      if  | n < thresh ->
              A.writeArray sa n $
                f $ These (A.indexArray l n) (A.indexArray r n)
          | isLftShort ->
              A.writeArray sa n $ f $ That $ A.indexArray r n
          | otherwise ->
              A.writeArray sa n $ f $ This $ A.indexArray l n
    A.unsafeFreezeArray sa
  {-# INLINE [1] calignWith #-}

-- | The zero-length array.
instance CAlign A.Array where
  cnil = A.fromListN 0 []
  {-# INLINE [1] cnil #-}

-- | Index-wise alignment.  The result has the length of the longer input:
-- indices below the shorter length receive 'These', the remaining indices
-- receive 'That' (left input shorter) or 'This' (otherwise).
instance CSemialign PA.PrimArray where
  calignWith f l r = runST $ do
    let !lenL = PA.sizeofPrimArray l
        !lenR = PA.sizeofPrimArray r
        -- thresh: length of the shorter input; len: length of the longer one.
        (isLftShort, thresh, len)
          | lenL < lenR = (True, lenL, lenR)
          | otherwise = (False, lenR, lenL)
    sa <- PA.newPrimArray len
    forM_ [0 .. len - 1] $ \n ->
      if  | n < thresh ->
              PA.writePrimArray sa n $
                f $ These (PA.indexPrimArray l n) (PA.indexPrimArray r n)
          | isLftShort ->
              PA.writePrimArray sa n $ f $ That $ PA.indexPrimArray r n
          | otherwise ->
              PA.writePrimArray sa n $ f $ This $ PA.indexPrimArray l n
    PA.unsafeFreezePrimArray sa
  {-# INLINE [1] calignWith #-}

-- | The zero-length array.
instance CAlign PA.PrimArray where
  cnil = PA.primArrayFromListN 0 []
  {-# INLINE [1] cnil #-}

-- | Alignment of monomorphic sequences: the zipped part is built with
-- 'ozipWith' (passing 'These'), and the unmatched tail of the longer input,
-- if any, is appended after mapping it with 'This' or 'That'.
instance (MT.IsSequence mono, MonoZip mono)
  => CSemialign (WrapMono mono) where
  calignWith f = coerce go
    where
      go :: mono -> mono -> mono
      go ls rs
        | lenL == lenR = ozipWith (fmap f . These) ls rs
        | lenL < lenR =
            ozipWith (fmap f . These) ls rs
              <> omap (f . That) (MT.drop (fromIntegral lenL) rs)
        | otherwise =
            -- The left input is longer here, so its unmatched tail starts
            -- after the first lenR elements (the length of the right input).
            ozipWith (fmap f . These) ls rs
              <> omap (f . This) (MT.drop (fromIntegral lenR) ls)
        where
          lenL = olength ls
          lenR = olength rs

-- | The empty sequence, i.e. 'mempty' of the wrapped type.
instance (MT.IsSequence mono, MonoZip mono)
  => CAlign (WrapMono mono) where
  cnil = WrapMono mempty

-- | Align two structures of the same element type, merging elements that are
-- present on both sides with '<>' (via 'mergeThese').
csalign
  :: (CSemialign f, Dom f a, Semigroup a)
  => f a -> f a -> f a
{-# INLINE [1] csalign #-}
csalign = calignWith $ mergeThese (<>)

-- | Zip two structures, padding the missing side with 'Nothing' and wrapping
-- present elements in 'Just'.
cpadZip
  :: (CSemialign f, Dom f a, Dom f b, Dom f (Maybe a, Maybe b))
  => f a -> f b -> f (Maybe a, Maybe b)
{-# INLINE [1] cpadZip #-}
cpadZip = calignWith (fromThese Nothing Nothing . bimap Just Just)

-- | Like 'cpadZip', but the padded pair is passed to a combining function
-- instead of being returned as a tuple.
cpadZipWith
  :: (CSemialign f, Dom f a, Dom f b, Dom f c)
  => (Maybe a -> Maybe b -> c) -> f a -> f b -> f c
{-# INLINE [1] cpadZipWith #-}
cpadZipWith f =
  calignWith $ uncurry f . fromThese Nothing Nothing . bimap Just Just