{-# LANGUAGE DataKinds                  #-}
{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MagicHash                  #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE StandaloneDeriving         #-}
{-# LANGUAGE TypeFamilies               #-}
{-# LANGUAGE TypeInType                 #-}
{-# LANGUAGE UnboxedTuples              #-}
{-# LANGUAGE UndecidableInstances       #-}
{-# LANGUAGE ViewPatterns               #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
-- | 'Float'-specialised quaternions.
--
--   @'Quater' 'Float'@ is a newtype over @'Vector' 'Float' 4@. Components are
--   unpacked as @(# x, y, z, w #)@: @x@, @y@, @z@ are the coefficients of the
--   imaginary units @i@, @j@, @k@ (see @takei@, @takej@, @takek@) and @w@ is
--   the real part (see @taker@).
module Numeric.Quaternion.Internal.QFloat
    ( QFloat, Quater (..)
    ) where

import qualified Control.Monad.ST                     as ST
import           Data.Coerce                          (coerce)
import           Numeric.Basics
import           Numeric.DataFrame.Internal.PrimArray
import qualified Numeric.DataFrame.ST                 as ST
import           Numeric.DataFrame.Type
import           Numeric.Quaternion.Internal
import           Numeric.Vector.Internal

type QFloat = Quater Float

deriving instance PrimBytes (Quater Float)
deriving instance PrimArray Float (Quater Float)

instance Quaternion Float where
    newtype Quater Float = QFloat (Vector Float 4)
    {-# INLINE packQ #-}
    packQ = coerce (vec4 :: Float -> Float -> Float -> Float -> Vector Float 4)
    {-# INLINE unpackQ# #-}
    unpackQ# = coerce (unpackV4# :: Vector Float 4 -> (# Float, Float, Float, Float #))
    {-# INLINE fromVecNum #-}
    -- Imaginary part from the vector; the real part is the remaining argument.
    fromVecNum (unpackV3# -> (# x, y, z #)) = packQ x y z
    {-# INLINE fromVec4 #-}
    fromVec4 = coerce
    {-# INLINE toVec4 #-}
    toVec4 = coerce
    {-# INLINE square #-}
    -- Squared norm: x^2 + y^2 + z^2 + w^2.
    square (unpackQ# -> (# x, y, z, w #)) = (x * x) + (y * y) + (z * z) + (w * w)
    {-# INLINE im #-}
    im (unpackQ# -> (# x, y, z, _ #)) = packQ x y z 0.0
    {-# INLINE re #-}
    re (unpackQ# -> (# _, _, _, w #)) = packQ 0 0 0 w
    {-# INLINE imVec #-}
    imVec (unpackQ# -> (# x, y, z, _ #)) = vec3 x y z
    {-# INLINE taker #-}
    taker (unpackQ# -> (# _, _, _, w #)) = w
    {-# INLINE takei #-}
    takei (unpackQ# -> (# x, _, _, _ #)) = x
    {-# INLINE takej #-}
    takej (unpackQ# -> (# _, y, _, _ #)) = y
    {-# INLINE takek #-}
    takek (unpackQ# -> (# _, _, z, _ #)) = z
    {-# INLINE conjugate #-}
    conjugate (unpackQ# -> (# x, y, z, w #))
      = packQ (negate x) (negate y) (negate z) w
    {-# INLINE rotScale #-}
    -- Closed-form expansion of  q * v * conjugate q  for q = (i,j,k,t) and a
    -- pure-imaginary v = (x,y,z):
    --   (t^2 - |u|^2) v + 2 (u . v) u + 2 t (u `cross` v),  where u = (i,j,k).
    rotScale (unpackQ# -> (# i, j, k, t #))
             (unpackV3# -> (# x, y, z #))
      = let l  = t*t - i*i - j*j - k*k
            d  = 2.0 * ( i*x + j*y + k*z)
            t2 = t * 2.0
        in vec3 (l*x + d*i + t2 * (z*j - y*k))
                (l*y + d*j + t2 * (x*k - z*i))
                (l*z + d*k + t2 * (y*i - x*j))
    {-# INLINE getRotScale #-}
    -- Find q such that  rotScale q a  gives b (up to rounding).
    -- A zero target gives the zero quaternion; a zero source (with non-zero
    -- target) gives a quaternion with all components set to +Infinity.
    getRotScale a b = case (# unpackV3# a, unpackV3# b #) of
      (# _, (# 0, 0, 0 #) #) -> packQ 0 0 0 0
      (# (# 0, 0, 0 #), _ #) -> let x = (1 / 0 :: Float) in packQ x x x x
      (# (# a1, a2, a3 #), (# b1, b2, b3 #) #) ->
        let ma  = sqrt (a1*a1 + a2*a2 + a3*a3)
            mb  = sqrt (b1*b1 + b2*b2 + b3*b3)
            d   = a1*b1 + a2*b2 + a3*b3
            c   = sqrt (ma*mb + d)
            ma2 = ma * 1.4142135623730951 -- sqrt 2.0
            r   = recip (ma2 * c)
            c'  = sqrt (mb / ma) -- ratio of a and b for corner cases
            r'  = recip (sqrt ( negate (a1*b1 + a2*b2) ))
        in case unpackV3# (cross a b) of
          (# 0, 0, 0 #)
              -- if a and b face the same direction, q is fully real
            | d >= 0 -> packQ 0 0 0 c'
              -- if a and b face opposite directions, find an orthogonal vector
              -- prerequisites: w == 0 and a·(x,y,z) == 0
              -- corner cases: only one vector component is non-zero
            | b1 == 0 -> packQ c' 0 0 0
              -- otherwise set the last component to zero,
              -- and get an orthogonal vector in 2D.
            | otherwise -> packQ (-b2*r') (b1*r') 0 0
              -- NB: here we have some precision troubles
              --     when a and b are close to parallel and opposite.
          (# t1, t2, t3 #) -> packQ (t1 * r) (t2 * r) (t3 * r) (c / ma2)
    {-# INLINE axisRotation #-}
    -- Rotation by angle a around axis v: (sin (a/2) * v / |v|, cos (a/2)).
    -- With a zero axis only the real part's sign is kept (negative if |a| >= pi).
    axisRotation v a = case unpackV3# v of
      (# 0, 0, 0 #) -> packQ 0 0 0 (negateUnless (abs a < M_PI) 1)
      (# x, y, z #) ->
        let c = cos (a * 0.5)
            s = sin (a * 0.5) / sqrt (x*x + y*y + z*z)
        in packQ (x * s) (y * s) (z * s) c
    {-# INLINE qArg #-}
    qArg (unpackQ# -> (# x, y, z, w #)) = 2 * atan2 (sqrt (x*x + y*y + z*z)) w
    {-# INLINE fromMatrix33 #-}
    fromMatrix33 m = fromM 1
      (ix# 0# m) (ix# 1# m) (ix# 2# m)
      (ix# 3# m) (ix# 4# m) (ix# 5# m)
      (ix# 6# m) (ix# 7# m) (ix# 8# m)
    {-# INLINE fromMatrix44 #-}
    -- Uses the upper-left 3x3 part (flat offsets 0-2, 4-6, 8-10);
    -- offset 15 is passed to fromM as the normalising constant.
    fromMatrix44 m = fromM (ix# 15# m)
      (ix# 0# m) (ix# 1# m) (ix# 2# m)
      (ix# 4# m) (ix# 5# m) (ix# 6# m)
      (ix# 8# m) (ix# 9# m) (ix# 10# m)
    {-# INLINE toMatrix33 #-}
    -- Real quaternion: diagonal matrix with w*w on the diagonal.
    -- The generator cycles a counter 3,2,1,0 so that every 4th element
    -- (offsets 0, 4, 8) gets w*w and all others get 0.
    toMatrix33 (unpackQ# -> (# 0.0, 0.0, 0.0, w #))
      = let x = w * w
            f 0 = (# 3 :: Int , x #)
            f k = (# k-1, 0 #)
        in case gen# (CumulDims [9,3,1]) f 0 of
          (# _, m #) -> m -- diag (scalar (w * w))
    -- General case: elements are written by flat offset, in the same order
    -- in which fromMatrix33 reads them back as m00 .. m22.
    toMatrix33 (unpackQ# -> (# x', y', z', w' #))
      = let x  = scalar x'
            y  = scalar y'
            z  = scalar z'
            w  = scalar w'
            x2 = x * x
            y2 = y * y
            z2 = z * z
            w2 = w * w
            l2 = x2 + y2 + z2 + w2
        in ST.runST $ do
          df <- ST.newDataFrame
          ST.writeDataFrameOff df 0 $ l2 - 2*(z2 + y2)
          ST.writeDataFrameOff df 1 $ 2*(x*y + z*w)
          ST.writeDataFrameOff df 2 $ 2*(x*z - y*w)
          ST.writeDataFrameOff df 3 $ 2*(x*y - z*w)
          ST.writeDataFrameOff df 4 $ l2 - 2*(z2 + x2)
          ST.writeDataFrameOff df 5 $ 2*(y*z + x*w)
          ST.writeDataFrameOff df 6 $ 2*(x*z + y*w)
          ST.writeDataFrameOff df 7 $ 2*(y*z - x*w)
          ST.writeDataFrameOff df 8 $ l2 - 2*(y2 + x2)
          ST.unsafeFreezeDataFrame df
    {-# INLINE toMatrix44 #-}
    -- Real quaternion: zero-fill, then w*w at offsets 0, 5, 10 and 1 at 15.
    toMatrix44 (unpackQ# -> (# 0.0, 0.0, 0.0, w #)) = ST.runST $ do
      df <- ST.newDataFrame
      mapM_ (flip (ST.writeDataFrameOff df) 0) [0..15]
      let w2 = scalar (w * w)
      ST.writeDataFrameOff df 0 w2
      ST.writeDataFrameOff df 5 w2
      ST.writeDataFrameOff df 10 w2
      ST.writeDataFrameOff df 15 1
      ST.unsafeFreezeDataFrame df
    -- General case: same 3x3 block as toMatrix33, padded with zeros and a
    -- 1 at offset 15.
    toMatrix44 (unpackQ# -> (# x', y', z', w' #))
      = let x  = scalar x'
            y  = scalar y'
            z  = scalar z'
            w  = scalar w'
            x2 = x * x
            y2 = y * y
            z2 = z * z
            w2 = w * w
            l2 = x2 + y2 + z2 + w2
        in ST.runST $ do
          df <- ST.newDataFrame
          ST.writeDataFrameOff df 0 $ l2 - 2*(z2 + y2)
          ST.writeDataFrameOff df 1 $ 2*(x*y + z*w)
          ST.writeDataFrameOff df 2 $ 2*(x*z - y*w)
          ST.writeDataFrameOff df 3 0
          ST.writeDataFrameOff df 4 $ 2*(x*y - z*w)
          ST.writeDataFrameOff df 5 $ l2 - 2*(z2 + x2)
          ST.writeDataFrameOff df 6 $ 2*(y*z + x*w)
          ST.writeDataFrameOff df 7 0
          ST.writeDataFrameOff df 8 $ 2*(x*z + y*w)
          ST.writeDataFrameOff df 9 $ 2*(y*z - x*w)
          ST.writeDataFrameOff df 10 $ l2 - 2*(y2 + x2)
          ST.writeDataFrameOff df 11 0
          ST.writeDataFrameOff df 12 0
          ST.writeDataFrameOff df 13 0
          ST.writeDataFrameOff df 14 0
          ST.writeDataFrameOff df 15 1
          ST.unsafeFreezeDataFrame df


{-
  Calculate a quaternion from a 3x3 matrix.

  The first argument c' is a constant: 1 for a 3x3 matrix, or m44 for a
  4x4 matrix. All resulting components are divided by sqrt c'.

  Notes for the formulae:

    d == square q == det m ** (1/3)
    t == trace m  == 4 w w - d

    m01 - m10 == 4 z w
    m20 - m02 == 4 y w
    m12 - m21 == 4 x w
    m01 + m10 == 4 x y
    m20 + m02 == 4 x z
    m12 + m21 == 4 y z

    m00 == + x x - y y - z z + w w
    m11 == - x x + y y - z z + w w
    m22 == - x x - y y + z z + w w

    4 x x == d + m00 - m11 - m22
    4 y y == d - m00 + m11 - m22
    4 z z == d - m00 - m11 + m22
    4 w w == d + m00 + m11 + m22

  The branches below pick the largest of 4ww, 4xx, 4yy, 4zz (via the trace
  and the diagonal) to avoid dividing by a small number.
-}
fromM :: Float
      -> Float -> Float -> Float
      -> Float -> Float -> Float
      -> Float -> Float -> Float
      -> QFloat
fromM c' m00 m01 m02 m10 m11 m12 m20 m21 m22
    | t > 0
      = let dd = sqrt ( d + t )
            is = c / dd
        in packQ ((m12 - m21)*is) ((m20 - m02)*is) ((m01 - m10)*is) (c*dd)
    | m00 > m11 && m00 > m22
      = let dd = sqrt ( d + m00 - m11 - m22 )
            is = c / dd
        in packQ (c*dd) ((m01 + m10)*is) ((m02 + m20)*is) ((m12 - m21)*is)
    | m11 > m22
      = let dd = sqrt ( d - m00 + m11 - m22 )
            is = c / dd
        in packQ ((m01 + m10)*is) (c*dd) ((m12 + m21)*is) ((m20 - m02)*is)
    | otherwise
      = let dd = sqrt ( d - m00 - m11 + m22 )
            is = c / dd
        in packQ ((m02 + m20)*is) ((m12 + m21)*is) (c*dd) ((m01 - m10)*is)
  where
    -- normalizing constant
    c = recip $ 2 * sqrt c'
    -- trace
    t = m00 + m11 + m22
    -- cubic root of determinant
    d = ( m00 * ( m11 * m22 - m12 * m21 )
        - m01 * ( m10 * m22 - m12 * m20 )
        + m02 * ( m10 * m21 - m11 * m20 )
        ) ** 0.33333333333333333333333333333333


instance Num QFloat where
    QFloat a + QFloat b
      = QFloat (a + b)
    {-# INLINE (+) #-}
    QFloat a - QFloat b
      = QFloat (a - b)
    {-# INLINE (-) #-}
    -- Hamilton product; components 1..3 are i, j, k and component 4 is real.
    (unpackQ# -> (# a1, a2, a3, a4 #)) * (unpackQ# -> (# b1, b2, b3, b4 #))
      = packQ
          ((a4 * b1) + (a1 * b4) + (a2 * b3) - (a3 * b2))
          ((a4 * b2) - (a1 * b3) + (a2 * b4) + (a3 * b1))
          ((a4 * b3) + (a1 * b2) - (a2 * b1) + (a3 * b4))
          ((a4 * b4) - (a1 * b1) - (a2 * b2) - (a3 * b3))
    {-# INLINE (*) #-}
    negate (QFloat a) = QFloat (negate a)
    {-# INLINE negate #-}
    -- Norm as a real quaternion.
    abs = packQ 0 0 0 . sqrt . square
    {-# INLINE abs #-}
    -- q / |q|. If exactly one component is infinite (and none is NaN), the
    -- result is the signed unit along that component; any other combination
    -- of infinities or NaNs yields all-NaN. A zero quaternion is returned as is.
    signum q@(unpackQ# -> (# x, y, z, w #))
      | qd == 0   = q
      | otherwise = case ix + iy + iz + iw + nn of
          0 -> packQ (x * l) (y * l) (z * l) (w * l)
          1 -> packQ (copysign 1 x) 0 0 0
          2 -> packQ 0 (copysign 1 y) 0 0
          4 -> packQ 0 0 (copysign 1 z) 0
          8 -> packQ 0 0 0 (copysign 1 w)
          _ -> packQ n n n n
      where
        n  = 0 / 0 :: Float
        qd = x*x + y*y + z*z + w*w
        ix = if isInfinite x then 1 else 0 :: Int
        iy = if isInfinite y then 2 else 0 :: Int
        iz = if isInfinite z then 4 else 0 :: Int
        iw = if isInfinite w then 8 else 0 :: Int
        nn = if isNaN x || isNaN y || isNaN z || isNaN w then 16 else 0 :: Int
        l  = recip (sqrt qd)
    {-# INLINE signum #-}
    fromInteger = packQ 0 0 0 . fromInteger
    {-# INLINE fromInteger #-}


instance Fractional QFloat where
    {-# INLINE recip #-}
    -- conjugate q / square q
    recip q@(unpackQ# -> (# x, y, z, w #)) = case negate (recip (square q)) of
      c -> packQ (x * c) (y * c) (z * c) (negate (w * c))
    {-# INLINE (/) #-}
    a / b = a * recip b
    {-# INLINE fromRational #-}
    fromRational = packQ 0 0 0 . fromRational


instance Floating QFloat where
    {-# INLINE pi #-}
    pi = packQ 0 0 0 M_PI
    {-# INLINE exp #-}
    exp (unpackQ# -> (# x, y, z, w #))
      | mv2 == 0  = packQ x y z ew
      | otherwise = packQ (x * l) (y * l) (z * l) arg
      where
        mv2 = (x * x) + (y * y) + (z * z)
        mv  = sqrt mv2
        ew  = exp w
        l   = ew * sin mv / mv
        arg = ew * cos mv
    {-# INLINE log #-}
    log = log' (Vec3 1 0 0)
    {-# INLINE sqrt #-}
    sqrt = sqrt' (Vec3 1 0 0)
    {-# INLINE sin #-}
    sin (unpackQ# -> (# x, y, z, w #))
      | mv2 == 0  = packQ x y z (sin w)
      | otherwise = packQ (x * l) (y * l) (z * l) arg
      where
        mv2 = (x * x) + (y * y) + (z * z)
        mv  = sqrt mv2
        l   = cos w * sinh mv / mv
        arg = sin w * cosh mv
    {-# INLINE cos #-}
    cos (unpackQ# -> (# x, y, z, w #))
      | mv2 == 0  = packQ x y z (cos w)
      | otherwise = packQ (x * l) (y * l) (z * l) arg
      where
        mv2 = (x * x) + (y * y) + (z * z)
        mv  = sqrt mv2
        l   = sin w * sinh mv / negate mv
        arg = cos w * cosh mv
    {-# INLINE tan #-}
    -- tan q = (sin 2w + u sinh 2|v|) / (cos 2w + cosh 2|v|), with u = v/|v|;
    -- numerator and denominator are scaled by 2 exp (-2|v|) to avoid
    -- overflow. Near a pole (d < M_EPS) an approximation is used instead.
    tan (unpackQ# -> (# x, y, z, w #))
      | mv2 == 0         = packQ x y z (tan w)
      | isInfinite mv2   = signum (packQ x y z 0)
      | otherwise        = packQ (x * l) (y * l) (z * l) arg
      where
        mv2  = (x * x) + (y * y) + (z * z)
        mv   = sqrt mv2
        b    = 2*mv
        a    = 2*w
        sina = sin a
        eb   = exp (-b)
        eb2  = eb*eb
        d    = 1 + eb2 + 2 * eb * cos a
        rd   = recip d
        pa   = M_PI - abs a
        rd'  = 2 / (b*b + pa*pa)
        (l, arg) = if d >= M_EPS
                   then ((1 - eb2) * rd / mv, 2 * sina * eb * rd)
                   else (2 * rd' , negate sina * rd')
    {-# INLINE sinh #-}
    sinh (unpackQ# -> (# x, y, z, w #))
      | mv2 == 0  = packQ x y z (sinh w)
      | otherwise = packQ (x * l) (y * l) (z * l) arg
      where
        mv2 = (x * x) + (y * y) + (z * z)
        mv  = sqrt mv2
        l   = cosh w * sin mv / mv
        arg = sinh w * cos mv
    {-# INLINE cosh #-}
    cosh (unpackQ# -> (# x, y, z, w #))
      | mv2 == 0  = packQ x y z (cosh w)
      | otherwise = packQ (x * l) (y * l) (z * l) arg
      where
        mv2 = (x * x) + (y * y) + (z * z)
        mv  = sqrt mv2
        l   = sinh w * sin mv / mv
        arg = cosh w * cos mv
    {-# INLINE tanh #-}
    -- Same scaling trick as tan, with the roles of w and |v| swapped.
    tanh (unpackQ# -> (# x, y, z, w #))
      | mv2 == 0       = packQ x y z (tanh w)
      | isInfinite mv2 = packQ 0 0 0 (signum w)
      | otherwise      = packQ (x * l) (y * l) (z * l) arg
      where
        mv2 = (x * x) + (y * y) + (z * z)
        mv  = sqrt mv2
        b   = 2*w
        a   = 2*mv
        eb  = exp (- abs b)
        eb2 = eb*eb
        d   = 1 + eb2 + 2 * eb * cos a
        rd  = recip d
        pa  = M_PI - a
        rd' = 2 / (b*b + pa*pa)
        (l, arg) = if d >= M_EPS
                   then (2 * sin a * eb * rd / mv, copysign (1 - eb2) b * rd)
                   else (2 * rd' , b * rd')
    {-# INLINE asin #-}
    -- The original formula:
    -- asin q = -i * log (i*q + sqrt (1 - q*q))
    -- below is a more numerically stable version.
    asin (unpackQ# -> (# x, y, z, w #))
      | v2 == 0   = if w2 <= 1 then packQ x y z (asin w)
                               else packQ l 0 0 arg
      | otherwise = packQ (x*c) (y*c) (z*c) arg
      where
        v2   = (x * x) + (y * y) + (z * z)
        v    = sqrt v2
        w2   = w*w
        w1qq = 0.5 *(1 - w2 + v2) -- real part of (1 - q*q)/2
        l1qq = sqrt (w1qq*w1qq + w2*v2) -- length of (1 - q*q)/2
        sp2  = l1qq + w1qq
        sn2  = l1qq - w1qq
        sp   = sqrt sp2
        sn   = copysign (sqrt sn2) w
        -- choose a more stable (symbolically equiv) version
        dp = if 2 * v2 <= sp2 then sp - v else v2 / ((sp + v)*(sn2 + v2))
        dn = if 2 * w2 <= sn2 then w - sn else w2 / ((sn + w)*(sp2 + w2))
        (wD, vD) = case compare w1qq 0 of
          GT -> (dp, w * dp / sp)
          LT -> (v * dn / sn, dn)
          EQ -> (-v, w)
        l   = -0.5 * log (wD*wD + vD*vD)
        c   = l / v
        arg = atan2 vD wD
    {-# INLINE acos #-}
    acos q = M_PI_2 - asin q
    {-# INLINE atan #-}
    -- atan q = i / 2 * log ( (i + q) / (i - q) )
    -- The pole is at a unit pure-imaginary q (w == 0, |v| == 1); there the
    -- imaginary scale is capped at sqrt maxFinite / v.
    atan (unpackQ# -> (# x, y, z, w #))
      | v2 == 0   = packQ x y z (atan w)
      | otherwise = packQ (x*c) (y*c) (z*c) arg
      where
        v2    = (x * x) + (y * y) + (z * z)
        v     = sqrt v2
        w2    = w*w
        q2    = w2 + v2
        v'    = v - 1
        mzero = w2 + v'*v'
        (c, arg) = if mzero == 0
          then ( sqrt maxFinite / v, 0)
          else ( 0.25 * (log (1 + q2 + 2*v) - log mzero) / v
               , 0.5 * atan2 (2*w) (1 - q2) )
    {-# INLINE asinh #-}
    -- The original formula:
    -- asinh q = log (q + sqrt (q*q + 1))
    -- below is a more numerically stable version.
    asinh (unpackQ# -> (# x, y, z, w #))
      | v2 == 0   = packQ x y z (asinh w)
      | otherwise = packQ (x*c) (y*c) (z*c) arg
      where
        v2   = (x * x) + (y * y) + (z * z)
        v    = sqrt v2
        w2   = w*w
        w1qq = 0.5 *(1 + w2 - v2) -- real part of (1 + q*q)/2
        l1qq = sqrt (w1qq*w1qq + w2*v2) -- length of (1 + q*q)/2
        sp2  = l1qq + w1qq
        sn2  = l1qq - w1qq
        sp   = sqrt sp2
        sn   = copysign (sqrt sn2) w
        -- choose a more stable (symbolically equiv) version
        dp = if 2 * w >= - sp then w + sp else w2 / ((sp - w)*(w2 + sn2))
        dn = if 2 * v <= - sn || sn >= 0 then v + sn else v2 / ((v - sn)*(v2 + sp2))
        (wD, vD) = case compare w1qq 0 of
          GT -> (dp, v * dp / sp)
          LT -> (w * dn / sn, dn)
          EQ -> (w, v)
        c   = atan2 vD wD / v
        arg = 0.5 * log (wD*wD + vD*vD)
    {-# INLINE acosh #-}
    -- The original formula:
    -- acosh q = log (q + sqrt (q + 1) * sqrt (q - 1))
    -- below is a more numerically stable version.
    -- note, log (q + sqrt (q*q - 1)) would not work, because that would not
    --   be the principal value.
    acosh (unpackQ# -> (# x, y, z, w #))
      | v2 == 0   = packQ x y z (acosh w)
      | otherwise = packQ (x*c) (y*c) (z*c) arg
      where
        v2   = (x * x) + (y * y) + (z * z)
        v    = sqrt v2
        w2   = w*w
        w1qq = 0.5 *(w2 - v2 - 1) -- real part of (q*q - 1)/2
        l1qq = sqrt (w1qq*w1qq + w2*v2) -- length of (q*q - 1)/2
        sp2  = l1qq + w1qq
        sn2  = l1qq - w1qq
        sp   = sqrt sp2
        sn   = copysign (sqrt sn2) w
        -- choose a more stable (symbolically equiv) version
        dp = if 2 * w >= - sp then w + sp else w2 / ((w - sp)*(w2 + sn2))
        dn = if 2 * v <= - sn || sn >= 0 then v + sn else v2 / ((sn - v)*(v2 + sp2))
        (wD, vD) = case compare w1qq 0 of
          GT -> (dp, v * dp / sp)
          LT -> (w * dn / sn, dn)
          EQ -> (w, v)
        c   = atan2 vD wD / v
        arg = 0.5 * log (wD*wD + vD*vD)
    {-# INLINE atanh #-}
    -- atanh q = 0.5 * log ( (1 + q) / (1 - q) )
    atanh (unpackQ# -> (# x, y, z, w #))
      | v2 == 0   = packQ x y z (atanh w)
      | otherwise = packQ (x*c) (y*c) (z*c) (copysign arg w)
      where
        v2 = (x * x) + (y * y) + (z * z)
        v  = sqrt v2
        w2 = w*w
        q2 = w2 + v2
        w' = abs w - 1
        -- imaginary part: 0.5 * arg ((1 + q) / (1 - q))
        c  = 0.5 * atan2 (2*v) (1 - q2) / v
        -- Real part: 0.25 * log (|1 + q|^2 / |1 - q|^2), computed for |w| and
        -- then given the sign of w (it is odd in w).
        -- In this branch v2 > 0, so |1 - q|^2 = w'*w' + v2 >= v2 > 0 and the
        -- logarithm is always finite (e.g. atanh (1 + i) has real part
        -- 0.25 * log 5). The only poles, q = +-1, are real and are handled by
        -- the v2 == 0 guard above; no special case for |w| == 1 is needed.
        arg = 0.25 * (log (1 + q2 + 2 * abs w) - log (v2 + w'*w'))


-- | Quaternion logarithm with an explicit fallback axis.
--
--   If q is a negative real number its logarithm is not unique; the result is
--   then @pi * r + log (negate w)@, i.e. the vector @r@ is used as the axis
--   (presumably @r@ is meant to be a unit vector; 'log' passes @Vec3 1 0 0@).
log' :: Vector Float 3 -> QFloat -> QFloat
log' r (unpackQ# -> (# x, y, z, w #)) = case (x * x) + (y * y) + (z * z) of
    0.0
      | w >= 0
        -> packQ 0 0 0 (log w)
      | Vec3 rx ry rz <- r
        -> packQ (M_PI*rx) (M_PI*ry) (M_PI*rz) (log (negate w))
    mv2 -> case (# mv2 + w * w, sqrt mv2 #) of
      (# q2, mv #) -> case atan2 mv w / mv of
        l -> packQ (x * l) (y * l) (z * l) (0.5 * log q2)

-- | Quaternion square root with an explicit fallback axis.
--
--   If q is a negative real number the result is @sqrt (negate w) * r@
--   (with zero real part); 'sqrt' passes @Vec3 1 0 0@ as @r@.
sqrt' :: Vector Float 3 -> QFloat -> QFloat
sqrt' r (unpackQ# -> (# x, y, z, w #))
  | v2 == 0 && w >= 0
  = packQ x y z (sqrt w)
  | v2 == 0
  , Vec3 rx ry rz <- r
  , sw <- sqrt (negate w)
  = packQ (sw*rx) (sw*ry) (sw*rz) 0
  | otherwise
  = packQ (x * c) (y * c) (z * c) arg
  where
    v2  = (x * x) + (y * y) + (z * z)
    mq  = sqrt (v2 + w * w)
    -- real part sqrt ((|q| + w) / 2); for w < 0 the sum |q| + w is rewritten
    -- as v2 / (|q| - w) to avoid cancellation.
    arg = sqrt $ 0.5 * if w >= 0 then mq + w else v2 / (mq - w)
    c   = 0.5 / arg


instance Eq QFloat where
    {-# INLINE (==) #-}
    QFloat a == QFloat b = a == b
    {-# INLINE (/=) #-}
    QFloat a /= QFloat b = a /= b