-- |
-- Module      : ConClusion.Array.Util
-- Description : Additional tools to work with numerical arrays
-- Copyright   : Phillip Seeber, 2022
-- License     : AGPL-3
-- Maintainer  : phillip.seeber@googlemail.com
-- Stability   : experimental
-- Portability : POSIX, Windows
module ConClusion.Array.Util
  ( IndexException (..),
    magnitude,
    normalise,
    angle,
    minDistAt,
    minDistAtVec,
    iMinimumM,
  )
where

import Data.Massiv.Array as Massiv hiding (IndexException)
import RIO
import System.IO.Unsafe (unsafePerformIO)

-- | Exception signalling a problem with indexing into some kind of array.
-- It carries a message string describing the problem.
newtype IndexException
  = IndexException String
  deriving (Show)

-- The default 'Exception' methods are sufficient here.
instance Exception IndexException

-- | Magnitude (length) of a vector: the square root of the vector's dot
-- product with itself, i.e. the Euclidean norm for real-valued elements.
magnitude :: (Massiv.Numeric r e, Source r e, Floating e) => Massiv.Vector r e -> e
magnitude vec = sqrt selfDot
  where
    selfDot = vec !.! vec

-- | Normalise a vector: scale every element by the reciprocal of the vector's
-- 'magnitude'.
--
-- A zero vector has magnitude 0, so the scale factor becomes @1 / 0@. For
-- IEEE floating-point element types this produces non-finite elements rather
-- than an exception.
normalise :: (Massiv.Numeric r e, Source r e, Floating e) => Massiv.Vector r e -> Massiv.Vector r e
normalise v = v .* scale
  where
    scale = 1 / magnitude v

-- | Angle between two vectors, computed as
-- @acos ((a . b) / (|a| * |b|))@. The result is in radians, as returned by
-- 'acos'.
--
-- NOTE(review): floating-point rounding can push the cosine marginally
-- outside [-1, 1], e.g. for (anti)parallel vectors. In that case 'acos'
-- returns NaN for IEEE types. Clamping the cosine would require an extra
-- 'Ord' constraint on the element type.
angle :: (Massiv.Numeric r e, Source r e, Floating e) => Massiv.Vector r e -> Massiv.Vector r e -> e
angle a b = acos cosine
  where
    cosine = (a !.! b) / (magnitude a * magnitude b)

