{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-- | Convolution of shaped arrays, implemented by extracting sliding windows
-- from the input and contracting them against the kernel with a single
-- matrix multiplication.
module Data.Array.Shaped.Convolve(convolve) where

import GHC.TypeLits
import qualified Numeric.LinearAlgebra as N

import Data.Array.Internal.Shape
import Data.Array.Shaped
import Data.Array.Shaped.MatMul

-- | Convolve the /n/ outer dimensions of the input with the given kernel.
--
-- Neither padding nor striding is applied.
--
-- * The input has shape /spatialSh/ ++ /channelSh/.
--
-- * The kernel has shape /spatialKernelSh/ ++ /channelSh/ ++ /featureSh/.
--
-- * The result has shape /spatialOutSh/ ++ /featureSh/.
--
-- Here /n/ is the rank of /spatialSh/; the first /n/ kernel dimensions
-- (/spatialKernelSh/) give the window size.
--
-- Example: a 20x30 image with 3 channels, convolved with a 5x5 kernel that
-- produces 8 output features:
--
-- > i :: Array [20,30,3] T
-- > k :: Array [5,5,3,8] T
-- > convolve @2 i k :: Array [16,26,8] T
--
-- The implementation proceeds in four steps:
--
-- 1. Take all windows of the input with 'window'.
--
-- 2. Flatten the windows into a matrix with one row per output spatial
--    position.
--
-- 3. Flatten the kernel into a matrix with one column per output feature.
--
-- 4. Multiply the two matrices and reshape the product to the output shape.
convolve
  :: forall (n :: Nat) ish ksh osh wsh a ksc ksf i ws isp iwc .
     ( i ~ Rank ish                     -- input rank
     , ws ~ Take n ksh                  -- window size
     , Window ws ish wsh, KnownNat (Rank ws)
     , ksc ~ Size (Take i ksh)          -- spatial + channels
     , ksf ~ Size (Drop i ksh)          -- features
     , isp ~ Size (Take n wsh)          -- spatial
     , iwc ~ Size (Drop n wsh)          -- kernel + channels
     , iwc ~ ksc
     , osh ~ (Take n wsh ++ Drop i ksh)
     , Size wsh ~ (isp * iwc)
     , Size ksh ~ (ksc * ksf)
     , Size osh ~ (isp * ksf)
     , Shape wsh, Shape ksh, Shape osh
     , KnownNat ksc, KnownNat isp, KnownNat ksf
     , N.Numeric a
     )
  => Array ish a
  -> Array ksh a
  -> Array osh a
convolve i k =
  let -- Windows of the input; the window extent is the first n kernel dims.
      iw :: Array wsh a
      iw = window @ws i
      -- Rows: the n spatial (output-position) dims of the windowed array.
      -- Columns: the remaining window-offset and channel dims.
      ir :: Array [isp, iwc] a
      ir = reshape iw
      -- Rows: kernel spatial + channel dims. Columns: feature dims.
      kr :: Array [ksc, ksf] a
      kr = reshape k
      -- The constraint iwc ~ ksc makes the inner dimensions agree.
      m :: Array [isp, ksf] a
      m = matMul ir kr
      -- Unflatten rows back to the spatial output dims, columns to features.
      r :: Array osh a
      r = reshape m
  in  r

-- Compile-time check that the example from the 'convolve' documentation
-- typechecks. The leading underscore keeps GHC from warning that it is unused.
_example
  :: Array [20,30,3] Float
  -> Array [5,5,3,8] Float
  -> Array [16,26,8] Float
_example = convolve @2