{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
module Data.Array.Accelerate.Convolution.Preprocessed (
   Transform2,
   karatsuba,
   ) where

import Data.Array.Accelerate.Convolution.Private (Transform2, indexPad, )

import qualified Data.Array.Accelerate.Utility.Sliced as Sliced
import qualified Data.Array.Accelerate.Utility.Lift.Exp as Exp
import Data.Array.Accelerate.Utility.Lift.Exp (expr)

import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate ((:.)((:.)), Any(Any), All(All), Slice, Shape, )


{- |
Convolve two arrays along their innermost dimension
following the Karatsuba divide-and-conquer scheme.

Both arrays must have the same size.
The @Int@ argument is used as the extent of that innermost dimension:
it fixes the split point (half of it, rounded up)
and the innermost extent @2*len-1@ of the result.
For @len <= 1@ the result is simply the element-wise product.

There is not much to preprocess,
thus you should prefer 'Data.Array.Accelerate.Convolution.Adhoc.karatsuba'.
-}
karatsuba ::
   (Shape sh, Slice sh, A.Num a) =>
   Int -> Transform2 (sh :. Int) a
karatsuba len x y
   | len <= 1 = A.zipWith (*) x y
   | otherwise =
        -- Sum the three partial results at offsets 0, half and 2*half.
        -- NOTE(review): this relies on 'indexPad' contributing nothing
        -- for positions outside of the respective array - confirm in Private.
        A.generate (A.lift $ outer :. (2*len - 1)) $
        Exp.modify (expr:.expr) $
        \(ix:.k) ->
           let at j = indexPad (ix :. j)
           in  at k low + at (k - half) cross + at (k - half*2) high
   where
      -- Extent of the lower part: @len@ halved and rounded up.
      halfLen = negate (div (negate len) 2)
      half = A.constant halfLen

      -- The lower part is taken as is; the upper part is brought
      -- to the same extent via 'Sliced.pad' with pad value 0,
      -- such that both parts can be combined element-wise.
      lowerPart arr = Sliced.take half arr
      upperPart arr = Sliced.pad 0 half (Sliced.drop half arr)

      xLow = lowerPart x
      yLow = lowerPart y
      xHigh = upperPart x
      yHigh = upperPart y

      -- A single recursive call handles all three Karatsuba products:
      -- low*low, (low+high)*(low+high) and high*high.
      -- The recursion runs at a shape with one more dimension,
      -- which is admissible thanks to the explicit type signature.
      products =
         karatsuba halfLen
            (Sliced.stack3 xLow (A.zipWith (+) xLow xHigh) xHigh)
            (Sliced.stack3 yLow (A.zipWith (+) yLow yHigh) yHigh)

      -- Select one of the stacked products
      -- by its index in the second-innermost dimension.
      plane i = A.slice products $ A.lift $ Any :. (i :: Int) :. All

      low = plane 0
      high = plane 2
      -- Middle term: product of the sums minus both outer products.
      cross = A.zipWith (-) (plane 1) (A.zipWith (+) low high)

      -- All dimensions of the result except the innermost one.
      outer = A.indexTail (A.shape cross)