-- | Find the minimal distance in a distance matrix, which is not on the main
-- diagonal.
--
-- Only the strictly lower triangle (row index greater than column index) is
-- searched, so the matrix is presumed to be symmetric. Returns the minimal
-- element together with its index. Among equal minima, the one reached first
-- by the fold is kept, because candidates only replace the current best when
-- they are strictly smaller.
--
-- Throws a 'SizeException' ('SizeEmptyException') for an empty matrix. Throws
-- an 'IndexException' if the matrix has no element below the main diagonal
-- (e.g. a matrix with a single row).
{-# SCC minDistAt #-}
minDistAt ::
  ( Manifest r e,
    MonadThrow m,
    Ord e
  ) =>
  Massiv.Matrix r e ->
  m (e, Ix2)
minDistAt arr
  | isEmpty arr = throwM $ SizeEmptyException (Massiv.size arr)
  | otherwise = maybe noLowerTriangle return lowerMin
  where
    noLowerTriangle =
      throwM $ IndexException "minDistAt: matrix has no elements below the main diagonal"
    -- Seed the fold with 'Nothing' rather than a concrete element. Seeding with
    -- e.g. arr ! (0 :. 0) would put a diagonal entry (typically 0 in a distance
    -- matrix) into the accumulator, and it could then be returned instead of
    -- the real off-diagonal minimum.
    -- ifoldlP is an IO action (its combine step suggests chunked, possibly
    -- parallel evaluation). The folding functions below are pure and the input
    -- array is immutable, which is what makes the unsafePerformIO wrapper
    -- observably pure.
    lowerMin = unsafePerformIO $ ifoldlP minFold Nothing chFold Nothing arr
    -- Per-element step: only strictly-lower-triangle elements are candidates.
    minFold acc ix@(row :. col) e
      | row > col && improves e acc = Just (e, ix)
      | otherwise = acc
    -- Combine chunk results; keep the accumulated one unless strictly beaten.
    chFold Nothing ch = ch
    chFold acc Nothing = acc
    chFold acc@(Just (eA, _)) ch@(Just (eC, _)) = if eC < eA then ch else acc
    -- Is @e@ strictly smaller than the current best (or is there none yet)?
    improves e = maybe True ((e <) . fst)

-- | Find the minimal element of a vector among the elements at indices
-- strictly larger than the supplied index.
--
-- Returns the element together with its index in the original vector. Among
-- equal minima the smallest index wins. A negative start index searches the
-- whole vector.
--
-- Throws a 'SizeException' ('SizeEmptyException') for an empty vector. Throws
-- massiv's 'IndexOutOfBoundsException' when no element lies after the supplied
-- index, i.e. @ixStart >= n - 1@ for a vector of length @n@.
minDistAtVec ::
  ( Manifest r e,
    MonadThrow m,
    Ord e
  ) =>
  Ix1 ->
  Massiv.Vector r e ->
  m (e, Ix1)
minDistAtVec ixStart vec
  | isEmpty vec = throwM $ SizeEmptyException (Massiv.size vec)
  -- The comparison is against nElems - 1, not nElems: starting at the last
  -- index would leave nothing to search, and seeding the fold with
  -- searchVec ! 0 would then fail. nElems >= 1 here, so nElems - 1 cannot
  -- overflow (unlike ixStart + 1).
  | ixStart >= nElems - 1 = throwM $ IndexOutOfBoundsException (Sz nElems) ixStart
  | otherwise = return (minE, minIx + offset)
  where
    Sz nElems = Massiv.size vec
    -- Number of leading elements skipped. Clamped at 0 so that a negative
    -- ixStart neither drops elements nor shifts the reported index (the 'Sz'
    -- used for dropping clamps negatives too, so both must agree).
    offset = max 0 (ixStart + 1)
    searchVec = Massiv.drop (Sz offset) vec
    startAcc = (searchVec Massiv.! 0, 0)
    -- ifoldlP is an IO action over an immutable vector with pure folding
    -- functions, so wrapping it in unsafePerformIO is observably pure.
    (minE, minIx) = unsafePerformIO $ ifoldlP minFold startAcc chFold startAcc searchVec
    -- Per-element step: replace the current best only when strictly smaller.
    minFold acc@(eA, _) ix e = if e < eA then (e, ix) else acc
    -- Combine chunk results; tuple ordering breaks ties by the smaller index.
    chFold acc ch = min acc ch

-- | Like 'Massiv.minimumM', but also returns the index of the minimal element.
--
-- When the minimum occurs several times, the occurrence the fold reaches
-- first is kept, because candidates only replace the running best when they
-- are strictly smaller.
--
-- Throws a 'SizeException' ('SizeEmptyException') for an empty array.
iMinimumM ::
  ( Manifest r a,
    MonadThrow m,
    Index ix,
    Ord a
  ) =>
  Array r ix a ->
  m (a, ix)
iMinimumM arr
  | isEmpty arr = throwM $ SizeEmptyException (Massiv.size arr)
  | otherwise = pure best
  where
    -- The zero index is valid because the empty case was handled above.
    firstIx = pureIndex 0
    seed = (arr Massiv.! firstIx, firstIx)
    -- ifoldlP is an IO action (its combine step suggests chunked, possibly
    -- parallel evaluation). The folding functions are pure and the array is
    -- immutable, so the unsafePerformIO wrapper is observably pure.
    best = unsafePerformIO (ifoldlP step seed merge seed arr)
    -- Per-element step within a chunk.
    step current@(bestVal, _) ix val
      | val < bestVal = (val, ix)
      | otherwise = current
    -- Merge a chunk's result into the running total.
    merge total chunk
      | fst chunk < fst total = chunk
      | otherwise = total