From 5a1c700e2c6883c27b63d2e0e2f60cf46c1b2c18 Mon Sep 17 00:00:00 2001 From: Tom Westerhout <14264576+twesterhout@users.noreply.github.com> Date: Thu, 5 Jan 2023 23:57:27 +0100 Subject: [PATCH 01/21] Update .cabal file to work with recent Cabal and GHC versions --- symengine.cabal | 80 ++++++++++++++++++++++++++++++------------------- 1 file changed, 49 insertions(+), 31 deletions(-) diff --git a/symengine.cabal b/symengine.cabal index 0f33da5..3ef89d7 100644 --- a/symengine.cabal +++ b/symengine.cabal @@ -1,40 +1,58 @@ -name: symengine -version: 0.1.2.0 -synopsis: SymEngine symbolic mathematics engine for Haskell -description: Please see README.md -homepage: http://github.com/symengine/symengine.hs#readme -license: MIT -license-file: LICENSE -author: Siddharth Bhat -maintainer: siddu.druid@gmail.com -copyright: 2016 Siddharth Bhat -category: FFI, Math, Symbolic Computation -build-type: Simple --- extra-source-files: -cabal-version: >=1.10 +cabal-version: 3.0 +name: symengine +version: 0.1.2.0 +synopsis: SymEngine symbolic mathematics engine for Haskell +description: Please see README.md +homepage: https://github.com/symengine/symengine.hs +license: MIT +license-file: LICENSE +author: Siddharth Bhat +maintainer: siddu.druid@gmail.com +copyright: 2016 Siddharth Bhat +category: FFI, Math, Symbolic Computation +build-type: Simple +tested-with: GHC == 8.10.7 + +common common-options + build-depends: base >= 4.13.0.0 + + ghc-options: -Wall + -Wcompat + -Widentities + -Wincomplete-uni-patterns + -Wincomplete-record-updates + if impl(ghc >= 8.0) + ghc-options: -Wredundant-constraints + if impl(ghc >= 8.2) + ghc-options: -fhide-source-paths + if impl(ghc >= 8.4) + ghc-options: -Wmissing-export-lists + -Wpartial-fields + if impl(ghc >= 8.8) + ghc-options: -Wmissing-deriving-strategies + + default-language: Haskell2010 + default-extensions: BangPatterns + FlexibleContexts + FlexibleInstances + DerivingVia library + import: common-options hs-source-dirs: src exposed-modules: Symengine - build-depends: base >= 4.5.0 && <= 5 - default-language: Haskell2010 + extra-libraries: symengine + teuchos + gmp + stdc++ test-suite symengine-test + import: common-options type: exitcode-stdio-1.0 - hs-source-dirs: test, src + hs-source-dirs: test main-is: Spec.hs - build-depends: base >= 4.5.0 && <= 5 - , symengine >= 0.1.1 && <= 0.2 - , tasty >= 0.10.0 && <= 0.13 - , tasty-hunit >= 0.9.0 && <= 1.5 - , tasty-quickcheck >= 0.8.0 && <= 1.5 + build-depends: symengine + , tasty >= 0.10.0 + , tasty-hunit >= 0.9.0 + , tasty-quickcheck >= 0.8.0 ghc-options: -threaded -rtsopts -with-rtsopts=-N - extra-libraries: symengine stdc++ gmpxx gmp - - other-modules: Symengine - - default-language: Haskell2010 - -source-repository head - type: git - location: https://github.com/symengine/symengine.hs From 5b7b4852bbbe8c7968e160dcae8c2dbb373dc42d Mon Sep 17 00:00:00 2001 From: Tom Westerhout <14264576+twesterhout@users.noreply.github.com> Date: Fri, 6 Jan 2023 00:13:01 +0100 Subject: [PATCH 02/21] First experiment with GitHub Actions --- .github/workflows/ci.yml | 78 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 .github/workflows/ci.yml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..7abd7b6 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,78 @@ +name: CI + +# Trigger the workflow on push or pull request, but only for the master branch +on: + pull_request: + push: + branches: [master] + +jobs: + build: + name: Building on ${{ matrix.os }} with ghc-${{ matrix.ghc }} + runs-on: ${{ matrix.os }} + strategy: + matrix: + include: + - os: ubuntu-latest + cabal: latest + ghc: "8.10.7" + - os: macos-latest + cabal: latest + ghc: "8.10.7" + steps: + - uses: actions/checkout@v2 + - uses: haskell/actions/setup@v1 + name: Setup Haskell + with: + ghc-version: ${{ matrix.ghc }} + cabal-version: ${{ matrix.cabal }} + - uses: actions/cache@v3 + name: Cache ~/.cabal/store + with: + path: ~/.cabal/store + key: ${{ runner.os }}-${{ matrix.ghc }}-cabal + + - name: Install system dependencies (Linux) + if: matrix.os == 'ubuntu-18.04' || matrix.os == 'ubuntu-20.04' || matrix.os == 'ubuntu-latest' + run: | + sudo apt-get update + sudo apt-get install -y --no-install-recommends \ + cmake ninja-build g++ + - name: Install system dependencies (MacOS) + if: matrix.os == 'macos-latest' + run: | + brew install cmake ninja-build + + - uses: actions/cache@v3 + name: Cache /opt/symengine + id: cache-symengine + with: + path: /opt/symengine + key: ${{ runner.os }}-symengine + + - name: Build C++ code + if: steps.cache-symengine.outputs.cache-hit != 'true' + run: | + cd $GITHUB_WORKSPACE + git clone --depth=0 https://github.com/symengine/symengine + cd symengine + cmake -B build -G Ninja \ + -DCMAKE_INSTALL_PREFIX=/opt/symengine \ + -DCMAKE_BUILD_TYPE=Debug \ + -DWITH_SYMENGINE_ASSERT=ON \ + -DWITH_SYMENGINE_THREAD_SAFE=ON \ + -DBUILD_TESTS=OFF \ + -DBUILD_BENCHMARKS=OFF \ + -DINTEGER_CLASS=boostmp + cmake --build build + sudo cmake --build build --target install + + - name: Build Haskell code + run: | + echo "package symengine" >> cabal.project.local + echo " extra-lib-dirs: /opt/symengine" >> cabal.project.local + cabal build + + - name: Test + run: | + cabal test --test-show-details=direct From af704c6b5dd6de52b8e47da1e21f30dc02d661f4 Mon Sep 17 00:00:00 2001 From: Tom Westerhout <14264576+twesterhout@users.noreply.github.com> Date: Fri, 6 Jan 2023 00:17:36 +0100 Subject: [PATCH 03/21] --depth=0 is invalid for git clone; increase to 1 --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7abd7b6..bd982ca 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -54,7 +54,7 @@ jobs: if: steps.cache-symengine.outputs.cache-hit != 'true' run: | cd $GITHUB_WORKSPACE - git clone --depth=0 https://github.com/symengine/symengine + git clone --depth=1 https://github.com/symengine/symengine cd symengine cmake -B build -G Ninja \ -DCMAKE_INSTALL_PREFIX=/opt/symengine \ From e170454ef58a5962d459348b8566b149ee1badca Mon Sep 17 00:00:00 2001 From: Tom Westerhout <14264576+twesterhout@users.noreply.github.com> Date: Fri, 6 Jan 2023 00:22:09 +0100 Subject: [PATCH 04/21] Install boost in CI --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index bd982ca..9c22fac 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -37,11 +37,11 @@ jobs: run: | sudo apt-get update sudo apt-get install -y --no-install-recommends \ - cmake ninja-build g++ + cmake ninja-build g++ libboost-all-dev - name: Install system dependencies (MacOS) if: matrix.os == 'macos-latest' run: | - brew install cmake ninja-build + brew install cmake ninja-build boost - uses: actions/cache@v3 name: Cache /opt/symengine From e0293c4cdd92b90340275031780080fd0082e249 Mon Sep 17 00:00:00 2001 From: Tom Westerhout <14264576+twesterhout@users.noreply.github.com> Date: Fri, 6 Jan 2023 00:25:45 +0100 Subject: [PATCH 05/21] ninja-build -> ninja on MacOS --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9c22fac..ec863ed 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -41,7 +41,7 @@ jobs: - name: Install system dependencies (MacOS) if: matrix.os == 'macos-latest' run: | - brew install cmake ninja-build boost + brew install cmake ninja boost - uses: actions/cache@v3 name: Cache /opt/symengine From 1c1ef2f261392fc69f62074b463a230ab3fe5131 Mon Sep 17 00:00:00 2001 From: Tom Westerhout <14264576+twesterhout@users.noreply.github.com> Date: Fri, 6 Jan 2023 00:31:27 +0100 Subject: [PATCH 06/21] Fix extra-lib-dirs --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ec863ed..e313e51 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -70,7 +70,7 @@ jobs: - name: Build Haskell code run: | echo "package symengine" >> cabal.project.local - echo " extra-lib-dirs: /opt/symengine" >> cabal.project.local + echo " extra-lib-dirs: /opt/symengine/lib" >> cabal.project.local cabal build - name: Test From 8fa17f4a820dfa4d961b8920a1780cfbc749f9bc Mon Sep 17 00:00:00 2001 From: Tom Westerhout <14264576+twesterhout@users.noreply.github.com> Date: Fri, 6 Jan 2023 00:50:08 +0100 Subject: [PATCH 07/21] Only link libstdc++ on Linux; test C++ code on Mac --- .github/workflows/ci.yml | 4 ++-- symengine.cabal | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e313e51..3a26aed 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -59,12 +59,12 @@ jobs: cmake -B build -G Ninja \ -DCMAKE_INSTALL_PREFIX=/opt/symengine \ -DCMAKE_BUILD_TYPE=Debug \ - -DWITH_SYMENGINE_ASSERT=ON \ -DWITH_SYMENGINE_THREAD_SAFE=ON \ - -DBUILD_TESTS=OFF \ + -DBUILD_TESTS=ON \ -DBUILD_BENCHMARKS=OFF \ -DINTEGER_CLASS=boostmp cmake --build build + cmake --build build --target test sudo cmake --build build --target install - name: Build Haskell code diff --git a/symengine.cabal b/symengine.cabal index 3ef89d7..0e6d809 100644 --- a/symengine.cabal +++ b/symengine.cabal @@ -43,8 +43,8 @@ library exposed-modules: Symengine extra-libraries: symengine teuchos - gmp - stdc++ + if os(linux) + extra-libraries: stdc++ test-suite symengine-test import: common-options From e878c5f19e16f4b122d0a9a081a1f8a8e34e52d9 Mon Sep 17 00:00:00 2001 From: Tom Westerhout <14264576+twesterhout@users.noreply.github.com> Date: Fri, 6 Jan 2023 01:08:37 +0100 Subject: [PATCH 08/21] Link libc++ on OS X --- symengine.cabal | 2 ++ 1 file changed, 2 insertions(+) diff --git a/symengine.cabal b/symengine.cabal index 0e6d809..ca54733 100644 --- a/symengine.cabal +++ b/symengine.cabal @@ -45,6 +45,8 @@ library teuchos if os(linux) extra-libraries: stdc++ + if os(darwin) || os(osx) + extra-libraries: c++ test-suite symengine-test import: common-options From 0d762937f0ec7d4e6047c3d9ec3236e6e3b84fb7 Mon Sep 17 00:00:00 2001 From: Tom Westerhout <14264576+twesterhout@users.noreply.github.com> Date: Fri, 6 Jan 2023 01:27:36 +0100 Subject: [PATCH 09/21] Remove all tests except ascii art --- test/Spec.hs | 71 +++++++++++++++++++++++++--------------------------- 1 file changed, 34 insertions(+), 37 deletions(-) diff --git a/test/Spec.hs b/test/Spec.hs index e934667..36706f6 100644 --- a/test/Spec.hs +++ b/test/Spec.hs @@ -1,12 +1,10 @@ -import Test.Tasty -import Test.Tasty.QuickCheck as QC -import Test.Tasty.HUnit as HU - import Data.List -import Data.Ord import Data.Monoid - +import Data.Ord import Symengine as Sym +import Test.Tasty +import Test.Tasty.HUnit as HU +import Test.Tasty.QuickCheck as QC import Prelude hiding (pi) main = defaultMain tests @@ -14,40 +12,39 @@ main = defaultMain tests tests :: TestTree tests = testGroup "Tests" [unitTests] - -- These are used to check invariants that can be tested by creating -- random members of the type and then checking invariants on them -- properties :: TestTree -- properties = testGroup "Properties" [qcProps] -unitTests = testGroup "Unit tests" - [ HU.testCase "FFI Sanity Check - ASCII Art should be non-empty" $ - do - ascii_art <- Sym.ascii_art_str - HU.assertBool "ASCII art from ascii_art_str is empty" (not . null $ ascii_art) - - - , HU.testCase "Basic Constructors" $ - do - "0" @?= (show zero) - "1" @?= (show one) - "-1" @?= (show minus_one) - , HU.testCase "Basic Trignometric Functions" $ - do - let pi_over_3 = pi / 3 :: BasicSym - let pi_over_2 = pi / 2 :: BasicSym - - sin zero @?= zero - cos zero @?= one - - sin (pi / 6) @?= 1 / 2 - sin (pi / 3) @?= (3 ** (1/2)) / 2 - - cos (pi / 6) @?= (3 ** (1/2)) / 2 - cos (pi / 3) @?= 1 / 2 - - sin pi_over_2 @?= one - cos pi_over_2 @?= zero - - ] +unitTests = + testGroup + "Unit tests" + [ HU.testCase "FFI Sanity Check - ASCII Art should be non-empty" $ + do + ascii_art <- Sym.ascii_art_str + HU.assertBool "ASCII art from ascii_art_str is empty" (not . null $ ascii_art) + + -- , HU.testCase "Basic Constructors" $ + -- do + -- "0" @?= (show zero) + -- "1" @?= (show one) + -- "-1" @?= (show minus_one) + -- , HU.testCase "Basic Trignometric Functions" $ + -- do + -- let pi_over_3 = pi / 3 :: BasicSym + -- let pi_over_2 = pi / 2 :: BasicSym + + -- sin zero @?= zero + -- cos zero @?= one + -- + -- sin (pi / 6) @?= 1 / 2 + -- sin (pi / 3) @?= (3 ** (1/2)) / 2 + + -- cos (pi / 6) @?= (3 ** (1/2)) / 2 + -- cos (pi / 3) @?= 1 / 2 + + -- sin pi_over_2 @?= one + -- cos pi_over_2 @?= zero + ] From 5a2fecf3e369070ccfb260867298b24158f025fa Mon Sep 17 00:00:00 2001 From: Tom Westerhout <14264576+twesterhout@users.noreply.github.com> Date: Fri, 6 Jan 2023 13:51:19 +0100 Subject: [PATCH 10/21] Safer wrappers --- .github/workflows/ci.yml | 2 +- src/Symengine/Internal.hs | 239 ++++++++++++++++++++++++++++++++++++++ symengine.cabal | 4 +- 3 files changed, 243 insertions(+), 2 deletions(-) create mode 100644 src/Symengine/Internal.hs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3a26aed..b6abec7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -60,7 +60,7 @@ jobs: -DCMAKE_INSTALL_PREFIX=/opt/symengine \ -DCMAKE_BUILD_TYPE=Debug \ -DWITH_SYMENGINE_THREAD_SAFE=ON \ - -DBUILD_TESTS=ON \ + -DBUILD_TESTS=OFF \ -DBUILD_BENCHMARKS=OFF \ -DINTEGER_CLASS=boostmp cmake --build build diff --git a/src/Symengine/Internal.hs b/src/Symengine/Internal.hs new file mode 100644 index 0000000..ce20654 --- /dev/null +++ b/src/Symengine/Internal.hs @@ -0,0 +1,239 @@ +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} + +module Symengine.Internal + ( Basic, + basicFromText, + basicToText, + constZero, + constOne, + symengineVersion, + ) +where + +import Control.Exception (bracket) +import Data.Bits (toIntegralSized) +import Data.ByteString (packCString, useAsCString) +import Data.Text (Text) +import qualified Data.Text as Text +import Data.Text.Encoding (decodeUtf8, encodeUtf8) +import Foreign.C.String (CString) +import Foreign.C.Types (CInt (..), CLong (..)) +import Foreign.ForeignPtr +import Foreign.Ptr +import Foreign.Storable +import GHC.Exts (IsString (..)) +import GHC.Generics (Generic) +import System.IO.Unsafe (unsafePerformIO) + +data Cbasic_struct + = Cbasic_struct + {-# UNPACK #-} !(Ptr ()) + {-# UNPACK #-} !(Ptr ()) + {-# UNPACK #-} !CInt + deriving stock (Show, Eq, Generic) + +instance Storable Cbasic_struct where + sizeOf _ = 24 + {-# INLINE sizeOf #-} + alignment _ = 8 + {-# INLINE alignment #-} + peek _ = error "Storable instance for Cbasic_struct does not implement peek, because you should not rely on the internal representation of it" + poke _ _ = error "Storable instance for Cbasic_struct does not implement poke, because you should not rely on the internal representation of it" + +data SymengineError + = RuntimeError + | DivideByZero + | NotImplemented + | DomainError + | ParseError + | SerializationError + deriving stock (Show, Eq, Generic) + +instance Enum SymengineError where + toEnum e = case e of + 1 -> RuntimeError + 2 -> DivideByZero + 3 -> NotImplemented + 4 -> DomainError + 5 -> ParseError + 6 -> SerializationError + _ -> error "invalid error code" + fromEnum _ = error "Enum instance of SymengineError does not provide fromEnum" + +newtype Basic = Basic (ForeignPtr Cbasic_struct) + +-- | Allocate a new 'Basic' and use the provided function for initialization. +newBasic :: (Ptr Cbasic_struct -> IO ()) -> IO Basic +newBasic initialize = do + x@(Basic fp) <- newBasicNoDestructor initialize + addForeignPtrFinalizer basic_free_stack fp + pure x + +-- | Same as 'newBasic', but do not attach a finalizer to the underlying 'ForeignPtr' +newBasicNoDestructor :: (Ptr Cbasic_struct -> IO ()) -> IO Basic +newBasicNoDestructor initialize = do + fp <- mallocForeignPtr + withForeignPtr fp (\p -> basic_new_stack p >> initialize p) + pure $ Basic fp + +withBasic :: Basic -> (Ptr Cbasic_struct -> IO a) -> IO a +withBasic (Basic fp) = withForeignPtr fp + +unaryOp :: (Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO ()) -> Basic -> Basic +unaryOp f x = unsafePerformIO $! + withBasic x $ \xPtr -> + newBasic (\out -> f out xPtr) +{-# NOINLINE unaryOp #-} + +binaryOp :: (Ptr Cbasic_struct -> Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO ()) -> Basic -> Basic -> Basic +binaryOp f x y = unsafePerformIO $! + withBasic x $ \xPtr -> + withBasic y $ \yPtr -> + newBasic (\out -> f out xPtr yPtr) +{-# NOINLINE binaryOp #-} + +foreign import ccall unsafe "basic_new_stack" + basic_new_stack :: Ptr Cbasic_struct -> IO () + +foreign import ccall unsafe "&basic_free_stack" + basic_free_stack :: FunPtr (Ptr Cbasic_struct -> IO ()) + +foreign import ccall unsafe "basic_str" + basic_str :: Ptr Cbasic_struct -> IO CString + +foreign import ccall unsafe "basic_str_free" + basic_str_free :: CString -> IO () + +basicToText :: Basic -> Text +basicToText x = unsafePerformIO $ + withBasic x $ \p -> + bracket (basic_str p) basic_str_free $ \cStr -> do + -- NOTE: need to force evaluation before the C string is freed + !r <- peekUtf8 cStr + pure r + +basicFromText :: Text -> Maybe Basic +basicFromText s = unsafePerformIO $! + withUtf8 s $ \cStr -> do + x <- newBasic (\_ -> pure ()) + withBasic x $ \p -> do + e <- basic_parse p cStr + if e /= 0 + then case toEnum (fromIntegral e) of + ParseError -> pure Nothing + otherError -> error $ "basic_parse of " <> show s <> " failed with: " <> show otherError + else pure (Just x) + +instance Show Basic where + showsPrec p x = + showParen (p > 0) + . showString + . Text.unpack + . basicToText + $ x + +instance IsString Basic where + fromString s = case (basicFromText . Text.pack) s of + Just x -> x + Nothing -> error $ "could not convert " <> show s <> " to Basic" + +instance Eq Basic where + (==) a b = unsafePerformIO $! + withBasic a $ \aPtr -> + withBasic b $ \bPtr -> + toEnum . fromIntegral <$> basic_eq aPtr bPtr + +basicFromInt :: Int -> Basic +basicFromInt n = + unsafePerformIO $! do + x <- newBasic (\_ -> pure ()) + withBasic x $ \p -> do + e <- integer_set_si p (fromIntegral n) + if e /= 0 + then error $ "integer_set_si failed: " <> show (toEnum (fromIntegral e) :: SymengineError) + else pure x + +instance Num Basic where + (+) = binaryOp basic_add + (-) = binaryOp basic_sub + (*) = binaryOp basic_mul + negate = unaryOp basic_neg + abs = unaryOp basic_abs + signum = error "Num instance of Basic does not implement signum" + fromInteger n = case toIntegralSized n of + Just k -> basicFromInt k + Nothing -> error $ "integer overflow in fromInteger " <> show n + +constZero :: Basic +constZero = unsafePerformIO $! newBasicNoDestructor basic_const_zero + +constOne :: Basic +constOne = unsafePerformIO $! newBasicNoDestructor basic_const_one + +foreign import ccall unsafe "basic_const_zero" + basic_const_zero :: Ptr Cbasic_struct -> IO () + +foreign import ccall unsafe "basic_const_one" + basic_const_one :: Ptr Cbasic_struct -> IO () + +foreign import ccall unsafe "basic_const_minus_one" + basic_const_minus_one :: Ptr Cbasic_struct -> IO () + +foreign import ccall unsafe "basic_const_I" + basic_const_I :: Ptr Cbasic_struct -> IO () + +foreign import ccall unsafe "basic_const_pi" + basic_const_pi :: Ptr Cbasic_struct -> IO () + +foreign import ccall unsafe "basic_const_E" + basic_const_E :: Ptr Cbasic_struct -> IO () + +foreign import ccall unsafe "basic_const_EulerGamma" + basic_const_EulerGamma :: Ptr Cbasic_struct -> IO () + +foreign import ccall unsafe "basic_const_Catalan" + basic_const_Catalan :: Ptr Cbasic_struct -> IO () + +foreign import ccall unsafe "basic_const_GoldenRatio" + basic_const_GoldenRatio :: Ptr Cbasic_struct -> IO () + +foreign import ccall unsafe "integer_set_si" + integer_set_si :: Ptr Cbasic_struct -> CLong -> IO CInt + +foreign import ccall unsafe "basic_parse" + basic_parse :: Ptr Cbasic_struct -> CString -> IO CInt + +foreign import ccall unsafe "basic_eq" + basic_eq :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt + +foreign import ccall unsafe "basic_add" + basic_add :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO () + +foreign import ccall unsafe "basic_sub" + basic_sub :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO () + +foreign import ccall unsafe "basic_mul" + basic_mul :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO () + +foreign import ccall unsafe "basic_abs" + basic_abs :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO () + +foreign import ccall unsafe "basic_neg" + basic_neg :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO () + +-- | Unicode-safe alternative to 'peekCString' +peekUtf8 :: CString -> IO Text +peekUtf8 = fmap decodeUtf8 . packCString + +-- | Unicode-safe alternative to 'withCString' +withUtf8 :: Text -> (CString -> IO a) -> IO a +withUtf8 x = useAsCString (encodeUtf8 x) + +-- | Version of the underlying SymEngine C++ library +symengineVersion :: Text +symengineVersion = unsafePerformIO $ peekUtf8 =<< symengine_version +{-# NOINLINE symengineVersion #-} + +foreign import ccall unsafe "symengine_version" + symengine_version :: IO CString diff --git a/symengine.cabal b/symengine.cabal index ca54733..8e0491d 100644 --- a/symengine.cabal +++ b/symengine.cabal @@ -41,8 +41,10 @@ library import: common-options hs-source-dirs: src exposed-modules: Symengine + Symengine.Internal + build-depends: text + , bytestring extra-libraries: symengine - teuchos if os(linux) extra-libraries: stdc++ if os(darwin) || os(osx) From c11b3f5fcfb53cf0dc9216d9cb4ba3b8646cc7aa Mon Sep 17 00:00:00 2001 From: Tom Westerhout <14264576+twesterhout@users.noreply.github.com> Date: Fri, 6 Jan 2023 13:53:58 +0100 Subject: [PATCH 11/21] Build shared instead of static libraries --- .github/workflows/ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b6abec7..fae922e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -59,6 +59,7 @@ jobs: cmake -B build -G Ninja \ -DCMAKE_INSTALL_PREFIX=/opt/symengine \ -DCMAKE_BUILD_TYPE=Debug \ + -DBUILD_SHARED_LIBS=ON \ -DWITH_SYMENGINE_THREAD_SAFE=ON \ -DBUILD_TESTS=OFF \ -DBUILD_BENCHMARKS=OFF \ From d990fd7da78fed56e7bd1f0f5d0446efaf05155a Mon Sep 17 00:00:00 2001 From: Tom Westerhout <14264576+twesterhout@users.noreply.github.com> Date: Fri, 6 Jan 2023 14:07:26 +0100 Subject: [PATCH 12/21] Set LD_LIBRARY_PATH and DYLD_LIBRARY_PATH --- .github/workflows/ci.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fae922e..ea72220 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -76,4 +76,6 @@ jobs: - name: Test run: | + export LD_LIBRARY_PATH=/opt/symengine/lib:$LD_LIBRARY_PATH + export DYLD_LIBRARY_PATH=/opt/symengine/lib:$DYLD_LIBRARY_PATH cabal test --test-show-details=direct From 7a607de60ccfb4b5edcb03195440193f4b0958fa Mon Sep 17 00:00:00 2001 From: Tom Westerhout <14264576+twesterhout@users.noreply.github.com> Date: Sat, 7 Jan 2023 23:29:14 +0100 Subject: [PATCH 13/21] Use CApiFFI for safety; Finish Num, Fractional, and Floating for Basic --- src/Symengine/Internal.hs | 270 +++++++++++++++++++++++++++++++------- 1 file changed, 219 insertions(+), 51 deletions(-) diff --git a/src/Symengine/Internal.hs b/src/Symengine/Internal.hs index ce20654..0ecfa50 100644 --- a/src/Symengine/Internal.hs +++ b/src/Symengine/Internal.hs @@ -1,5 +1,7 @@ +{-# LANGUAGE CApiFFI #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE OverloadedStrings #-} module Symengine.Internal ( Basic, @@ -7,13 +9,20 @@ module Symengine.Internal basicToText, constZero, constOne, + isNumber, + isPositive, + isNegative, + isZero, + isComplex, symengineVersion, ) where import Control.Exception (bracket) +import Control.Monad (when) import Data.Bits (toIntegralSized) import Data.ByteString (packCString, useAsCString) +import Data.Ratio import Data.Text (Text) import qualified Data.Text as Text import Data.Text.Encoding (decodeUtf8, encodeUtf8) @@ -24,9 +33,10 @@ import Foreign.Ptr import Foreign.Storable import GHC.Exts (IsString (..)) import GHC.Generics (Generic) +import GHC.Stack (HasCallStack) import System.IO.Unsafe (unsafePerformIO) -data Cbasic_struct +data {-# CTYPE "symengine/cwrapper.h" "basic_struct" #-} Cbasic_struct = Cbasic_struct {-# UNPACK #-} !(Ptr ()) {-# UNPACK #-} !(Ptr ()) @@ -80,30 +90,35 @@ newBasicNoDestructor initialize = do withBasic :: Basic -> (Ptr Cbasic_struct -> IO a) -> IO a withBasic (Basic fp) = withForeignPtr fp +checkError :: HasCallStack => Text -> CInt -> IO () +checkError name e + | e == 0 = pure () + | otherwise = + error $ + Text.unpack name <> " failed with: " <> show (toEnum (fromIntegral e) :: SymengineError) + unaryOp :: (Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO ()) -> Basic -> Basic -unaryOp f x = unsafePerformIO $! +unaryOp = unaryOp' pure + +unaryOp' :: (a -> IO ()) -> (Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO a) -> Basic -> Basic +unaryOp' check f x = unsafePerformIO $! withBasic x $ \xPtr -> - newBasic (\out -> f out xPtr) -{-# NOINLINE unaryOp #-} + newBasic (\out -> check =<< f out xPtr) binaryOp :: (Ptr Cbasic_struct -> Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO ()) -> Basic -> Basic -> Basic -binaryOp f x y = unsafePerformIO $! +binaryOp = binaryOp' pure + +binaryOp' :: (a -> IO ()) -> (Ptr Cbasic_struct -> Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO a) -> Basic -> Basic -> Basic +binaryOp' check f x y = unsafePerformIO $! withBasic x $ \xPtr -> withBasic y $ \yPtr -> - newBasic (\out -> f out xPtr yPtr) -{-# NOINLINE binaryOp #-} - -foreign import ccall unsafe "basic_new_stack" - basic_new_stack :: Ptr Cbasic_struct -> IO () + newBasic (\out -> check =<< f out xPtr yPtr) -foreign import ccall unsafe "&basic_free_stack" - basic_free_stack :: FunPtr (Ptr Cbasic_struct -> IO ()) - -foreign import ccall unsafe "basic_str" - basic_str :: Ptr Cbasic_struct -> IO CString - -foreign import ccall unsafe "basic_str_free" - basic_str_free :: CString -> IO () +queryOp :: (Ptr Cbasic_struct -> IO CInt) -> Basic -> Bool +queryOp f x = + unsafePerformIO $! + withBasic x $ + fmap (toEnum . fromIntegral) . f basicToText :: Basic -> Text basicToText x = unsafePerformIO $ @@ -149,10 +164,23 @@ basicFromInt n = unsafePerformIO $! do x <- newBasic (\_ -> pure ()) withBasic x $ \p -> do - e <- integer_set_si p (fromIntegral n) - if e /= 0 - then error $ "integer_set_si failed: " <> show (toEnum (fromIntegral e) :: SymengineError) - else pure x + checkError "integer_set_si" =<< integer_set_si p (fromIntegral n) + pure x + +isZero :: Basic -> Bool +isZero x = queryOp number_is_zero x + +isPositive :: Basic -> Bool +isPositive x = queryOp number_is_positive x + +isNegative :: Basic -> Bool +isNegative x = queryOp number_is_negative x + +isNumber :: Basic -> Bool +isNumber = queryOp is_a_Number + +isComplex :: Basic -> Bool +isComplex = queryOp is_a_Complex instance Num Basic where (+) = binaryOp basic_add @@ -160,80 +188,220 @@ instance Num Basic where (*) = binaryOp basic_mul negate = unaryOp basic_neg abs = unaryOp basic_abs - signum = error "Num instance of Basic does not implement signum" + signum _ = error "Num instance of Basic does not implement signum" fromInteger n = case toIntegralSized n of Just k -> basicFromInt k Nothing -> error $ "integer overflow in fromInteger " <> show n +instance Fractional Basic where + (/) = binaryOp basic_div + fromRational r = + binaryOp' + (checkError "rational_set") + rational_set + (fromInteger (numerator r)) + (fromInteger (denominator r)) + recip r = constOne / r + +instance Floating Basic where + pi = constPi + exp = unaryOp' (checkError "basic_exp") basic_exp + log = unaryOp' (checkError "basic_log") basic_log + (**) = binaryOp' (checkError "basic_pow") basic_pow + sqrt = unaryOp' (checkError "basic_sqrt") basic_sqrt + sin = unaryOp' (checkError "basic_sin") basic_sin + cos = unaryOp' (checkError "basic_cos") basic_cos + tan = unaryOp' (checkError "basic_tan") basic_tan + asin = unaryOp' (checkError "basic_asin") basic_asin + acos = unaryOp' (checkError "basic_acos") basic_acos + atan = unaryOp' (checkError "basic_atan") basic_atan + sinh = unaryOp' (checkError "basic_sinh") basic_sinh + cosh = unaryOp' (checkError "basic_cosh") basic_cosh + tanh = unaryOp' (checkError "basic_tanh") basic_tanh + asinh = unaryOp' (checkError "basic_asinh") basic_asinh + acosh = unaryOp' (checkError "basic_acosh") basic_acosh + atanh = unaryOp' (checkError "basic_atanh") basic_atanh + constZero :: Basic constZero = unsafePerformIO $! newBasicNoDestructor basic_const_zero constOne :: Basic constOne = unsafePerformIO $! newBasicNoDestructor basic_const_one -foreign import ccall unsafe "basic_const_zero" +constPi :: Basic +constPi = unsafePerformIO $! newBasicNoDestructor basic_const_pi + +-- | Unicode-safe alternative to 'peekCString' +peekUtf8 :: CString -> IO Text +peekUtf8 = fmap decodeUtf8 . packCString + +-- | Unicode-safe alternative to 'withCString' +withUtf8 :: Text -> (CString -> IO a) -> IO a +withUtf8 x = useAsCString (encodeUtf8 x) + +-- | Version of the underlying SymEngine C++ library +symengineVersion :: Text +symengineVersion = unsafePerformIO $ peekUtf8 =<< symengine_version +{-# NOINLINE symengineVersion #-} + +foreign import capi unsafe "symengine/cwrapper.h basic_new_stack" + basic_new_stack :: Ptr Cbasic_struct -> IO () + +foreign import capi unsafe "symengine/cwrapper.h &basic_free_stack" + basic_free_stack :: FunPtr (Ptr Cbasic_struct -> IO ()) + +foreign import capi unsafe "symengine/cwrapper.h basic_str" + basic_str :: Ptr Cbasic_struct -> IO CString + +foreign import capi unsafe "symengine/cwrapper.h basic_str_free" + basic_str_free :: CString -> IO () + +foreign import capi unsafe "symengine/cwrapper.h basic_const_zero" basic_const_zero :: Ptr Cbasic_struct -> IO () -foreign import ccall unsafe "basic_const_one" +foreign import capi unsafe "symengine/cwrapper.h basic_const_one" basic_const_one :: Ptr Cbasic_struct -> IO () -foreign import ccall unsafe "basic_const_minus_one" +foreign import capi unsafe "symengine/cwrapper.h basic_const_minus_one" basic_const_minus_one :: Ptr Cbasic_struct -> IO () -foreign import ccall unsafe "basic_const_I" +foreign import capi unsafe "symengine/cwrapper.h basic_const_I" basic_const_I :: Ptr Cbasic_struct -> IO () -foreign import ccall unsafe "basic_const_pi" +foreign import capi unsafe "symengine/cwrapper.h basic_const_pi" basic_const_pi :: Ptr Cbasic_struct -> IO () -foreign import ccall unsafe "basic_const_E" +foreign import capi unsafe "symengine/cwrapper.h basic_const_E" basic_const_E :: Ptr Cbasic_struct -> IO () -foreign import ccall unsafe "basic_const_EulerGamma" +foreign import capi unsafe "symengine/cwrapper.h basic_const_EulerGamma" basic_const_EulerGamma :: Ptr Cbasic_struct -> IO () -foreign import ccall unsafe "basic_const_Catalan" +foreign import capi unsafe "symengine/cwrapper.h basic_const_Catalan" basic_const_Catalan :: Ptr Cbasic_struct -> IO () -foreign import ccall unsafe "basic_const_GoldenRatio" +foreign import capi unsafe "symengine/cwrapper.h basic_const_GoldenRatio" basic_const_GoldenRatio :: Ptr Cbasic_struct -> IO () -foreign import ccall unsafe "integer_set_si" +foreign import capi unsafe "symengine/cwrapper.h integer_set_si" integer_set_si :: Ptr Cbasic_struct -> CLong -> IO CInt -foreign import ccall unsafe "basic_parse" +foreign import capi unsafe "symengine/cwrapper.h rational_set" + rational_set :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h basic_parse" basic_parse :: Ptr Cbasic_struct -> CString -> IO CInt -foreign import ccall unsafe "basic_eq" +foreign import capi unsafe "symengine/cwrapper.h basic_eq" basic_eq :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt -foreign import ccall unsafe "basic_add" +foreign import capi unsafe "symengine/cwrapper.h basic_add" basic_add :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO () -foreign import ccall unsafe "basic_sub" +foreign import capi unsafe "symengine/cwrapper.h basic_sub" basic_sub :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO () -foreign import ccall unsafe "basic_mul" +foreign import capi unsafe "symengine/cwrapper.h basic_mul" basic_mul :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO () -foreign import ccall unsafe "basic_abs" +foreign import capi unsafe "symengine/cwrapper.h basic_div" + basic_div :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO () + +foreign import capi unsafe "symengine/cwrapper.h basic_pow" + basic_pow :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h basic_abs" basic_abs :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO () -foreign import ccall unsafe "basic_neg" +foreign import capi unsafe "symengine/cwrapper.h basic_neg" basic_neg :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO () --- | Unicode-safe alternative to 'peekCString' -peekUtf8 :: CString -> IO Text -peekUtf8 = fmap decodeUtf8 . packCString +foreign import capi unsafe "symengine/cwrapper.h basic_sqrt" + basic_sqrt :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt --- | Unicode-safe alternative to 'withCString' -withUtf8 :: Text -> (CString -> IO a) -> IO a -withUtf8 x = useAsCString (encodeUtf8 x) +foreign import capi unsafe "symengine/cwrapper.h basic_sin" + basic_sin :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt --- | Version of the underlying SymEngine C++ library -symengineVersion :: Text -symengineVersion = unsafePerformIO $ peekUtf8 =<< symengine_version -{-# NOINLINE symengineVersion #-} +foreign import capi unsafe "symengine/cwrapper.h basic_cos" + basic_cos :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h basic_tan" + basic_tan :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h basic_asin" + basic_asin :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h basic_acos" + basic_acos :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h basic_atan" + basic_atan :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h basic_sinh" + basic_sinh :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h basic_cosh" + basic_cosh :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h basic_tanh" + basic_tanh :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h basic_asinh" + basic_asinh :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h basic_acosh" + basic_acosh :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h basic_atanh" + basic_atanh :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h basic_exp" + basic_exp :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h basic_log" + basic_log :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h number_is_zero" + number_is_zero :: Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h number_is_negative" + number_is_negative :: Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h number_is_positive" + number_is_positive :: Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h number_is_complex" + number_is_complex :: Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h is_a_Number" + is_a_Number :: Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h is_a_Integer" + is_a_Integer :: Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h is_a_Rational" + is_a_Rational :: Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h is_a_Symbol" + is_a_Symbol :: Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h is_a_Complex" + is_a_Complex :: Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h is_a_RealDouble" + is_a_RealDouble :: Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h is_a_ComplexDouble" + is_a_ComplexDouble :: Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h is_a_RealMPFR" + is_a_RealMPFR :: Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h is_a_ComplexMPC" + is_a_ComplexMPC :: Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h is_a_Set" + is_a_Set :: Ptr Cbasic_struct -> IO CInt -foreign import ccall unsafe "symengine_version" +foreign import ccall unsafe "symengine/cwrapper.h symengine_version" symengine_version :: IO CString From f8f50f0cc3dc9422e20870102bbb3856ca4312f1 Mon Sep 17 00:00:00 2001 From: Tom Westerhout <14264576+twesterhout@users.noreply.github.com> Date: Wed, 11 Jan 2023 00:04:24 +0100 Subject: [PATCH 14/21] Remove old files --- .travis.yml | 280 ---------------------------------------------------- Setup.hs | 2 - stack.yaml | 29 ------ 3 files changed, 311 deletions(-) delete mode 100644 .travis.yml delete mode 100644 Setup.hs delete mode 100644 stack.yaml diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index bf75be4..0000000 --- a/.travis.yml +++ /dev/null @@ -1,280 +0,0 @@ -# Copy these contents into the root directory of your Github project in a file -# named .travis.yml - -# Use new container infrastructure to enable caching -sudo: false - -# Choose a lightweight base image; we provide our own build tools. -language: c - -addons: - apt: - sources: - - ubuntu-toolchain-r-test - packages: - - libgmp-dev - - libmpfr-dev - - libmpc-dev - - binutils-dev - - g++-4.7 - - gcc - -# Caching so the next build will be fast too. -cache: - directories: - - $HOME/.ghc - - $HOME/.cabal - - $HOME/.stack - -# The different configurations we want to test. We have BUILD=cabal which uses -# cabal-install, and BUILD=stack which uses Stack. More documentation on each -# of those below. -# -# We set the compiler values here to tell Travis to use a different -# cache file per set of arguments. -# -# If you need to have different apt packages for each combination in the -# matrix, you can use a line such as: -# addons: {apt: {packages: [libfcgi-dev,libgmp-dev]}} -matrix: - include: - # We grab the appropriate GHC and cabal-install versions from hvr's PPA. See: - # https://github.com/hvr/multi-ghc-travis - #- env: BUILD=cabal GHCVER=7.0.4 CABALVER=1.16 HAPPYVER=1.19.5 ALEXVER=3.1.7 - # compiler: ": #GHC 7.0.4" - # addons: {apt: {packages: [cabal-install-1.16,ghc-7.0.4,happy-1.19.5,alex-3.1.7], sources: [hvr-ghc, ubuntu-toolchain-r-test]}} - #- env: BUILD=cabal GHCVER=7.2.2 CABALVER=1.16 HAPPYVER=1.19.5 ALEXVER=3.1.7 - # compiler: ": #GHC 7.2.2" - # addons: {apt: {packages: [cabal-install-1.16,ghc-7.2.2,happy-1.19.5,alex-3.1.7], sources: [hvr-ghc, ubuntu-toolchain-r-test]}} - - env: BUILD=cabal GHCVER=7.4.2 CABALVER=1.16 HAPPYVER=1.19.5 ALEXVER=3.1.7 - compiler: ": #GHC 7.4.2" - addons: {apt: {packages: [libgmp-dev, - libmpfr-dev, - libmpc-dev, - binutils-dev, - g++-4.7, - gcc, cabal-install-1.16,ghc-7.4.2,happy-1.19.5,alex-3.1.7], sources: [hvr-ghc, ubuntu-toolchain-r-test]}} - - env: BUILD=cabal GHCVER=7.6.3 CABALVER=1.16 HAPPYVER=1.19.5 ALEXVER=3.1.7 - compiler: ": #GHC 7.6.3" - addons: {apt: {packages: [libgmp-dev, - libmpfr-dev, - libmpc-dev, - binutils-dev, - g++-4.7, - gcc, cabal-install-1.16,ghc-7.6.3,happy-1.19.5,alex-3.1.7], sources: [hvr-ghc, ubuntu-toolchain-r-test]}} - - env: BUILD=cabal GHCVER=7.8.4 CABALVER=1.18 HAPPYVER=1.19.5 ALEXVER=3.1.7 - compiler: ": #GHC 7.8.4" - addons: {apt: {packages: [libgmp-dev, - libmpfr-dev, - libmpc-dev, - binutils-dev, - g++-4.7, - gcc, cabal-install-1.18,ghc-7.8.4,happy-1.19.5,alex-3.1.7], sources: [hvr-ghc, ubuntu-toolchain-r-test]}} - - env: BUILD=cabal GHCVER=7.10.3 CABALVER=1.22 HAPPYVER=1.19.5 ALEXVER=3.1.7 - compiler: ": #GHC 7.10.3" - addons: {apt: {packages: [libgmp-dev, - libmpfr-dev, - libmpc-dev, - binutils-dev, - g++-4.7, - gcc, cabal-install-1.22,ghc-7.10.3,happy-1.19.5,alex-3.1.7], sources: [hvr-ghc, ubuntu-toolchain-r-test]}} - - # Build with the newest GHC and cabal-install. This is an accepted failure, - # see below. - - env: BUILD=cabal GHCVER=head CABALVER=head HAPPYVER=1.19.5 ALEXVER=3.1.7 - compiler: ": #GHC HEAD" - addons: {apt: {packages: [libgmp-dev, - libmpfr-dev, - libmpc-dev, - binutils-dev, - g++-4.7, - gcc, cabal-install-head,ghc-head,happy-1.19.5,alex-3.1.7], sources: [hvr-ghc, ubuntu-toolchain-r-test]}} - - # The Stack builds. We can pass in arbitrary Stack arguments via the ARGS - # variable, such as using --stack-yaml to point to a different file. - - env: BUILD=stack ARGS=" " - compiler: ": #stack default" - addons: {apt: {packages: [libgmp-dev, - libmpfr-dev, - libmpc-dev, - binutils-dev, - g++-4.7, - gcc, ghc-7.10.3], sources: [hvr-ghc, ubuntu-toolchain-r-test]}} - - - env: BUILD=stack ARGS="--resolver lts-2" - compiler: ": #stack 7.8.4" - addons: {apt: {packages: [libgmp-dev, - libmpfr-dev, - libmpc-dev, - binutils-dev, - g++-4.7, - gcc, ghc-7.8.4], sources: [hvr-ghc, ubuntu-toolchain-r-test]}} - - - env: BUILD=stack ARGS="--resolver lts-3" - compiler: ": #stack 7.10.2" - addons: {apt: {packages: [libgmp-dev, - libmpfr-dev, - libmpc-dev, - binutils-dev, - g++-4.7, - gcc, ghc-7.10.2], sources: [hvr-ghc, ubuntu-toolchain-r-test]}} - - - env: BUILD=stack ARGS="--resolver lts-5" - compiler: ": #stack 7.10.3" - addons: {apt: {packages: [libgmp-dev, - libmpfr-dev, - libmpc-dev, - binutils-dev, - g++-4.7, - gcc, ghc-7.10.3], sources: [hvr-ghc, ubuntu-toolchain-r-test]}} - - # Nightly builds are allowed to fail - - env: BUILD=stack ARGS="--resolver nightly" - compiler: ": #stack nightly" - addons: {apt: {packages: [libgmp-dev, - libmpfr-dev, - libmpc-dev, - binutils-dev, - g++-4.7, - gcc, libgmp-dev]}} - - # Build on OS X in addition to Linux - - env: BUILD=stack ARGS=" " - compiler: ": #stack default osx" - os: osx - - - env: BUILD=stack ARGS="--resolver lts-2" - compiler: ": #stack 7.8.4 osx" - os: osx - - - env: BUILD=stack ARGS="--resolver lts-3" - compiler: ": #stack 7.10.2 osx" - os: osx - - - env: BUILD=stack ARGS="--resolver lts-5" - compiler: ": #stack 7.10.3 osx" - os: osx - - - env: BUILD=stack ARGS="--resolver nightly" - compiler: ": #stack nightly osx" - os: osx - - allow_failures: - - env: BUILD=cabal GHCVER=head CABALVER=head HAPPYVER=1.19.5 ALEXVER=3.1.7 - - env: BUILD=stack ARGS="--resolver nightly" - - -install: -# SYMENGINE INSTALL PHASE -# ----------------------- -# Download and install SymEngine -- cd $HOME && git clone https://github.com/symengine/symengine.git - -# Setup C compiler variables -# The reason we need to do this is because our build system is haskell, so none of the -# C variables are set. -- | - if [ `uname` = "Darwin" ] - then - export CC="clang" && export CXX="clang++" - else - export CC="gcc" && export CXX="g++-4.7" - fi - -- | - set -ex - export TEST_CPP="no" - cd $HOME/symengine - source bin/install_travis.sh - bin/test_travis.sh - -# EXPORT PHASE -# ------------ -# Export environment variables related to SymEngine's library and includes -- | - # $our_install_dir is exported by test_travis.sh from symengine - set -ex - export SYMENGINE_LIB_ARGS="--extra-lib-dirs=$our_install_dir/lib/" - export SYMENGINE_INCLUDE_ARGS="--extra-include-dirs=$our_install_dir/include/" - cd $TRAVIS_BUILD_DIR - -# GHC INSTALL PHASE -# ----------------- -# Install Stack if needed, then install GHC - -# Using compiler above sets CC to an invalid value, so unset it -- unset CC -# We want to always allow newer versions of packages when building on GHC HEAD -- CABALARGS="" -- if [ "x$GHCVER" = "xhead" ]; then CABALARGS=--allow-newer; fi - -# Download and unpack the stack executable -- export PATH=/opt/ghc/$GHCVER/bin:/opt/cabal/$CABALVER/bin:$HOME/.local/bin:/opt/alex/$ALEXVER/bin:/opt/happy/$HAPPYVER/bin:$HOME/.cabal/bin:$PATH -- mkdir -p ~/.local/bin -- | - if [ `uname` = "Darwin" ] - then - travis_retry curl --insecure -L https://www.stackage.org/stack/osx-x86_64 | tar xz --strip-components=1 --include '*/stack' -C ~/.local/bin - else - travis_retry curl -L https://www.stackage.org/stack/linux-x86_64 | tar xz --wildcards --strip-components=1 -C ~/.local/bin '*/stack' - fi - # Use the more reliable S3 mirror of Hackage - mkdir -p $HOME/.cabal - echo 'remote-repo: hackage.haskell.org:http://hackage.fpcomplete.com/' > $HOME/.cabal/config - echo 'remote-repo-cache: $HOME/.cabal/packages' >> $HOME/.cabal/config - - if [ "$CABALVER" != "1.16" ] - then - echo 'jobs: $ncpus' >> $HOME/.cabal/config - fi - -# Get the list of packages from the stack.yaml file -- PACKAGES=$(stack --install-ghc query locals | grep '^ *path' | sed 's@^ *path:@@') - -- echo "$(ghc --version) [$(ghc --print-project-git-commit-id 2> /dev/null || echo '?')]" -- if [ -f configure.ac ]; then autoreconf -i; fi -- | - set -ex - case "$BUILD" in - stack) - stack --no-terminal --install-ghc $ARGS test --only-dependencies - ;; - cabal) - cabal --version - travis_retry cabal update - cabal install --only-dependencies --enable-tests --enable-benchmarks --force-reinstalls --ghc-options=-O0 --reorder-goals --max-backjumps=-1 $CABALARGS $PACKAGES $SYMENGINE_LIB_ARGS $SYMENGINE_INCLUDE_ARGS - ;; - esac - set +ex - - -script: -- | - set -ex - case "$BUILD" in - stack) - stack --no-terminal $ARGS $SYMENGINE_LIB_ARGS $SYMENGINE_INCLUDE_ARGS test --haddock --no-haddock-deps - ;; - cabal) - cabal update - cabal configure --enable-tests --enable-benchmarks --ghc-options=-O0 $CABALARGS $SYMENGINE_LIB_ARGS $SYMENGINE_INCLUDE_ARGS - cabal build - # run the test suite - cabal test --show-details=always - - # install after building the library - # cabal install --enable-tests --enable-benchmarks --force-reinstalls --ghc-options=-O0 --reorder-goals --max-backjumps=-1 $CABALARGS $PACKAGES $SYMENGINE_LIB_ARGS $SYMENGINE_INCLUDE_ARGS - - ORIGDIR=$(pwd) - for dir in $PACKAGES - do - cd $dir - cabal check || [ "$CABALVER" == "1.16" ] - cabal sdist - SRC_TGZ=$(cabal info . | awk '{print $2;exit}').tar.gz && \ - (cd dist && cabal install --force-reinstalls "$SRC_TGZ") - cd $ORIGDIR - done - ;; - esac - set +ex diff --git a/Setup.hs b/Setup.hs deleted file mode 100644 index 9a994af..0000000 --- a/Setup.hs +++ /dev/null @@ -1,2 +0,0 @@ -import Distribution.Simple -main = defaultMain diff --git a/stack.yaml b/stack.yaml deleted file mode 100644 index 7b5a9de..0000000 --- a/stack.yaml +++ /dev/null @@ -1,29 +0,0 @@ -# For more information, see: https://github.com/commercialhaskell/stack/blob/master/doc/yaml_configuration.md - -# Specifies the GHC version and set of packages available (e.g., lts-3.5, nightly-2015-09-21, ghc-7.10.2) -resolver: lts-3.2 - -# Local packages, usually specified by relative directory name -packages: -- '.' - -# Packages to be pulled from upstream that are not in the resolver (e.g., acme-missiles-0.3) -extra-deps: [] - -# Override default flag values for local packages and extra-deps -flags: {} - -# Control whether we use the GHC we find on the path -# system-ghc: true - -# Require a specific version of stack, using version ranges -# require-stack-version: -any # Default -# require-stack-version: >= 0.1.4.0 - -# Override the architecture used by stack, especially useful on Windows -# arch: i386 -# arch: x86_64 - -# Extra directories used by stack for building -# extra-include-dirs: [/path/to/dir] -# extra-lib-dirs: [/path/to/dir] From d64927e805b7985b471832ac3967a27dbebba5ac Mon Sep 17 00:00:00 2001 From: Tom Westerhout <14264576+twesterhout@users.noreply.github.com> Date: Wed, 11 Jan 2023 00:06:25 +0100 Subject: [PATCH 15/21] Add more tests --- .github/workflows/ci.yml | 1 + src/Symengine/Internal.hs | 246 ++++++++++++++++++++++++++++++++++++-- test/Spec.hs | 82 ++++++++----- 3 files changed, 294 insertions(+), 35 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ea72220..0423987 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -72,6 +72,7 @@ jobs: run: | echo "package symengine" >> cabal.project.local echo " extra-lib-dirs: /opt/symengine/lib" >> cabal.project.local + echo " extra-include-dirs: /opt/symengine/include" >> cabal.project.local cabal build - name: Test diff --git a/src/Symengine/Internal.hs b/src/Symengine/Internal.hs index 0ecfa50..8ac7bd1 100644 --- a/src/Symengine/Internal.hs +++ b/src/Symengine/Internal.hs @@ -1,25 +1,52 @@ {-# LANGUAGE CApiFFI #-} {-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE TypeFamilies #-} module Symengine.Internal ( Basic, basicFromText, basicToText, - constZero, - constOne, + mkFunction, + im, + + -- ** Predicates isNumber, + isInteger, + isRational, + isComplex, + isSymbol, isPositive, isNegative, isZero, - isComplex, + + -- ** Complex numbers + realPart, + imagPart, + + -- ** Vector + Vec, + vecSize, + vecIndex, + + -- ** Set + Set, + setSize, + setElem, + + -- ** Utilities + freeSymbols, + functionSymbols, symengineVersion, + + -- ** Reexports + toList, + fromList, + fromString, ) where import Control.Exception (bracket) -import Control.Monad (when) import Data.Bits (toIntegralSized) import Data.ByteString (packCString, useAsCString) import Data.Ratio @@ -27,11 +54,11 @@ import Data.Text (Text) import qualified Data.Text as Text import Data.Text.Encoding (decodeUtf8, encodeUtf8) import Foreign.C.String (CString) -import Foreign.C.Types (CInt (..), CLong (..)) +import Foreign.C.Types (CInt (..), CLong (..), CSize (..)) import Foreign.ForeignPtr import Foreign.Ptr import Foreign.Storable -import GHC.Exts (IsString (..)) +import GHC.Exts (IsList (..), IsString (..)) import GHC.Generics (Generic) import GHC.Stack (HasCallStack) import System.IO.Unsafe (unsafePerformIO) @@ -51,6 +78,10 @@ instance Storable Cbasic_struct where peek _ = error "Storable instance for Cbasic_struct does not implement peek, because you should not rely on the internal representation of it" poke _ _ = error "Storable instance for Cbasic_struct does not implement poke, because you should not rely on the internal representation of it" +data {-# CTYPE "symengine/cwrapper.h" "CVecBasic" #-} CVecBasic + +data {-# CTYPE "symengine/cwrapper.h" "CSetBasic" #-} CSetBasic + data SymengineError = RuntimeError | DivideByZero @@ -73,6 +104,10 @@ instance Enum SymengineError where newtype Basic = Basic (ForeignPtr Cbasic_struct) +newtype Vec = Vec (ForeignPtr CVecBasic) + +newtype Set = Set (ForeignPtr CSetBasic) + -- | Allocate a new 'Basic' and use the provided function for initialization. newBasic :: (Ptr Cbasic_struct -> IO ()) -> IO Basic newBasic initialize = do @@ -90,6 +125,91 @@ newBasicNoDestructor initialize = do withBasic :: Basic -> (Ptr Cbasic_struct -> IO a) -> IO a withBasic (Basic fp) = withForeignPtr fp +-- | Allocate a new 'Vec'. +newVec :: IO Vec +newVec = pure . Vec =<< newForeignPtr vecbasic_free =<< vecbasic_new + +withVec :: Vec -> (Ptr CVecBasic -> IO a) -> IO a +withVec (Vec fp) = withForeignPtr fp + +vecSize :: Vec -> Int +vecSize x = + unsafePerformIO . withVec x $ + fmap fromIntegral . vecbasic_size + +vecGet :: HasCallStack => Vec -> Int -> IO Basic +vecGet v i = withVec v $ \vPtr -> newBasic $ \xPtr -> + checkError "vecbasic_get" =<< vecbasic_get vPtr (fromIntegral i) xPtr + +vecIndex :: HasCallStack => Vec -> Int -> Basic +vecIndex v i = unsafePerformIO $! vecGet v i + +vecSet :: HasCallStack => Vec -> Int -> Basic -> IO () +vecSet v i x = withVec v $ \vPtr -> withBasic x $ \xPtr -> + checkError "vecbasic_set" =<< vecbasic_set vPtr (fromIntegral i) xPtr + +vecPushBack :: HasCallStack => Vec -> Basic -> IO () +vecPushBack v x = withVec v $ \vPtr -> withBasic x $ \xPtr -> + checkError "vecbasic_push_back" =<< vecbasic_push_back vPtr xPtr + +instance IsList Vec where + type Item Vec = Basic + toList v = unsafePerformIO $ go (vecSize v - 1) [] + where + go !i acc + | i >= 0 = do + !x <- vecGet v i + go (i - 1) (x : acc) + | otherwise = pure acc + fromList list = unsafePerformIO $ do + v <- newVec + let go [] = pure () + go (x : xs) = vecPushBack v x >> go xs + go list + pure v + +newSet :: IO Set +newSet = pure . Set =<< newForeignPtr setbasic_free =<< setbasic_new + +withSet :: Set -> (Ptr CSetBasic -> IO a) -> IO a +withSet (Set fp) = withForeignPtr fp + +setSize :: Set -> Int +setSize x = unsafePerformIO . withSet x $ fmap fromIntegral . setbasic_size + +setGet :: Set -> Int -> IO Basic +setGet s i = withSet s $ \sPtr -> newBasic $ \xPtr -> + setbasic_get sPtr (fromIntegral i) xPtr + +setInsert :: Set -> Basic -> IO Bool +setInsert s x = withSet s $ \sPtr -> withBasic x $ \xPtr -> + toEnum . fromIntegral <$> setbasic_insert sPtr xPtr + +setFind :: Set -> Basic -> IO Bool +setFind s x = withSet s $ \sPtr -> withBasic x $ + fmap (toEnum . fromIntegral) . setbasic_find sPtr + +setElem :: Basic -> Set -> Bool +setElem x s = unsafePerformIO $! setFind s x + +instance IsList Set where + type Item Set = Basic + toList s = unsafePerformIO $ go (setSize s - 1) [] + where + go !i acc + | i >= 0 = do + !x <- setGet s i + go (i - 1) (x : acc) + | otherwise = pure acc + fromList list = unsafePerformIO $ do + s <- newSet + let go [] = pure () + go (x : xs) = do + _ <- setInsert s x + go xs + go list + pure s + checkError :: HasCallStack => Text -> CInt -> IO () checkError name e | e == 0 = pure () @@ -179,6 +299,15 @@ isNegative x = queryOp number_is_negative x isNumber :: Basic -> Bool isNumber = queryOp is_a_Number +isInteger :: Basic -> Bool +isInteger = queryOp is_a_Integer + +isRational :: Basic -> Bool +isRational = queryOp is_a_Rational + +isSymbol :: Basic -> Bool +isSymbol = queryOp is_a_Symbol + isComplex :: Basic -> Bool isComplex = queryOp is_a_Complex @@ -231,6 +360,40 @@ constOne = unsafePerformIO $! newBasicNoDestructor basic_const_one constPi :: Basic constPi = unsafePerformIO $! newBasicNoDestructor basic_const_pi +im :: Basic +im = unsafePerformIO $! newBasicNoDestructor basic_const_I + +realPart :: HasCallStack => Basic -> Basic +realPart = unaryOp' (checkError "complex_base_real_part") complex_base_real_part + +imagPart :: HasCallStack => Basic -> Basic +imagPart = unaryOp' (checkError "complex_base_imaginary_part") complex_base_imaginary_part + +mkFunction :: HasCallStack => Text -> Vec -> Basic +mkFunction name args = + unsafePerformIO $! + newBasic $ \xPtr -> + withUtf8 name $ \namePtr -> + withVec args $ \argsPtr -> + checkError "function_symbol_set" + =<< function_symbol_set xPtr namePtr argsPtr + +freeSymbols :: HasCallStack => Basic -> Set +freeSymbols x = + unsafePerformIO $! do + s <- newSet + withBasic x $ \xPtr -> withSet s $ \sPtr -> + checkError "basic_free_symbols" =<< basic_free_symbols xPtr sPtr + pure s + +functionSymbols :: HasCallStack => Basic -> Set +functionSymbols x = + unsafePerformIO $! do + s <- newSet + withBasic x $ \xPtr -> withSet s $ \sPtr -> + checkError "basic_function_symbols" =<< basic_function_symbols sPtr xPtr + pure s + -- | Unicode-safe alternative to 'peekCString' peekUtf8 :: CString -> IO Text peekUtf8 = fmap decodeUtf8 . packCString @@ -403,5 +566,74 @@ foreign import capi unsafe "symengine/cwrapper.h is_a_ComplexMPC" foreign import capi unsafe "symengine/cwrapper.h is_a_Set" is_a_Set :: Ptr Cbasic_struct -> IO CInt +foreign import capi unsafe "symengine/cwrapper.h basic_get_args" + basic_get_args :: Ptr Cbasic_struct -> Ptr CVecBasic -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h basic_free_symbols" + basic_free_symbols :: Ptr Cbasic_struct -> Ptr CSetBasic -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h basic_function_symbols" + basic_function_symbols :: Ptr CSetBasic -> Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h function_symbol_set" + function_symbol_set :: Ptr Cbasic_struct -> CString -> Ptr CVecBasic -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h complex_base_real_part" + complex_base_real_part :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h complex_base_imaginary_part" + complex_base_imaginary_part :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h vecbasic_new" + vecbasic_new :: IO (Ptr CVecBasic) + +foreign import capi unsafe "symengine/cwrapper.h &vecbasic_free" + vecbasic_free :: FunPtr (Ptr CVecBasic -> IO ()) + +foreign import capi unsafe "symengine/cwrapper.h vecbasic_push_back" + vecbasic_push_back :: Ptr CVecBasic -> Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h vecbasic_get" + vecbasic_get :: Ptr CVecBasic -> CSize -> Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h vecbasic_set" + vecbasic_set :: Ptr CVecBasic -> CSize -> Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h vecbasic_erase" + vecbasic_erase :: Ptr CVecBasic -> CSize -> IO () + +foreign import capi unsafe "symengine/cwrapper.h vecbasic_size" + vecbasic_size :: Ptr CVecBasic -> IO CSize + +foreign import capi unsafe "symengine/cwrapper.h basic_max" + basic_max :: Ptr Cbasic_struct -> Ptr CVecBasic -> IO CSize + +foreign import capi unsafe "symengine/cwrapper.h basic_min" + basic_min :: Ptr Cbasic_struct -> Ptr CVecBasic -> IO CSize + +foreign import capi unsafe "symengine/cwrapper.h basic_add_vec" + basic_add_vec :: Ptr Cbasic_struct -> Ptr CVecBasic -> IO CSize + +foreign import capi unsafe "symengine/cwrapper.h basic_mul_vec" + basic_mul_vec :: Ptr Cbasic_struct -> Ptr CVecBasic -> IO CSize + +foreign import capi unsafe "symengine/cwrapper.h setbasic_new" + setbasic_new :: IO (Ptr CSetBasic) + +foreign import capi unsafe "symengine/cwrapper.h &setbasic_free" + setbasic_free :: FunPtr (Ptr CSetBasic -> IO ()) + +foreign import capi unsafe "symengine/cwrapper.h setbasic_insert" + setbasic_insert :: Ptr CSetBasic -> Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h setbasic_get" + setbasic_get :: Ptr CSetBasic -> CInt -> Ptr Cbasic_struct -> IO () + +foreign import capi unsafe "symengine/cwrapper.h setbasic_find" + setbasic_find :: Ptr CSetBasic -> Ptr Cbasic_struct -> IO CInt + +foreign import capi unsafe "symengine/cwrapper.h setbasic_size" + setbasic_size :: Ptr CSetBasic -> IO CSize + foreign import ccall unsafe "symengine/cwrapper.h symengine_version" symengine_version :: IO CString diff --git a/test/Spec.hs b/test/Spec.hs index 36706f6..fef6449 100644 --- a/test/Spec.hs +++ b/test/Spec.hs @@ -1,9 +1,14 @@ +{-# LANGUAGE OverloadedLists #-} +{-# LANGUAGE OverloadedStrings #-} + import Data.List import Data.Monoid import Data.Ord -import Symengine as Sym +-- import Symengine as Sym +import Data.Ratio +import Symengine.Internal import Test.Tasty -import Test.Tasty.HUnit as HU +import Test.Tasty.HUnit import Test.Tasty.QuickCheck as QC import Prelude hiding (pi) @@ -21,30 +26,51 @@ tests = testGroup "Tests" [unitTests] unitTests = testGroup "Unit tests" - [ HU.testCase "FFI Sanity Check - ASCII Art should be non-empty" $ - do - ascii_art <- Sym.ascii_art_str - HU.assertBool "ASCII art from ascii_art_str is empty" (not . null $ ascii_art) - - -- , HU.testCase "Basic Constructors" $ - -- do - -- "0" @?= (show zero) - -- "1" @?= (show one) - -- "-1" @?= (show minus_one) - -- , HU.testCase "Basic Trignometric Functions" $ - -- do - -- let pi_over_3 = pi / 3 :: BasicSym - -- let pi_over_2 = pi / 2 :: BasicSym - - -- sin zero @?= zero - -- cos zero @?= one - -- - -- sin (pi / 6) @?= 1 / 2 - -- sin (pi / 3) @?= (3 ** (1/2)) / 2 - - -- cos (pi / 6) @?= (3 ** (1/2)) / 2 - -- cos (pi / 3) @?= 1 / 2 - - -- sin pi_over_2 @?= one - -- cos pi_over_2 @?= zero + [ -- testCase "FFI Sanity Check - ASCII Art should be non-empty" $ + -- do + -- ascii_art <- Sym.ascii_art_str + -- assertBool "ASCII art from ascii_art_str is empty" (not . null $ ascii_art), + testCase "test_complex" $ do + let r = fromRational (100 % 47) :: Basic + i = fromRational (76 % 59) + e = r + i * im + show e @?= "100/47 + 76/59*I" + isSymbol e @?= False + isRational e @?= False + isInteger e @?= False + isComplex e @?= True + isZero e @?= False + isNegative e @?= False + isPositive e @?= False + + show (realPart e) @?= "100/47" + isSymbol (realPart e) @?= False + isRational (realPart e) @?= True + isInteger (realPart e) @?= False + isComplex (realPart e) @?= False + + show (imagPart e) @?= "76/59" + isSymbol (imagPart e) @?= False + isRational (imagPart e) @?= True + isInteger (imagPart e) @?= False + isComplex (imagPart e) @?= False, + testCase "test_free_symbols" $ do + let x = "x" :: Basic + y = "y" + z = "z" + e = 123 + expr = (e + x) ** y / z + + setSize (freeSymbols expr) @?= 3 + toList (freeSymbols expr) @?= ["x", "y", "z"], + testCase "test_function_symbols" $ do + let x = "x" :: Basic + y = "y" + z = "z" + g = mkFunction "g" [x] + h = mkFunction "h" [g] + f = mkFunction "f" [x + y, g, h] + + show (z + f) @?= "z + f(x + y, g(x), h(g(x)))" + setSize (functionSymbols f) @?= 3 ] From eb4120fec3db442d491fd141a00424f58f5cb34b Mon Sep 17 00:00:00 2001 From: twesterhout <14264576+twesterhout@users.noreply.github.com> Date: Sat, 22 Apr 2023 18:28:21 +0200 Subject: [PATCH 16/21] Proper conversions for Integer and Rational --- cabal.project.local | 5 + flake.lock | 76 +++ flake.nix | 118 +++++ fourmolu.yaml | 14 + src/Symengine.hs | 1032 +++++++++++++++++++++++++++---------- src/Symengine/Context.hs | 108 ++++ src/Symengine/Internal.hs | 757 +++++---------------------- symengine.cabal | 138 +++-- test/Spec.hs | 130 +++-- 9 files changed, 1343 insertions(+), 1035 deletions(-) create mode 100644 cabal.project.local create mode 100644 flake.lock create mode 100644 flake.nix create mode 100644 fourmolu.yaml create mode 100644 src/Symengine/Context.hs diff --git a/cabal.project.local b/cabal.project.local new file mode 100644 index 0000000..7f529ad --- /dev/null +++ b/cabal.project.local @@ -0,0 +1,5 @@ +ignore-project: False +write-ghc-environment-files: always +tests: True +test-options: "--color" +test-show-details: streaming diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..cee9ed8 --- /dev/null +++ b/flake.lock @@ -0,0 +1,76 @@ +{ + "nodes": { + "flake-compat": { + "flake": false, + "locked": { + "lastModified": 1673956053, + "narHash": "sha256-4gtG9iQuiKITOjNQQeQIpoIB6b16fm+504Ch3sNKLd8=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "35bb57c0c8d8b62bbfd284272c928ceb64ddbde9", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-utils": { + "locked": { + "lastModified": 1680776469, + "narHash": "sha256-3CXUDK/3q/kieWtdsYpDOBJw3Gw4Af6x+2EiSnIkNQw=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "411e8764155aa9354dbcd6d5faaeb97e9e3dce24", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nix-filter": { + "locked": { + "lastModified": 1678109515, + "narHash": "sha256-C2X+qC80K2C1TOYZT8nabgo05Dw2HST/pSn6s+n6BO8=", + "owner": "numtide", + "repo": "nix-filter", + "rev": "aa9ff6ce4a7f19af6415fb3721eaa513ea6c763c", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "nix-filter", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1680758185, + "narHash": "sha256-sCVWwfnk7zEX8Z+OItiH+pcSklrlsLZ4TJTtnxAYREw=", + "owner": "nixos", + "repo": "nixpkgs", + "rev": "0e19daa510e47a40e06257e205965f3b96ce0ac9", + "type": "github" + }, + "original": { + "owner": "nixos", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "flake-compat": "flake-compat", + "flake-utils": "flake-utils", + "nix-filter": "nix-filter", + "nixpkgs": "nixpkgs" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..326f073 --- /dev/null +++ b/flake.nix @@ -0,0 +1,118 @@ +{ + description = "symengine/symengine.hs: SymEngine symbolic mathematics engine for Haskell"; + + nixConfig = { + extra-experimental-features = "nix-command flakes"; + }; + + inputs = { + nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable"; + flake-utils.url = "github:numtide/flake-utils"; + nix-filter.url = "github:numtide/nix-filter"; + flake-compat = { + url = "github:edolstra/flake-compat"; + flake = false; + }; + }; + + outputs = inputs: inputs.flake-utils.lib.eachDefaultSystem (system: + with builtins; + let + inherit (inputs.nixpkgs) lib; + pkgs = import inputs.nixpkgs { + inherit system; + overlays = [ + (self: super: { + symengine = super.symengine.overrideAttrs (attrs: rec { + version = "0.10.1"; + src = self.fetchFromGitHub { + owner = attrs.pname; + repo = attrs.pname; + rev = "v${version}"; + sha256 = "sha256-qTu0vS9K6rrr/0SXKpGC9P1QSN/AN7hyO/4DrGvhxWM="; + }; + cmakeFlags = (attrs.cmakeFlags or [ ]) ++ [ + "-DCMAKE_BUILD_TYPE=Debug" + "-DBUILD_SHARED_LIBS=ON" + ]; + }); + }) + ]; + }; + + src = inputs.nix-filter.lib { + root = ./.; + include = [ + "src" + "test" + "symengine.cabal" + "README.md" + "LICENSE" + ]; + }; + + # This allows us to build a Haskell package with any given GHC version. + # It will also affects all dependent libraries. + # overrides allows us to patch existing Haskell packages, or introduce new ones + # see here for specifics: https://nixos.wiki/wiki/Overlays + haskellPackagesOverride = ps: args: + ps.override + { + overrides = self: super: { + symengine = (self.callCabal2nix "symengine" src { + inherit (pkgs) symengine; + mpc = pkgs.libmpc; + }); + }; + }; + + outputsFor = + { haskellPackages + , name + , package ? "" + , ... + }: + let + ps = haskellPackagesOverride haskellPackages { }; + in + { + packages.${name} = ps.${package} or ps; + devShells.${name} = ps.shellFor { + packages = ps: with ps; [ + symengine + ]; + withHoogle = true; + nativeBuildInputs = with pkgs; with ps; [ + # Building and testing + cabal-install + # Language servers + haskell-language-server + nil + # Formatters + fourmolu + cabal-fmt + nixpkgs-fmt + # Previewing markdown files + python3Packages.grip + ]; + shellHook = '' + LD_LIBRARY_PATH=${pkgs.symengine}/lib:${pkgs.flint}/lib:${pkgs.libmpc}/lib:${pkgs.mpfr}/lib:$LD_LIBRARY_PATH + SYMENGINE_PATH=${pkgs.symengine} + ''; + }; + # The formatter to use for .nix files (but not .hs files) + # Allows us to run `nix fmt` to reformat nix files. + formatter = pkgs.nixpkgs-fmt; + }; + in + foldl' (acc: conf: lib.recursiveUpdate acc (outputsFor conf)) { } + (lib.mapAttrsToList (name: haskellPackages: { inherit name haskellPackages; }) + (lib.filterAttrs (_: ps: ps ? ghc) pkgs.haskell.packages) ++ [ + { + haskellPackages = pkgs.haskellPackages; + name = "default"; + package = "symengine"; + } + ]) + ); +} diff --git a/fourmolu.yaml b/fourmolu.yaml new file mode 100644 index 0000000..e9f82b2 --- /dev/null +++ b/fourmolu.yaml @@ -0,0 +1,14 @@ +indentation: 2 +function-arrows: leading +comma-style: leading +import-export-style: leading +indent-wheres: true +record-brace-space: true +newlines-between-decls: 1 +haddock-style: single-line +haddock-style-module: single-line +let-style: auto +in-style: right-align +unicode: never +respectful: true +fixities: [] diff --git a/src/Symengine.hs b/src/Symengine.hs index d0671ff..88e10c5 100644 --- a/src/Symengine.hs +++ b/src/Symengine.hs @@ -1,281 +1,769 @@ -{-# LANGUAGE RecordWildCards #-} - -{-| -Module : Symengine -Description : Symengine bindings to Haskell --} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE ViewPatterns #-} + +-- | +-- Module : Symengine +-- Description : Symengine bindings to Haskell module Symengine - ( - ascii_art_str, - zero, - one, - im, - Symengine.pi, - e, - minus_one, - rational, - complex, - symbol, - BasicSym, - ) where - + ( Basic (..) + , symbol + , parse + , e + , infinity + , nan + , diff + , evalf + , inverse + , identityMatrix + , zeroMatrix + , allocaCxxInteger + , peekCxxInteger + , withCxxInteger + , EvalDomain (..) + , InverseMethod (..) + , toAST + , fromAST + , AST (..) + ) where + +import Control.Exception (bracket_) +import Control.Monad +import Data.Bits +import Data.ByteString (packCString) +import Data.Text (Text, pack, unpack) +import Data.Text.Encoding qualified as T +import Data.Vector (Vector) +import Data.Vector qualified as V import Foreign.C.Types -import Foreign.Ptr -import Foreign.C.String -import Foreign.Storable -import Foreign.Marshal.Array -import Foreign.Marshal.Alloc import Foreign.ForeignPtr -import Control.Applicative +import Foreign.Marshal (allocaBytes, allocaBytesAligned, toBool, withArrayLen) +import Foreign.Ptr +import GHC.Exts (IsString (..)) +import GHC.Int +import GHC.Num.BigNat +import GHC.Num.Integer +import GHC.Real (Ratio (..)) +import Language.C.Inline qualified as C +import Language.C.Inline.Cpp.Exception qualified as C +import Language.C.Inline.Unsafe qualified as CU +import Symengine.Context +import Symengine.Internal import System.IO.Unsafe -import Control.Monad -import GHC.Real - -data BasicStruct = BasicStruct { - data_ptr :: Ptr () -} -instance Storable BasicStruct where - alignment _ = 8 - sizeOf _ = sizeOf nullPtr - peek basic_ptr = BasicStruct <$> peekByteOff basic_ptr 0 - poke basic_ptr BasicStruct{..} = pokeByteOff basic_ptr 0 data_ptr +importSymengine +-- | Convert a pointer to @std::string@ into a string. +-- +-- It properly handles unicode characters. +peekCxxString :: Ptr CxxString -> IO Text +peekCxxString p = + fmap T.decodeUtf8 $ + packCString + =<< [CU.exp| char const* { $(const std::string* p)->c_str() } |] + +-- | Call 'peekCxxString' and @delete@ the pointer. +peekAndDeleteCxxString :: Ptr CxxString -> IO Text +peekAndDeleteCxxString p = do + s <- peekCxxString p + [CU.exp| void { delete $(const std::string* p) } |] + pure s + +constructBasic :: (Ptr CxxBasic -> IO ()) -> IO Basic +constructBasic construct = + fmap Basic $ constructWithDeleter size deleter $ \ptr -> do + [CU.block| void { new ($(Object* ptr)) Object{}; } |] + construct ptr + where + size = fromIntegral [CU.pure| size_t { sizeof(Object) } |] + deleter = [C.funPtr| void deleteBasic(Object* ptr) { ptr->~Object(); } |] + +constructWithDeleter :: Int -> FinalizerPtr a -> (Ptr a -> IO ()) -> IO (ForeignPtr a) +constructWithDeleter size deleter constructor = do + fp <- mallocForeignPtrBytes size + withForeignPtr fp constructor + addForeignPtrFinalizer deleter fp + pure fp + +-- newtype VecBasic = VecBasic (ForeignPtr CxxVecBasic) + +-- constructVecBasic :: (Ptr CxxVecBasic -> IO ()) -> IO VecBasic +-- constructVecBasic construct = +-- fmap VecBasic $ constructWithDeleter size deleter $ \ptr -> do +-- [CU.block| void { new ($(SymEngine::vec_basic* ptr)) SymEngine::vec_basic{}; } |] +-- construct ptr +-- where +-- size = fromIntegral [CU.pure| size_t { sizeof(SymEngine::vec_basic) } |] +-- deleter = [C.funPtr| void deleteBasic(SymEngine::vec_basic* ptr) { ptr->~vector(); } |] + +withBasic :: Basic -> (Ptr CxxBasic -> IO a) -> IO a +withBasic (Basic fp) = withForeignPtr fp + +-- withVecBasic :: VecBasic -> (Ptr CxxVecBasic -> IO a) -> IO a +-- withVecBasic (VecBasic fp) = withForeignPtr fp + +-- vecBasicToList :: VecBasic -> [Basic] +-- vecBasicToList v = unsafePerformIO $ +-- withVecBasic v $ \v' -> do +-- size <- [CU.exp| size_t { $(const SymEngine::vec_basic* v')->size() } |] +-- forM [0 .. size - 1] $ \i -> +-- constructBasic $ \dest -> +-- [CU.exp| void { CONSTRUCT_BASIC($(Object* dest), +-- $(const SymEngine::vec_basic* v')->at($(size_t i))) } |] + +cxxVectorSize :: Ptr (Vector Basic) -> IO Int +cxxVectorSize ptr = fromIntegral <$> [CU.exp| size_t { $(const Vector* ptr)->size() } |] + +cxxVectorIndex :: Ptr (Vector Basic) -> Int -> IO Basic +cxxVectorIndex ptr (fromIntegral -> i) = + $(constructBasicFrom "$(const Vector* ptr)->at($(size_t i))") + +cxxVectorPushBack :: Ptr (Vector Basic) -> Basic -> IO () +cxxVectorPushBack ptr basic = + withBasic basic $ \x -> + [CU.exp| void { $(Vector* ptr)->push_back(*$(const Object* x)) } |] + +peekVector :: Ptr (Vector Basic) -> IO (Vector Basic) +peekVector ptr = do + size <- cxxVectorSize ptr + V.forM (V.enumFromStepN 0 1 size) (cxxVectorIndex ptr) + +allocaVector :: (Ptr (Vector Basic) -> IO a) -> IO a +allocaVector action = + allocaBytesAligned sizeBytes alignmentBytes $ \v -> + let construct = [CU.exp| void { new ($(Vector* v)) Vector{} } |] + destruct = [CU.exp| void { $(Vector* v)->~Vector() } |] + in bracket_ construct destruct (action v) + where + sizeBytes = fromIntegral [CU.pure| size_t { sizeof(Vector) } |] + alignmentBytes = fromIntegral [CU.pure| size_t { alignof(Vector) } |] + +withVector :: Vector Basic -> (Ptr (Vector Basic) -> IO a) -> IO a +withVector v action = do + allocaVector $ \ptr -> do + V.forM_ v $ cxxVectorPushBack ptr + action ptr + +-- \$ \dest -> +-- [CU.exp| void { CONSTRUCT_BASIC($(Object* dest), +-- $(const Vector* ptr)->at($(size_t i))) } |] + +allocaDenseMatrix :: Int -> Int -> (Ptr (DenseMatrix Basic) -> IO a) -> IO a +allocaDenseMatrix (fromIntegral -> nrows) (fromIntegral -> ncols) action = do + allocaBytesAligned sizeBytes alignmentBytes $ \v -> + let construct = + [CU.exp| void { new ($(DenseMatrix * v)) DenseMatrix{ + $(unsigned nrows), $(unsigned ncols)} } |] + destruct = [CU.exp| void { $(DenseMatrix* v)->~DenseMatrix() } |] + in bracket_ construct destruct (action v) + where + sizeBytes = fromIntegral [CU.pure| size_t { sizeof(DenseMatrix) } |] + alignmentBytes = fromIntegral [CU.pure| size_t { alignof(DenseMatrix) } |] + +withDenseMatrix :: DenseMatrix Basic -> (Ptr (DenseMatrix Basic) -> IO a) -> IO a +withDenseMatrix matrix action = + allocaDenseMatrix 0 0 $ \ptr -> + withVector (dmData matrix) $ \v -> do + let n = fromIntegral $ dmRows matrix + m = fromIntegral $ dmCols matrix + [CU.block| void { + *$(DenseMatrix* ptr) = DenseMatrix{$(unsigned n), $(unsigned m), *$(const Vector* v)}; + } |] + action ptr + +peekDenseMatrix :: Ptr (DenseMatrix Basic) -> IO (DenseMatrix Basic) +peekDenseMatrix ptr = do + n <- fromIntegral <$> [CU.exp| unsigned { $(const DenseMatrix* ptr)->nrows() } |] + m <- fromIntegral <$> [CU.exp| unsigned { $(const DenseMatrix* ptr)->ncols() } |] + v <- + allocaVector $ \v -> do + [CU.block| void { *$(Vector* v) = $(const DenseMatrix* ptr)->as_vec_basic(); } |] + peekVector v + pure $ DenseMatrix n m v + +instance Show Basic where + show basic = unpack . unsafePerformIO $ + withBasic basic $ \basic' -> + $(constructStringFrom "SymEngine::str(**$(Object* basic'))") + +deriving stock instance Show (DenseMatrix Basic) + +-- peekAndDeleteCxxString +-- =<< [CU.exp| std::string* { new std::string{SymEngine::str(**$(Object* basic'))} } |] + +instance Eq Basic where + a == b = unsafePerformIO $ + withBasic a $ \a' -> + withBasic b $ \b' -> + toBool + <$> [CU.exp| bool { eq(**$(const Object* a'), **$(const Object* b')) } |] + +parse :: Text -> Basic +parse (T.encodeUtf8 -> name) = + unsafePerformIO $ $(constructBasicFrom "parse($bs-cstr:name)") + +instance IsString Basic where + fromString = parse . pack + +symbol :: Text -> Basic +symbol (T.encodeUtf8 -> name) = + unsafePerformIO $ + $(constructBasicFrom "symbol(std::string{$bs-ptr:name, static_cast($bs-len:name)})") + +-- constructBasic $ \dest -> +-- [CU.exp| void { new ($(Object* dest)) Object{} } |] + +-- pureUnaryOp :: (Ptr CxxBasic -> Ptr CxxBasic -> IO ()) -> Basic -> Basic +-- pureUnaryOp f a = unsafePerformIO $ +-- withBasic a $ \a' -> +-- constructBasic $ \dest -> +-- f dest a' + +-- pureBinaryOp :: (Ptr CxxBasic -> Ptr CxxBasic -> Ptr CxxBasic -> IO ()) -> Basic -> Basic -> Basic +-- pureBinaryOp f a b = unsafePerformIO $ +-- withBasic a $ \a' -> +-- withBasic b $ \b' -> +-- constructBasic $ \dest -> +-- f dest a' b' + +allocaCxxInteger :: (Ptr CxxInteger -> IO a) -> IO a +allocaCxxInteger f = + allocaBytesAligned sizeBytes alignmentBytes $ \i -> + let construct = + [CU.exp| void { new ($(integer_class * i)) integer_class{} } |] + destruct = [CU.exp| void { $(integer_class * i)->~integer_class() } |] + in bracket_ construct destruct (f i) + where + sizeBytes = fromIntegral [CU.pure| size_t { sizeof(integer_class) } |] + alignmentBytes = fromIntegral [CU.pure| size_t { alignof(integer_class) } |] + +integerToWords :: Integer -> [Word] +integerToWords (IP b) = bigNatToWordList b +integerToWords (IN b) = bigNatToWordList b +integerToWords (IS n) = [fromIntegral (abs (I# n))] + +withCxxInteger :: Integer -> (Ptr CxxInteger -> IO a) -> IO a +withCxxInteger n action = + allocaCxxInteger $ \i -> + withArrayLen (fromIntegral <$> integerToWords n) $ + \(fromIntegral -> numWords) wordsPtr -> do + [CU.block| void { + auto const numWords = $(int numWords); + auto const* words = $(const uint64_t* wordsPtr); + if (numWords > 0) { + integer_class x{words[0]}; + for (int k = 1; k < numWords; ++k) { + x <<= 64; + x += words[k]; + } + *$(integer_class* i) = x; + } + } |] + when (n < 0) $ do + [CU.block| void { + auto& i = *$(integer_class* i); + i = -i; + } |] + action i + +peekCxxInteger :: Ptr CxxInteger -> IO Integer +peekCxxInteger i = do + allocaCxxInteger $ \j -> do + isNegative <- + toBool + <$> [CU.block| bool { + auto const& i = *$(integer_class const* i); + auto& j = *$(integer_class* j); + j = mp_abs(i); + return i < 0; + } |] + let go acc = do + w <- + [CU.block| uint64_t { + auto const& j = *$(integer_class const* j); + return mp_get_ui(j); + } |] + continue <- + toBool + <$> [CU.block| bool { + auto& j = *$(integer_class* j); + j >>= 64; + return j != 0; + } |] + if continue + then go $ w : acc + else pure $ w : acc + integerFromWordList isNegative . fmap fromIntegral <$> go [] + +instance Num Basic where + fromInteger n = unsafePerformIO $ + withCxxInteger n $ \i -> + $(constructBasicFrom "integer(*$(const integer_class* i))") + (+) = $(mkBinaryFunction "add(a, b)") + (-) = $(mkBinaryFunction "sub(a, b)") + (*) = $(mkBinaryFunction "mul(a, b)") + abs = $(mkUnaryFunction "abs(a)") + signum = $(mkUnaryFunction "sign(a)") + +instance Fractional Basic where + (/) = $(mkBinaryFunction "div(a, b)") + fromRational (numer :% denom) = + unsafePerformIO $ + withCxxInteger numer $ \numer' -> + withCxxInteger denom $ \denom' -> + $( constructBasicFrom + "Rational::from_two_ints(\ + \Integer(*$(const integer_class* numer')),\ + \Integer(*$(const integer_class* denom')))" + ) + +e :: Basic +e = unsafePerformIO $ $(constructBasicFrom "E") +{-# NOINLINE e #-} + +infinity :: Basic +infinity = unsafePerformIO $ $(constructBasicFrom "Inf") +{-# NOINLINE infinity #-} + +nan :: Basic +nan = unsafePerformIO $ $(constructBasicFrom "Nan") +{-# NOINLINE nan #-} + +instance Floating Basic where + pi = unsafePerformIO $ $(constructBasicFrom "pi") + exp = $(mkUnaryFunction "exp(a)") + log = $(mkUnaryFunction "log(a)") + sqrt = $(mkUnaryFunction "sqrt(a)") + (**) = $(mkBinaryFunction "pow(a, b)") + sin = $(mkUnaryFunction "sin(a)") + cos = $(mkUnaryFunction "cos(a)") + tan = $(mkUnaryFunction "tan(a)") + asin = $(mkUnaryFunction "asin(a)") + acos = $(mkUnaryFunction "acos(a)") + atan = $(mkUnaryFunction "atan(a)") + sinh = $(mkUnaryFunction "sinh(a)") + cosh = $(mkUnaryFunction "cosh(a)") + tanh = $(mkUnaryFunction "tanh(a)") + asinh = $(mkUnaryFunction "asinh(a)") + acosh = $(mkUnaryFunction "acosh(a)") + atanh = $(mkUnaryFunction "atanh(a)") + +diff :: Basic -> Basic -> Basic +diff f x + | basicTypeCode x == [CU.pure| int { static_cast(SYMENGINE_SYMBOL) } |] = + $(mkBinaryFunction "a->diff(rcp_static_cast(b))") f x + | otherwise = error "can only differentiate with respect to symbols" + +data EvalDomain = EvalComplex | EvalReal | EvalSymbolic + deriving stock (Show, Eq) + +evalDomainToCInt :: EvalDomain -> CInt +evalDomainToCInt EvalComplex = [CU.pure| int { static_cast(SymEngine::EvalfDomain::Complex) } |] +evalDomainToCInt EvalReal = [CU.pure| int { static_cast(SymEngine::EvalfDomain::Real) } |] +evalDomainToCInt EvalSymbolic = [CU.pure| int { static_cast(SymEngine::EvalfDomain::Symbolic) } |] + +evalf :: EvalDomain -> Int -> Basic -> Basic +evalf (evalDomainToCInt -> domain) (fromIntegral -> bits) x = unsafePerformIO $ + withBasic x $ \x' -> + $(constructBasicFrom "evalf(**$(const Object* x'), $(int bits), static_cast($(int domain)))") + +-- pureBinaryOp $ \dest f x -> do +-- [CU.exp| void { +-- CONSTRUCT_BASIC($(Object* dest), (*$(const Object* f))->diff( +-- SymEngine::rcp_static_cast(*$(const Object* x)))) } |] + +generateDenseMatrix :: Int -> Int -> (Int -> Int -> Basic) -> DenseMatrix Basic +generateDenseMatrix nrows ncols f = + DenseMatrix nrows ncols $ + V.generate (nrows * ncols) $ \i -> + let (!r, !c) = i `divMod` ncols + in f r c + +identityMatrix :: Int -> DenseMatrix Basic +identityMatrix n = generateDenseMatrix n n (\i j -> if i == j then 1 else 0) + +zeroMatrix :: Int -> Int -> DenseMatrix Basic +zeroMatrix n m = generateDenseMatrix n m (\_ _ -> 0) + +data InverseMethod + = InverseDefault + | InverseFractionFreeLU + | InverseLU + | InversePivotedLU + | InverseGaussJordan + deriving stock (Show, Eq) + +inverse :: InverseMethod -> DenseMatrix Basic -> DenseMatrix Basic +inverse InverseDefault m = unsafePerformIO $ withDenseMatrix m $ \a -> + $( createDenseMatrixVia + "auto const& a = *$(const DenseMatrix* a);\ + \out.resize(a.nrows(), a.ncols());\ + \a.inv(out);" + ) + +data AST + = SymengineInteger Integer + | SymengineRational Rational + | SymengineInfinity + | SymengineNaN + | SymengineConstant Basic + | SymengineSymbol Text + | SymengineMul (Vector Basic) + | SymengineAdd (Vector Basic) + | SymenginePow Basic Basic + | SymengineLog Basic + | SymengineSign Basic + | SymengineFunction Text (Vector Basic) + | SymengineDerivative Basic (Vector Basic) + deriving stock (Show, Eq) + +basicTypeCode :: Basic -> CInt +basicTypeCode x = unsafePerformIO $ + withBasic x $ + \x' -> [CU.exp| int { static_cast((*$(const Object* x'))->get_type_code()) } |] + +forceOneArg :: (Basic -> a) -> Vector Basic -> a +forceOneArg f v = case V.toList v of + [a] -> f a + _ -> error "expected a one-element vector" + +forceTwoArgs :: (Basic -> Basic -> a) -> Vector Basic -> a +forceTwoArgs f v = case V.toList v of + [a, b] -> f a b + _ -> error "expected a two-element vector" + +unsafeIntegerToAST :: Basic -> AST +unsafeIntegerToAST x = SymengineInteger n + where + n = unsafePerformIO $ + withBasic x $ \x' -> + allocaCxxInteger $ \i -> do + [CU.exp| void { + *$(integer_class* i) = + down_cast(**$(const Object* x')).as_integer_class() + } |] + peekCxxInteger i + +unsafeRationalToAST :: Basic -> AST +unsafeRationalToAST x = SymengineRational q + where + q = unsafePerformIO $ + withBasic x $ \x' -> + allocaCxxInteger $ \m -> + allocaCxxInteger $ \n -> do + [CU.block| void { + auto const& x = + down_cast(**$(const Object* x')).as_rational_class(); + *$(integer_class* m) = x.get_num(); + *$(integer_class* n) = x.get_den(); + } |] + (:%) <$> peekCxxInteger m <*> peekCxxInteger n + +unsafeSymbolToAST :: Basic -> AST +unsafeSymbolToAST x = SymengineSymbol . unsafePerformIO $ do + withBasic x $ \x' -> + $(constructStringFrom "down_cast(**$(const Object* x')).get_name()") + +toAST :: Basic -> AST +toAST x + | tp == [CU.pure| int { static_cast(SYMENGINE_INTEGER) } |] = unsafeIntegerToAST x + | tp == [CU.pure| int { static_cast(SYMENGINE_RATIONAL) } |] = unsafeRationalToAST x + | tp == [CU.pure| int { static_cast(SYMENGINE_INFTY) } |] = SymengineInfinity + | tp == [CU.pure| int { static_cast(SYMENGINE_NOT_A_NUMBER) } |] = SymengineNaN + | tp == [CU.pure| int { static_cast(SYMENGINE_CONSTANT) } |] = SymengineConstant x + | tp == [CU.pure| int { static_cast(SYMENGINE_SYMBOL) } |] = unsafeSymbolToAST x + | tp == [CU.pure| int { static_cast(SYMENGINE_ADD) } |] = + unsafePerformIO $ SymengineAdd . V.reverse <$> $(unpackFunction "Add") x + | tp == [CU.pure| int { static_cast(SYMENGINE_MUL) } |] = + unsafePerformIO $ SymengineMul <$> $(unpackFunction "Mul") x + | tp == [CU.pure| int { static_cast(SYMENGINE_POW) } |] = + unsafePerformIO $ forceTwoArgs SymenginePow <$> $(unpackFunction "Pow") x + | tp == [CU.pure| int { static_cast(SYMENGINE_LOG) } |] = + unsafePerformIO $ forceOneArg SymengineLog <$> $(unpackFunction "Log") x + | tp == [CU.pure| int { static_cast(SYMENGINE_SIGN) } |] = + unsafePerformIO $ forceOneArg SymengineSign <$> $(unpackFunction "Sign") x + | tp == [CU.pure| int { static_cast(SYMENGINE_FUNCTIONSYMBOL) } |] = + unsafePerformIO $ do + name <- withBasic x $ \x' -> + $(constructStringFrom "down_cast(**$(const Object* x')).get_name()") + args <- $(unpackFunction "FunctionSymbol") x + pure $ SymengineFunction name args + | tp == [CU.pure| int { static_cast(SYMENGINE_DERIVATIVE) } |] = + unsafePerformIO $ do + args <- $(unpackFunction "Derivative") x + pure $ SymengineDerivative (V.head args) (V.tail args) + | otherwise = error $ "unknown type code: " <> show tp + where + tp = basicTypeCode x + +fromAST :: AST -> Basic +fromAST = \case + SymengineInteger x -> fromInteger x + SymengineRational x -> fromRational x + SymengineInfinity -> infinity + SymengineNaN -> nan + SymengineConstant x -> x + SymengineSymbol x -> symbol x + SymengineAdd v -> V.foldl' (+) 0 v + SymengineMul v -> V.foldl' (*) 0 v + SymenginePow a b -> a ** b + SymengineLog x -> log x + SymengineSign x -> signum x + SymengineDerivative f v -> V.foldl' diff f v + +{- +-- | Convert a C string into a Haskell string properly handling unicode characters. +peekCString :: CString -> IO Text +peekCString = fmap T.decodeUtf8 . packCString + +withTempCString :: IO CString -> (CString -> IO a) -> IO a +withTempCString allocate = bracket allocate destroy + where + destroy p = [CU.exp| void { basic_str_free($(char* p)) } |] + +asciiArt :: IO Text +asciiArt = withTempCString [CU.exp| char* { ascii_art_str() } |] peekCString +-} --- |represents a symbol exported by SymEngine. create this using the functions --- 'zero', 'one', 'minus_one', 'e', 'im', 'rational', 'complex', and also by --- constructing a number and converting it to a Symbol --- --- >>> 3.5 :: BasicSym --- 7/2 +-- newtype BasicStruct = BasicStruct +-- { data_ptr :: Ptr () +-- } -- --- >>> rational 2 10 --- 1 /5 +-- instance Storable BasicStruct where +-- alignment _ = 8 +-- sizeOf _ = sizeOf nullPtr +-- peek basic_ptr = BasicStruct <$> peekByteOff basic_ptr 0 +-- poke basic_ptr BasicStruct {..} = pokeByteOff basic_ptr 0 data_ptr -- --- >>> complex 1 2 --- 1 + 2*I -data BasicSym = BasicSym { fptr :: ForeignPtr BasicStruct } - -withBasicSym :: BasicSym -> (Ptr BasicStruct -> IO a) -> IO a -withBasicSym p f = withForeignPtr (fptr p ) f - -withBasicSym2 :: BasicSym -> BasicSym -> (Ptr BasicStruct -> Ptr BasicStruct -> IO a) -> IO a -withBasicSym2 p1 p2 f = withBasicSym p1 (\p1 -> withBasicSym p2 (\p2 -> f p1 p2)) - -withBasicSym3 :: BasicSym -> BasicSym -> BasicSym -> (Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO a) -> IO a -withBasicSym3 p1 p2 p3 f = withBasicSym p1 (\p1 -> withBasicSym p2 (\p2 -> withBasicSym p3 (\p3 -> f p1 p2 p3))) - - --- | constructor for 0 -zero :: BasicSym -zero = basic_obj_constructor basic_const_zero_ffi - --- | constructor for 1 -one :: BasicSym -one = basic_obj_constructor basic_const_one_ffi - --- | constructor for -1 -minus_one :: BasicSym -minus_one = basic_obj_constructor basic_const_minus_one_ffi - --- | constructor for i = sqrt(-1) -im :: BasicSym -im = basic_obj_constructor basic_const_I_ffi - --- | the ratio of the circumference of a circle to its radius -pi :: BasicSym -pi = basic_obj_constructor basic_const_pi_ffi - --- | The base of the natural logarithm -e :: BasicSym -e = basic_obj_constructor basic_const_E_ffi - -expand :: BasicSym -> BasicSym -expand = basic_unaryop basic_expand_ffi - - -eulerGamma :: BasicSym -eulerGamma = basic_obj_constructor basic_const_EulerGamma_ffi - -basic_obj_constructor :: (Ptr BasicStruct -> IO ()) -> BasicSym -basic_obj_constructor init_fn = unsafePerformIO $ do - basic_ptr <- create_basic_ptr - withBasicSym basic_ptr init_fn - return basic_ptr - -basic_str :: BasicSym -> String -basic_str basic_ptr = unsafePerformIO $ withBasicSym basic_ptr (basic_str_ffi >=> peekCString) - -integerToCLong :: Integer -> CLong -integerToCLong i = CLong (fromInteger i) - - -intToCLong :: Int -> CLong -intToCLong i = integerToCLong (toInteger i) - -basic_int_signed :: Int -> BasicSym -basic_int_signed i = unsafePerformIO $ do - iptr <- create_basic_ptr - withBasicSym iptr (\iptr -> integer_set_si_ffi iptr (intToCLong i) ) - return iptr - - -basic_from_integer :: Integer -> BasicSym -basic_from_integer i = unsafePerformIO $ do - iptr <- create_basic_ptr - withBasicSym iptr (\iptr -> integer_set_si_ffi iptr (fromInteger i)) - return iptr - --- |The `ascii_art_str` function prints SymEngine in ASCII art. --- this is useful as a sanity check -ascii_art_str :: IO String -ascii_art_str = ascii_art_str_ffi >>= peekCString - --- Unexported ffi functions------------------------ - --- |Create a basic object that represents all other objects through --- the FFI -create_basic_ptr :: IO BasicSym -create_basic_ptr = do - basic_ptr <- newArray [BasicStruct { data_ptr = nullPtr }] - basic_new_heap_ffi basic_ptr - finalized_ptr <- newForeignPtr ptr_basic_free_heap_ffi basic_ptr - return $ BasicSym { fptr = finalized_ptr } - -basic_binaryop :: (Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO ()) -> BasicSym -> BasicSym -> BasicSym -basic_binaryop f a b = unsafePerformIO $ do - s <- create_basic_ptr - withBasicSym3 s a b f - return s - -basic_unaryop :: (Ptr BasicStruct -> Ptr BasicStruct -> IO ()) -> BasicSym -> BasicSym -basic_unaryop f a = unsafePerformIO $ do - s <- create_basic_ptr - withBasicSym2 s a f - return s - - -basic_pow :: BasicSym -> BasicSym -> BasicSym -basic_pow = basic_binaryop basic_pow_ffi - --- |Create a rational number with numerator and denominator -rational :: BasicSym -> BasicSym -> BasicSym -rational = basic_binaryop rational_set_ffi - --- |Create a complex number a + b * im -complex :: BasicSym -> BasicSym -> BasicSym -complex a b = (basic_binaryop complex_set_ffi) a b - -basic_rational_from_integer :: Integer -> Integer -> BasicSym -basic_rational_from_integer i j = unsafePerformIO $ do - s <- create_basic_ptr - withBasicSym s (\s -> rational_set_si_ffi s (integerToCLong i) (integerToCLong j)) - return s - --- |Create a symbol with the given name -symbol :: String -> BasicSym -symbol name = unsafePerformIO $ do - s <- create_basic_ptr - cname <- newCString name - withBasicSym s (\s -> symbol_set_ffi s cname) - free cname - return s - --- |Differentiate an expression with respect to a symbol -diff :: BasicSym -> BasicSym -> BasicSym -diff expr symbol = (basic_binaryop basic_diff_ffi) expr symbol - -instance Show BasicSym where - show = basic_str - -instance Eq BasicSym where - (==) a b = unsafePerformIO $ do - i <- withBasicSym2 a b basic_eq_ffi - return $ i == 1 - - -instance Num BasicSym where - (+) = basic_binaryop basic_add_ffi - (-) = basic_binaryop basic_sub_ffi - (*) = basic_binaryop basic_mul_ffi - negate = basic_unaryop basic_neg_ffi - abs = basic_unaryop basic_abs_ffi - signum = undefined - fromInteger = basic_from_integer - -instance Fractional BasicSym where - (/) = basic_binaryop basic_div_ffi - fromRational (num :% denom) = basic_rational_from_integer num denom - recip r = one / r - -instance Floating BasicSym where - pi = Symengine.pi - exp x = e ** x - log = undefined - sqrt x = x ** 1/2 - (**) = basic_pow - logBase = undefined - sin = basic_unaryop basic_sin_ffi - cos = basic_unaryop basic_cos_ffi - tan = basic_unaryop basic_tan_ffi - asin = basic_unaryop basic_asin_ffi - acos = basic_unaryop basic_acos_ffi - atan = basic_unaryop basic_atan_ffi - sinh = basic_unaryop basic_sinh_ffi - cosh = basic_unaryop basic_cosh_ffi - tanh = basic_unaryop basic_tanh_ffi - asinh = basic_unaryop basic_asinh_ffi - acosh = basic_unaryop basic_acosh_ffi - atanh = basic_unaryop basic_atanh_ffi - -foreign import ccall "symengine/cwrapper.h ascii_art_str" ascii_art_str_ffi :: IO CString -foreign import ccall "symengine/cwrapper.h basic_new_heap" basic_new_heap_ffi :: Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h &basic_free_heap" ptr_basic_free_heap_ffi :: FunPtr(Ptr BasicStruct -> IO ()) - --- constants -foreign import ccall "symengine/cwrapper.h basic_const_zero" basic_const_zero_ffi :: Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_const_one" basic_const_one_ffi :: Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_const_minus_one" basic_const_minus_one_ffi :: Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_const_I" basic_const_I_ffi :: Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_const_pi" basic_const_pi_ffi :: Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_const_E" basic_const_E_ffi :: Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_const_EulerGamma" basic_const_EulerGamma_ffi :: Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_str" basic_str_ffi :: Ptr BasicStruct -> IO CString -foreign import ccall "symengine/cwrapper.h basic_eq" basic_eq_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO Int - -foreign import ccall "symengine/cwrapper.h symbol_set" symbol_set_ffi :: Ptr BasicStruct -> CString -> IO () -foreign import ccall "symengine/cwrapper.h basic_diff" basic_diff_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO () - -foreign import ccall "symengine/cwrapper.h integer_set_si" integer_set_si_ffi :: Ptr BasicStruct -> CLong -> IO () - -foreign import ccall "symengine/cwrapper.h rational_set" rational_set_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h rational_set_si" rational_set_si_ffi :: Ptr BasicStruct -> CLong -> CLong -> IO () - -foreign import ccall "symengine/cwrapper.h complex_set" complex_set_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO () - -foreign import ccall "symengine/cwrapper.h basic_expand" basic_expand_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () - - -foreign import ccall "symengine/cwrapper.h basic_add" basic_add_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_sub" basic_sub_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_mul" basic_mul_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_div" basic_div_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_pow" basic_pow_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_neg" basic_neg_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_abs" basic_abs_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () - -foreign import ccall "symengine/cwrapper.h basic_sin" basic_sin_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_cos" basic_cos_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_tan" basic_tan_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () - -foreign import ccall "symengine/cwrapper.h basic_asin" basic_asin_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_acos" basic_acos_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_atan" basic_atan_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () - -foreign import ccall "symengine/cwrapper.h basic_sinh" basic_sinh_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_cosh" basic_cosh_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_tanh" basic_tanh_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () - -foreign import ccall "symengine/cwrapper.h basic_asinh" basic_asinh_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_acosh" basic_acosh_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () -foreign import ccall "symengine/cwrapper.h basic_atanh" basic_atanh_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- -- |represents a symbol exported by SymEngine. create this using the functions +-- -- 'zero', 'one', 'minus_one', 'e', 'im', 'rational', 'complex', and also by +-- -- constructing a number and converting it to a Symbol +-- -- +-- -- >>> 3.5 :: BasicSym +-- -- 7/2 +-- -- +-- -- >>> rational 2 10 +-- -- 1 /5 +-- -- +-- -- >>> complex 1 2 +-- -- 1 + 2*I +-- data BasicSym = BasicSym {fptr :: ForeignPtr BasicStruct} +-- +-- withBasicSym :: BasicSym -> (Ptr BasicStruct -> IO a) -> IO a +-- withBasicSym p f = withForeignPtr (fptr p) f +-- +-- withBasicSym2 :: BasicSym -> BasicSym -> (Ptr BasicStruct -> Ptr BasicStruct -> IO a) -> IO a +-- withBasicSym2 p1 p2 f = withBasicSym p1 (\p1 -> withBasicSym p2 (\p2 -> f p1 p2)) +-- +-- withBasicSym3 :: BasicSym -> BasicSym -> BasicSym -> (Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO a) -> IO a +-- withBasicSym3 p1 p2 p3 f = withBasicSym p1 (\p1 -> withBasicSym p2 (\p2 -> withBasicSym p3 (\p3 -> f p1 p2 p3))) +-- +-- -- | constructor for 0 +-- zero :: BasicSym +-- zero = basic_obj_constructor basic_const_zero_ffi +-- +-- -- | constructor for 1 +-- one :: BasicSym +-- one = basic_obj_constructor basic_const_one_ffi +-- +-- -- | constructor for -1 +-- minus_one :: BasicSym +-- minus_one = basic_obj_constructor basic_const_minus_one_ffi +-- +-- -- | constructor for i = sqrt(-1) +-- im :: BasicSym +-- im = basic_obj_constructor basic_const_I_ffi +-- +-- -- | the ratio of the circumference of a circle to its radius +-- pi :: BasicSym +-- pi = basic_obj_constructor basic_const_pi_ffi +-- +-- -- | The base of the natural logarithm +-- e :: BasicSym +-- e = basic_obj_constructor basic_const_E_ffi +-- +-- expand :: BasicSym -> BasicSym +-- expand = basic_unaryop basic_expand_ffi +-- +-- eulerGamma :: BasicSym +-- eulerGamma = basic_obj_constructor basic_const_EulerGamma_ffi +-- +-- basic_obj_constructor :: (Ptr BasicStruct -> IO ()) -> BasicSym +-- basic_obj_constructor init_fn = unsafePerformIO $ do +-- basic_ptr <- create_basic_ptr +-- withBasicSym basic_ptr init_fn +-- return basic_ptr +-- +-- basic_str :: BasicSym -> String +-- basic_str basic_ptr = unsafePerformIO $ withBasicSym basic_ptr (basic_str_ffi >=> peekCString) +-- +-- integerToCLong :: Integer -> CLong +-- integerToCLong i = CLong (fromInteger i) +-- +-- intToCLong :: Int -> CLong +-- intToCLong i = integerToCLong (toInteger i) +-- +-- basic_int_signed :: Int -> BasicSym +-- basic_int_signed i = unsafePerformIO $ do +-- iptr <- create_basic_ptr +-- withBasicSym iptr (\iptr -> integer_set_si_ffi iptr (intToCLong i)) +-- return iptr +-- +-- basic_from_integer :: Integer -> BasicSym +-- basic_from_integer i = unsafePerformIO $ do +-- iptr <- create_basic_ptr +-- withBasicSym iptr (\iptr -> integer_set_si_ffi iptr (fromInteger i)) +-- return iptr +-- +-- -- |The `ascii_art_str` function prints SymEngine in ASCII art. +-- -- this is useful as a sanity check +-- ascii_art_str :: IO String +-- ascii_art_str = ascii_art_str_ffi >>= peekCString +-- +-- -- Unexported ffi functions------------------------ +-- +-- -- |Create a basic object that represents all other objects through +-- -- the FFI +-- create_basic_ptr :: IO BasicSym +-- create_basic_ptr = do +-- basic_ptr <- newArray [BasicStruct {data_ptr = nullPtr}] +-- basic_new_heap_ffi basic_ptr +-- finalized_ptr <- newForeignPtr ptr_basic_free_heap_ffi basic_ptr +-- return $ BasicSym {fptr = finalized_ptr} +-- +-- basic_binaryop :: (Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO ()) -> BasicSym -> BasicSym -> BasicSym +-- basic_binaryop f a b = unsafePerformIO $ do +-- s <- create_basic_ptr +-- withBasicSym3 s a b f +-- return s +-- +-- basic_unaryop :: (Ptr BasicStruct -> Ptr BasicStruct -> IO ()) -> BasicSym -> BasicSym +-- basic_unaryop f a = unsafePerformIO $ do +-- s <- create_basic_ptr +-- withBasicSym2 s a f +-- return s +-- +-- basic_pow :: BasicSym -> BasicSym -> BasicSym +-- basic_pow = basic_binaryop basic_pow_ffi +-- +-- -- |Create a rational number with numerator and denominator +-- rational :: BasicSym -> BasicSym -> BasicSym +-- rational = basic_binaryop rational_set_ffi +-- +-- -- |Create a complex number a + b * im +-- complex :: BasicSym -> BasicSym -> BasicSym +-- complex a b = (basic_binaryop complex_set_ffi) a b +-- +-- basic_rational_from_integer :: Integer -> Integer -> BasicSym +-- basic_rational_from_integer i j = unsafePerformIO $ do +-- s <- create_basic_ptr +-- withBasicSym s (\s -> rational_set_si_ffi s (integerToCLong i) (integerToCLong j)) +-- return s +-- +-- -- |Create a symbol with the given name +-- symbol :: String -> BasicSym +-- symbol name = unsafePerformIO $ do +-- s <- create_basic_ptr +-- cname <- newCString name +-- withBasicSym s (\s -> symbol_set_ffi s cname) +-- free cname +-- return s +-- +-- -- |Differentiate an expression with respect to a symbol +-- diff :: BasicSym -> BasicSym -> BasicSym +-- diff expr symbol = (basic_binaryop basic_diff_ffi) expr symbol +-- +-- instance Show BasicSym where +-- show = basic_str +-- +-- instance Eq BasicSym where +-- (==) a b = unsafePerformIO $ do +-- i <- withBasicSym2 a b basic_eq_ffi +-- return $ i == 1 +-- +-- instance Num BasicSym where +-- (+) = basic_binaryop basic_add_ffi +-- (-) = basic_binaryop basic_sub_ffi +-- (*) = basic_binaryop basic_mul_ffi +-- negate = basic_unaryop basic_neg_ffi +-- abs = basic_unaryop basic_abs_ffi +-- signum = undefined +-- fromInteger = basic_from_integer +-- +-- instance Fractional BasicSym where +-- (/) = basic_binaryop basic_div_ffi +-- fromRational (num :% denom) = basic_rational_from_integer num denom +-- recip r = one / r +-- +-- instance Floating BasicSym where +-- pi = Symengine.pi +-- exp x = e ** x +-- log = undefined +-- sqrt x = x ** 1 / 2 +-- (**) = basic_pow +-- logBase = undefined +-- sin = basic_unaryop basic_sin_ffi +-- cos = basic_unaryop basic_cos_ffi +-- tan = basic_unaryop basic_tan_ffi +-- asin = basic_unaryop basic_asin_ffi +-- acos = basic_unaryop basic_acos_ffi +-- atan = basic_unaryop basic_atan_ffi +-- sinh = basic_unaryop basic_sinh_ffi +-- cosh = basic_unaryop basic_cosh_ffi +-- tanh = basic_unaryop basic_tanh_ffi +-- asinh = basic_unaryop basic_asinh_ffi +-- acosh = basic_unaryop basic_acosh_ffi +-- atanh = basic_unaryop basic_atanh_ffi +-- +-- foreign import ccall "symengine/cwrapper.h ascii_art_str" ascii_art_str_ffi :: IO CString +-- foreign import ccall "symengine/cwrapper.h basic_new_heap" basic_new_heap_ffi :: Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h &basic_free_heap" ptr_basic_free_heap_ffi :: FunPtr (Ptr BasicStruct -> IO ()) +-- +-- -- constants +-- foreign import ccall "symengine/cwrapper.h basic_const_zero" basic_const_zero_ffi :: Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_const_one" basic_const_one_ffi :: Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_const_minus_one" basic_const_minus_one_ffi :: Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_const_I" basic_const_I_ffi :: Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_const_pi" basic_const_pi_ffi :: Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_const_E" basic_const_E_ffi :: Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_const_EulerGamma" basic_const_EulerGamma_ffi :: Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_str" basic_str_ffi :: Ptr BasicStruct -> IO CString +-- foreign import ccall "symengine/cwrapper.h basic_eq" basic_eq_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO Int +-- +-- foreign import ccall "symengine/cwrapper.h symbol_set" symbol_set_ffi :: Ptr BasicStruct -> CString -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_diff" basic_diff_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- +-- foreign import ccall "symengine/cwrapper.h integer_set_si" integer_set_si_ffi :: Ptr BasicStruct -> CLong -> IO () +-- +-- foreign import ccall "symengine/cwrapper.h rational_set" rational_set_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h rational_set_si" rational_set_si_ffi :: Ptr BasicStruct -> CLong -> CLong -> IO () +-- +-- foreign import ccall "symengine/cwrapper.h complex_set" complex_set_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- +-- foreign import ccall "symengine/cwrapper.h basic_expand" basic_expand_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- +-- foreign import ccall "symengine/cwrapper.h basic_add" basic_add_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_sub" basic_sub_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_mul" basic_mul_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_div" basic_div_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_pow" basic_pow_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_neg" basic_neg_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_abs" basic_abs_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- +-- foreign import ccall "symengine/cwrapper.h basic_sin" basic_sin_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_cos" basic_cos_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_tan" basic_tan_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- +-- foreign import ccall "symengine/cwrapper.h basic_asin" basic_asin_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_acos" basic_acos_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_atan" basic_atan_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- +-- foreign import ccall "symengine/cwrapper.h basic_sinh" basic_sinh_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_cosh" basic_cosh_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_tanh" basic_tanh_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- +-- foreign import ccall "symengine/cwrapper.h basic_asinh" basic_asinh_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_acosh" basic_acosh_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () +-- foreign import ccall "symengine/cwrapper.h basic_atanh" basic_atanh_ffi :: Ptr BasicStruct -> Ptr BasicStruct -> IO () diff --git a/src/Symengine/Context.hs b/src/Symengine/Context.hs new file mode 100644 index 0000000..c2ddde6 --- /dev/null +++ b/src/Symengine/Context.hs @@ -0,0 +1,108 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE TemplateHaskellQuotes #-} +{-# LANGUAGE ViewPatterns #-} + +-- | +-- Module : Symengine.Context +-- Description : Helpers to setup inline-c for Symengine +-- Copyright : (c) Tom Westerhout, 2023 +-- +-- This module defines a Template Haskell function 'importSymengine' that sets up everything you need +-- to call SymEngine functions from 'Language.C.Inline' quasiquotes. +module Symengine.Context + ( Basic (..) + , DenseMatrix (..) + , CxxString + , CxxBasic + , CxxInteger + , importSymengine + , constructBasicFrom + ) +where + +import Data.Kind (Type) +import Data.Map (Map, fromList) +import Data.Vector (Vector) +import Foreign.ForeignPtr +import Language.C.Inline qualified as C +import Language.C.Inline.Context (Context (ctxTypesTable)) +import Language.C.Inline.Cpp qualified as Cpp +import Language.C.Inline.Cpp.Exception qualified as C +import Language.C.Inline.Unsafe qualified as CU +import Language.C.Types (TypeSpecifier (..)) +import Language.Haskell.TH (DecsQ, Exp, Q, TypeQ) + +-- | Basic building block of SymEngine expressions. +newtype Basic = Basic (ForeignPtr CxxBasic) + +data DenseMatrix a = DenseMatrix {dmRows :: !Int, dmCols :: !Int, dmData :: !(Vector a)} + +data CxxBasic + +data CxxInteger + +data CxxString + +-- | One stop function to include all the neccessary machinery to call SymEngine functions via +-- inline-c. +-- +-- Put @importSymengine@ somewhere at the beginning of the file and enjoy using the C interface of +-- SymEngine via inline-c quasiquotes. +importSymengine :: DecsQ +importSymengine = + concat + <$> sequence + [ C.context symengineCxt + , C.include "" + , C.include "" + , C.include "" + , C.include "" + , C.include "" + , C.include "" + , C.include "" + , C.include "" + , C.include "" + , C.include "" + , C.include "" + , defineCxxUtils + ] + +symengineCxt :: C.Context +symengineCxt = + C.funCtx <> C.fptrCtx <> C.bsCtx <> Cpp.cppCtx <> C.baseCtx <> mempty {ctxTypesTable = symengineTypePairs} + +symengineTypePairs :: Map TypeSpecifier TypeQ +symengineTypePairs = + fromList + [ (TypeName "Object", [t|CxxBasic|]) + , (TypeName "Vector", [t|Vector Basic|]) + , (TypeName "DenseMatrix", [t|DenseMatrix Basic|]) + , (TypeName "integer_class", [t|CxxInteger|]) + , (TypeName "std::string", [t|CxxString|]) + ] + +defineCxxUtils :: DecsQ +defineCxxUtils = + C.verbatim + "\ + \using Object = SymEngine::RCP; \n\ + \using Vector = SymEngine::vec_basic; \n\ + \using namespace SymEngine; \n\ + \ \n\ + \#define CONSTRUCT_BASIC(dest, expr) new (dest) Object{expr} \n\ + \ \n\ + \" + +constructBasicFrom :: String -> Q Exp +constructBasicFrom expr = + C.substitute + [("expr", const expr)] + [| + constructBasic $ \dest -> + [CU.block| void { + using namespace SymEngine; + new ($(Object* dest)) Object{@expr()}; + } |] + |] diff --git a/src/Symengine/Internal.hs b/src/Symengine/Internal.hs index 8ac7bd1..127f50d 100644 --- a/src/Symengine/Internal.hs +++ b/src/Symengine/Internal.hs @@ -1,639 +1,124 @@ -{-# LANGUAGE CApiFFI #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE TypeFamilies #-} - +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE ViewPatterns #-} + +-- | +-- Module : Symengine.Internal +-- Description : Symengine bindings to Haskell module Symengine.Internal - ( Basic, - basicFromText, - basicToText, - mkFunction, - im, - - -- ** Predicates - isNumber, - isInteger, - isRational, - isComplex, - isSymbol, - isPositive, - isNegative, - isZero, - - -- ** Complex numbers - realPart, - imagPart, - - -- ** Vector - Vec, - vecSize, - vecIndex, - - -- ** Set - Set, - setSize, - setElem, - - -- ** Utilities - freeSymbols, - functionSymbols, - symengineVersion, - - -- ** Reexports - toList, - fromList, - fromString, - ) -where - -import Control.Exception (bracket) -import Data.Bits (toIntegralSized) -import Data.ByteString (packCString, useAsCString) -import Data.Ratio -import Data.Text (Text) -import qualified Data.Text as Text -import Data.Text.Encoding (decodeUtf8, encodeUtf8) -import Foreign.C.String (CString) -import Foreign.C.Types (CInt (..), CLong (..), CSize (..)) + ( constructStringFrom + , createDenseMatrixVia + , mkUnaryFunction + , mkBinaryFunction + , unpackFunction + ) where + +import Control.Exception (bracket_) +import Control.Monad +import Data.Bits +import Data.ByteString (packCString) +import Data.Text (Text, pack, unpack) +import Data.Text.Encoding qualified as T +import Data.Vector (Vector) +import Data.Vector qualified as V +import Foreign.C.Types import Foreign.ForeignPtr +import Foreign.Marshal (allocaBytes, toBool) import Foreign.Ptr -import Foreign.Storable -import GHC.Exts (IsList (..), IsString (..)) -import GHC.Generics (Generic) -import GHC.Stack (HasCallStack) -import System.IO.Unsafe (unsafePerformIO) - -data {-# CTYPE "symengine/cwrapper.h" "basic_struct" #-} Cbasic_struct - = Cbasic_struct - {-# UNPACK #-} !(Ptr ()) - {-# UNPACK #-} !(Ptr ()) - {-# UNPACK #-} !CInt - deriving stock (Show, Eq, Generic) - -instance Storable Cbasic_struct where - sizeOf _ = 24 - {-# INLINE sizeOf #-} - alignment _ = 8 - {-# INLINE alignment #-} - peek _ = error "Storable instance for Cbasic_struct does not implement peek, because you should not rely on the internal representation of it" - poke _ _ = error "Storable instance for Cbasic_struct does not implement poke, because you should not rely on the internal representation of it" - -data {-# CTYPE "symengine/cwrapper.h" "CVecBasic" #-} CVecBasic - -data {-# CTYPE "symengine/cwrapper.h" "CSetBasic" #-} CSetBasic - -data SymengineError - = RuntimeError - | DivideByZero - | NotImplemented - | DomainError - | ParseError - | SerializationError - deriving stock (Show, Eq, Generic) - -instance Enum SymengineError where - toEnum e = case e of - 1 -> RuntimeError - 2 -> DivideByZero - 3 -> NotImplemented - 4 -> DomainError - 5 -> ParseError - 6 -> SerializationError - _ -> error "invalid error code" - fromEnum _ = error "Enum instance of SymengineError does not provide fromEnum" - -newtype Basic = Basic (ForeignPtr Cbasic_struct) - -newtype Vec = Vec (ForeignPtr CVecBasic) - -newtype Set = Set (ForeignPtr CSetBasic) - --- | Allocate a new 'Basic' and use the provided function for initialization. -newBasic :: (Ptr Cbasic_struct -> IO ()) -> IO Basic -newBasic initialize = do - x@(Basic fp) <- newBasicNoDestructor initialize - addForeignPtrFinalizer basic_free_stack fp - pure x - --- | Same as 'newBasic', but do not attach a finalizer to the underlying 'ForeignPtr' -newBasicNoDestructor :: (Ptr Cbasic_struct -> IO ()) -> IO Basic -newBasicNoDestructor initialize = do - fp <- mallocForeignPtr - withForeignPtr fp (\p -> basic_new_stack p >> initialize p) - pure $ Basic fp - -withBasic :: Basic -> (Ptr Cbasic_struct -> IO a) -> IO a -withBasic (Basic fp) = withForeignPtr fp - --- | Allocate a new 'Vec'. -newVec :: IO Vec -newVec = pure . Vec =<< newForeignPtr vecbasic_free =<< vecbasic_new - -withVec :: Vec -> (Ptr CVecBasic -> IO a) -> IO a -withVec (Vec fp) = withForeignPtr fp - -vecSize :: Vec -> Int -vecSize x = - unsafePerformIO . withVec x $ - fmap fromIntegral . vecbasic_size - -vecGet :: HasCallStack => Vec -> Int -> IO Basic -vecGet v i = withVec v $ \vPtr -> newBasic $ \xPtr -> - checkError "vecbasic_get" =<< vecbasic_get vPtr (fromIntegral i) xPtr - -vecIndex :: HasCallStack => Vec -> Int -> Basic -vecIndex v i = unsafePerformIO $! vecGet v i - -vecSet :: HasCallStack => Vec -> Int -> Basic -> IO () -vecSet v i x = withVec v $ \vPtr -> withBasic x $ \xPtr -> - checkError "vecbasic_set" =<< vecbasic_set vPtr (fromIntegral i) xPtr - -vecPushBack :: HasCallStack => Vec -> Basic -> IO () -vecPushBack v x = withVec v $ \vPtr -> withBasic x $ \xPtr -> - checkError "vecbasic_push_back" =<< vecbasic_push_back vPtr xPtr - -instance IsList Vec where - type Item Vec = Basic - toList v = unsafePerformIO $ go (vecSize v - 1) [] - where - go !i acc - | i >= 0 = do - !x <- vecGet v i - go (i - 1) (x : acc) - | otherwise = pure acc - fromList list = unsafePerformIO $ do - v <- newVec - let go [] = pure () - go (x : xs) = vecPushBack v x >> go xs - go list - pure v - -newSet :: IO Set -newSet = pure . Set =<< newForeignPtr setbasic_free =<< setbasic_new - -withSet :: Set -> (Ptr CSetBasic -> IO a) -> IO a -withSet (Set fp) = withForeignPtr fp - -setSize :: Set -> Int -setSize x = unsafePerformIO . withSet x $ fmap fromIntegral . setbasic_size - -setGet :: Set -> Int -> IO Basic -setGet s i = withSet s $ \sPtr -> newBasic $ \xPtr -> - setbasic_get sPtr (fromIntegral i) xPtr - -setInsert :: Set -> Basic -> IO Bool -setInsert s x = withSet s $ \sPtr -> withBasic x $ \xPtr -> - toEnum . fromIntegral <$> setbasic_insert sPtr xPtr - -setFind :: Set -> Basic -> IO Bool -setFind s x = withSet s $ \sPtr -> withBasic x $ - fmap (toEnum . fromIntegral) . setbasic_find sPtr - -setElem :: Basic -> Set -> Bool -setElem x s = unsafePerformIO $! setFind s x - -instance IsList Set where - type Item Set = Basic - toList s = unsafePerformIO $ go (setSize s - 1) [] - where - go !i acc - | i >= 0 = do - !x <- setGet s i - go (i - 1) (x : acc) - | otherwise = pure acc - fromList list = unsafePerformIO $ do - s <- newSet - let go [] = pure () - go (x : xs) = do - _ <- setInsert s x - go xs - go list - pure s - -checkError :: HasCallStack => Text -> CInt -> IO () -checkError name e - | e == 0 = pure () - | otherwise = - error $ - Text.unpack name <> " failed with: " <> show (toEnum (fromIntegral e) :: SymengineError) - -unaryOp :: (Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO ()) -> Basic -> Basic -unaryOp = unaryOp' pure - -unaryOp' :: (a -> IO ()) -> (Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO a) -> Basic -> Basic -unaryOp' check f x = unsafePerformIO $! - withBasic x $ \xPtr -> - newBasic (\out -> check =<< f out xPtr) - -binaryOp :: (Ptr Cbasic_struct -> Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO ()) -> Basic -> Basic -> Basic -binaryOp = binaryOp' pure - -binaryOp' :: (a -> IO ()) -> (Ptr Cbasic_struct -> Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO a) -> Basic -> Basic -> Basic -binaryOp' check f x y = unsafePerformIO $! - withBasic x $ \xPtr -> - withBasic y $ \yPtr -> - newBasic (\out -> check =<< f out xPtr yPtr) - -queryOp :: (Ptr Cbasic_struct -> IO CInt) -> Basic -> Bool -queryOp f x = - unsafePerformIO $! - withBasic x $ - fmap (toEnum . fromIntegral) . f - -basicToText :: Basic -> Text -basicToText x = unsafePerformIO $ - withBasic x $ \p -> - bracket (basic_str p) basic_str_free $ \cStr -> do - -- NOTE: need to force evaluation before the C string is freed - !r <- peekUtf8 cStr - pure r - -basicFromText :: Text -> Maybe Basic -basicFromText s = unsafePerformIO $! - withUtf8 s $ \cStr -> do - x <- newBasic (\_ -> pure ()) - withBasic x $ \p -> do - e <- basic_parse p cStr - if e /= 0 - then case toEnum (fromIntegral e) of - ParseError -> pure Nothing - otherError -> error $ "basic_parse of " <> show s <> " failed with: " <> show otherError - else pure (Just x) - -instance Show Basic where - showsPrec p x = - showParen (p > 0) - . showString - . Text.unpack - . basicToText - $ x - -instance IsString Basic where - fromString s = case (basicFromText . Text.pack) s of - Just x -> x - Nothing -> error $ "could not convert " <> show s <> " to Basic" - -instance Eq Basic where - (==) a b = unsafePerformIO $! - withBasic a $ \aPtr -> - withBasic b $ \bPtr -> - toEnum . fromIntegral <$> basic_eq aPtr bPtr - -basicFromInt :: Int -> Basic -basicFromInt n = - unsafePerformIO $! do - x <- newBasic (\_ -> pure ()) - withBasic x $ \p -> do - checkError "integer_set_si" =<< integer_set_si p (fromIntegral n) - pure x - -isZero :: Basic -> Bool -isZero x = queryOp number_is_zero x - -isPositive :: Basic -> Bool -isPositive x = queryOp number_is_positive x - -isNegative :: Basic -> Bool -isNegative x = queryOp number_is_negative x - -isNumber :: Basic -> Bool -isNumber = queryOp is_a_Number - -isInteger :: Basic -> Bool -isInteger = queryOp is_a_Integer - -isRational :: Basic -> Bool -isRational = queryOp is_a_Rational - -isSymbol :: Basic -> Bool -isSymbol = queryOp is_a_Symbol - -isComplex :: Basic -> Bool -isComplex = queryOp is_a_Complex - -instance Num Basic where - (+) = binaryOp basic_add - (-) = binaryOp basic_sub - (*) = binaryOp basic_mul - negate = unaryOp basic_neg - abs = unaryOp basic_abs - signum _ = error "Num instance of Basic does not implement signum" - fromInteger n = case toIntegralSized n of - Just k -> basicFromInt k - Nothing -> error $ "integer overflow in fromInteger " <> show n - -instance Fractional Basic where - (/) = binaryOp basic_div - fromRational r = - binaryOp' - (checkError "rational_set") - rational_set - (fromInteger (numerator r)) - (fromInteger (denominator r)) - recip r = constOne / r - -instance Floating Basic where - pi = constPi - exp = unaryOp' (checkError "basic_exp") basic_exp - log = unaryOp' (checkError "basic_log") basic_log - (**) = binaryOp' (checkError "basic_pow") basic_pow - sqrt = unaryOp' (checkError "basic_sqrt") basic_sqrt - sin = unaryOp' (checkError "basic_sin") basic_sin - cos = unaryOp' (checkError "basic_cos") basic_cos - tan = unaryOp' (checkError "basic_tan") basic_tan - asin = unaryOp' (checkError "basic_asin") basic_asin - acos = unaryOp' (checkError "basic_acos") basic_acos - atan = unaryOp' (checkError "basic_atan") basic_atan - sinh = unaryOp' (checkError "basic_sinh") basic_sinh - cosh = unaryOp' (checkError "basic_cosh") basic_cosh - tanh = unaryOp' (checkError "basic_tanh") basic_tanh - asinh = unaryOp' (checkError "basic_asinh") basic_asinh - acosh = unaryOp' (checkError "basic_acosh") basic_acosh - atanh = unaryOp' (checkError "basic_atanh") basic_atanh - -constZero :: Basic -constZero = unsafePerformIO $! newBasicNoDestructor basic_const_zero - -constOne :: Basic -constOne = unsafePerformIO $! newBasicNoDestructor basic_const_one - -constPi :: Basic -constPi = unsafePerformIO $! newBasicNoDestructor basic_const_pi - -im :: Basic -im = unsafePerformIO $! newBasicNoDestructor basic_const_I - -realPart :: HasCallStack => Basic -> Basic -realPart = unaryOp' (checkError "complex_base_real_part") complex_base_real_part - -imagPart :: HasCallStack => Basic -> Basic -imagPart = unaryOp' (checkError "complex_base_imaginary_part") complex_base_imaginary_part - -mkFunction :: HasCallStack => Text -> Vec -> Basic -mkFunction name args = - unsafePerformIO $! - newBasic $ \xPtr -> - withUtf8 name $ \namePtr -> - withVec args $ \argsPtr -> - checkError "function_symbol_set" - =<< function_symbol_set xPtr namePtr argsPtr - -freeSymbols :: HasCallStack => Basic -> Set -freeSymbols x = - unsafePerformIO $! do - s <- newSet - withBasic x $ \xPtr -> withSet s $ \sPtr -> - checkError "basic_free_symbols" =<< basic_free_symbols xPtr sPtr - pure s - -functionSymbols :: HasCallStack => Basic -> Set -functionSymbols x = - unsafePerformIO $! do - s <- newSet - withBasic x $ \xPtr -> withSet s $ \sPtr -> - checkError "basic_function_symbols" =<< basic_function_symbols sPtr xPtr - pure s - --- | Unicode-safe alternative to 'peekCString' -peekUtf8 :: CString -> IO Text -peekUtf8 = fmap decodeUtf8 . packCString - --- | Unicode-safe alternative to 'withCString' -withUtf8 :: Text -> (CString -> IO a) -> IO a -withUtf8 x = useAsCString (encodeUtf8 x) - --- | Version of the underlying SymEngine C++ library -symengineVersion :: Text -symengineVersion = unsafePerformIO $ peekUtf8 =<< symengine_version -{-# NOINLINE symengineVersion #-} - -foreign import capi unsafe "symengine/cwrapper.h basic_new_stack" - basic_new_stack :: Ptr Cbasic_struct -> IO () - -foreign import capi unsafe "symengine/cwrapper.h &basic_free_stack" - basic_free_stack :: FunPtr (Ptr Cbasic_struct -> IO ()) - -foreign import capi unsafe "symengine/cwrapper.h basic_str" - basic_str :: Ptr Cbasic_struct -> IO CString - -foreign import capi unsafe "symengine/cwrapper.h basic_str_free" - basic_str_free :: CString -> IO () - -foreign import capi unsafe "symengine/cwrapper.h basic_const_zero" - basic_const_zero :: Ptr Cbasic_struct -> IO () - -foreign import capi unsafe "symengine/cwrapper.h basic_const_one" - basic_const_one :: Ptr Cbasic_struct -> IO () - -foreign import capi unsafe "symengine/cwrapper.h basic_const_minus_one" - basic_const_minus_one :: Ptr Cbasic_struct -> IO () - -foreign import capi unsafe "symengine/cwrapper.h basic_const_I" - basic_const_I :: Ptr Cbasic_struct -> IO () - -foreign import capi unsafe "symengine/cwrapper.h basic_const_pi" - basic_const_pi :: Ptr Cbasic_struct -> IO () - -foreign import capi unsafe "symengine/cwrapper.h basic_const_E" - basic_const_E :: Ptr Cbasic_struct -> IO () - -foreign import capi unsafe "symengine/cwrapper.h basic_const_EulerGamma" - basic_const_EulerGamma :: Ptr Cbasic_struct -> IO () - -foreign import capi unsafe "symengine/cwrapper.h basic_const_Catalan" - basic_const_Catalan :: Ptr Cbasic_struct -> IO () - -foreign import capi unsafe "symengine/cwrapper.h basic_const_GoldenRatio" - basic_const_GoldenRatio :: Ptr Cbasic_struct -> IO () - -foreign import capi unsafe "symengine/cwrapper.h integer_set_si" - integer_set_si :: Ptr Cbasic_struct -> CLong -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h rational_set" - rational_set :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h basic_parse" - basic_parse :: Ptr Cbasic_struct -> CString -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h basic_eq" - basic_eq :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h basic_add" - basic_add :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO () - -foreign import capi unsafe "symengine/cwrapper.h basic_sub" - basic_sub :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO () - -foreign import capi unsafe "symengine/cwrapper.h basic_mul" - basic_mul :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO () - -foreign import capi unsafe "symengine/cwrapper.h basic_div" - basic_div :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO () - -foreign import capi unsafe "symengine/cwrapper.h basic_pow" - basic_pow :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h basic_abs" - basic_abs :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO () - -foreign import capi unsafe "symengine/cwrapper.h basic_neg" - basic_neg :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO () - -foreign import capi unsafe "symengine/cwrapper.h basic_sqrt" - basic_sqrt :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h basic_sin" - basic_sin :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h basic_cos" - basic_cos :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h basic_tan" - basic_tan :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h basic_asin" - basic_asin :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h basic_acos" - basic_acos :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h basic_atan" - basic_atan :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h basic_sinh" - basic_sinh :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h basic_cosh" - basic_cosh :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h basic_tanh" - basic_tanh :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h basic_asinh" - basic_asinh :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h basic_acosh" - basic_acosh :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h basic_atanh" - basic_atanh :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h basic_exp" - basic_exp :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h basic_log" - basic_log :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h number_is_zero" - number_is_zero :: Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h number_is_negative" - number_is_negative :: Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h number_is_positive" - number_is_positive :: Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h number_is_complex" - number_is_complex :: Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h is_a_Number" - is_a_Number :: Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h is_a_Integer" - is_a_Integer :: Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h is_a_Rational" - is_a_Rational :: Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h is_a_Symbol" - is_a_Symbol :: Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h is_a_Complex" - is_a_Complex :: Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h is_a_RealDouble" - is_a_RealDouble :: Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h is_a_ComplexDouble" - is_a_ComplexDouble :: Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h is_a_RealMPFR" - is_a_RealMPFR :: Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h is_a_ComplexMPC" - is_a_ComplexMPC :: Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h is_a_Set" - is_a_Set :: Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h basic_get_args" - basic_get_args :: Ptr Cbasic_struct -> Ptr CVecBasic -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h basic_free_symbols" - basic_free_symbols :: Ptr Cbasic_struct -> Ptr CSetBasic -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h basic_function_symbols" - basic_function_symbols :: Ptr CSetBasic -> Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h function_symbol_set" - function_symbol_set :: Ptr Cbasic_struct -> CString -> Ptr CVecBasic -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h complex_base_real_part" - complex_base_real_part :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h complex_base_imaginary_part" - complex_base_imaginary_part :: Ptr Cbasic_struct -> Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h vecbasic_new" - vecbasic_new :: IO (Ptr CVecBasic) - -foreign import capi unsafe "symengine/cwrapper.h &vecbasic_free" - vecbasic_free :: FunPtr (Ptr CVecBasic -> IO ()) - -foreign import capi unsafe "symengine/cwrapper.h vecbasic_push_back" - vecbasic_push_back :: Ptr CVecBasic -> Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h vecbasic_get" - vecbasic_get :: Ptr CVecBasic -> CSize -> Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h vecbasic_set" - vecbasic_set :: Ptr CVecBasic -> CSize -> Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h vecbasic_erase" - vecbasic_erase :: Ptr CVecBasic -> CSize -> IO () - -foreign import capi unsafe "symengine/cwrapper.h vecbasic_size" - vecbasic_size :: Ptr CVecBasic -> IO CSize - -foreign import capi unsafe "symengine/cwrapper.h basic_max" - basic_max :: Ptr Cbasic_struct -> Ptr CVecBasic -> IO CSize - -foreign import capi unsafe "symengine/cwrapper.h basic_min" - basic_min :: Ptr Cbasic_struct -> Ptr CVecBasic -> IO CSize - -foreign import capi unsafe "symengine/cwrapper.h basic_add_vec" - basic_add_vec :: Ptr Cbasic_struct -> Ptr CVecBasic -> IO CSize - -foreign import capi unsafe "symengine/cwrapper.h basic_mul_vec" - basic_mul_vec :: Ptr Cbasic_struct -> Ptr CVecBasic -> IO CSize - -foreign import capi unsafe "symengine/cwrapper.h setbasic_new" - setbasic_new :: IO (Ptr CSetBasic) - -foreign import capi unsafe "symengine/cwrapper.h &setbasic_free" - setbasic_free :: FunPtr (Ptr CSetBasic -> IO ()) - -foreign import capi unsafe "symengine/cwrapper.h setbasic_insert" - setbasic_insert :: Ptr CSetBasic -> Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h setbasic_get" - setbasic_get :: Ptr CSetBasic -> CInt -> Ptr Cbasic_struct -> IO () - -foreign import capi unsafe "symengine/cwrapper.h setbasic_find" - setbasic_find :: Ptr CSetBasic -> Ptr Cbasic_struct -> IO CInt - -foreign import capi unsafe "symengine/cwrapper.h setbasic_size" - setbasic_size :: Ptr CSetBasic -> IO CSize - -foreign import ccall unsafe "symengine/cwrapper.h symengine_version" - symengine_version :: IO CString +import GHC.Exts (IsString (..)) +import GHC.Real (Ratio (..)) +import Language.C.Inline qualified as C +import Language.C.Inline.Context (Context (ctxTypesTable)) +import Language.C.Inline.Cpp qualified as Cpp +import Language.C.Inline.Cpp.Exception qualified as C +import Language.C.Inline.Unsafe qualified as CU +import Language.C.Types (TypeSpecifier (..)) +import Language.Haskell.TH (DecsQ, Exp, Q, TypeQ) +import Symengine.Context +import System.IO.Unsafe + +constructStringFrom :: String -> Q Exp +constructStringFrom expr = + C.substitute + [("expr", const expr)] + [| + let size = fromIntegral [CU.pure| size_t { sizeof(std::string) } |] + construct s = + [CU.block| void { + using namespace SymEngine; + new ($(std::string* s)) std::string{@expr()}; + } |] + destruct s = [CU.exp| void { $(std::string* s)->~basic_string() } |] + in allocaBytes size $ \s -> + bracket_ (construct s) (destruct s) $ + fmap T.decodeUtf8 $ + packCString + =<< [CU.exp| char const* { $(const std::string* s)->c_str() } |] + |] + +createDenseMatrixVia :: String -> Q Exp +createDenseMatrixVia expr = + C.substitute + [("expr", const expr)] + [| + allocaDenseMatrix 0 0 $ \ptr -> do + [CU.block| void { + auto& out = *$(DenseMatrix* ptr); + @expr() + } |] + peekDenseMatrix ptr + |] + +mkUnaryFunction :: String -> Q Exp +mkUnaryFunction expr = + C.substitute + [ ("expr", const expr) + ] + [| + \a' -> + unsafePerformIO $ + withBasic a' $ \a -> + constructBasic $ \dest -> + [CU.block| void { + using namespace SymEngine; + auto const& a = *$(const Object* a); + new ($(Object* dest)) Object{@expr()}; + } |] + |] + +mkBinaryFunction :: String -> Q Exp +mkBinaryFunction expr = + C.substitute + [ ("expr", const expr) + ] + [| + \a' b' -> + unsafePerformIO $ + withBasic a' $ \a -> + withBasic b' $ \b -> + constructBasic $ \dest -> + [CU.block| void { + using namespace SymEngine; + auto const& a = *$(const Object* a); + auto const& b = *$(const Object* b); + new ($(Object* dest)) Object{@expr()}; + } |] + |] + +unpackFunction :: String -> Q Exp +unpackFunction className = + C.substitute + [ ("class", const className) + ] + [| + \f' -> + withBasic f' $ \f -> + allocaVector $ \v -> do + [CU.block| void { + using namespace SymEngine; + auto const& f = down_cast<@class() const&>(**$(const Object* f)); + *$(Vector* v) = f.get_args(); + } |] + peekVector v + |] diff --git a/symengine.cabal b/symengine.cabal index 8e0491d..677bb4d 100644 --- a/symengine.cabal +++ b/symengine.cabal @@ -1,62 +1,90 @@ -cabal-version: 3.0 -name: symengine -version: 0.1.2.0 -synopsis: SymEngine symbolic mathematics engine for Haskell -description: Please see README.md -homepage: https://github.com/symengine/symengine.hs -license: MIT -license-file: LICENSE -author: Siddharth Bhat -maintainer: siddu.druid@gmail.com -copyright: 2016 Siddharth Bhat -category: FFI, Math, Symbolic Computation -build-type: Simple -tested-with: GHC == 8.10.7 +cabal-version: 3.0 +name: symengine +version: 0.2.0.0 +synopsis: SymEngine symbolic mathematics engine for Haskell +description: Please see README.md +homepage: https://github.com/symengine/symengine.hs +license: MIT +license-file: LICENSE +author: Siddharth Bhat +maintainer: siddu.druid@gmail.com +copyright: + 2016 Siddharth Bhat + 2023 Tom Westerhout + +category: FFI, Math, Symbolic Computation +build-type: Simple +tested-with: GHC ==9.2.7 + +flag no-flint + description: disable linking with Flint + manual: True + default: False + +flag no-mpc + description: disable linking with MPC + manual: True + default: False + +flag no-mpfr + description: disable linking with MPFR + manual: True + default: False common common-options - build-depends: base >= 4.13.0.0 - - ghc-options: -Wall - -Wcompat - -Widentities - -Wincomplete-uni-patterns - -Wincomplete-record-updates - if impl(ghc >= 8.0) - ghc-options: -Wredundant-constraints - if impl(ghc >= 8.2) - ghc-options: -fhide-source-paths - if impl(ghc >= 8.4) - ghc-options: -Wmissing-export-lists - -Wpartial-fields - if impl(ghc >= 8.8) - ghc-options: -Wmissing-deriving-strategies - - default-language: Haskell2010 - default-extensions: BangPatterns - FlexibleContexts - FlexibleInstances - DerivingVia + build-depends: base >=4.16.0.0 && <5 + ghc-options: + -Weverything -Wno-unsafe -Wno-all-missed-specialisations + -Wno-missing-safe-haskell-mode -Wno-implicit-prelude + -Wno-missing-import-lists -Wno-missing-kind-signatures + -Wno-monomorphism-restriction + + default-language: GHC2021 + default-extensions: DerivingStrategies library - import: common-options - hs-source-dirs: src - exposed-modules: Symengine - Symengine.Internal - build-depends: text - , bytestring - extra-libraries: symengine + import: common-options + hs-source-dirs: src + exposed-modules: Symengine + other-modules: + Symengine.Context + Symengine.Internal + + build-depends: + , bytestring + , containers + , ghc-bignum + , inline-c + , inline-c-cpp + , template-haskell + , text + , vector + + extra-libraries: symengine + + if !flag(no-flint) + extra-libraries: flint + + if !flag(no-mpc) + extra-libraries: mpc + + if !flag(no-mpfr) + extra-libraries: mpfr + if os(linux) - extra-libraries: stdc++ - if os(darwin) || os(osx) - extra-libraries: c++ + extra-libraries: stdc++ + + if os(osx) + extra-libraries: c++ test-suite symengine-test - import: common-options - type: exitcode-stdio-1.0 - hs-source-dirs: test - main-is: Spec.hs - build-depends: symengine - , tasty >= 0.10.0 - , tasty-hunit >= 0.9.0 - , tasty-quickcheck >= 0.8.0 - ghc-options: -threaded -rtsopts -with-rtsopts=-N + import: common-options + type: exitcode-stdio-1.0 + hs-source-dirs: test + main-is: Spec.hs + build-depends: + , hspec + , symengine + , text + + ghc-options: -threaded -rtsopts -with-rtsopts=-N diff --git a/test/Spec.hs b/test/Spec.hs index fef6449..6bd19ca 100644 --- a/test/Spec.hs +++ b/test/Spec.hs @@ -1,76 +1,62 @@ {-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} -import Data.List -import Data.Monoid -import Data.Ord --- import Symengine as Sym +import Control.Monad (unless) import Data.Ratio -import Symengine.Internal -import Test.Tasty -import Test.Tasty.HUnit -import Test.Tasty.QuickCheck as QC -import Prelude hiding (pi) - -main = defaultMain tests - -tests :: TestTree -tests = testGroup "Tests" [unitTests] - --- These are used to check invariants that can be tested by creating --- random members of the type and then checking invariants on them - --- properties :: TestTree --- properties = testGroup "Properties" [qcProps] - -unitTests = - testGroup - "Unit tests" - [ -- testCase "FFI Sanity Check - ASCII Art should be non-empty" $ - -- do - -- ascii_art <- Sym.ascii_art_str - -- assertBool "ASCII art from ascii_art_str is empty" (not . null $ ascii_art), - testCase "test_complex" $ do - let r = fromRational (100 % 47) :: Basic - i = fromRational (76 % 59) - e = r + i * im - show e @?= "100/47 + 76/59*I" - isSymbol e @?= False - isRational e @?= False - isInteger e @?= False - isComplex e @?= True - isZero e @?= False - isNegative e @?= False - isPositive e @?= False - - show (realPart e) @?= "100/47" - isSymbol (realPart e) @?= False - isRational (realPart e) @?= True - isInteger (realPart e) @?= False - isComplex (realPart e) @?= False - - show (imagPart e) @?= "76/59" - isSymbol (imagPart e) @?= False - isRational (imagPart e) @?= True - isInteger (imagPart e) @?= False - isComplex (imagPart e) @?= False, - testCase "test_free_symbols" $ do - let x = "x" :: Basic - y = "y" - z = "z" - e = 123 - expr = (e + x) ** y / z - - setSize (freeSymbols expr) @?= 3 - toList (freeSymbols expr) @?= ["x", "y", "z"], - testCase "test_function_symbols" $ do - let x = "x" :: Basic - y = "y" - z = "z" - g = mkFunction "g" [x] - h = mkFunction "h" [g] - f = mkFunction "f" [x + y, g, h] - - show (z + f) @?= "z + f(x + y, g(x), h(g(x)))" - setSize (functionSymbols f) @?= 3 - ] +import Data.Text (pack) +import Symengine +import Test.Hspec +import Test.Hspec.QuickCheck + +main :: IO () +main = hspec $ do + describe "Num" $ do + prop "Integer" $ \(a :: Integer) (b :: Integer) -> do + show (fromIntegral @_ @Basic a) `shouldBe` show a + show (fromIntegral @_ @Basic a + fromIntegral b) `shouldBe` show (a + b) + show (fromIntegral @_ @Basic a - fromIntegral b) `shouldBe` show (a - b) + show (fromIntegral @_ @Basic a * fromIntegral b) `shouldBe` show (a * b) + show (abs (fromIntegral @_ @Basic a)) `shouldBe` show (abs a) + show (negate (fromIntegral @_ @Basic a)) `shouldBe` show (negate a) + show (signum (fromIntegral @_ @Basic a)) `shouldBe` show (signum a) + describe "AST" $ do + prop "SymengineInteger" $ \(x :: Integer) -> do + toAST (fromInteger x) `shouldBe` SymengineInteger x + prop "SymengineRational" $ \(x :: Rational) -> do + if denominator x == 1 + then toAST (fromRational x) `shouldBe` SymengineInteger (numerator x) + else toAST (fromRational x) `shouldBe` SymengineRational x + it "SymengineConstant" $ do + toAST (pi :: Basic) `shouldBe` SymengineConstant pi + toAST e `shouldBe` SymengineConstant e + prop "SymengineSymbol" $ \(x :: String) -> do + unless ('\NUL' `elem` x) $ + toAST (symbol (pack x)) `shouldBe` SymengineSymbol (pack x) + it "SymengineInfinity" $ do + toAST infinity `shouldBe` SymengineInfinity + it "SymengineNaN" $ do + toAST nan `shouldBe` SymengineNaN + it "SymengineAdd" $ do + toAST (symbol "x" + symbol "z" + symbol "y") `shouldBe` SymengineAdd [symbol "x", symbol "z", symbol "y"] + toAST (symbol "x" + symbol "y") `shouldBe` SymengineAdd [symbol "x", symbol "y"] + it "SymengineMul" $ do + toAST (2 * symbol "y") `shouldBe` SymengineMul [2, symbol "y"] + toAST (-symbol "x") `shouldBe` SymengineMul [-1, symbol "x"] + it "SymenginePow" $ do + toAST (sqrt (symbol "y")) `shouldBe` SymenginePow (symbol "y") 0.5 + toAST (exp (symbol "y")) `shouldBe` SymenginePow e (symbol "y") + it "SymengineLog" $ do + toAST (log (symbol "y")) `shouldBe` SymengineLog (symbol "y") + it "SymengineSign" $ do + toAST (signum (symbol "y")) `shouldBe` SymengineSign (symbol "y") + it "SymengineFunction" $ do + toAST (parse "f(1, x, y + 2)") `shouldBe` SymengineFunction "f" [1, symbol "x", 2 + symbol "y"] + it "SymengineDerivative" $ do + toAST (diff (parse "f(x)") (symbol "x")) `shouldBe` SymengineDerivative (parse "f(x)") [symbol "x"] + toAST (diff ((symbol "x") ** 2) (symbol "x")) `shouldBe` SymengineMul [2, symbol "x"] + + describe "Misc" $ do + it "" $ do + print $ parse "a + f(x) / x - 4**2" + print $ evalf EvalSymbolic 20 $ parse "a + 8/3 * f(x) / x - 4**2" + print $ inverse InverseDefault (identityMatrix 3) From 42e6d713bb23f68a7a778db8f426dbecbd82af70 Mon Sep 17 00:00:00 2001 From: twesterhout <14264576+twesterhout@users.noreply.github.com> Date: Mon, 19 Jun 2023 12:12:50 +0200 Subject: [PATCH 17/21] Update the flake to use packageOverrides to support multiple GHC versions --- .gitignore | 4 +- flake.lock | 36 +++++++++++---- flake.nix | 125 +++++++++++++++++++++++------------------------------ 3 files changed, 83 insertions(+), 82 deletions(-) diff --git a/.gitignore b/.gitignore index d01581c..6393707 100644 --- a/.gitignore +++ b/.gitignore @@ -37,4 +37,6 @@ tags /*.iml /src/highlight.js /src/style.css -/_site/ \ No newline at end of file +/_site/.ghc.environment.* +result +result-1 diff --git a/flake.lock b/flake.lock index cee9ed8..305e052 100644 --- a/flake.lock +++ b/flake.lock @@ -17,12 +17,15 @@ } }, "flake-utils": { + "inputs": { + "systems": "systems" + }, "locked": { - "lastModified": 1680776469, - "narHash": "sha256-3CXUDK/3q/kieWtdsYpDOBJw3Gw4Af6x+2EiSnIkNQw=", + "lastModified": 1685518550, + "narHash": "sha256-o2d0KcvaXzTrPRIo0kOLV0/QXHhDQ5DTi+OxcjO8xqY=", "owner": "numtide", "repo": "flake-utils", - "rev": "411e8764155aa9354dbcd6d5faaeb97e9e3dce24", + "rev": "a1720a10a6cfe8234c0e93907ffe81be440f4cef", "type": "github" }, "original": { @@ -33,11 +36,11 @@ }, "nix-filter": { "locked": { - "lastModified": 1678109515, - "narHash": "sha256-C2X+qC80K2C1TOYZT8nabgo05Dw2HST/pSn6s+n6BO8=", + "lastModified": 1681154353, + "narHash": "sha256-MCJ5FHOlbfQRFwN0brqPbCunLEVw05D/3sRVoNVt2tI=", "owner": "numtide", "repo": "nix-filter", - "rev": "aa9ff6ce4a7f19af6415fb3721eaa513ea6c763c", + "rev": "f529f42792ade8e32c4be274af6b6d60857fbee7", "type": "github" }, "original": { @@ -48,11 +51,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1680758185, - "narHash": "sha256-sCVWwfnk7zEX8Z+OItiH+pcSklrlsLZ4TJTtnxAYREw=", + "lastModified": 1686960236, + "narHash": "sha256-AYCC9rXNLpUWzD9hm+askOfpliLEC9kwAo7ITJc4HIw=", "owner": "nixos", "repo": "nixpkgs", - "rev": "0e19daa510e47a40e06257e205965f3b96ce0ac9", + "rev": "04af42f3b31dba0ef742d254456dc4c14eedac86", "type": "github" }, "original": { @@ -69,6 +72,21 @@ "nix-filter": "nix-filter", "nixpkgs": "nixpkgs" } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } } }, "root": "root", diff --git a/flake.nix b/flake.nix index 326f073..72d1bf4 100644 --- a/flake.nix +++ b/flake.nix @@ -18,7 +18,17 @@ outputs = inputs: inputs.flake-utils.lib.eachDefaultSystem (system: with builtins; let - inherit (inputs.nixpkgs) lib; + src = inputs.nix-filter.lib { + root = ./.; + include = [ + "src" + "test" + "symengine.cabal" + "README.md" + "LICENSE" + ]; + }; + pkgs = import inputs.nixpkgs { inherit system; overlays = [ @@ -36,83 +46,54 @@ "-DBUILD_SHARED_LIBS=ON" ]; }); + + haskell = super.haskell // { + packageOverrides = hself: hsuper: { + symengine = (hself.callCabal2nix "symengine" src { + inherit (self) symengine; + mpc = self.libmpc; + }); + }; + }; }) ]; - }; - src = inputs.nix-filter.lib { - root = ./.; - include = [ - "src" - "test" - "symengine.cabal" - "README.md" - "LICENSE" - ]; + config.allowBroken = true; }; - # This allows us to build a Haskell package with any given GHC version. - # It will also affects all dependent libraries. - # overrides allows us to patch existing Haskell packages, or introduce new ones - # see here for specifics: https://nixos.wiki/wiki/Overlays - haskellPackagesOverride = ps: args: - ps.override - { - overrides = self: super: { - symengine = (self.callCabal2nix "symengine" src { - inherit (pkgs) symengine; - mpc = pkgs.libmpc; - }); - }; - }; - outputsFor = - { haskellPackages - , name - , package ? "" - , ... - }: - let - ps = haskellPackagesOverride haskellPackages { }; - in - { - packages.${name} = ps.${package} or ps; - devShells.${name} = ps.shellFor { - packages = ps: with ps; [ - symengine - ]; - withHoogle = true; - nativeBuildInputs = with pkgs; with ps; [ - # Building and testing - cabal-install - # Language servers - haskell-language-server - nil - # Formatters - fourmolu - cabal-fmt - nixpkgs-fmt - # Previewing markdown files - python3Packages.grip - ]; - shellHook = '' - LD_LIBRARY_PATH=${pkgs.symengine}/lib:${pkgs.flint}/lib:${pkgs.libmpc}/lib:${pkgs.mpfr}/lib:$LD_LIBRARY_PATH - SYMENGINE_PATH=${pkgs.symengine} - ''; - }; - # The formatter to use for .nix files (but not .hs files) - # Allows us to run `nix fmt` to reformat nix files. - formatter = pkgs.nixpkgs-fmt; - }; in - foldl' (acc: conf: lib.recursiveUpdate acc (outputsFor conf)) { } - (lib.mapAttrsToList (name: haskellPackages: { inherit name haskellPackages; }) - (lib.filterAttrs (_: ps: ps ? ghc) pkgs.haskell.packages) ++ [ - { - haskellPackages = pkgs.haskellPackages; - name = "default"; - package = "symengine"; - } - ]) + { + packages = { + default = pkgs.haskellPackages.symengine; + symengine = pkgs.haskellPackages.symengine; + haskell = pkgs.haskell.packages; + }; + + devShells.default = haskellPackages.shellFor { + packages = ps: with ps; [ symengine ]; + withHoogle = true; + nativeBuildInputs = with pkgs; with ps; [ + # Building and testing + cabal-install + # Language servers + haskell-language-server + nil + # Formatters + fourmolu + cabal-fmt + nixpkgs-fmt + # Previewing markdown files + python3Packages.grip + ]; + shellHook = '' + LD_LIBRARY_PATH=${pkgs.symengine}/lib:${pkgs.flint}/lib:${pkgs.libmpc}/lib:${pkgs.mpfr}/lib:$LD_LIBRARY_PATH + SYMENGINE_PATH=${pkgs.symengine} + ''; + }; + # The formatter to use for .nix files (but not .hs files) + # Allows us to run `nix fmt` to reformat nix files. + formatter = pkgs.nixpkgs-fmt; + } ); } From 1253ca957007384a9bd2c90407c05f9162c80288 Mon Sep 17 00:00:00 2001 From: twesterhout <14264576+twesterhout@users.noreply.github.com> Date: Mon, 19 Jun 2023 14:53:46 +0200 Subject: [PATCH 18/21] Restructure the flake; implement fromAST for SymengineFunction --- flake.nix | 130 ++++++++++++++++++-------------------- src/Symengine.hs | 45 +++++++------ src/Symengine/Context.hs | 85 ++++++++++--------------- src/Symengine/Internal.hs | 37 ++++++----- test/Spec.hs | 3 +- 5 files changed, 142 insertions(+), 158 deletions(-) diff --git a/flake.nix b/flake.nix index 72d1bf4..88cc76b 100644 --- a/flake.nix +++ b/flake.nix @@ -1,10 +1,6 @@ { description = "symengine/symengine.hs: SymEngine symbolic mathematics engine for Haskell"; - nixConfig = { - extra-experimental-features = "nix-command flakes"; - }; - inputs = { nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable"; flake-utils.url = "github:numtide/flake-utils"; @@ -15,10 +11,9 @@ }; }; - outputs = inputs: inputs.flake-utils.lib.eachDefaultSystem (system: - with builtins; + outputs = { nixpkgs, flake-utils, nix-filter, ... }: let - src = inputs.nix-filter.lib { + src = nix-filter.lib { root = ./.; include = [ "src" @@ -28,72 +23,73 @@ "LICENSE" ]; }; + overlay = self: super: { + symengine = super.symengine.overrideAttrs (attrs: rec { + version = "0.10.1"; + src = self.fetchFromGitHub { + owner = attrs.pname; + repo = attrs.pname; + rev = "v${version}"; + sha256 = "sha256-qTu0vS9K6rrr/0SXKpGC9P1QSN/AN7hyO/4DrGvhxWM="; + }; + cmakeFlags = (attrs.cmakeFlags or [ ]) ++ [ + "-DCMAKE_BUILD_TYPE=Debug" + "-DBUILD_SHARED_LIBS=ON" + ]; + }); - pkgs = import inputs.nixpkgs { - inherit system; - overlays = [ - (self: super: { - symengine = super.symengine.overrideAttrs (attrs: rec { - version = "0.10.1"; - src = self.fetchFromGitHub { - owner = attrs.pname; - repo = attrs.pname; - rev = "v${version}"; - sha256 = "sha256-qTu0vS9K6rrr/0SXKpGC9P1QSN/AN7hyO/4DrGvhxWM="; - }; - cmakeFlags = (attrs.cmakeFlags or [ ]) ++ [ - "-DCMAKE_BUILD_TYPE=Debug" - "-DBUILD_SHARED_LIBS=ON" - ]; + haskell = super.haskell // { + packageOverrides = hself: hsuper: { + symengine = (hself.callCabal2nix "symengine" src { + inherit (self) symengine; + mpc = self.libmpc; }); + }; + }; + }; - haskell = super.haskell // { - packageOverrides = hself: hsuper: { - symengine = (hself.callCabal2nix "symengine" src { - inherit (self) symengine; - mpc = self.libmpc; - }); - }; - }; - }) - ]; - + pkgsFor = system: import nixpkgs { + inherit system; + overlays = [ overlay ]; config.allowBroken = true; }; - - in { - packages = { - default = pkgs.haskellPackages.symengine; - symengine = pkgs.haskellPackages.symengine; - haskell = pkgs.haskell.packages; - }; + packages = flake-utils.lib.eachDefaultSystemMap (system: + with (pkgsFor system); { + default = haskellPackages.symengine; + symengine = haskellPackages.symengine; + haskell = haskell.packages; + }); - devShells.default = haskellPackages.shellFor { - packages = ps: with ps; [ symengine ]; - withHoogle = true; - nativeBuildInputs = with pkgs; with ps; [ - # Building and testing - cabal-install - # Language servers - haskell-language-server - nil - # Formatters - fourmolu - cabal-fmt - nixpkgs-fmt - # Previewing markdown files - python3Packages.grip - ]; - shellHook = '' - LD_LIBRARY_PATH=${pkgs.symengine}/lib:${pkgs.flint}/lib:${pkgs.libmpc}/lib:${pkgs.mpfr}/lib:$LD_LIBRARY_PATH - SYMENGINE_PATH=${pkgs.symengine} - ''; - }; - # The formatter to use for .nix files (but not .hs files) - # Allows us to run `nix fmt` to reformat nix files. - formatter = pkgs.nixpkgs-fmt; - } - ); + devShells = flake-utils.lib.eachDefaultSystemMap (system: + with (pkgsFor system); { + default = haskellPackages.shellFor { + packages = ps: with ps; [ symengine ]; + withHoogle = true; + nativeBuildInputs = with pkgs; with haskellPackages; [ + # Building and testing + cabal-install + # Language servers + haskell-language-server + nil + # Formatters + fourmolu + cabal-fmt + nixpkgs-fmt + # Previewing markdown files + python3Packages.grip + ]; + shellHook = '' + LD_LIBRARY_PATH=${pkgs.symengine}/lib:${pkgs.flint}/lib:${pkgs.libmpc}/lib:${pkgs.mpfr}/lib:$LD_LIBRARY_PATH + SYMENGINE_PATH=${pkgs.symengine} + ''; + }; + # The formatter to use for .nix files (but not .hs files) + # Allows us to run `nix fmt` to reformat nix files. + formatter = pkgs.nixpkgs-fmt; + } + ); + overlays.default = overlay; + }; } diff --git a/src/Symengine.hs b/src/Symengine.hs index 88e10c5..d603e48 100644 --- a/src/Symengine.hs +++ b/src/Symengine.hs @@ -3,6 +3,7 @@ {-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -Wno-unused-matches #-} -- | -- Module : Symengine @@ -31,15 +32,14 @@ module Symengine import Control.Exception (bracket_) import Control.Monad -import Data.Bits -import Data.ByteString (packCString) +import Data.ByteString (useAsCString) import Data.Text (Text, pack, unpack) import Data.Text.Encoding qualified as T import Data.Vector (Vector) import Data.Vector qualified as V import Foreign.C.Types import Foreign.ForeignPtr -import Foreign.Marshal (allocaBytes, allocaBytesAligned, toBool, withArrayLen) +import Foreign.Marshal (allocaBytesAligned, toBool, withArrayLen) import Foreign.Ptr import GHC.Exts (IsString (..)) import GHC.Int @@ -47,30 +47,39 @@ import GHC.Num.BigNat import GHC.Num.Integer import GHC.Real (Ratio (..)) import Language.C.Inline qualified as C -import Language.C.Inline.Cpp.Exception qualified as C import Language.C.Inline.Unsafe qualified as CU import Symengine.Context import Symengine.Internal import System.IO.Unsafe +-- | Basic building block of SymEngine expressions. +newtype Basic = Basic (ForeignPtr CxxBasic) + +data DenseMatrix a = DenseMatrix {dmRows :: !Int, dmCols :: !Int, dmData :: !(Vector a)} + +data CxxBasic + +data CxxInteger + +data CxxString + importSymengine -- | Convert a pointer to @std::string@ into a string. -- -- It properly handles unicode characters. -peekCxxString :: Ptr CxxString -> IO Text -peekCxxString p = - fmap T.decodeUtf8 $ - packCString - =<< [CU.exp| char const* { $(const std::string* p)->c_str() } |] +-- peekCxxString :: Ptr CxxString -> IO Text +-- peekCxxString p = +-- fmap T.decodeUtf8 $ +-- packCString +-- =<< [CU.exp| char const* { $(const std::string* p)->c_str() } |] -- | Call 'peekCxxString' and @delete@ the pointer. -peekAndDeleteCxxString :: Ptr CxxString -> IO Text -peekAndDeleteCxxString p = do - s <- peekCxxString p - [CU.exp| void { delete $(const std::string* p) } |] - pure s - +-- peekAndDeleteCxxString :: Ptr CxxString -> IO Text +-- peekAndDeleteCxxString p = do +-- s <- peekCxxString p +-- [CU.exp| void { delete $(const std::string* p) } |] +-- pure s constructBasic :: (Ptr CxxBasic -> IO ()) -> IO Basic constructBasic construct = fmap Basic $ constructWithDeleter size deleter $ \ptr -> do @@ -190,9 +199,6 @@ instance Show Basic where deriving stock instance Show (DenseMatrix Basic) --- peekAndDeleteCxxString --- =<< [CU.exp| std::string* { new std::string{SymEngine::str(**$(Object* basic'))} } |] - instance Eq Basic where a == b = unsafePerformIO $ withBasic a $ \a' -> @@ -512,6 +518,9 @@ fromAST = \case SymengineLog x -> log x SymengineSign x -> signum x SymengineDerivative f v -> V.foldl' diff f v + SymengineFunction (T.encodeUtf8 -> s) v -> unsafePerformIO $ + withVector v $ \args -> + $(constructBasicFrom "function_symbol(std::string{$bs-ptr:s, static_cast($bs-len:s)}, *$(const Vector* args))") {- -- | Convert a C string into a Haskell string properly handling unicode characters. diff --git a/src/Symengine/Context.hs b/src/Symengine/Context.hs index c2ddde6..b495a86 100644 --- a/src/Symengine/Context.hs +++ b/src/Symengine/Context.hs @@ -1,8 +1,4 @@ -{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE QuasiQuotes #-} -{-# LANGUAGE TemplateHaskellQuotes #-} -{-# LANGUAGE ViewPatterns #-} -- | -- Module : Symengine.Context @@ -12,38 +8,17 @@ -- This module defines a Template Haskell function 'importSymengine' that sets up everything you need -- to call SymEngine functions from 'Language.C.Inline' quasiquotes. module Symengine.Context - ( Basic (..) - , DenseMatrix (..) - , CxxString - , CxxBasic - , CxxInteger - , importSymengine - , constructBasicFrom + ( importSymengine ) where -import Data.Kind (Type) -import Data.Map (Map, fromList) -import Data.Vector (Vector) -import Foreign.ForeignPtr +import Data.Map.Strict qualified as Map import Language.C.Inline qualified as C import Language.C.Inline.Context (Context (ctxTypesTable)) import Language.C.Inline.Cpp qualified as Cpp -import Language.C.Inline.Cpp.Exception qualified as C -import Language.C.Inline.Unsafe qualified as CU -import Language.C.Types (TypeSpecifier (..)) -import Language.Haskell.TH (DecsQ, Exp, Q, TypeQ) - --- | Basic building block of SymEngine expressions. -newtype Basic = Basic (ForeignPtr CxxBasic) - -data DenseMatrix a = DenseMatrix {dmRows :: !Int, dmCols :: !Int, dmData :: !(Vector a)} - -data CxxBasic - -data CxxInteger - -data CxxString +import Language.C.Types (CIdentifier, TypeSpecifier (..)) +import Language.Haskell.TH (DecsQ, Q, TypeQ, lookupTypeName) +import Language.Haskell.TH.Syntax (Type (..)) -- | One stop function to include all the neccessary machinery to call SymEngine functions via -- inline-c. @@ -54,7 +29,7 @@ importSymengine :: DecsQ importSymengine = concat <$> sequence - [ C.context symengineCxt + [ C.context =<< symengineCxt , C.include "" , C.include "" , C.include "" @@ -69,19 +44,35 @@ importSymengine = , defineCxxUtils ] -symengineCxt :: C.Context -symengineCxt = - C.funCtx <> C.fptrCtx <> C.bsCtx <> Cpp.cppCtx <> C.baseCtx <> mempty {ctxTypesTable = symengineTypePairs} +symengineCxt :: Q C.Context +symengineCxt = do + typePairs <- Map.fromList <$> symengineTypePairs + pure $ + C.funCtx <> C.fptrCtx <> C.bsCtx <> Cpp.cppCtx <> C.baseCtx <> mempty {ctxTypesTable = typePairs} -symengineTypePairs :: Map TypeSpecifier TypeQ +symengineTypePairs :: Q [(TypeSpecifier, TypeQ)] symengineTypePairs = - fromList - [ (TypeName "Object", [t|CxxBasic|]) - , (TypeName "Vector", [t|Vector Basic|]) - , (TypeName "DenseMatrix", [t|DenseMatrix Basic|]) - , (TypeName "integer_class", [t|CxxInteger|]) - , (TypeName "std::string", [t|CxxString|]) + optionals + [ ("Object", "CxxBasic") + , ("Vector", "Vector Basic") + , ("DenseMatrix", "DenseMatrix Basic") + , ("integer_class", "CxxInteger") + , ("std::string", "CxxString") ] + where + optional :: (CIdentifier, String) -> Q [(TypeSpecifier, TypeQ)] + optional (cName, hsName) = do + hsType <- case words hsName of + [x] -> fmap ConT <$> lookupTypeName x + -- TODO: generalize to multiple arguments + [f, x] -> do + con <- fmap ConT <$> lookupTypeName f + arg <- fmap ConT <$> lookupTypeName x + pure $ AppT <$> con <*> arg + _ -> pure Nothing + pure $ maybe [] (\x -> [(TypeName cName, pure x)]) hsType + optionals :: [(CIdentifier, String)] -> Q [(TypeSpecifier, TypeQ)] + optionals pairs = concat <$> mapM optional pairs defineCxxUtils :: DecsQ defineCxxUtils = @@ -94,15 +85,3 @@ defineCxxUtils = \#define CONSTRUCT_BASIC(dest, expr) new (dest) Object{expr} \n\ \ \n\ \" - -constructBasicFrom :: String -> Q Exp -constructBasicFrom expr = - C.substitute - [("expr", const expr)] - [| - constructBasic $ \dest -> - [CU.block| void { - using namespace SymEngine; - new ($(Object* dest)) Object{@expr()}; - } |] - |] diff --git a/src/Symengine/Internal.hs b/src/Symengine/Internal.hs index 127f50d..b6e5fe5 100644 --- a/src/Symengine/Internal.hs +++ b/src/Symengine/Internal.hs @@ -1,13 +1,15 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE QuasiQuotes #-} -{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TemplateHaskellQuotes #-} {-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -Wno-unused-matches #-} -- | -- Module : Symengine.Internal -- Description : Symengine bindings to Haskell module Symengine.Internal - ( constructStringFrom + ( constructBasicFrom + , constructStringFrom , createDenseMatrixVia , mkUnaryFunction , mkBinaryFunction @@ -15,29 +17,26 @@ module Symengine.Internal ) where import Control.Exception (bracket_) -import Control.Monad -import Data.Bits import Data.ByteString (packCString) -import Data.Text (Text, pack, unpack) import Data.Text.Encoding qualified as T -import Data.Vector (Vector) -import Data.Vector qualified as V -import Foreign.C.Types -import Foreign.ForeignPtr -import Foreign.Marshal (allocaBytes, toBool) -import Foreign.Ptr -import GHC.Exts (IsString (..)) -import GHC.Real (Ratio (..)) +import Foreign.Marshal (allocaBytes) import Language.C.Inline qualified as C -import Language.C.Inline.Context (Context (ctxTypesTable)) -import Language.C.Inline.Cpp qualified as Cpp -import Language.C.Inline.Cpp.Exception qualified as C import Language.C.Inline.Unsafe qualified as CU -import Language.C.Types (TypeSpecifier (..)) -import Language.Haskell.TH (DecsQ, Exp, Q, TypeQ) -import Symengine.Context +import Language.Haskell.TH (Exp, Q) import System.IO.Unsafe +constructBasicFrom :: String -> Q Exp +constructBasicFrom expr = + C.substitute + [("expr", const expr)] + [| + constructBasic $ \dest -> + [CU.block| void { + using namespace SymEngine; + new ($(Object* dest)) Object{@expr()}; + } |] + |] + constructStringFrom :: String -> Q Exp constructStringFrom expr = C.substitute diff --git a/test/Spec.hs b/test/Spec.hs index 6bd19ca..dc80433 100644 --- a/test/Spec.hs +++ b/test/Spec.hs @@ -51,9 +51,10 @@ main = hspec $ do toAST (signum (symbol "y")) `shouldBe` SymengineSign (symbol "y") it "SymengineFunction" $ do toAST (parse "f(1, x, y + 2)") `shouldBe` SymengineFunction "f" [1, symbol "x", 2 + symbol "y"] + show (fromAST (SymengineFunction "f" [1, symbol "x", 2 + symbol "y"])) `shouldBe` "f(1, x, 2 + y)" it "SymengineDerivative" $ do toAST (diff (parse "f(x)") (symbol "x")) `shouldBe` SymengineDerivative (parse "f(x)") [symbol "x"] - toAST (diff ((symbol "x") ** 2) (symbol "x")) `shouldBe` SymengineMul [2, symbol "x"] + toAST (diff (symbol "x" ** 2) (symbol "x")) `shouldBe` SymengineMul [2, symbol "x"] describe "Misc" $ do it "" $ do From 6fe50f95234662983808ae645486735d6d7b55de Mon Sep 17 00:00:00 2001 From: Tom Westerhout <14264576+twesterhout@users.noreply.github.com> Date: Thu, 17 Aug 2023 12:14:00 +0200 Subject: [PATCH 19/21] constructBasicFrom can throw; flake update; export DenseMatrix --- flake.lock | 35 +++++++++-------------------------- flake.nix | 4 ---- src/Symengine.hs | 1 + src/Symengine/Internal.hs | 3 ++- 4 files changed, 12 insertions(+), 31 deletions(-) diff --git a/flake.lock b/flake.lock index 305e052..92c0573 100644 --- a/flake.lock +++ b/flake.lock @@ -1,31 +1,15 @@ { "nodes": { - "flake-compat": { - "flake": false, - "locked": { - "lastModified": 1673956053, - "narHash": "sha256-4gtG9iQuiKITOjNQQeQIpoIB6b16fm+504Ch3sNKLd8=", - "owner": "edolstra", - "repo": "flake-compat", - "rev": "35bb57c0c8d8b62bbfd284272c928ceb64ddbde9", - "type": "github" - }, - "original": { - "owner": "edolstra", - "repo": "flake-compat", - "type": "github" - } - }, "flake-utils": { "inputs": { "systems": "systems" }, "locked": { - "lastModified": 1685518550, - "narHash": "sha256-o2d0KcvaXzTrPRIo0kOLV0/QXHhDQ5DTi+OxcjO8xqY=", + "lastModified": 1689068808, + "narHash": "sha256-6ixXo3wt24N/melDWjq70UuHQLxGV8jZvooRanIHXw0=", "owner": "numtide", "repo": "flake-utils", - "rev": "a1720a10a6cfe8234c0e93907ffe81be440f4cef", + "rev": "919d646de7be200f3bf08cb76ae1f09402b6f9b4", "type": "github" }, "original": { @@ -36,11 +20,11 @@ }, "nix-filter": { "locked": { - "lastModified": 1681154353, - "narHash": "sha256-MCJ5FHOlbfQRFwN0brqPbCunLEVw05D/3sRVoNVt2tI=", + "lastModified": 1687178632, + "narHash": "sha256-HS7YR5erss0JCaUijPeyg2XrisEb959FIct3n2TMGbE=", "owner": "numtide", "repo": "nix-filter", - "rev": "f529f42792ade8e32c4be274af6b6d60857fbee7", + "rev": "d90c75e8319d0dd9be67d933d8eb9d0894ec9174", "type": "github" }, "original": { @@ -51,11 +35,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1686960236, - "narHash": "sha256-AYCC9rXNLpUWzD9hm+askOfpliLEC9kwAo7ITJc4HIw=", + "lastModified": 1692174805, + "narHash": "sha256-xmNPFDi/AUMIxwgOH/IVom55Dks34u1g7sFKKebxUm0=", "owner": "nixos", "repo": "nixpkgs", - "rev": "04af42f3b31dba0ef742d254456dc4c14eedac86", + "rev": "caac0eb6bdcad0b32cb2522e03e4002c8975c62e", "type": "github" }, "original": { @@ -67,7 +51,6 @@ }, "root": { "inputs": { - "flake-compat": "flake-compat", "flake-utils": "flake-utils", "nix-filter": "nix-filter", "nixpkgs": "nixpkgs" diff --git a/flake.nix b/flake.nix index 88cc76b..692c757 100644 --- a/flake.nix +++ b/flake.nix @@ -5,10 +5,6 @@ nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable"; flake-utils.url = "github:numtide/flake-utils"; nix-filter.url = "github:numtide/nix-filter"; - flake-compat = { - url = "github:edolstra/flake-compat"; - flake = false; - }; }; outputs = { nixpkgs, flake-utils, nix-filter, ... }: diff --git a/src/Symengine.hs b/src/Symengine.hs index d603e48..217ad10 100644 --- a/src/Symengine.hs +++ b/src/Symengine.hs @@ -10,6 +10,7 @@ -- Description : Symengine bindings to Haskell module Symengine ( Basic (..) + , DenseMatrix (..) , symbol , parse , e diff --git a/src/Symengine/Internal.hs b/src/Symengine/Internal.hs index b6e5fe5..bd2fdf4 100644 --- a/src/Symengine/Internal.hs +++ b/src/Symengine/Internal.hs @@ -24,6 +24,7 @@ import Language.C.Inline qualified as C import Language.C.Inline.Unsafe qualified as CU import Language.Haskell.TH (Exp, Q) import System.IO.Unsafe +import Language.C.Inline.Cpp.Exception qualified as C constructBasicFrom :: String -> Q Exp constructBasicFrom expr = @@ -31,7 +32,7 @@ constructBasicFrom expr = [("expr", const expr)] [| constructBasic $ \dest -> - [CU.block| void { + [C.throwBlock| void { using namespace SymEngine; new ($(Object* dest)) Object{@expr()}; } |] From 2ddfe464c795a9644bab1ca5c878f5748df082d8 Mon Sep 17 00:00:00 2001 From: Tom Westerhout <14264576+twesterhout@users.noreply.github.com> Date: Thu, 17 Aug 2023 14:19:05 +0200 Subject: [PATCH 20/21] Implement subs --- src/Symengine.hs | 27 ++++++++++++++++++++++----- src/Symengine/Context.hs | 2 ++ test/Spec.hs | 6 ++++++ 3 files changed, 30 insertions(+), 5 deletions(-) diff --git a/src/Symengine.hs b/src/Symengine.hs index 217ad10..fda6811 100644 --- a/src/Symengine.hs +++ b/src/Symengine.hs @@ -18,6 +18,7 @@ module Symengine , nan , diff , evalf + , subs , inverse , identityMatrix , zeroMatrix @@ -31,7 +32,7 @@ module Symengine , AST (..) ) where -import Control.Exception (bracket_) +import Control.Exception (bracket, bracket_) import Control.Monad import Data.ByteString (useAsCString) import Data.Text (Text, pack, unpack) @@ -64,6 +65,8 @@ data CxxInteger data CxxString +data CxxMapBasicBasic + importSymengine -- | Convert a pointer to @std::string@ into a string. @@ -376,10 +379,24 @@ evalf (evalDomainToCInt -> domain) (fromIntegral -> bits) x = unsafePerformIO $ withBasic x $ \x' -> $(constructBasicFrom "evalf(**$(const Object* x'), $(int bits), static_cast($(int domain)))") --- pureBinaryOp $ \dest f x -> do --- [CU.exp| void { --- CONSTRUCT_BASIC($(Object* dest), (*$(const Object* f))->diff( --- SymEngine::rcp_static_cast(*$(const Object* x)))) } |] +withCxxMapBasicBasic :: [(Basic, Basic)] -> (Ptr CxxMapBasicBasic -> IO a) -> IO a +withCxxMapBasicBasic pairs action = + bracket allocate destroy $ \p -> do + forM_ pairs $ \(from, to) -> + withBasic from $ \fromPtr -> withBasic to $ \toPtr -> + [CU.exp| void { + $(map_basic_basic* p)->emplace(*$(Object const* fromPtr), *$(Object const* toPtr)) } |] + action p + where + allocate = [CU.exp| map_basic_basic* { new map_basic_basic } |] + destroy p = [CU.exp| void { delete $(map_basic_basic* p) } |] + +subs :: [(Basic, Basic)] -> Basic -> Basic +subs replacements expr = + unsafePerformIO $ + withCxxMapBasicBasic replacements $ \replacementsPtr -> + withBasic expr $ \exprPtr -> + $(constructBasicFrom "subs(*$(Object const* exprPtr), *$(map_basic_basic const* replacementsPtr))") generateDenseMatrix :: Int -> Int -> (Int -> Int -> Basic) -> DenseMatrix Basic generateDenseMatrix nrows ncols f = diff --git a/src/Symengine/Context.hs b/src/Symengine/Context.hs index b495a86..4958b66 100644 --- a/src/Symengine/Context.hs +++ b/src/Symengine/Context.hs @@ -37,6 +37,7 @@ importSymengine = , C.include "" , C.include "" , C.include "" + , C.include "" , C.include "" , C.include "" , C.include "" @@ -58,6 +59,7 @@ symengineTypePairs = , ("DenseMatrix", "DenseMatrix Basic") , ("integer_class", "CxxInteger") , ("std::string", "CxxString") + , ("map_basic_basic", "CxxMapBasicBasic") ] where optional :: (CIdentifier, String) -> Q [(TypeSpecifier, TypeQ)] diff --git a/test/Spec.hs b/test/Spec.hs index dc80433..513fe2d 100644 --- a/test/Spec.hs +++ b/test/Spec.hs @@ -56,6 +56,12 @@ main = hspec $ do toAST (diff (parse "f(x)") (symbol "x")) `shouldBe` SymengineDerivative (parse "f(x)") [symbol "x"] toAST (diff (symbol "x" ** 2) (symbol "x")) `shouldBe` SymengineMul [2, symbol "x"] + describe "subs" $ do + it "" $ do + subs [("x", 1)] "a + f(x) / x" `shouldBe` "a + f(1)" + subs [("k", "c")] "a + b" `shouldBe` "a + b" + subs [] "a + b" `shouldBe` "a + b" + subs [("a + b", "c")] "a + b" `shouldBe` "c" describe "Misc" $ do it "" $ do print $ parse "a + f(x) / x - 4**2" From b25a51de39c6e6379feecb729f058de551bf7d6c Mon Sep 17 00:00:00 2001 From: Tom Westerhout <14264576+twesterhout@users.noreply.github.com> Date: Thu, 17 Aug 2023 14:55:47 +0200 Subject: [PATCH 21/21] Add BasicKey newtype that is usable as a key with containers from the containers and unordered-containers libraries --- flake.nix | 2 +- src/Symengine.hs | 32 +++++++++++++++++++++++++++++++- symengine.cabal | 1 + 3 files changed, 33 insertions(+), 2 deletions(-) diff --git a/flake.nix b/flake.nix index 692c757..516f393 100644 --- a/flake.nix +++ b/flake.nix @@ -71,7 +71,7 @@ nil # Formatters fourmolu - cabal-fmt + # cabal-fmt nixpkgs-fmt # Previewing markdown files python3Packages.grip diff --git a/src/Symengine.hs b/src/Symengine.hs index fda6811..00bb99f 100644 --- a/src/Symengine.hs +++ b/src/Symengine.hs @@ -30,11 +30,12 @@ module Symengine , toAST , fromAST , AST (..) + , BasicKey (..) ) where import Control.Exception (bracket, bracket_) import Control.Monad -import Data.ByteString (useAsCString) +import Data.Hashable (Hashable (..)) import Data.Text (Text, pack, unpack) import Data.Text.Encoding qualified as T import Data.Vector (Vector) @@ -210,6 +211,35 @@ instance Eq Basic where toBool <$> [CU.exp| bool { eq(**$(const Object* a'), **$(const Object* b')) } |] +instance Hashable Basic where + hashWithSalt s = hashWithSalt s . hashInternal + where + hashInternal x = unsafePerformIO $ withBasic x $ \p -> [CU.exp| uint64_t { (*$(Object const* p))->hash() } |] + +newtype BasicKey = BasicKey {unBasicKey :: Basic} + +instance Eq BasicKey where + (BasicKey a) == (BasicKey b) + | hash a /= hash b = False + | otherwise = a == b + +instance Ord BasicKey where + compare (BasicKey a) (BasicKey b) = + case compare hashA hashB of + LT -> LT + GT -> GT + EQ -> case compareInternal of + -1 -> LT + 0 -> EQ + 1 -> GT + x -> error $ "__cmp__ returned invalid value: " <> show x + where + hashA = hash a + hashB = hash b + compareInternal = + unsafePerformIO $ withBasic a $ \aPtr -> withBasic b $ \bPtr -> + [CU.exp| int { (*$(Object const* aPtr))->__cmp__(**$(Object const* bPtr)) } |] + parse :: Text -> Basic parse (T.encodeUtf8 -> name) = unsafePerformIO $ $(constructBasicFrom "parse($bs-cstr:name)") diff --git a/symengine.cabal b/symengine.cabal index 677bb4d..ebe9573 100644 --- a/symengine.cabal +++ b/symengine.cabal @@ -54,6 +54,7 @@ library , bytestring , containers , ghc-bignum + , hashable , inline-c , inline-c-cpp , template-haskell