diff --git a/CHANGELOG.md b/CHANGELOG.md index 987808e9..d7b91c78 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,55 @@ # Changelog for Jikka +## 2021-06-25: v5.0.5.0 + +Some optimizations are implemented. +Now it can convert a O(N) Python code for fibonacci to O(log N) C++ code. + +Input, O(N): + +``` python +def f(n: int) -> int: + a = 0 + b = 1 + for _ in range(n): + c = a + b + a = b + b = c + return a + +def solve(n: int) -> int: + return f(n) % 1000000007 +``` + +Output, O(log N): + +``` c++ +#include "jikka/all.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +int64_t solve(int64_t n_317) { + return jikka::modmatap<2, 2>( + jikka::modmatpow<2>(jikka::make_array>( + jikka::make_array(1, 1), + jikka::make_array(1, 0)), + n_317, 1000000007), + jikka::make_array(1, 0), 1000000007)[1]; +} +int main() { + int64_t x318; + std::cin >> x318; + int64_t x319 = solve(x318); + std::cout << x319; + std::cout << '\n'; +} +``` + ## 2021-06-23: v5.0.4.0 Now executable C++ code is generated. diff --git a/README.md b/README.md index 35bd9e4e..057cc22a 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ for users: - [docs/language.md](https://github.com/kmyk/Jikka/blob/master/docs/language.md) - [CHANGELOG.md](https://github.com/kmyk/Jikka/blob/master/CHANGELOG.md) +- blog article [競技プログラミングの問題を自動で解きたい - うさぎ小屋](https://kimiyuki.net/blog/2020/12/09/automated-solvers-of-competitive-programming/) for developpers: @@ -37,87 +38,50 @@ for developpers: - [Haddock](https://kmyk.github.io/Jikka/) -## Examples (`v3.1.0`) +## Examples (`v5.0.5.0`) -The below are examples of old the version (at `v3.1.0`). The input was a ML code. +Input, O(N): -### Sum of Max - -Problem: -You are given a natural number K and a sequence A = (a₀, a₁, …, aₙ) of length N. -Compute the value of ∑ᵢ˱ₖ maxⱼ˱ₙ (i + 2 aⱼ). - -Input, O(K N): - -``` sml -let K = 100000 in -let given N : Nat in -let given A : N -> Nat in - -sum K (fun i -> max N (fun j -> i + 2 * A j)) +``` python +def f(n: int) -> int: + a = 0 + b = 1 + for _ in range(n): + c = a + b + a = b + b = c + return a + +def solve(n: int) -> int: + return f(n) % 1000000007 ``` -Output, O(K + N): +Output, O(log N): ``` c++ -int64_t solve(int64_t N, const vector & A) { - int64_t K = 100000; - int64_t a2 = 0; - for (int64_t i2 = 0; i2 < K; ++ i2) { - a2 += i2; - } - int64_t a1 = INT64_MIN; - for (int64_t i1 = 0; i1 < N; ++ i1) { - a1 = max(a1, 2 * A[i1]); - } - return a2 + K * a1; -} -``` - -### AtCoder Beginner Contest 134: C - Exception Handling - -Problem: - -Input, O(N^2): - -``` sml -let given N : [2, 200001) in -let given A : N -> 200001 in - -let f (i : N) = max N (fun j -> if j = i then 0 else A j) in -f -``` - -Output, O(N): - - -## Examples (`v5.0.1.0`) - -``` console -$ cat examples/fact.py -def f(n: int) -> int: - if n == 0: - return 1 - else: - return n * f(n - 1) - -$ stack run convert examples/fact.py -int64_t f0_f(int64_t a1_n) { - bool x2 = a1_n == 0; - if (x2) { - return 1; - } else { - int64_t x3 = - 1; - int64_t x4 = x3; - int64_t x5 = x4; - int64_t x6 = a1_n + x5; - int64_t x7 = x6; - int64_t x8 = f0_f(x7); - return a1_n * x8; - } +#include "jikka/all.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +int64_t solve(int64_t n_317) { + return jikka::modmatap<2, 2>( + jikka::modmatpow<2>(jikka::make_array>( + jikka::make_array(1, 1), + jikka::make_array(1, 0)), + n_317, 1000000007), + jikka::make_array(1, 0), 1000000007)[1]; } -int64_t solve(int64_t a9) { - return f0_f(a9); +int main() { + int64_t x318; + std::cin >> x318; + int64_t x319 = solve(x318); + std::cout << x319; + std::cout << '\n'; } ``` diff --git a/examples/fib.in b/examples/fib.1.in similarity index 100% rename from examples/fib.in rename to examples/fib.1.in diff --git a/examples/fib.out b/examples/fib.1.out similarity index 100% rename from examples/fib.out rename to examples/fib.1.out diff --git a/examples/fib.2.in b/examples/fib.2.in new file mode 100644 index 00000000..29d6383b --- /dev/null +++ b/examples/fib.2.in @@ -0,0 +1 @@ +100 diff --git a/examples/fib.2.out b/examples/fib.2.out new file mode 100644 index 00000000..32e8647a --- /dev/null +++ b/examples/fib.2.out @@ -0,0 +1 @@ +687995182 diff --git a/examples/fib.large.in b/examples/fib.large.in new file mode 100644 index 00000000..770fdcfb --- /dev/null +++ b/examples/fib.large.in @@ -0,0 +1 @@ +1000000000 diff --git a/examples/fib.large.out b/examples/fib.large.out new file mode 100644 index 00000000..aabe6ec3 --- /dev/null +++ b/examples/fib.large.out @@ -0,0 +1 @@ +21 diff --git a/examples/fib.py b/examples/fib.py index 6c2f5669..f824277f 100644 --- a/examples/fib.py +++ b/examples/fib.py @@ -1,4 +1,4 @@ -def solve(n: int) -> int: +def f(n: int) -> int: a = 0 b = 1 for _ in range(n): @@ -6,3 +6,6 @@ def solve(n: int) -> int: a = b b = c return a + +def solve(n: int) -> int: + return f(n) % 1000000007 diff --git a/examples/test.sh b/examples/test.sh index 8c787838..6b2744e9 100644 --- a/examples/test.sh +++ b/examples/test.sh @@ -2,10 +2,15 @@ set -ex tempdir=$(mktemp -d) trap "rm -rf $tempdir" EXIT -for f in examples/*.in ; do - diff <(stack --system-ghc run -- execute --target rpython ${f%.in}.py < $f) ${f%.in}.out - diff <(stack --system-ghc run -- execute --target core ${f%.in}.py < $f) ${f%.in}.out - stack --system-ghc run -- convert --target cxx ${f%.in}.py > $tempdir/$(basename $f .in).cpp - g++ -std=c++17 -Wall -O2 -Iruntime/include $tempdir/$(basename $f .in).cpp -o $tempdir/$(basename $f .in) - diff <($tempdir/$(basename $f .in) < $f) ${f%.in}.out +for input in examples/*.*.in ; do + output=${input%.in}.out + code=${input%.*.in}.py + name=$(basename ${input%.*.in}) + if [[ ! $input =~ large ]] ; then + diff <(stack --system-ghc run -- execute --target rpython $code < $input) $output + fi + diff <(stack --system-ghc run -- execute --target core $code < $input) $output + stack --system-ghc run -- convert --target cxx $code > $tempdir/$name.cpp + g++ -std=c++17 -Wall -O2 -Iruntime/include $tempdir/$name.cpp -o $tempdir/$name + diff <($tempdir/$name < $input) $output done diff --git a/package.yaml b/package.yaml index 272b3092..7b47fb56 100644 --- a/package.yaml +++ b/package.yaml @@ -1,5 +1,5 @@ name: Jikka -version: 5.0.4.0 +version: 5.0.5.0 github: "kmyk/Jikka" license: Apache author: "Kimiyuki Onaka" diff --git a/runtime/include/jikka/all.hpp b/runtime/include/jikka/all.hpp index f49a02de..c0cdfa2a 100644 --- a/runtime/include/jikka/all.hpp +++ b/runtime/include/jikka/all.hpp @@ -1,8 +1,10 @@ #ifndef JIKKA_ALL_HPP #define JIKKA_ALL_HPP #include +#include #include #include +#include #include #include @@ -41,6 +43,84 @@ inline int64_t pow(int64_t x, int64_t k) { return y; } +template inline T natind(T x, std::function f, int64_t n) { + if (n < 0) { + return x; + } + while (n--) { + x = f(x); + } + return x; +} + +template +using matrix = std::array, H>; + +template +std::array make_array(Args... args) { + return {args...}; +} + +template +std::array matap(const matrix &a, + const std::array &b) { + std::array c = {}; + for (size_t y = 0; y < H; ++y) { + for (size_t x = 0; x < W; ++x) { + c[y] += a[y][x] * b[x]; + } + } + return c; +} + +template matrix matzero() { return {}; } + +template matrix matone() { + matrix a = {}; + for (size_t i = 0; i < N; ++i) { + a[i][i] = 1; + } + return a; +} + +template +matrix matadd(const matrix &a, + const matrix &b) { + matrix c; + for (size_t y = 0; y < H; ++y) { + for (size_t x = 0; x < W; ++x) { + c[y][x] = a[y][x] + b[y][x]; + } + } + return c; +} + +template +matrix matmul(const matrix &a, + const matrix &b) { + matrix c = {}; + for (size_t y = 0; y < H; ++y) { + for (size_t z = 0; z < N; ++z) { + for (size_t x = 0; x < W; ++x) { + c[y][x] += a[y][z] * b[z][x]; + } + } + } + return c; +} + +template +matrix matpow(matrix x, int64_t k) { + matrix y = matone(); + for (; k; k >>= 1) { + if (k & 1) { + y = matmul(y, x); + } + x = matmul(x, x); + } + return y; +} + inline int64_t modinv(int64_t value, int64_t MOD) { assert(0 < value and value < MOD); int64_t a = value, b = MOD; @@ -77,6 +157,63 @@ inline int64_t modpow(int64_t x, int64_t k, int64_t MOD) { return y; } +template +std::array modmatap(const matrix &a, + const std::array &b, int64_t MOD) { + std::array c = {}; + for (size_t y = 0; y < H; ++y) { + for (size_t x = 0; x < W; ++x) { + c[y] += a[y][x] * b[x] % MOD; + } + c[y] = floormod(c[y], MOD); + } + return c; +} + +template +matrix modmatadd(const matrix &a, + const matrix &b, int64_t MOD) { + matrix c; + for (size_t y = 0; y < H; ++y) { + for (size_t x = 0; x < W; ++x) { + c[y][x] = floormod(a[y][x] + b[y][x], MOD); + } + } + return c; +} + +template +matrix modmatmul(const matrix &a, + const matrix &b, int64_t MOD) { + matrix c = {}; + for (size_t y = 0; y < H; ++y) { + for (size_t z = 0; z < N; ++z) { + for (size_t x = 0; x < W; ++x) { + c[y][x] += a[y][z] * b[z][x] % MOD; + } + } + } + for (size_t y = 0; y < H; ++y) { + for (size_t x = 0; x < W; ++x) { + c[y][x] = floormod(c[y][x], MOD); + } + } + return c; +} + +template +matrix modmatpow(matrix x, int64_t k, + int64_t MOD) { + matrix y = matone(); + for (; k; k >>= 1) { + if (k & 1) { + y = modmatmul(y, x, MOD); + } + x = modmatmul(x, x, MOD); + } + return y; +} + template std::vector cons(T x, const std::vector &xs) { std::vector ys(xs.size() + 1); ys[0] = x; diff --git a/src/Jikka/CPlusPlus/Convert.hs b/src/Jikka/CPlusPlus/Convert.hs index 3234d6e5..cd4d4ddf 100644 --- a/src/Jikka/CPlusPlus/Convert.hs +++ b/src/Jikka/CPlusPlus/Convert.hs @@ -9,10 +9,8 @@ import qualified Jikka.CPlusPlus.Convert.FromCore as FromCore import qualified Jikka.CPlusPlus.Language.Expr as Y import Jikka.Common.Alpha import Jikka.Common.Error -import qualified Jikka.Core.Convert.ANormal as ANormal import qualified Jikka.Core.Language.Expr as X run :: (MonadAlpha m, MonadError Error m) => X.Program -> m Y.Program run prog = do - prog <- ANormal.run prog FromCore.run prog diff --git a/src/Jikka/CPlusPlus/Convert/FromCore.hs b/src/Jikka/CPlusPlus/Convert/FromCore.hs index e1260a7b..945fcb97 100644 --- a/src/Jikka/CPlusPlus/Convert/FromCore.hs +++ b/src/Jikka/CPlusPlus/Convert/FromCore.hs @@ -78,7 +78,10 @@ runType = \case X.IntTy -> return Y.TyInt64 X.BoolTy -> return Y.TyBool X.ListTy t -> Y.TyVector <$> runType t - X.TupleTy ts -> Y.TyTuple <$> mapM runType ts + X.TupleTy ts -> + if not (null ts) && ts == replicate (length ts) (head ts) + then Y.TyArray <$> runType (head ts) <*> pure (fromIntegral (length ts)) + else Y.TyTuple <$> mapM runType ts X.FunTy ts t -> Y.TyFunction <$> runType t <*> mapM runType ts runLiteral :: MonadError Error m => X.Literal -> m Y.Expr @@ -129,9 +132,20 @@ runAppBuiltin f args = case (f, args) of (X.BitXor, [e1, e2]) -> return $ Y.BinOp Y.BitXor e1 e2 (X.BitLeftShift, [e1, e2]) -> return $ Y.BinOp Y.BitLeftShift e1 e2 (X.BitRightShift, [e1, e2]) -> return $ Y.BinOp Y.BitRightShift e1 e2 + -- matrix functions + (X.MatAp h w, [f, x]) -> return $ Y.Call (Y.Function (Y.FunName ("jikka::matap<" ++ show h ++ ", " ++ show w ++ ">")) []) [f, x] + (X.MatZero n, []) -> return $ Y.Call (Y.Function (Y.FunName ("jikka::matzero<" ++ show n ++ ">")) []) [] + (X.MatOne n, []) -> return $ Y.Call (Y.Function (Y.FunName ("jikka::matone<" ++ show n ++ ">")) []) [] + (X.MatAdd h w, [f, g]) -> return $ Y.Call (Y.Function (Y.FunName ("jikka::matadd<" ++ show h ++ ", " ++ show w ++ ">")) []) [f, g] + (X.MatMul h n w, [f, g]) -> return $ Y.Call (Y.Function (Y.FunName ("jikka::matmul<" ++ show h ++ ", " ++ show n ++ ", " ++ show w ++ ">")) []) [f, g] + (X.MatPow n, [f, k]) -> return $ Y.Call (Y.Function (Y.FunName ("jikka::matpow<" ++ show n ++ ">")) []) [f, k] -- modular functions (X.ModInv, [e1, e2]) -> return $ Y.Call (Y.Function "jikka::modinv" []) [e1, e2] (X.ModPow, [e1, e2, e3]) -> return $ Y.Call (Y.Function "jikka::modpow" []) [e1, e2, e3] + (X.ModMatAp h w, [f, x, m]) -> return $ Y.Call (Y.Function (Y.FunName ("jikka::modmatap<" ++ show h ++ ", " ++ show w ++ ">")) []) [f, x, m] + (X.ModMatAdd h w, [f, g, m]) -> return $ Y.Call (Y.Function (Y.FunName ("jikka::modmatadd<" ++ show h ++ ", " ++ show w ++ ">")) []) [f, g, m] + (X.ModMatMul h n w, [f, g, m]) -> return $ Y.Call (Y.Function (Y.FunName ("jikka::modmatmul<" ++ show h ++ ", " ++ show n ++ ", " ++ show w ++ ">")) []) [f, g, m] + (X.ModMatPow n, [f, k, m]) -> return $ Y.Call (Y.Function (Y.FunName ("jikka::modmatpow<" ++ show n ++ ">")) []) [f, k, m] -- list functions (X.Cons t, [e1, e2]) -> do t <- runType t @@ -191,8 +205,15 @@ runAppBuiltin f args = case (f, args) of -- tuple functions (X.Tuple ts, es) -> do ts <- mapM runType ts - return $ Y.Call (Y.Function "std::tuple" ts) es - (X.Proj _ n, [e]) -> return $ Y.Call (Y.Function (Y.FunName ("std::get<" ++ show n ++ ">")) []) [e] + return $ + if not (null ts) && ts == replicate (length ts) (head ts) + then Y.Call (Y.Function "jikka::make_array" [head ts]) es + else Y.Call (Y.Function "std::tuple" ts) es + (X.Proj ts n, [e]) -> + return $ + if not (null ts) && ts == replicate (length ts) (head ts) + then Y.At e (Y.Lit (Y.LitInt32 (fromIntegral n))) + else Y.Call (Y.Function (Y.FunName ("std::get<" ++ show n ++ ">")) []) [e] -- comparison (X.LessThan _, [e1, e2]) -> return $ Y.BinOp Y.LessThan e1 e2 (X.LessEqual _, [e1, e2]) -> return $ Y.BinOp Y.LessEqual e1 e2 @@ -327,13 +348,28 @@ runToplevelExpr env = \case case t of X.FunTy ts ret -> do let f = Y.VarName "solve" - args <- forM ts $ \t -> do - t <- runType t - y <- newFreshName ArgumentNameKind "" - return (t, y) + (args, body) <- case e of + X.Lam args body -> do + when (map snd args /= ts) $ do + throwInternalError "type error" + args <- forM args $ \(x, t) -> do + y <- renameVarName ArgumentNameKind x + return (x, t, y) + e <- runExpr (reverse args ++ env) body + let body = [Y.Return e] + args' <- forM args $ \(_, t, y) -> do + t <- runType t + return (t, y) + return (args', body) + _ -> do + args <- forM ts $ \t -> do + t <- runType t + y <- newFreshName ArgumentNameKind "" + return (t, y) + e <- runExpr env e + let body = [Y.Return (Y.Call (Y.Callable e) (map (Y.Var . snd) args))] + return (args, body) ret <- runType ret - e <- runExpr env e - let body = [Y.Return (Y.Call (Y.Callable e) (map (Y.Var . snd) args))] let solve = [Y.FunDef ret f args body] main <- runMain f t return $ solve ++ main diff --git a/src/Jikka/Common/Matrix.hs b/src/Jikka/Common/Matrix.hs index ca2959ab..aba2f15c 100644 --- a/src/Jikka/Common/Matrix.hs +++ b/src/Jikka/Common/Matrix.hs @@ -1,7 +1,10 @@ +{-# LANGUAGE DeriveFunctor #-} + module Jikka.Common.Matrix ( Matrix, unMatrix, makeMatrix, + makeMatrix', matsize, matsize', matcheck, @@ -23,7 +26,7 @@ import qualified Data.Vector.Mutable as MV -- | `Matrix` is data for matrices. -- It is guaranteed that internal arrays are not jagged arrays. newtype Matrix a = Matrix (V.Vector (V.Vector a)) - deriving (Eq, Ord, Show) + deriving (Eq, Ord, Show, Functor) unMatrix :: Matrix a -> V.Vector (V.Vector a) unMatrix (Matrix a) = a @@ -50,6 +53,11 @@ matcheck a = makeMatrix :: V.Vector (V.Vector a) -> Maybe (Matrix a) makeMatrix a = if matcheck a then Just (Matrix a) else Nothing +makeMatrix' :: V.Vector (V.Vector a) -> Matrix a +makeMatrix' a = case makeMatrix a of + Nothing -> error "Jikka.Common.Matrix.makeMatrix': the input is not a matrix" + Just a -> a + matzero :: Num a => Int -> Matrix a matzero n = Matrix $ V.replicate n (V.replicate n 0) @@ -67,12 +75,12 @@ matadd (Matrix a) (Matrix b) = -- This assumes sizes of inputs match. matmul :: Num a => Matrix a -> Matrix a -> Matrix a matmul (Matrix a) (Matrix b) = runST $ do - let (h, w) = matsize' a - let (_, w') = matsize' b - c <- MV.replicateM h (MV.replicate w' 0) - forM_ [0 .. h] $ \y -> do - forM_ [0 .. w] $ \z -> do - forM_ [0 .. w'] $ \x -> do + let (h, n) = matsize' a + let (_, w) = matsize' b + c <- MV.replicateM h (MV.replicate w 0) + forM_ [0 .. h - 1] $ \y -> do + forM_ [0 .. n - 1] $ \z -> do + forM_ [0 .. w - 1] $ \x -> do let delta = (a V.! y V.! z) * (b V.! z V.! x) row <- MV.read c y MV.modify row (+ delta) x @@ -84,8 +92,8 @@ matap :: Num a => Matrix a -> V.Vector a -> V.Vector a matap (Matrix a) b = runST $ do let (h, w) = matsize' a c <- MV.replicate h 0 - forM_ [0 .. h] $ \y -> do - forM_ [0 .. w] $ \x -> do + forM_ [0 .. h - 1] $ \y -> do + forM_ [0 .. w - 1] $ \x -> do let delta = (a V.! y V.! x) * (b V.! x) MV.modify c (+ delta) y V.freeze c @@ -96,10 +104,10 @@ matscalar a (Matrix b) = Matrix $ V.map (V.map (a *)) b -- | `matpow` calculates the power \(A^k\) of a matrix \(A\) and a natural number \(k\). -- This assumes inputs are square matrices. -- This fails for \(k \lt 0\). -matpow :: Num a => Matrix a -> Integer -> Matrix a +matpow :: (Show a, Num a) => Matrix a -> Integer -> Matrix a matpow _ k | k < 0 = error "cannot calculate a negative power for a monoid" matpow x k = go unit x k where unit = let (h, _) = matsize x in matone h go y _ 0 = y - go y x k = go (if k `mod` 2 == 1 then matmul y x else x) (matmul x x) (k `div` 2) + go y x k = go (if k `mod` 2 == 1 then matmul y x else y) (matmul x x) (k `div` 2) diff --git a/src/Jikka/Common/ModInt.hs b/src/Jikka/Common/ModInt.hs new file mode 100644 index 00000000..c8462975 --- /dev/null +++ b/src/Jikka/Common/ModInt.hs @@ -0,0 +1,38 @@ +module Jikka.Common.ModInt + ( ModInt, + toModInt, + fromModInt, + moduloOfModInt, + ) +where + +import Data.Monoid + +data ModInt = ModInt Integer (Maybe Integer) + deriving (Eq, Ord, Read, Show) + +toModInt :: Integer -> Integer -> ModInt +toModInt _ m | m <= 0 = error $ "Jikka.Common.ModInt.toModInt: modulo must be positive, but m = " ++ show m +toModInt a m = ModInt (a `mod` m) (Just m) + +fromModInt :: ModInt -> Integer +fromModInt (ModInt a _) = a + +moduloOfModInt :: ModInt -> Maybe Integer +moduloOfModInt (ModInt _ m) = m + +instance Num ModInt where + ModInt _ (Just m1) + ModInt _ (Just m2) | m1 /= m2 = error $ "Jikka.Common.ModInt.(+): modulo must be the same, but m1 = " ++ show m1 ++ " and m2 = " ++ show m2 + ModInt a m1 + ModInt b m2 = case getFirst (First m1 <> First m2) of + Nothing -> ModInt (a + b) Nothing + Just m -> ModInt (let c = a + b in if c >= m then c - m else c) (Just m) + ModInt _ (Just m1) * ModInt _ (Just m2) | m1 /= m2 = error $ "Jikka.Common.ModInt.(*): modulo must be the same, but m1 = " ++ show m1 ++ " and m2 = " ++ show m2 + ModInt a m1 * ModInt b m2 = case getFirst (First m1 <> First m2) of + Nothing -> ModInt (a * b) Nothing + Just m -> ModInt ((a * b) `mod` m) (Just m) + abs = error "Jikka.Common.ModInt.fromInteger: cannot call abs for modint" + signum = error "Jikka.Common.ModInt.fromInteger: cannot signum for modint" + fromInteger a = ModInt a Nothing + negate (ModInt a m) = case m of + Nothing -> ModInt (- a) m + Just m -> ModInt (if a == 0 then 0 else m - a) (Just m) diff --git a/src/Jikka/Core/Convert.hs b/src/Jikka/Core/Convert.hs new file mode 100644 index 00000000..228db675 --- /dev/null +++ b/src/Jikka/Core/Convert.hs @@ -0,0 +1,48 @@ +{-# LANGUAGE FlexibleContexts #-} + +-- | +-- Module : Jikka.Core.Convert +-- Description : is a meta module to combine other optimizers. +-- Copyright : (c) Kimiyuki Onaka, 2020 +-- License : Apache License 2.0 +-- Maintainer : kimiyuki95@gmail.com +-- Stability : experimental +-- Portability : portable +-- +-- `Jikka.Core.Convert` is a module to combine other all optimizers. +module Jikka.Core.Convert + ( run, + ) +where + +import Jikka.Common.Alpha +import Jikka.Common.Error +import qualified Jikka.Core.Convert.Alpha as Alpha +import qualified Jikka.Core.Convert.ConstantFolding as ConstantFolding +import qualified Jikka.Core.Convert.ConstantPropagation as ConstantPropagation +import qualified Jikka.Core.Convert.ImmediateAppToLet as ImmediateAppToLet +import qualified Jikka.Core.Convert.LinearFunction as LinearFunction +import qualified Jikka.Core.Convert.PropagateMod as PropagateMod +import qualified Jikka.Core.Convert.RemoveUnusedVars as RemoveUnusedVars +import qualified Jikka.Core.Convert.StrengthReduction as StrengthReduction +import qualified Jikka.Core.Convert.TrivialLetElimination as TrivialLetElimination +import qualified Jikka.Core.Convert.TypeInfer as TypeInfer +import Jikka.Core.Language.Expr + +run' :: (MonadAlpha m, MonadError Error m) => Program -> m Program +run' prog = do + prog <- Alpha.run prog + prog <- TypeInfer.run prog + prog <- RemoveUnusedVars.run prog + prog <- ImmediateAppToLet.run prog + prog <- TrivialLetElimination.run prog + prog <- LinearFunction.run prog + prog <- PropagateMod.run prog + prog <- ConstantPropagation.run prog + prog <- ConstantFolding.run prog + StrengthReduction.run prog + +run :: (MonadAlpha m, MonadError Error m) => Program -> m Program +run prog = + let iteration = 20 + in foldM (\prog _ -> run' prog) prog [0 .. iteration - 1] diff --git a/src/Jikka/Core/Convert/ANormal.hs b/src/Jikka/Core/Convert/ANormal.hs index e79296bf..35a57ac6 100644 --- a/src/Jikka/Core/Convert/ANormal.hs +++ b/src/Jikka/Core/Convert/ANormal.hs @@ -83,6 +83,18 @@ runToplevelExpr env = \case cont <- runToplevelExpr ((f, t) : env) cont return $ ToplevelLetRec f args ret body cont +-- | `run` makes a given program A-normal form. +-- A program is an A-normal form iff assigned exprs of all let-statements are values or function applications. +-- For example, this converts the following: +-- +-- > (let x = 1 in x) + ((fun y -> y) 1) +-- +-- to: +-- +-- > let x = 1 +-- > in let f = fun y -> y +-- > in let z = f x +-- > in z run :: (MonadAlpha m, MonadError Error m) => Program -> m Program run prog = wrapError' "Jikka.Core.Convert.ANormal" $ do prog <- Alpha.runProgram prog diff --git a/src/Jikka/Core/Convert/ConstantFolding.hs b/src/Jikka/Core/Convert/ConstantFolding.hs new file mode 100644 index 00000000..ab6971ca --- /dev/null +++ b/src/Jikka/Core/Convert/ConstantFolding.hs @@ -0,0 +1,71 @@ +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE LambdaCase #-} + +module Jikka.Core.Convert.ConstantFolding + ( run, + ) +where + +import Data.Bits +import Jikka.Common.Error +import Jikka.Core.Language.BuiltinPatterns +import Jikka.Core.Language.Expr +import Jikka.Core.Language.Lint +import Jikka.Core.Language.Runtime +import Jikka.Core.Language.Util + +runExpr :: MonadError Error m => [(VarName, Type)] -> Expr -> m Expr +runExpr _ = \case + Negate' (LitInt' a) -> return $ LitInt' (- a) + Plus' (LitInt' a) (LitInt' b) -> return $ LitInt' (a + b) + Minus' (LitInt' a) (LitInt' b) -> return $ LitInt' (a - b) + Mult' (LitInt' a) (LitInt' b) -> return $ LitInt' (a * b) + FloorDiv' (LitInt' a) (LitInt' b) | b /= 0 -> LitInt' <$> floorDiv a b + FloorMod' (LitInt' a) (LitInt' b) | b /= 0 -> LitInt' <$> floorMod a b + CeilDiv' (LitInt' a) (LitInt' b) | b /= 0 -> LitInt' <$> ceilDiv a b + CeilMod' (LitInt' a) (LitInt' b) | b /= 0 -> LitInt' <$> ceilMod a b + Pow' (LitInt' a) (LitInt' b) | b >= 0 && fromInteger b * log (abs (fromInteger a)) < 100 -> return $ LitInt' (a ^ b) + Abs' (LitInt' a) -> return $ LitInt' (abs a) + Gcd' (LitInt' a) (LitInt' b) -> return $ LitInt' (gcd a b) + Lcm' (LitInt' a) (LitInt' b) -> return $ LitInt' (lcm a b) + Min2' _ (LitInt' a) (LitInt' b) -> return $ LitInt' (min a b) + Max2' _ (LitInt' a) (LitInt' b) -> return $ LitInt' (max a b) + Not' (LitBool' a) -> return $ LitBool' (not a) + And' (LitBool' a) (LitBool' b) -> return $ LitBool' (a && b) + Or' (LitBool' a) (LitBool' b) -> return $ LitBool' (a || b) + Implies' (LitBool' a) (LitBool' b) -> return $ LitBool' (not a || b) + If' _ (LitBool' a) e1 e2 -> return $ if a then e1 else e2 + BitNot' (LitInt' a) -> return $ LitInt' (complement a) + BitAnd' (LitInt' a) (LitInt' b) -> return $ LitInt' (a .&. b) + BitOr' (LitInt' a) (LitInt' b) -> return $ LitInt' (a .|. b) + BitXor' (LitInt' a) (LitInt' b) -> return $ LitInt' (a `xor` b) + BitLeftShift' (LitInt' a) (LitInt' b) | - 100 < b && b < 100 -> return $ LitInt' (a `shift` fromInteger b) + BitRightShift' (LitInt' a) (LitInt' b) | - 100 < b && b < 100 -> return $ LitInt' (a `shift` fromInteger (- b)) + LessThan' _ (LitInt' a) (LitInt' b) -> return $ LitBool' (a < b) + LessEqual' _ (LitBool' a) (LitBool' b) -> return $ LitBool' (a <= b) + LessEqual' _ (LitInt' a) (LitInt' b) -> return $ LitBool' (a <= b) + GreaterThan' _ (LitBool' a) (LitBool' b) -> return $ LitBool' (a > b) + GreaterThan' _ (LitInt' a) (LitInt' b) -> return $ LitBool' (a > b) + GreaterEqual' _ (LitBool' a) (LitBool' b) -> return $ LitBool' (a >= b) + Equal' _ (LitInt' a) (LitInt' b) -> return $ LitBool' (a == b) + Equal' _ (LitBool' a) (LitBool' b) -> return $ LitBool' (a == b) + NotEqual' _ (LitInt' a) (LitInt' b) -> return $ LitBool' (a /= b) + NotEqual' _ (LitBool' a) (LitBool' b) -> return $ LitBool' (a /= b) + e -> return e + +runProgram :: MonadError Error m => Program -> m Program +runProgram = mapExprProgramM runExpr + +-- | `run` folds constants in given programs. +-- For example, this converts the following: +-- +-- > 3 x + 2 + 1 +-- +-- to the follwoing: +-- +-- > 3 x + 3 +run :: MonadError Error m => Program -> m Program +run prog = wrapError' "Jikka.Core.Convert.ConstantFolding" $ do + prog <- runProgram prog + ensureWellTyped prog + return prog diff --git a/src/Jikka/Core/Convert/ConstantPropagation.hs b/src/Jikka/Core/Convert/ConstantPropagation.hs new file mode 100644 index 00000000..7f93a4b3 --- /dev/null +++ b/src/Jikka/Core/Convert/ConstantPropagation.hs @@ -0,0 +1,75 @@ +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE LambdaCase #-} + +module Jikka.Core.Convert.ConstantPropagation + ( run, + run', + + -- * internal functions + isSmallExpr, + ) +where + +import qualified Data.Map as M +import Data.Maybe (fromMaybe) +import Jikka.Common.Error +import Jikka.Core.Language.Expr +import Jikka.Core.Language.Lint + +type Env = M.Map VarName Expr + +-- | `isSmallExpr` checks whether given exprs are suitable to propagate. +isSmallExpr :: Expr -> Bool +isSmallExpr = \case + Var _ -> True + Lit _ -> True + App f args -> all isSmallExpr (f : args) + Lam _ _ -> False + Let _ _ _ _ -> False + +runExpr :: Env -> Expr -> Expr +runExpr env = \case + Var x -> fromMaybe (Var x) (M.lookup x env) + Lit lit -> Lit lit + App f args -> App (runExpr env f) (map (runExpr env) args) + Lam args body -> Lam args (runExpr env body) + Let x t e1 e2 -> + let e1' = runExpr env e1 + in if isSmallExpr e1' + then runExpr (M.insert x e1' env) e2 + else Let x t e1' (runExpr env e2) + +runToplevelExpr :: Env -> ToplevelExpr -> ToplevelExpr +runToplevelExpr env = \case + ResultExpr e -> ResultExpr (runExpr env e) + ToplevelLet x t e cont -> + let e' = runExpr env e + in if isSmallExpr e' + then runToplevelExpr (M.insert x e' env) cont + else ToplevelLet x t e' (runToplevelExpr env cont) + ToplevelLetRec f args ret body cont -> + ToplevelLetRec f args ret (runExpr env body) (runToplevelExpr env cont) + +run' :: Program -> Program +run' = runToplevelExpr M.empty + +-- | `run` does constant propagation. +-- This assumes that the program is alpha-converted. +-- +-- For example, this converts the following: +-- +-- > let x = 1 +-- > in let f = fun y -> y +-- > in x + x + f(x) +-- +-- to: +-- +-- > let f = fun y -> y +-- > in 1 + 1 + f(1) +-- +-- NOTE: this doesn't constant folding. +run :: MonadError Error m => Program -> m Program +run prog = wrapError' "Jikka.Core.Convert.ConstantPropagation" $ do + prog <- return $ run' prog + ensureWellTyped prog + return prog diff --git a/src/Jikka/Core/Convert/ImmediateAppToLet.hs b/src/Jikka/Core/Convert/ImmediateAppToLet.hs new file mode 100644 index 00000000..0377cef2 --- /dev/null +++ b/src/Jikka/Core/Convert/ImmediateAppToLet.hs @@ -0,0 +1,37 @@ +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE LambdaCase #-} + +module Jikka.Core.Convert.ImmediateAppToLet + ( run, + ) +where + +import Jikka.Common.Alpha +import Jikka.Common.Error +import qualified Jikka.Core.Convert.Alpha as Alpha +import Jikka.Core.Language.Expr +import Jikka.Core.Language.Lint +import Jikka.Core.Language.Util + +runExpr :: [(VarName, Type)] -> Expr -> Expr +runExpr _ = \case + App (Lam formal body) actual -> foldr (\((x, t), e) -> Let x t e) body (zip formal actual) + e -> e + +runProgram :: Program -> Program +runProgram = mapExprProgram runExpr + +-- | `run` does beta-reductions in given programs. +-- For example, this converts the following: +-- +-- > (fun x -> x + x) a +-- +-- to the follwoing: +-- +-- > let x = a in x + x +run :: (MonadAlpha m, MonadError Error m) => Program -> m Program +run prog = wrapError' "Jikka.Core.Convert.ImmediateAppToLet" $ do + prog <- Alpha.run prog + prog <- return $ runProgram prog + ensureWellTyped prog + return prog diff --git a/src/Jikka/Core/Convert/LinearFunction.hs b/src/Jikka/Core/Convert/LinearFunction.hs new file mode 100644 index 00000000..cd81b7cc --- /dev/null +++ b/src/Jikka/Core/Convert/LinearFunction.hs @@ -0,0 +1,69 @@ +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE LambdaCase #-} + +module Jikka.Core.Convert.LinearFunction + ( run, + ) +where + +import Control.Monad.Trans +import Control.Monad.Trans.Maybe +import Data.Maybe (fromMaybe) +import qualified Data.Vector as V +import Jikka.Common.Alpha +import Jikka.Common.Error +import Jikka.Common.Matrix +import Jikka.Core.Language.ArithmeticalExpr +import Jikka.Core.Language.BuiltinPatterns +import Jikka.Core.Language.Expr +import Jikka.Core.Language.Lint +import Jikka.Core.Language.Util +import Jikka.Core.Language.Vars + +fromMatrix :: Matrix ArithmeticalExpr -> Expr +fromMatrix f = + let (h, w) = matsize f + go row = Tuple' (replicate w IntTy) (map formatArithmeticalExpr (V.toList row)) + in Tuple' (replicate h (TupleTy (replicate w IntTy))) (map go (V.toList (unMatrix f))) + +runExpr :: MonadAlpha m => [(VarName, Type)] -> Expr -> m Expr +runExpr env = \case + orig@(Lam [(x, TupleTy ts)] (Tuple' ts' es)) -> + (fromMaybe orig <$>) . runMaybeT $ do + guard $ not (null ts) && all (== IntTy) ts + guard $ not (null ts') && all (== IntTy) ts' + xs <- V.fromList <$> replicateM (length ts) (lift (genVarName x)) + let indexOfProj = \case + (Proj' ts'' i (Var x')) | ts'' == ts && x' == x -> Just i + _ -> Nothing + let replaceWithVar _ e = case indexOfProj e of + Just i -> Var (xs V.! i) + Nothing -> e + rows <- forM es $ \e -> MaybeT . return $ do + let e' = mapExpr replaceWithVar env e + guard $ x `isUnusedVar` e' + (row, c) <- makeVectorFromArithmeticalExpr xs (parseArithmeticalExpr e') + guard $ c == zeroSumExpr -- TODO: support affine functions + return row + f <- MaybeT . return $ makeMatrix (V.fromList rows) + return $ Lam [(x, TupleTy ts)] (MatAp' (length ts') (length ts) (fromMatrix f) (Var x)) + e -> return e + +runProgram :: MonadAlpha m => Program -> m Program +runProgram = mapExprProgramM runExpr + +-- | `run` simplifies a functions from tuples of integers to tuples of integers. +-- For example, this converts the following: +-- +-- > fun xs -> (xs[0] + 2 * xs[1], xs[1]) +-- +-- to the follwoing: +-- +-- > (fun xs -> matap ((1, 2), (0, 1)) xs) +-- +-- TODO: support affine functions +run :: (MonadAlpha m, MonadError Error m) => Program -> m Program +run prog = wrapError' "Jikka.Core.Convert.LinearFunction" $ do + prog <- runProgram prog + ensureWellTyped prog + return prog diff --git a/src/Jikka/Core/Convert/PropagateMod.hs b/src/Jikka/Core/Convert/PropagateMod.hs new file mode 100644 index 00000000..550a4fcb --- /dev/null +++ b/src/Jikka/Core/Convert/PropagateMod.hs @@ -0,0 +1,70 @@ +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE LambdaCase #-} + +module Jikka.Core.Convert.PropagateMod + ( run, + ) +where + +import Jikka.Common.Error +import Jikka.Core.Language.BuiltinPatterns +import Jikka.Core.Language.Expr +import Jikka.Core.Language.Lint +import Jikka.Core.Language.TypeCheck +import Jikka.Core.Language.Util + +runFloorMod :: MonadError Error m => [(VarName, Type)] -> Expr -> Expr -> m Expr +runFloorMod env e m = go e + where + go :: MonadError Error m => Expr -> m Expr + go = \case + Negate' e -> FloorMod' <$> (Negate' <$> go e) <*> pure m + Plus' e1 e2 -> FloorMod' <$> (Plus' <$> go e1 <*> go e2) <*> pure m + Minus' e1 e2 -> FloorMod' <$> (Minus' <$> go e1 <*> go e2) <*> pure m + Mult' e1 e2 -> FloorMod' <$> (Mult' <$> go e1 <*> go e2) <*> pure m + Pow' e1 e2 -> ModPow' <$> go e1 <*> pure e2 <*> pure m + ModInv' e m' | m == m' -> ModInv' <$> go e <*> pure m + ModPow' e1 e2 m' | m == m' -> ModPow' <$> go e1 <*> pure e2 <*> pure m + MatAp' h w e1 e2 -> ModMatAp' h w <$> go e1 <*> go e2 <*> pure m + MatAdd' h w e1 e2 -> ModMatAdd' h w <$> go e1 <*> go e2 <*> pure m + MatMul' h n w e1 e2 -> ModMatMul' h n w <$> go e1 <*> go e2 <*> pure m + MatPow' n e1 e2 -> ModMatPow' n <$> go e1 <*> pure e2 <*> pure m + Proj' ts i e@MatAp' {} -> Proj' ts i <$> go e + Proj' ts i e@ModMatAp' {} -> Proj' ts i <$> go e + ModMatAp' h w e1 e2 m' | m == m' -> ModMatAp' h w <$> go e1 <*> go e2 <*> pure m + ModMatAdd' h w e1 e2 m' | m == m' -> ModMatAdd' h w <$> go e1 <*> go e2 <*> pure m + ModMatMul' h n w e1 e2 m' | m == m' -> ModMatMul' h n w <$> go e1 <*> go e2 <*> pure m + ModMatPow' n e1 e2 m' | m == m' -> ModMatPow' n <$> go e1 <*> pure e2 <*> pure m + App (Lam args body) es -> App <$> (Lam args <$> runFloorMod (reverse args ++ env) body m) <*> pure es + Tuple' ts es -> Tuple' ts <$> mapM go es + FloorMod' e m' -> + if m == m' + then go e + else runFloorMod env e (Lcm' m m') + e -> do + t <- typecheckExpr env e + return $ case t of + IntTy -> FloorMod' e m + _ -> e + +runExpr :: MonadError Error m => [(VarName, Type)] -> Expr -> m Expr +runExpr env = \case + FloorMod' e m -> runFloorMod env e m + e -> return e + +runProgram :: MonadError Error m => Program -> m Program +runProgram = mapExprProgramM runExpr + +-- | `run` propagates `FloorMod` to leaves of exprs. +-- For example, this converts the following: +-- +-- > mod ((fun x -> x * x + x) y) 1000000007 +-- +-- to: +-- +-- > (fun x -> mod (mod (x * x) 1000000007 + x) 1000000007) y +run :: MonadError Error m => Program -> m Program +run prog = wrapError' "Jikka.Core.Convert.PropagateMod" $ do + prog <- runProgram prog + ensureWellTyped prog + return prog diff --git a/src/Jikka/Core/Convert/StrengthReduction.hs b/src/Jikka/Core/Convert/StrengthReduction.hs index 655d915a..810385a0 100644 --- a/src/Jikka/Core/Convert/StrengthReduction.hs +++ b/src/Jikka/Core/Convert/StrengthReduction.hs @@ -175,6 +175,7 @@ reduceFoldMap = \case reduceFoldBuild :: Expr -> Expr reduceFoldBuild = \case + Foldl' _ t (Lam [(x1, t1), (x2, _)] body) x (Range1' n) | x2 `isUnusedVar` body -> NatInd' t x (Lam [(x1, t1)] body) n Len' _ (Range1' n) -> n At' _ (Range1' _) i -> i Sum' (Range1' n) -> go $ FloorDiv' (Mult' n (Minus' n Lit1)) Lit2 @@ -187,8 +188,13 @@ reduceFoldBuild = \case ArgMin' _ (Range1' _) -> Lit0 e -> e +reduceFold :: Expr -> Expr +reduceFold = \case + NatInd' _ v (Lam [(x, _)] (MatAp' n _ f (Var x'))) k | x `isUnusedVar` f && x == x' -> MatAp' n n (MatPow' n f k) v + e -> e + reduceList :: Expr -> Expr -reduceList = reduceFoldBuild . reduceFoldMap . reduceMapMap . reduceBuild +reduceList = reduceFold . reduceFoldBuild . reduceFoldMap . reduceMapMap . reduceBuild misc :: Expr -> Expr misc = \case diff --git a/src/Jikka/Core/Convert/TrivialLetElimination.hs b/src/Jikka/Core/Convert/TrivialLetElimination.hs new file mode 100644 index 00000000..55e90c84 --- /dev/null +++ b/src/Jikka/Core/Convert/TrivialLetElimination.hs @@ -0,0 +1,64 @@ +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE LambdaCase #-} + +module Jikka.Core.Convert.TrivialLetElimination + ( run, + run', + ) +where + +import qualified Data.Map as M +import Data.Maybe (fromMaybe) +import Jikka.Common.Error +import Jikka.Core.Language.Expr +import Jikka.Core.Language.Lint +import Jikka.Core.Language.Util + +type Env = M.Map VarName Expr + +runExpr :: Env -> Expr -> Expr +runExpr env = \case + Var x -> fromMaybe (Var x) (M.lookup x env) + Lit lit -> Lit lit + App f args -> App (runExpr env f) (map (runExpr env) args) + Lam args body -> Lam args (runExpr env body) + Let x t e1 e2 -> + let e1' = runExpr env e1 + in if countOccurrences x e2 <= 1 + then runExpr (M.insert x e1' env) e2 + else Let x t e1' (runExpr env e2) + +runToplevelExpr :: Env -> ToplevelExpr -> ToplevelExpr +runToplevelExpr env = \case + ResultExpr e -> ResultExpr (runExpr env e) + ToplevelLet x t e cont -> + let e' = runExpr env e + in if countOccurrencesToplevelExpr x cont <= 1 + then runToplevelExpr (M.insert x e' env) cont + else ToplevelLet x t e' (runToplevelExpr env cont) + ToplevelLetRec f args ret body cont -> + ToplevelLetRec f args ret (runExpr env body) (runToplevelExpr env cont) + +run' :: Program -> Program +run' = runToplevelExpr M.empty + +-- | `run` remove let-exprs whose assigned variables are used only at most once. +-- This assumes that the program is alpha-converted. +-- +-- For example, this converts the following: +-- +-- > let f = fun y -> y +-- > in let x = 1 +-- > in f(x + x) +-- +-- to: +-- +-- > let x = 1 +-- > in (fun y -> y) (x + x) +-- +-- NOTE: this doesn't constant folding. +run :: MonadError Error m => Program -> m Program +run prog = wrapError' "Jikka.Core.Convert.ConstantPropagation" $ do + prog <- return $ run' prog + ensureWellTyped prog + return prog diff --git a/src/Jikka/Core/Evaluate.hs b/src/Jikka/Core/Evaluate.hs index 2850aae4..87921ac2 100644 --- a/src/Jikka/Core/Evaluate.hs +++ b/src/Jikka/Core/Evaluate.hs @@ -30,9 +30,11 @@ import Data.Bits import Data.List (intercalate, sort) import qualified Data.Vector as V import Jikka.Common.Error +import Jikka.Common.Matrix import qualified Jikka.Core.Convert.MakeEager as MakeEager -import Jikka.Core.Format (formatBuiltinIsolated) +import Jikka.Core.Format (formatBuiltinIsolated, formatExpr) import Jikka.Core.Language.Expr +import Jikka.Core.Language.Runtime import Jikka.Core.Language.TypeCheck (builtinToType) import Jikka.Core.Language.Value import Text.Read (readEither) @@ -93,64 +95,24 @@ readInput t tokens = case (t, tokens) of -- ----------------------------------------------------------------------------- -- builtins -floorDiv :: MonadError Error m => Integer -> Integer -> m Integer -floorDiv _ 0 = throwRuntimeError "zero div" -floorDiv a b = return (a `div` b) - -floorMod :: MonadError Error m => Integer -> Integer -> m Integer -floorMod _ 0 = throwRuntimeError "zero div" -floorMod a b = return (a `mod` b) - -ceilDiv :: MonadError Error m => Integer -> Integer -> m Integer -ceilDiv _ 0 = throwRuntimeError "zero div" -ceilDiv a b = return ((a + b - 1) `div` b) - -ceilMod :: MonadError Error m => Integer -> Integer -> m Integer -ceilMod _ 0 = throwRuntimeError "zero div" -ceilMod a b = return (a - ((a + b - 1) `div` b) * b) - natind :: MonadError Error m => Value -> Value -> Integer -> m Value natind _ _ n | n < 0 = throwRuntimeError $ "negative number for mathematical induction: " ++ show n natind base _ 0 = return base natind base step n = do val <- natind base step (n - 1) - callValue step [val, ValInt (n - 1)] - -minimumEither :: (MonadError Error m, Ord a) => [a] -> m a -minimumEither [] = throwRuntimeError "there is no minimum for the empty list" -minimumEither a = return $ minimum a - -maximumEither :: (MonadError Error m, Ord a) => [a] -> m a -maximumEither [] = throwRuntimeError "there is no maximum for the empty list" -maximumEither a = return $ maximum a + callValue step [val] -argminEither :: (MonadError Error m, Ord a) => [a] -> m Integer -argminEither [] = throwRuntimeError "there is no minimum for the empty list" -argminEither a = return $ snd (minimum (zip a [0 ..])) - -argmaxEither :: (MonadError Error m, Ord a) => [a] -> m Integer -argmaxEither [] = throwRuntimeError "there is no maximum for the empty list" -argmaxEither a = return $ snd (maximum (zip a [0 ..])) - -inv :: MonadError Error m => Integer -> Integer -> m Integer -inv a m | m <= 0 || a `mod` m == 0 = throwRuntimeError $ "invalid argument for inv: " ++ show (a, m) -inv _ _ = throwInternalError "TODO: implement inv()" +tabulate :: MonadError Error m => Integer -> Value -> m (V.Vector Value) +tabulate n f = V.fromList <$> mapM (\i -> callValue f [ValInt i]) [0 .. n - 1] -powmod :: MonadError Error m => Integer -> Integer -> Integer -> m Integer -powmod _ _ m | m <= 0 = throwRuntimeError $ "invalid argument for powmod: MOD = " ++ show m -powmod a b m = return $ (a ^ b) `mod` m +map' :: MonadError Error m => Value -> V.Vector Value -> m (V.Vector Value) +map' f a = V.fromList <$> mapM (\val -> callValue f [val]) (V.toList a) scanM :: Monad m => (a -> b -> m a) -> a -> V.Vector b -> m (V.Vector a) scanM f y xs = do (ys, y) <- V.foldM (\(ys, y) x -> (y : ys,) <$> f y x) ([], y) xs return $ V.fromList (reverse (y : ys)) -tabulate :: MonadError Error m => Integer -> Value -> m (V.Vector Value) -tabulate n f = V.fromList <$> mapM (\i -> callValue f [ValInt i]) [0 .. n - 1] - -map' :: MonadError Error m => Value -> V.Vector Value -> m (V.Vector Value) -map' f a = V.fromList <$> mapM (\val -> callValue f [val]) (V.toList a) - atEither :: MonadError Error m => V.Vector a -> Integer -> m a atEither xs i = case xs V.!? fromInteger i of Just x -> return x @@ -177,114 +139,112 @@ range3 :: MonadError Error m => Integer -> Integer -> Integer -> m (V.Vector Val range3 l r step | not (l <= r && step >= 0) = throwRuntimeError $ "invalid argument for range3: " ++ show (l, r, step) range3 l r step = return $ V.fromList (map ValInt [l, l + step .. r]) -fact :: MonadError Error m => Integer -> m Integer -fact n | n < 0 = throwRuntimeError $ "invalid argument for fact: " ++ show n -fact n = return $ product [1 .. n] - -choose :: MonadError Error m => Integer -> Integer -> m Integer -choose n r | not (0 <= r && r <= n) = throwRuntimeError $ "invalid argument for choose: " ++ show (n, r) -choose n r = return $ product [n - r + 1 .. n] `div` product [1 .. r] - -permute :: MonadError Error m => Integer -> Integer -> m Integer -permute n r | not (0 <= r && r <= n) = throwRuntimeError $ "invalid argument for choose: " ++ show (n, r) -permute n r = return $ product [n - r + 1 .. n] - -multichoose :: MonadError Error m => Integer -> Integer -> m Integer -multichoose n r | not (0 <= r && r <= n) = throwRuntimeError $ "invalid argument for multichoose: " ++ show (n, r) -multichoose 0 0 = return 1 -multichoose n r = choose (n + r - 1) r - -- ----------------------------------------------------------------------------- -- evaluator callBuiltin :: MonadError Error m => Builtin -> [Value] -> m Value -callBuiltin builtin args = case (builtin, args) of - -- arithmetical functions - (Negate, [ValInt n]) -> return $ ValInt (- n) - (Plus, [ValInt a, ValInt b]) -> return $ ValInt (a + b) - (Minus, [ValInt a, ValInt b]) -> return $ ValInt (a - b) - (Mult, [ValInt a, ValInt b]) -> return $ ValInt (a * b) - (FloorDiv, [ValInt a, ValInt b]) -> ValInt <$> floorDiv a b - (FloorMod, [ValInt a, ValInt b]) -> ValInt <$> floorMod a b - (CeilDiv, [ValInt a, ValInt b]) -> ValInt <$> ceilDiv a b - (CeilMod, [ValInt a, ValInt b]) -> ValInt <$> ceilMod a b - (Pow, [ValInt a, ValInt b]) -> return $ ValInt (a ^ b) - -- induction functions - (NatInd _, [base, step, ValInt n]) -> natind base step n - -- advanced arithmetical functions - (Abs, [ValInt n]) -> return $ ValInt (abs n) - (Gcd, [ValInt a, ValInt b]) -> return $ ValInt (gcd a b) - (Lcm, [ValInt a, ValInt b]) -> return $ ValInt (lcm a b) - (Min2 IntTy, [ValInt a, ValInt b]) -> return $ ValInt (min a b) -- TODO: allow non-integers - (Max2 IntTy, [ValInt a, ValInt b]) -> return $ ValInt (max a b) -- TODO: allow non-integers - -- logical functions - (Not, [ValBool p]) -> return $ ValBool (not p) - (And, [ValBool p, ValBool q]) -> return $ ValBool (p && q) - (Or, [ValBool p, ValBool q]) -> return $ ValBool (p || q) - (Implies, [ValBool p, ValBool q]) -> return $ ValBool (not p || q) - (If _, [ValBool p, a, b]) -> return $ if p then a else b - -- bitwise functions - (BitNot, [ValInt a]) -> return $ ValInt (complement a) - (BitAnd, [ValInt a, ValInt b]) -> return $ ValInt (a .&. b) - (BitOr, [ValInt a, ValInt b]) -> return $ ValInt (a .|. b) - (BitXor, [ValInt a, ValInt b]) -> return $ ValInt (a `xor` b) - (BitLeftShift, [ValInt a, ValInt b]) -> return $ ValInt (a `shift` fromInteger b) - (BitRightShift, [ValInt a, ValInt b]) -> return $ ValInt (a `shift` fromInteger (- b)) - -- modular functions - (ModInv, [ValInt a, ValInt b]) -> ValInt <$> inv a b - (ModPow, [ValInt a, ValInt b, ValInt c]) -> ValInt <$> powmod a b c - -- list functions - (Cons _, [x, ValList xs]) -> return $ ValList (V.cons x xs) - (Foldl _ _, [f, x, ValList a]) -> V.foldM (\x y -> callValue f [x, y]) x a - (Scanl _ _, [f, x, ValList a]) -> ValList <$> scanM (\x y -> callValue f [x, y]) x a - (Len _, [ValList a]) -> return $ ValInt (fromIntegral (V.length a)) - (Tabulate _, [ValInt n, f]) -> ValList <$> tabulate n f - (Map _ _, [f, ValList a]) -> ValList <$> map' f a - (Filter _, [f, ValList a]) -> ValList <$> V.filterM (\x -> (/= ValBool False) <$> callValue f [x]) a -- TODO - (At _, [ValList a, ValInt n]) -> atEither a n - (SetAt _, [ValList a, ValInt n, x]) -> ValList <$> setAtEither a n x - (Elem _, [x, ValList a]) -> return $ ValBool (x `V.elem` a) - (Sum, [ValList a]) -> ValInt . sum <$> valueToIntList a - (Product, [ValList a]) -> ValInt . product <$> valueToIntList a - (Min1 IntTy, [ValList a]) -> ValInt <$> (minimumEither =<< valueToIntList a) -- TODO: allow non-integers - (Max1 IntTy, [ValList a]) -> ValInt <$> (maximumEither =<< valueToIntList a) -- TODO: allow non-integers - (ArgMin IntTy, [ValList a]) -> ValInt <$> (argminEither =<< valueToIntList a) -- TODO: allow non-integers - (ArgMax IntTy, [ValList a]) -> ValInt <$> (argmaxEither =<< valueToIntList a) -- TODO: allow non-integers - (All, [ValList a]) -> ValBool . and <$> valueToBoolList a - (Any, [ValList a]) -> ValBool . or <$> valueToBoolList a - (Sorted _, [ValList a]) -> return $ ValList (sortVector a) - (List _, [ValList a]) -> return $ ValList a - (Reversed _, [ValList a]) -> return $ ValList (V.reverse a) - (Range1, [ValInt n]) -> ValList <$> range1 n - (Range2, [ValInt l, ValInt r]) -> ValList <$> range2 l r - (Range3, [ValInt l, ValInt r, ValInt step]) -> ValList <$> range3 l r step - -- tuple functions - (Tuple _, xs) -> return $ ValTuple xs - (Proj _ n, [ValTuple xs]) -> return $ xs !! n - -- comparison - (LessThan IntTy, [ValInt a, ValInt b]) -> return $ ValBool (a < b) -- TODO: allow non-integers - (LessEqual IntTy, [ValInt a, ValInt b]) -> return $ ValBool (a <= b) -- TODO: allow non-integers - (GreaterThan IntTy, [ValInt a, ValInt b]) -> return $ ValBool (a > b) -- TODO: allow non-integers - (GreaterEqual IntTy, [ValInt a, ValInt b]) -> return $ ValBool (a >= b) -- TODO: allow non-integers - (Equal _, [a, b]) -> return $ ValBool (a == b) - (NotEqual _, [a, b]) -> return $ ValBool (a /= b) - -- combinational functions - (Fact, [ValInt n]) -> ValInt <$> fact n - (Choose, [ValInt n, ValInt r]) -> ValInt <$> choose n r - (Permute, [ValInt n, ValInt r]) -> ValInt <$> permute n r - (MultiChoose, [ValInt n, ValInt r]) -> ValInt <$> multichoose n r - _ -> throwInternalError $ "invalid builtin call: " ++ formatBuiltinIsolated builtin ++ "(" ++ intercalate "," (map formatValue args) ++ ")" - -callLambda :: MonadError Error m => Env -> [(VarName, Type)] -> Expr -> [Value] -> m Value -callLambda env formalArgs body actualArgs = case (formalArgs, actualArgs) of - ([], []) -> evaluateExpr env body - ((x, _) : formalArgs, val : actualArgs) -> callLambda ((x, val) : env) formalArgs body actualArgs - _ -> throwInternalError "wrong number of arguments for lambda function" +callBuiltin builtin args = wrapError' ("while calling builtin " ++ formatBuiltinIsolated builtin) $ do + case (builtin, args) of + -- arithmetical functions + (Negate, [ValInt n]) -> return $ ValInt (- n) + (Plus, [ValInt a, ValInt b]) -> return $ ValInt (a + b) + (Minus, [ValInt a, ValInt b]) -> return $ ValInt (a - b) + (Mult, [ValInt a, ValInt b]) -> return $ ValInt (a * b) + (FloorDiv, [ValInt a, ValInt b]) -> ValInt <$> floorDiv a b + (FloorMod, [ValInt a, ValInt b]) -> ValInt <$> floorMod a b + (CeilDiv, [ValInt a, ValInt b]) -> ValInt <$> ceilDiv a b + (CeilMod, [ValInt a, ValInt b]) -> ValInt <$> ceilMod a b + (Pow, [ValInt a, ValInt b]) -> return $ ValInt (a ^ b) + -- induction functions + (NatInd _, [base, step, ValInt n]) -> natind base step n + -- advanced arithmetical functions + (Abs, [ValInt n]) -> return $ ValInt (abs n) + (Gcd, [ValInt a, ValInt b]) -> return $ ValInt (gcd a b) + (Lcm, [ValInt a, ValInt b]) -> return $ ValInt (lcm a b) + (Min2 IntTy, [ValInt a, ValInt b]) -> return $ ValInt (min a b) -- TODO: allow non-integers + (Max2 IntTy, [ValInt a, ValInt b]) -> return $ ValInt (max a b) -- TODO: allow non-integers + -- logical functions + (Not, [ValBool p]) -> return $ ValBool (not p) + (And, [ValBool p, ValBool q]) -> return $ ValBool (p && q) + (Or, [ValBool p, ValBool q]) -> return $ ValBool (p || q) + (Implies, [ValBool p, ValBool q]) -> return $ ValBool (not p || q) + (If _, [ValBool p, a, b]) -> return $ if p then a else b + -- bitwise functions + (BitNot, [ValInt a]) -> return $ ValInt (complement a) + (BitAnd, [ValInt a, ValInt b]) -> return $ ValInt (a .&. b) + (BitOr, [ValInt a, ValInt b]) -> return $ ValInt (a .|. b) + (BitXor, [ValInt a, ValInt b]) -> return $ ValInt (a `xor` b) + (BitLeftShift, [ValInt a, ValInt b]) -> return $ ValInt (a `shift` fromInteger b) + (BitRightShift, [ValInt a, ValInt b]) -> return $ ValInt (a `shift` fromInteger (- b)) + -- matrix functions + (MatAp _ _, [f, x]) -> valueFromVector <$> (matap <$> valueToMatrix f <*> valueToVector x) + (MatZero n, []) -> return $ valueFromMatrix (matzero n) + (MatOne n, []) -> return $ valueFromMatrix (matone n) + (MatAdd _ _, [f, g]) -> valueFromMatrix <$> (matadd <$> valueToMatrix f <*> valueToMatrix g) + (MatMul _ _ _, [f, g]) -> valueFromMatrix <$> (matmul <$> valueToMatrix f <*> valueToMatrix g) + (MatPow _, [f, ValInt k]) -> valueFromMatrix <$> (matpow <$> valueToMatrix f <*> pure k) + -- modular functions + (ModInv, [ValInt x, ValInt m]) -> ValInt <$> modinv x m + (ModPow, [ValInt x, ValInt k, ValInt m]) -> ValInt <$> modpow x k m + (ModMatAp _ _, [f, x, ValInt m]) -> valueFromModVector <$> (matap <$> valueToModMatrix m f <*> valueToModVector m x) + (ModMatAdd _ _, [f, g, ValInt m]) -> valueFromModMatrix <$> (matadd <$> valueToModMatrix m f <*> valueToModMatrix m g) + (ModMatMul _ _ _, [f, g, ValInt m]) -> valueFromModMatrix <$> (matmul <$> valueToModMatrix m f <*> valueToModMatrix m g) + (ModMatPow _, [f, ValInt k, ValInt m]) -> valueFromModMatrix <$> (matpow <$> valueToModMatrix m f <*> pure k) + -- list functions + (Cons _, [x, ValList xs]) -> return $ ValList (V.cons x xs) + (Foldl _ _, [f, x, ValList a]) -> V.foldM (\x y -> callValue f [x, y]) x a + (Scanl _ _, [f, x, ValList a]) -> ValList <$> scanM (\x y -> callValue f [x, y]) x a + (Len _, [ValList a]) -> return $ ValInt (fromIntegral (V.length a)) + (Tabulate _, [ValInt n, f]) -> ValList <$> tabulate n f + (Map _ _, [f, ValList a]) -> ValList <$> map' f a + (Filter _, [f, ValList a]) -> ValList <$> V.filterM (\x -> (/= ValBool False) <$> callValue f [x]) a -- TODO + (At _, [ValList a, ValInt n]) -> atEither a n + (SetAt _, [ValList a, ValInt n, x]) -> ValList <$> setAtEither a n x + (Elem _, [x, ValList a]) -> return $ ValBool (x `V.elem` a) + (Sum, [ValList a]) -> ValInt . sum <$> valueToIntList a + (Product, [ValList a]) -> ValInt . product <$> valueToIntList a + (Min1 IntTy, [ValList a]) -> ValInt <$> (minimumEither =<< valueToIntList a) -- TODO: allow non-integers + (Max1 IntTy, [ValList a]) -> ValInt <$> (maximumEither =<< valueToIntList a) -- TODO: allow non-integers + (ArgMin IntTy, [ValList a]) -> ValInt <$> (argminEither =<< valueToIntList a) -- TODO: allow non-integers + (ArgMax IntTy, [ValList a]) -> ValInt <$> (argmaxEither =<< valueToIntList a) -- TODO: allow non-integers + (All, [ValList a]) -> ValBool . and <$> valueToBoolList a + (Any, [ValList a]) -> ValBool . or <$> valueToBoolList a + (Sorted _, [ValList a]) -> return $ ValList (sortVector a) + (List _, [ValList a]) -> return $ ValList a + (Reversed _, [ValList a]) -> return $ ValList (V.reverse a) + (Range1, [ValInt n]) -> ValList <$> range1 n + (Range2, [ValInt l, ValInt r]) -> ValList <$> range2 l r + (Range3, [ValInt l, ValInt r, ValInt step]) -> ValList <$> range3 l r step + -- tuple functions + (Tuple _, xs) -> return $ ValTuple xs + (Proj _ n, [ValTuple xs]) -> return $ xs !! n + -- comparison + (LessThan IntTy, [ValInt a, ValInt b]) -> return $ ValBool (a < b) -- TODO: allow non-integers + (LessEqual IntTy, [ValInt a, ValInt b]) -> return $ ValBool (a <= b) -- TODO: allow non-integers + (GreaterThan IntTy, [ValInt a, ValInt b]) -> return $ ValBool (a > b) -- TODO: allow non-integers + (GreaterEqual IntTy, [ValInt a, ValInt b]) -> return $ ValBool (a >= b) -- TODO: allow non-integers + (Equal _, [a, b]) -> return $ ValBool (a == b) + (NotEqual _, [a, b]) -> return $ ValBool (a /= b) + -- combinational functions + (Fact, [ValInt n]) -> ValInt <$> fact n + (Choose, [ValInt n, ValInt r]) -> ValInt <$> choose n r + (Permute, [ValInt n, ValInt r]) -> ValInt <$> permute n r + (MultiChoose, [ValInt n, ValInt r]) -> ValInt <$> multichoose n r + _ -> throwInternalError $ "invalid builtin call: " ++ formatBuiltinIsolated builtin ++ "(" ++ intercalate "," (map formatValue args) ++ ")" + +callLambda :: MonadError Error m => Maybe VarName -> Env -> [(VarName, Type)] -> Expr -> [Value] -> m Value +callLambda name env formalArgs body actualArgs = wrapError' ("while calling lambda " ++ maybe "(anonymous)" unVarName name) $ do + if length formalArgs /= length actualArgs + then throwInternalError $ "wrong number of arguments for lambda function: expr = " ++ formatExpr (Lam formalArgs body) ++ ", args = (" ++ intercalate ", " (map formatValue actualArgs) ++ ")" + else case (formalArgs, actualArgs) of + ([], []) -> evaluateExpr env body + ((x, _) : formalArgs, val : actualArgs) -> callLambda name ((x, val) : env) formalArgs body actualArgs + _ -> throwInternalError "wrong number of arguments for lambda function" callValue :: MonadError Error m => Value -> [Value] -> m Value callValue f args = case f of ValBuiltin builtin -> callBuiltin builtin args - ValLambda env args' body -> callLambda env args' body args + ValLambda name env args' body -> callLambda name env args' body args _ -> throwInternalError $ "call non-function: " ++ formatValue f evaluateExpr :: MonadError Error m => Env -> Expr -> m Value @@ -297,13 +257,13 @@ evaluateExpr env = \case f <- evaluateExpr env f args <- mapM (evaluateExpr env) args callValue f args - Lam args body -> return $ ValLambda env args body + Lam args body -> return $ ValLambda Nothing env args body Let x _ e1 e2 -> do v1 <- evaluateExpr env e1 evaluateExpr ((x, v1) : env) e2 callBuiltinWithTokens :: MonadError Error m => [Token] -> Builtin -> m (Value, [Token]) -callBuiltinWithTokens tokens builtin = do +callBuiltinWithTokens tokens builtin = wrapError' ("while calling builtin " ++ formatBuiltinIsolated builtin) $ do case builtinToType builtin of FunTy ts _ -> do (args, tokens) <- readInputMap ts tokens @@ -311,14 +271,16 @@ callBuiltinWithTokens tokens builtin = do return (val, tokens) _ -> throwInternalError "all builtin must be functions" -callLambdaWithTokens :: MonadError Error m => [Token] -> Env -> [(VarName, Type)] -> Expr -> m (Value, [Token]) -callLambdaWithTokens tokens env args body = case args of - ((x, t) : args) -> do - (val, tokens) <- readInput t tokens - callLambdaWithTokens tokens ((x, val) : env) args body - [] -> do - val <- evaluateExpr env body - return (val, tokens) +callLambdaWithTokens :: MonadError Error m => [Token] -> Maybe VarName -> Env -> [(VarName, Type)] -> Expr -> m (Value, [Token]) +callLambdaWithTokens tokens name env args body = wrapError' ("while calling lambda " ++ maybe "(anonymous)" unVarName name) $ go tokens env args + where + go tokens env args = case args of + ((x, t) : args) -> do + (val, tokens) <- readInput t tokens + go tokens ((x, val) : env) args + [] -> do + val <- evaluateExpr env body + return (val, tokens) evaluateToplevelExpr :: (MonadFix m, MonadError Error m) => [Token] -> Env -> ToplevelExpr -> m (Value, [Token]) evaluateToplevelExpr tokens env = \case @@ -332,7 +294,7 @@ evaluateToplevelExpr tokens env = \case val <- evaluateExpr env e case val of ValBuiltin builtin -> callBuiltinWithTokens tokens builtin - ValLambda env args body -> callLambdaWithTokens tokens env args body + ValLambda name env args body -> callLambdaWithTokens tokens name env args body _ -> return (val, tokens) evaluateProgram :: (MonadFix m, MonadError Error m) => [Token] -> Program -> m Value diff --git a/src/Jikka/Core/Format.hs b/src/Jikka/Core/Format.hs index 883ac274..b39d4237 100644 --- a/src/Jikka/Core/Format.hs +++ b/src/Jikka/Core/Format.hs @@ -87,9 +87,20 @@ analyzeBuiltin = \case BitXor -> infixOp "^" BitLeftShift -> infixOp "<<" BitRightShift -> infixOp ">>" + -- matrix functions + MatAp _ _ -> fun "matap" + MatZero _ -> fun "matzero" + MatOne _ -> fun "matone" + MatAdd _ _ -> fun "matadd" + MatMul _ _ _ -> fun "matmul" + MatPow _ -> fun "matpow" -- modular functions ModInv -> fun "modinv" ModPow -> fun "modpow" + ModMatAp _ _ -> fun "modmatap" + ModMatAdd _ _ -> fun "modmatadd" + ModMatMul _ _ _ -> fun "modmatmul" + ModMatPow _ -> fun "modmatpow" -- list functions Cons t -> Fun [t] "cons" Foldl t1 t2 -> Fun [t1, t2] "foldl" @@ -137,7 +148,7 @@ formatTemplate = \case ts -> "<" ++ intercalate ", " (map formatType ts) ++ ">" formatFunCall :: String -> [Type] -> [Expr] -> String -formatFunCall f _ args = f ++ "(" ++ intercalate ", " (map formatExpr args) ++ ")" +formatFunCall f _ args = f ++ "(" ++ intercalate ", " (map formatExpr' args) ++ ")" formatBuiltinIsolated' :: Builtin' -> String formatBuiltinIsolated' = \case @@ -153,10 +164,10 @@ formatBuiltinIsolated = formatBuiltinIsolated' . analyzeBuiltin formatBuiltin' :: Builtin' -> [Expr] -> String formatBuiltin' builtin args = case (builtin, args) of (Fun ts name, _) -> formatFunCall name ts args - (PrefixOp op, [e1]) -> paren $ op ++ " " ++ formatExpr e1 - (InfixOp _ op, [e1, e2]) -> paren $ formatExpr e1 ++ " " ++ op ++ " " ++ formatExpr e2 - (At' _, [e1, e2]) -> paren $ formatExpr e1 ++ ")[" ++ formatExpr e2 ++ "]" - (If' _, [e1, e2, e3]) -> paren $ "if" ++ " " ++ formatExpr e1 ++ " then " ++ formatExpr e2 ++ " else " ++ formatExpr e3 + (PrefixOp op, [e1]) -> paren $ op ++ " " ++ formatExpr' e1 + (InfixOp _ op, [e1, e2]) -> paren $ formatExpr' e1 ++ " " ++ op ++ " " ++ formatExpr' e2 + (At' _, [e1, e2]) -> paren $ formatExpr' e1 ++ ")[" ++ formatExpr' e2 ++ "]" + (If' _, [e1, e2, e3]) -> paren $ "if" ++ " " ++ formatExpr' e1 ++ " then " ++ formatExpr' e2 ++ " else " ++ formatExpr' e3 _ -> formatFunCall (formatBuiltinIsolated' builtin) [] args formatBuiltin :: Builtin -> [Expr] -> String @@ -172,26 +183,29 @@ formatLiteral = \case formatFormalArgs :: [(VarName, Type)] -> String formatFormalArgs args = unwords $ map (\(x, t) -> paren (unVarName x ++ ": " ++ formatType t)) args -formatExpr :: Expr -> String -formatExpr = \case +formatExpr' :: Expr -> String +formatExpr' = \case Var x -> unVarName x Lit lit -> formatLiteral lit App f args -> case f of Var x -> formatFunCall (unVarName x) [] args Lit (LitBuiltin builtin) -> formatBuiltin builtin args - _ -> formatFunCall (formatExpr f) [] args - Lam args e -> paren $ "fun " ++ formatFormalArgs args ++ " ->\n" ++ indent ++ "\n" ++ formatExpr e ++ "\n" ++ dedent ++ "\n" - Let x t e1 e2 -> "let " ++ unVarName x ++ ": " ++ formatType t ++ " =\n" ++ indent ++ "\n" ++ formatExpr e1 ++ "\n" ++ dedent ++ "\nin " ++ formatExpr e2 + _ -> formatFunCall (formatExpr' f) [] args + Lam args e -> paren $ "fun " ++ formatFormalArgs args ++ " ->\n" ++ indent ++ "\n" ++ formatExpr' e ++ "\n" ++ dedent ++ "\n" + Let x t e1 e2 -> "let " ++ unVarName x ++ ": " ++ formatType t ++ " =\n" ++ indent ++ "\n" ++ formatExpr' e1 ++ "\n" ++ dedent ++ "\nin " ++ formatExpr' e2 + +formatExpr :: Expr -> String +formatExpr = unwords . makeIndentFromMarkers 4 . lines . formatExpr' formatToplevelExpr :: ToplevelExpr -> [String] formatToplevelExpr = \case - ResultExpr e -> [formatExpr e] + ResultExpr e -> lines (formatExpr' e) ToplevelLet x t e cont -> let' (unVarName x) t e cont ToplevelLetRec f args ret e cont -> let' ("rec " ++ unVarName f ++ " " ++ formatFormalArgs args) ret e cont where let' s t e cont = ["let " ++ s ++ ": " ++ formatType t ++ " =", indent] - ++ lines (formatExpr e) + ++ lines (formatExpr' e) ++ [dedent, "in"] ++ formatToplevelExpr cont diff --git a/src/Jikka/Core/Language/ArithmeticalExpr.hs b/src/Jikka/Core/Language/ArithmeticalExpr.hs new file mode 100644 index 00000000..60105dbe --- /dev/null +++ b/src/Jikka/Core/Language/ArithmeticalExpr.hs @@ -0,0 +1,182 @@ +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TupleSections #-} + +module Jikka.Core.Language.ArithmeticalExpr where + +import Control.Monad +import Control.Monad.ST +import Control.Monad.Trans +import Control.Monad.Trans.Maybe +import Data.List (findIndices, groupBy, sort, sortBy) +import Data.STRef +import qualified Data.Vector as V +import qualified Data.Vector.Mutable as MV +import Jikka.Core.Language.BuiltinPatterns +import Jikka.Core.Language.Expr +import Jikka.Core.Language.Vars + +data ProductExpr = ProductExpr + { productExprConst :: Integer, + productExprList :: [Expr] + } + deriving (Eq, Ord, Show, Read) + +data SumExpr = SumExpr + { sumExprList :: [ProductExpr], + sumExprConst :: Integer + } + deriving (Eq, Ord, Show, Read) + +type ArithmeticalExpr = SumExpr + +oneProductExpr :: ProductExpr +oneProductExpr = + ProductExpr + { productExprConst = 1, + productExprList = [] + } + +negateProductExpr :: ProductExpr -> ProductExpr +negateProductExpr e = e {productExprConst = negate (productExprConst e)} + +multProductExpr :: ProductExpr -> ProductExpr -> ProductExpr +multProductExpr e1 e2 = + ProductExpr + { productExprConst = productExprConst e1 * productExprConst e2, + productExprList = productExprList e1 ++ productExprList e2 + } + +parseProductExpr :: Expr -> ProductExpr +parseProductExpr = \case + LitInt' n -> ProductExpr {productExprConst = n, productExprList = []} + Negate' e -> negateProductExpr (parseProductExpr e) + Mult' e1 e2 -> multProductExpr (parseProductExpr e1) (parseProductExpr e2) + e -> ProductExpr {productExprConst = 1, productExprList = [e]} + +sumExprFromProductExpr :: ProductExpr -> SumExpr +sumExprFromProductExpr e = + SumExpr + { sumExprList = [e], + sumExprConst = 0 + } + +sumExprFromInteger :: Integer -> SumExpr +sumExprFromInteger n = + SumExpr + { sumExprConst = n, + sumExprList = [] + } + +zeroSumExpr :: SumExpr +zeroSumExpr = sumExprFromInteger 0 + +negateSumExpr :: SumExpr -> SumExpr +negateSumExpr e = + SumExpr + { sumExprList = map negateProductExpr (sumExprList e), + sumExprConst = negate (sumExprConst e) + } + +plusSumExpr :: SumExpr -> SumExpr -> SumExpr +plusSumExpr e1 e2 = + SumExpr + { sumExprList = sumExprList e1 ++ sumExprList e2, + sumExprConst = sumExprConst e1 + sumExprConst e2 + } + +multSumExpr :: SumExpr -> SumExpr -> SumExpr +multSumExpr e1 e2 = + SumExpr + { sumExprList = + let es1 = parseProductExpr (LitInt' (sumExprConst e1)) : sumExprList e1 + es2 = parseProductExpr (LitInt' (sumExprConst e2)) : sumExprList e2 + in map (uncurry multProductExpr) ((,) <$> es1 <*> es2), + sumExprConst = sumExprConst e1 * sumExprConst e2 + } + +parseSumExpr :: Expr -> SumExpr +parseSumExpr = \case + LitInt' n -> SumExpr {sumExprList = [], sumExprConst = n} + Negate' e -> negateSumExpr (parseSumExpr e) + Plus' e1 e2 -> plusSumExpr (parseSumExpr e1) (parseSumExpr e2) + Minus' e1 e2 -> plusSumExpr (parseSumExpr e1) (negateSumExpr (parseSumExpr e2)) + Mult' e1 e2 -> multSumExpr (parseSumExpr e1) (parseSumExpr e2) + e -> sumExprFromProductExpr (parseProductExpr e) + +-- | `parseArithmeticalExpr` converts a given expr to a normal form \(\sum_i \prod_j e _ {i,j})\). +-- This assumes given exprs have the type \(\mathbf{int}\). +parseArithmeticalExpr :: Expr -> ArithmeticalExpr +parseArithmeticalExpr = parseSumExpr + +formatProductExpr :: ProductExpr -> Expr +formatProductExpr e = + let k = LitInt' (productExprConst e) + k' e' = if productExprConst e == 0 then Lit0 else Mult' e' k + in case productExprList e of + [] -> k + eHead : esTail -> k' (foldl Mult' eHead esTail) + +formatSumExpr :: SumExpr -> Expr +formatSumExpr e = + let k = LitInt' (sumExprConst e) + in case sumExprList e of + [] -> k + eHead : esTail -> + let op e' + | productExprConst e' > 0 = Plus' + | productExprConst e' < 0 = Minus' + | otherwise = const + go e1 e2 = op e2 e1 (formatProductExpr (e2 {productExprConst = abs (productExprConst e2)})) + k' e' + | sumExprConst e > 0 = Plus' e' k + | sumExprConst e < 0 = Minus' e' k + | otherwise = e' + in k' (foldl go (formatProductExpr eHead) esTail) + +formatArithmeticalExpr :: ArithmeticalExpr -> Expr +formatArithmeticalExpr = formatSumExpr + +normalizeProductExpr :: ProductExpr -> ProductExpr +normalizeProductExpr e = + let es = + if productExprConst e == 0 + then [] + else sort (productExprList e) + in e {productExprList = es} + +normalizeSumExpr :: SumExpr -> SumExpr +normalizeSumExpr e = + let cmp e1 e2 = productExprList e1 `compare` productExprList e2 + cmp' e1 e2 = cmp e1 e2 == EQ + es = sortBy cmp (map normalizeProductExpr (sumExprList e)) + es' = groupBy cmp' es + es'' = map (\group -> ProductExpr {productExprConst = sum (map productExprConst group), productExprList = productExprList (head group)}) es' + es''' = filter (\e -> productExprConst e /= 0 && not (null (productExprList e))) es'' + k = sum (map (\e -> if null (productExprList e) then productExprConst e else 0) es'') + in SumExpr + { sumExprList = es''', + sumExprConst = sumExprConst e + k + } + +normalizeArithmeticalExpr :: ArithmeticalExpr -> ArithmeticalExpr +normalizeArithmeticalExpr = normalizeSumExpr + +-- | `makeVectorFromArithmeticalExpr` makes a vector \(f\) and a expr \(c\) from a given vector of variables \(x_0, x_1, \dots, x _ {n - 1}\) and a given expr \(e\) s.t. \(f\) and \(c\) don't have \(x_0, x_1, \dots, x _ {n - 1}\) as free variables and \(e = c + f \cdot (x_0, x_1, \dots, x _ {n - 1})\) holds. +-- This assumes given variables and exprs have the type \(\mathbf{int}\). +makeVectorFromArithmeticalExpr :: V.Vector VarName -> ArithmeticalExpr -> Maybe (V.Vector ArithmeticalExpr, ArithmeticalExpr) +makeVectorFromArithmeticalExpr xs es = runST $ do + runMaybeT $ do + f <- lift $ MV.replicate (V.length xs) zeroSumExpr + c <- lift $ newSTRef (sumExprFromInteger (sumExprConst es)) + forM_ (sumExprList es) $ \e -> do + let indices = V.imap (\i x -> map (i,) (findIndices (x `isFreeVar`) (productExprList e))) xs + case concat (V.toList indices) of + [] -> lift $ modifySTRef c (plusSumExpr (sumExprFromProductExpr e)) + [(i, j)] -> do + let e' = e {productExprList = take j (productExprList e) ++ drop (j + 1) (productExprList e)} + lift $ MV.modify f (plusSumExpr (sumExprFromProductExpr e')) i + _ -> MaybeT $ return Nothing + f <- V.freeze f + c <- lift $ readSTRef c + return (V.map normalizeArithmeticalExpr f, normalizeArithmeticalExpr c) diff --git a/src/Jikka/Core/Language/BuiltinPatterns.hs b/src/Jikka/Core/Language/BuiltinPatterns.hs index 88bdee1f..cf809bd6 100644 --- a/src/Jikka/Core/Language/BuiltinPatterns.hs +++ b/src/Jikka/Core/Language/BuiltinPatterns.hs @@ -72,11 +72,33 @@ pattern BitLeftShift' e1 e2 = AppBuiltin BitLeftShift [e1, e2] pattern BitRightShift' e1 e2 = AppBuiltin BitRightShift [e1, e2] +-- matrix functions + +pattern MatAp' h w e1 e2 = AppBuiltin (MatAp h w) [e1, e2] + +pattern MatZero' n = AppBuiltin (MatZero n) [] + +pattern MatOne' n = AppBuiltin (MatOne n) [] + +pattern MatAdd' h w e1 e2 = AppBuiltin (MatAdd h w) [e1, e2] + +pattern MatMul' h n w e1 e2 = AppBuiltin (MatMul h n w) [e1, e2] + +pattern MatPow' n e1 e2 = AppBuiltin (MatPow n) [e1, e2] + -- modular functions pattern ModInv' e1 e2 = AppBuiltin ModInv [e1, e2] pattern ModPow' e1 e2 e3 = AppBuiltin ModPow [e1, e2, e3] +pattern ModMatAp' h w e1 e2 e3 = AppBuiltin (ModMatAp h w) [e1, e2, e3] + +pattern ModMatAdd' h w e1 e2 e3 = AppBuiltin (ModMatAdd h w) [e1, e2, e3] + +pattern ModMatMul' h n w e1 e2 e3 = AppBuiltin (ModMatMul h n w) [e1, e2, e3] + +pattern ModMatPow' n e1 e2 e3 = AppBuiltin (ModMatPow n) [e1, e2, e3] + -- list functions pattern Cons' t e1 e2 = AppBuiltin (Cons t) [e1, e2] diff --git a/src/Jikka/Core/Language/Expr.hs b/src/Jikka/Core/Language/Expr.hs index a236830c..6ffab626 100644 --- a/src/Jikka/Core/Language/Expr.hs +++ b/src/Jikka/Core/Language/Expr.hs @@ -118,12 +118,34 @@ data Builtin BitLeftShift | -- | \(: \int \times \int \to \int\) BitRightShift + | -- matrix functions + + -- | matrix application \(: \int^{H \times W} \times \int^W \to \int^H\) + MatAp Int Int + | -- | zero matrix \(: \to \int^{n \times n}\) + MatZero Int + | -- | unit matrix \(: \to \int^{n \times n}\) + MatOne Int + | -- | matrix addition \(: \int^{H \times W} \times \int^{H \times W} \to \int^{H \times W}\) + MatAdd Int Int + | -- | matrix multiplication \(: \int^{H \times n} \times \int^{n \times W} \to \int^{H \times W}\) + MatMul Int Int Int + | -- | matrix power \(: \int^{n \times n} \times \int \to \int^{n \times n}\) + MatPow Int | -- modular functions -- | \(: \int \times \int \to \int\) ModInv | -- | \(: \int \times \int \times \int \to \int\) ModPow + | -- | matrix application \(: \int^{H \times W} \times \int^W \times \int \to \int^H\) + ModMatAp Int Int + | -- | matrix addition \(: \int^{H \times W} \times \int^{H \times W} \times \int \to \int^{H \times W}\) + ModMatAdd Int Int + | -- | matrix multiplication \(: \int^{H \times n} \times \int^{n \times W} \times \int \to \int^{H \times W}\) + ModMatMul Int Int Int + | -- | matrix power \(: \int^{n \times n} \times \int \to \int^{n \times n}\) + ModMatPow Int | -- list functions -- | \(: \forall \alpha. \alpha \times \list(\alpha) \to \list(\alpha)\) @@ -259,6 +281,14 @@ pattern FunLTy t <- where FunLTy t = FunTy [ListTy t] t +vectorTy :: Int -> Type +vectorTy n = TupleTy (replicate n IntTy) + +matrixTy :: Int -> Int -> Type +matrixTy h w = TupleTy (replicate h (TupleTy (replicate w IntTy))) + +pattern LitInt' n = Lit (LitInt n) + pattern Lit0 = Lit (LitInt 0) pattern Lit1 = Lit (LitInt 1) @@ -267,6 +297,8 @@ pattern Lit2 = Lit (LitInt 2) pattern LitMinus1 = Lit (LitInt (-1)) +pattern LitBool' p = Lit (LitBool p) + pattern LitTrue = Lit (LitBool True) pattern LitFalse = Lit (LitBool False) diff --git a/src/Jikka/Core/Language/Runtime.hs b/src/Jikka/Core/Language/Runtime.hs new file mode 100644 index 00000000..3028835f --- /dev/null +++ b/src/Jikka/Core/Language/Runtime.hs @@ -0,0 +1,62 @@ +{-# LANGUAGE FlexibleContexts #-} + +module Jikka.Core.Language.Runtime where + +import Jikka.Common.Error + +floorDiv :: MonadError Error m => Integer -> Integer -> m Integer +floorDiv _ 0 = throwRuntimeError "zero div" +floorDiv a b = return (a `div` b) + +floorMod :: MonadError Error m => Integer -> Integer -> m Integer +floorMod _ 0 = throwRuntimeError "zero div" +floorMod a b = return (a `mod` b) + +ceilDiv :: MonadError Error m => Integer -> Integer -> m Integer +ceilDiv _ 0 = throwRuntimeError "zero div" +ceilDiv a b = return ((a + b - 1) `div` b) + +ceilMod :: MonadError Error m => Integer -> Integer -> m Integer +ceilMod _ 0 = throwRuntimeError "zero div" +ceilMod a b = return (a - ((a + b - 1) `div` b) * b) + +minimumEither :: (MonadError Error m, Ord a) => [a] -> m a +minimumEither [] = throwRuntimeError "there is no minimum for the empty list" +minimumEither a = return $ minimum a + +maximumEither :: (MonadError Error m, Ord a) => [a] -> m a +maximumEither [] = throwRuntimeError "there is no maximum for the empty list" +maximumEither a = return $ maximum a + +argminEither :: (MonadError Error m, Ord a) => [a] -> m Integer +argminEither [] = throwRuntimeError "there is no minimum for the empty list" +argminEither a = return $ snd (minimum (zip a [0 ..])) + +argmaxEither :: (MonadError Error m, Ord a) => [a] -> m Integer +argmaxEither [] = throwRuntimeError "there is no maximum for the empty list" +argmaxEither a = return $ snd (maximum (zip a [0 ..])) + +modinv :: MonadError Error m => Integer -> Integer -> m Integer +modinv a m | m <= 0 || a `mod` m == 0 = throwRuntimeError $ "invalid argument for inv: " ++ show (a, m) +modinv _ _ = throwInternalError "TODO: implement inv()" + +modpow :: MonadError Error m => Integer -> Integer -> Integer -> m Integer +modpow _ _ m | m <= 0 = throwRuntimeError $ "invalid argument for modpow: MOD = " ++ show m +modpow a b m = return $ (a ^ b) `mod` m + +fact :: MonadError Error m => Integer -> m Integer +fact n | n < 0 = throwRuntimeError $ "invalid argument for fact: " ++ show n +fact n = return $ product [1 .. n] + +choose :: MonadError Error m => Integer -> Integer -> m Integer +choose n r | not (0 <= r && r <= n) = throwRuntimeError $ "invalid argument for choose: " ++ show (n, r) +choose n r = return $ product [n - r + 1 .. n] `div` product [1 .. r] + +permute :: MonadError Error m => Integer -> Integer -> m Integer +permute n r | not (0 <= r && r <= n) = throwRuntimeError $ "invalid argument for choose: " ++ show (n, r) +permute n r = return $ product [n - r + 1 .. n] + +multichoose :: MonadError Error m => Integer -> Integer -> m Integer +multichoose n r | not (0 <= r && r <= n) = throwRuntimeError $ "invalid argument for multichoose: " ++ show (n, r) +multichoose 0 0 = return 1 +multichoose n r = choose (n + r - 1) r diff --git a/src/Jikka/Core/Language/TypeCheck.hs b/src/Jikka/Core/Language/TypeCheck.hs index e94f03f3..2da817ca 100644 --- a/src/Jikka/Core/Language/TypeCheck.hs +++ b/src/Jikka/Core/Language/TypeCheck.hs @@ -3,7 +3,9 @@ module Jikka.Core.Language.TypeCheck where +import Data.List (intercalate) import Jikka.Common.Error +import Jikka.Core.Format (formatExpr, formatType) import Jikka.Core.Language.Expr builtinToType :: Builtin -> Type @@ -19,7 +21,7 @@ builtinToType = \case CeilMod -> Fun2Ty IntTy Pow -> Fun2Ty IntTy -- induction functions - NatInd t -> FunTy [t, FunTy [IntTy, t] t, IntTy] t + NatInd t -> FunTy [t, FunTy [t] t, IntTy] t -- advanced arithmetical functions Abs -> Fun1Ty IntTy Gcd -> Fun2Ty IntTy @@ -39,9 +41,20 @@ builtinToType = \case BitXor -> Fun2Ty IntTy BitLeftShift -> Fun2Ty IntTy BitRightShift -> Fun2Ty IntTy + -- matrix functions + MatAp h w -> FunTy [matrixTy h w, vectorTy w] (vectorTy h) + MatZero n -> FunTy [] (matrixTy n n) + MatOne n -> FunTy [] (matrixTy n n) + MatAdd h w -> FunTy [matrixTy h w, matrixTy h w] (matrixTy h w) + MatMul h n w -> FunTy [matrixTy h n, matrixTy n w] (matrixTy h w) + MatPow n -> FunTy [matrixTy n n, IntTy] (matrixTy n n) -- modular functions ModInv -> Fun2Ty IntTy ModPow -> Fun3Ty IntTy + ModMatAp h w -> FunTy [matrixTy h w, vectorTy w, IntTy] (vectorTy h) + ModMatAdd h w -> FunTy [matrixTy h w, matrixTy h w, IntTy] (matrixTy h w) + ModMatMul h n w -> FunTy [matrixTy h n, matrixTy n w, IntTy] (matrixTy h w) + ModMatPow n -> FunTy [matrixTy n n, IntTy, IntTy] (matrixTy n n) -- list functions Cons t -> FunTy [t, ListTy t] (ListTy t) Foldl t1 t2 -> FunTy [FunTy [t2, t1] t2, t2, ListTy t1] t2 @@ -96,7 +109,7 @@ type TypeEnv = [(VarName, Type)] typecheckExpr :: MonadError Error m => TypeEnv -> Expr -> m Type typecheckExpr env = \case Var x -> case lookup x env of - Nothing -> throwInternalError $ "undefined variable: " ++ show (unVarName x) + Nothing -> throwInternalError $ "undefined variable: " ++ unVarName x Just t -> return t Lit lit -> return $ literalToType lit App e args -> do @@ -104,28 +117,30 @@ typecheckExpr env = \case ts <- mapM (typecheckExpr env) args case t of FunTy ts' ret | ts' == ts -> return ret - _ -> throwInternalError $ "invalid funcall: " ++ show (App e args, t, ts) + _ -> throwInternalError $ "wrong type funcall: expr = " ++ formatExpr (App e args) ++ ", expected type = " ++ intercalate " * " (map formatType ts) ++ " -> ?, actual type = " ++ formatType t Lam args e -> FunTy (map snd args) <$> typecheckExpr (reverse args ++ env) e Let x t e1 e2 -> do t' <- typecheckExpr env e1 - if t == t' - then typecheckExpr ((x, t) : env) e2 - else throwInternalError $ "wrong type binding: " ++ show (Let x t e1 e2) + when (t /= t') $ do + throwInternalError $ "wrong type binding: " ++ formatExpr (Let x t e1 e2) + typecheckExpr ((x, t) : env) e2 typecheckToplevelExpr :: MonadError Error m => TypeEnv -> ToplevelExpr -> m Type typecheckToplevelExpr env = \case ResultExpr e -> typecheckExpr env e ToplevelLet x t e cont -> do t' <- typecheckExpr env e - if t' == t then return () else throwInternalError "assigned type is not correct" + when (t' /= t) $ do + throwInternalError $ "assigned type is not correct: context = (let " ++ unVarName x ++ ": " ++ formatType t ++ " = " ++ formatExpr e ++ " in ...), expected type = " ++ formatType t ++ ", actual type = " ++ formatType t' typecheckToplevelExpr ((x, t) : env) cont - ToplevelLetRec x args ret body cont -> do + ToplevelLetRec f args ret body cont -> do let t = case args of [] -> ret _ -> FunTy (map snd args) ret - ret' <- typecheckExpr (reverse args ++ (x, t) : env) body - if ret' == ret then return () else throwInternalError "returned type is not correct" - typecheckToplevelExpr ((x, t) : env) cont + ret' <- typecheckExpr (reverse args ++ (f, t) : env) body + when (ret' /= ret) $ do + throwInternalError $ "returned type is not correct: context = (let rec " ++ unVarName f ++ " " ++ unwords (map (\(x, t) -> unVarName x ++ ": " ++ formatType t) args) ++ ": " ++ formatType ret ++ " = " ++ formatExpr body ++ " in ...), expected type = " ++ formatType ret ++ ", actual type = " ++ formatType ret' + typecheckToplevelExpr ((f, t) : env) cont typecheckProgram :: MonadError Error m => Program -> m Type typecheckProgram prog = wrapError' "Jikka.Core.Language.TypeCheck.typecheckProgram" $ do diff --git a/src/Jikka/Core/Language/Util.hs b/src/Jikka/Core/Language/Util.hs index d29989b5..1f9358f2 100644 --- a/src/Jikka/Core/Language/Util.hs +++ b/src/Jikka/Core/Language/Util.hs @@ -2,6 +2,7 @@ module Jikka.Core.Language.Util where +import Control.Monad.Identity import Jikka.Common.Alpha import Jikka.Core.Language.Expr @@ -52,9 +53,20 @@ mapTypeInBuiltin f = \case BitXor -> BitXor BitLeftShift -> BitLeftShift BitRightShift -> BitRightShift + -- matrix functions + MatAp h w -> MatAp h w + MatZero n -> MatZero n + MatOne n -> MatOne n + MatAdd h w -> MatAdd h w + MatMul h n w -> MatMul h n w + MatPow n -> MatPow n -- modular functionsmodular ModInv -> ModInv ModPow -> ModPow + ModMatAp h w -> ModMatAp h w + ModMatAdd h w -> ModMatAdd h w + ModMatMul h n w -> ModMatMul h n w + ModMatPow n -> ModMatPow n -- list functionslist Cons t -> Cons (f t) Foldl t1 t2 -> Foldl (f t1) (f t2) @@ -95,3 +107,47 @@ mapTypeInBuiltin f = \case Choose -> Choose Permute -> Permute MultiChoose -> MultiChoose + +countOccurrences :: VarName -> Expr -> Int +countOccurrences x = \case + Var y -> if x == y then 1 else 0 + Lit _ -> 0 + App f args -> sum (map (countOccurrences x) (f : args)) + Lam args body -> if x `elem` map fst args then 0 else countOccurrences x body + Let y _ e1 e2 -> countOccurrences x e1 + (if x == y then 0 else countOccurrences x e2) + +countOccurrencesToplevelExpr :: VarName -> ToplevelExpr -> Int +countOccurrencesToplevelExpr x = \case + ResultExpr e -> countOccurrences x e + ToplevelLet y _ e cont -> countOccurrences x e + (if x == y then 0 else countOccurrencesToplevelExpr x cont) + ToplevelLetRec f args _ body cont -> if x == f then 0 else countOccurrencesToplevelExpr x cont + (if x `elem` map fst args then 0 else countOccurrences x body) + +mapExprM :: Monad m => ([(VarName, Type)] -> Expr -> m Expr) -> [(VarName, Type)] -> Expr -> m Expr +mapExprM f env = \case + Var y -> f env (Var y) + Lit lit -> f env (Lit lit) + App g args -> f env =<< (App <$> mapExprM f env g <*> mapM (mapExprM f env) args) + Lam args body -> f env . Lam args =<< mapExprM f (reverse args ++ env) body + Let y t e1 e2 -> f env =<< (Let y t <$> mapExprM f env e1 <*> mapExprM f ((y, t) : env) e2) + +mapExprToplevelExprM :: Monad m => ([(VarName, Type)] -> Expr -> m Expr) -> [(VarName, Type)] -> ToplevelExpr -> m ToplevelExpr +mapExprToplevelExprM f env = \case + ResultExpr e -> ResultExpr <$> mapExprM f env e + ToplevelLet y t e cont -> + let env' = (y, t) : env + in ToplevelLet y t <$> mapExprM f env' e <*> mapExprToplevelExprM f env' cont + ToplevelLetRec g args ret body cont -> + let env' = (g, FunTy (map snd args) ret) : env + in ToplevelLetRec g args ret <$> mapExprM f (reverse args ++ env) body <*> mapExprToplevelExprM f env' cont + +mapExprProgramM :: Monad m => ([(VarName, Type)] -> Expr -> m Expr) -> Program -> m Program +mapExprProgramM f = mapExprToplevelExprM f [] + +mapExpr :: ([(VarName, Type)] -> Expr -> Expr) -> [(VarName, Type)] -> Expr -> Expr +mapExpr f env e = runIdentity $ mapExprM (\env e -> return $ f env e) env e + +mapExprToplevelExpr :: ([(VarName, Type)] -> Expr -> Expr) -> [(VarName, Type)] -> ToplevelExpr -> ToplevelExpr +mapExprToplevelExpr f env e = runIdentity $ mapExprToplevelExprM (\env e -> return $ f env e) env e + +mapExprProgram :: ([(VarName, Type)] -> Expr -> Expr) -> Program -> Program +mapExprProgram f prog = runIdentity $ mapExprProgramM (\env e -> return $ f env e) prog diff --git a/src/Jikka/Core/Language/Value.hs b/src/Jikka/Core/Language/Value.hs index 514e40d8..f83b5955 100644 --- a/src/Jikka/Core/Language/Value.hs +++ b/src/Jikka/Core/Language/Value.hs @@ -7,6 +7,8 @@ import Data.Char (toLower) import Data.List (intercalate) import qualified Data.Vector as V import Jikka.Common.Error +import Jikka.Common.Matrix +import Jikka.Common.ModInt import Jikka.Core.Language.Expr data Value @@ -15,7 +17,7 @@ data Value | ValList (V.Vector Value) | ValTuple [Value] | ValBuiltin Builtin - | ValLambda Env [(VarName, Type)] Expr + | ValLambda (Maybe VarName) Env [(VarName, Type)] Expr deriving (Eq, Ord, Show, Read) type Env = [(VarName, Value)] @@ -35,6 +37,10 @@ valueToInt = \case valueToIntList :: MonadError Error m => V.Vector Value -> m [Integer] valueToIntList = mapM valueToInt . V.toList +valueToIntList' :: MonadError Error m => Value -> m [Integer] +valueToIntList' (ValList xs) = valueToIntList xs +valueToIntList' _ = throwRuntimeError "Internal Error: type error" + valueToBool :: MonadError Error m => Value -> m Bool valueToBool = \case ValBool p -> return p @@ -43,6 +49,36 @@ valueToBool = \case valueToBoolList :: MonadError Error m => V.Vector Value -> m [Bool] valueToBoolList = mapM valueToBool . V.toList +valueToVector :: MonadError Error m => Value -> m (V.Vector Integer) +valueToVector (ValTuple x) = V.fromList <$> mapM valueToInt x +valueToVector _ = throwRuntimeError "Internal Error: value is not a vector" + +valueToMatrix :: MonadError Error m => Value -> m (Matrix Integer) +valueToMatrix (ValTuple f) = do + f <- V.fromList <$> mapM valueToVector f + case makeMatrix f of + Nothing -> throwRuntimeError "Internal Error: value is not a matrix" + Just f -> return f +valueToMatrix _ = throwRuntimeError "Internal Error: value is not a matrix" + +valueFromVector :: V.Vector Integer -> Value +valueFromVector x = ValTuple (map ValInt (V.toList x)) + +valueFromMatrix :: Matrix Integer -> Value +valueFromMatrix f = ValTuple (map (ValTuple . map ValInt . V.toList) (V.toList (unMatrix f))) + +valueToModVector :: MonadError Error m => Integer -> Value -> m (V.Vector ModInt) +valueToModVector m x = V.map (`toModInt` m) <$> valueToVector x + +valueToModMatrix :: MonadError Error m => Integer -> Value -> m (Matrix ModInt) +valueToModMatrix m f = fmap (`toModInt` m) <$> valueToMatrix f + +valueFromModVector :: V.Vector ModInt -> Value +valueFromModVector = valueFromVector . V.map fromModInt + +valueFromModMatrix :: Matrix ModInt -> Value +valueFromModMatrix = valueFromMatrix . fmap fromModInt + formatValue :: Value -> String formatValue = \case ValInt n -> show n diff --git a/src/Jikka/Core/Optimize.hs b/src/Jikka/Core/Optimize.hs deleted file mode 100644 index d5870554..00000000 --- a/src/Jikka/Core/Optimize.hs +++ /dev/null @@ -1,31 +0,0 @@ -{-# LANGUAGE FlexibleContexts #-} - --- | --- Module : Jikka.Core.Optimize --- Description : is a meta module to combine other optimizers. --- Copyright : (c) Kimiyuki Onaka, 2020 --- License : Apache License 2.0 --- Maintainer : kimiyuki95@gmail.com --- Stability : experimental --- Portability : portable --- --- `Jikka.Core.Optimize` is a module to combine other all optimizers. -module Jikka.Core.Optimize - ( run, - ) -where - -import Jikka.Common.Alpha -import Jikka.Common.Error -import qualified Jikka.Core.Convert.Alpha as Alpha -import qualified Jikka.Core.Convert.RemoveUnusedVars as RemoveUnusedVars -import qualified Jikka.Core.Convert.StrengthReduction as StrengthReduction -import qualified Jikka.Core.Convert.TypeInfer as TypeInfer -import Jikka.Core.Language.Expr - -run :: (MonadAlpha m, MonadError Error m) => Program -> m Program -run prog = do - prog <- Alpha.run prog - prog <- TypeInfer.run prog - prog <- RemoveUnusedVars.run prog - StrengthReduction.run prog diff --git a/src/Jikka/Main/Subcommand/Convert.hs b/src/Jikka/Main/Subcommand/Convert.hs index 74afdd86..7eb54dfc 100644 --- a/src/Jikka/Main/Subcommand/Convert.hs +++ b/src/Jikka/Main/Subcommand/Convert.hs @@ -7,8 +7,8 @@ import qualified Jikka.CPlusPlus.Convert as FromCore import qualified Jikka.CPlusPlus.Format as FormatCPlusPlus import Jikka.Common.Alpha import Jikka.Common.Error +import qualified Jikka.Core.Convert as Convert import qualified Jikka.Core.Format as FormatCore -import qualified Jikka.Core.Optimize as Optimize import Jikka.Main.Target import qualified Jikka.Python.Convert.ToRestrictedPython as ToRestrictedPython import qualified Jikka.Python.Parse as ParsePython @@ -32,7 +32,7 @@ runCore path input = flip evalAlphaT 0 $ do prog <- ParsePython.run path input prog <- ToRestrictedPython.run prog prog <- ToCore.run prog - prog <- Optimize.run prog + prog <- Convert.run prog FormatCore.run prog runCPlusPlus :: FilePath -> Text -> Either Error Text @@ -40,7 +40,7 @@ runCPlusPlus path input = flip evalAlphaT 0 $ do prog <- ParsePython.run path input prog <- ToRestrictedPython.run prog prog <- ToCore.run prog - prog <- Optimize.run prog + prog <- Convert.run prog prog <- FromCore.run prog FormatCPlusPlus.run prog diff --git a/src/Jikka/Main/Subcommand/Execute.hs b/src/Jikka/Main/Subcommand/Execute.hs index 0e52c331..3e565b78 100644 --- a/src/Jikka/Main/Subcommand/Execute.hs +++ b/src/Jikka/Main/Subcommand/Execute.hs @@ -7,9 +7,9 @@ import Control.Monad.Except import qualified Data.Text.IO as T (readFile) import Jikka.Common.Alpha import Jikka.Common.Error +import qualified Jikka.Core.Convert as ConvertCore import qualified Jikka.Core.Evaluate as EvaluateCore import qualified Jikka.Core.Language.Value as ValueCore -import qualified Jikka.Core.Optimize as OptimizeCore import Jikka.Main.Target import qualified Jikka.Python.Convert.ToRestrictedPython as ToRestrictedPython import qualified Jikka.Python.Parse as FromPython @@ -37,7 +37,7 @@ runCore path = flip evalAlphaT 0 $ do prog <- liftEither $ FromPython.run path prog prog <- ToRestrictedPython.run prog prog <- ToCore.run prog - prog <- OptimizeCore.run prog + prog <- ConvertCore.run prog value <- EvaluateCore.run prog liftIO $ putStrLn (ValueCore.formatValue value) diff --git a/src/Jikka/RestrictedPython/Convert/Alpha.hs b/src/Jikka/RestrictedPython/Convert/Alpha.hs index 75514687..5cce6875 100644 --- a/src/Jikka/RestrictedPython/Convert/Alpha.hs +++ b/src/Jikka/RestrictedPython/Convert/Alpha.hs @@ -67,6 +67,17 @@ renameNew x = do } return y +-- | `renameShadow` renames given variables ignoring the current `Env` and record them to the `Env`. +renameShadow :: (MonadAlpha m, MonadState Env m) => VarName -> m VarName +renameShadow x = do + env <- get + y <- genVarName x + put $ + env + { currentMapping = (x, y) : currentMapping env + } + return y + -- | `renameCompletelyNew` throws errors when given variables already exists in environments. renameCompletelyNew :: (MonadAlpha m, MonadState Env m, MonadError Error m) => VarName -> m VarName renameCompletelyNew x = do @@ -80,7 +91,10 @@ renameToplevel :: (MonadAlpha m, MonadState Env m, MonadError Error m) => VarNam renameToplevel x = do env <- get case lookupName x env of - Just _ -> throwSemanticError $ "cannot redefine variable in toplevel: " ++ unVarName x + Just _ -> do + if x `S.member` builtinNames + then throwSemanticError $ "cannot assign to builtin function: " ++ unVarName x + else throwSemanticError $ "cannot redefine variable in toplevel: " ++ unVarName x Nothing -> do when (unVarName x /= "_") $ do put $ @@ -89,6 +103,10 @@ renameToplevel x = do } return x +-- | `renameToplevelArgument` always introduces a new variable. +renameToplevelArgument :: (MonadAlpha m, MonadState Env m, MonadError Error m) => VarName -> m VarName +renameToplevelArgument = renameShadow + popRename :: (MonadState Env m, MonadError Error m) => VarName -> m () popRename x = when (unVarName x /= "_") $ do @@ -206,7 +224,7 @@ runToplevelStatement = \case g <- renameToplevel f withToplevelScope $ do args <- forM args $ \(x, t) -> do - y <- renameToplevel x + y <- renameToplevelArgument x return (y, t) body <- runStatements body return $ ToplevelFunctionDef g args ret body diff --git a/test/Jikka/Common/MatrixSpec.hs b/test/Jikka/Common/MatrixSpec.hs new file mode 100644 index 00000000..278fd45f --- /dev/null +++ b/test/Jikka/Common/MatrixSpec.hs @@ -0,0 +1,57 @@ +module Jikka.Common.MatrixSpec + ( spec, + ) +where + +import qualified Data.Vector as V +import Jikka.Common.Matrix +import Test.Hspec + +makeMatrix'' :: [[Integer]] -> Matrix Integer +makeMatrix'' = makeMatrix' . V.fromList . map V.fromList + +spec :: Spec +spec = do + describe "matcheck" $ do + it "works" $ do + let f = V.fromList $ map V.fromList [[1, 2, 3], [3, 4, 5]] + let expected = True + matcheck f `shouldBe` expected + it "works'" $ do + let f = V.fromList $ map V.fromList [[1, 2, 3], [3, 4]] + let expected = False + matcheck f `shouldBe` expected + describe "matap" $ do + it "works" $ do + let f = makeMatrix'' [[1, 2, 3], [3, 4, 5]] + let x = V.fromList [1, 2, 3] + let y = V.fromList [14, 26] + matap f x `shouldBe` y + + describe "matadd" $ do + it "works" $ do + let f = makeMatrix'' [[1, 2, 3], [3, 4, 5]] + let g = makeMatrix'' [[7, 7, 7], [6, 5, 4]] + let h = makeMatrix'' [[8, 9, 10], [9, 9, 9]] + matadd f g `shouldBe` h + + describe "matmul" $ do + it "works" $ do + let f = makeMatrix'' [[1, 2, 3], [3, 4, 5]] + let g = makeMatrix'' [[1, 2], [3, 4], [5, 6]] + let h = makeMatrix'' [[22, 28], [40, 52]] + matmul f g `shouldBe` h + + describe "matscalar" $ do + it "works" $ do + let k = 3 + let f = makeMatrix'' [[1, 2, 3], [3, 4, 5]] + let g = makeMatrix'' [[3, 6, 9], [9, 12, 15]] + matscalar k f `shouldBe` g + + describe "matpow" $ do + it "works" $ do + let f = makeMatrix'' [[1, 1], [1, 0]] + let k = 10 + let g = makeMatrix'' [[89, 55], [55, 34]] + matpow f k `shouldBe` g diff --git a/test/Jikka/Core/Convert/ConstantFoldingSpec.hs b/test/Jikka/Core/Convert/ConstantFoldingSpec.hs new file mode 100644 index 00000000..e48137bf --- /dev/null +++ b/test/Jikka/Core/Convert/ConstantFoldingSpec.hs @@ -0,0 +1,27 @@ +{-# LANGUAGE OverloadedStrings #-} + +module Jikka.Core.Convert.ConstantFoldingSpec + ( spec, + ) +where + +import Jikka.Common.Alpha +import Jikka.Common.Error +import Jikka.Core.Convert.ConstantFolding (run) +import Jikka.Core.Language.BuiltinPatterns +import Jikka.Core.Language.Expr +import Test.Hspec + +run' :: Program -> Either Error Program +run' = flip evalAlphaT 0 . run + +spec :: Spec +spec = describe "run" $ do + it "works" $ do + let prog = + ResultExpr $ + Lam [("x", IntTy)] (Plus' (Mult' (LitInt' 3) (Var "x")) (Plus' (LitInt' 2) (LitInt' 1))) + let expected = + ResultExpr $ + Lam [("x", IntTy)] (Plus' (Mult' (LitInt' 3) (Var "x")) (LitInt' 3)) + run' prog `shouldBe` Right expected diff --git a/test/Jikka/Core/Convert/ConstantPropagationSpec.hs b/test/Jikka/Core/Convert/ConstantPropagationSpec.hs new file mode 100644 index 00000000..242232dd --- /dev/null +++ b/test/Jikka/Core/Convert/ConstantPropagationSpec.hs @@ -0,0 +1,42 @@ +{-# LANGUAGE OverloadedStrings #-} + +module Jikka.Core.Convert.ConstantPropagationSpec + ( spec, + ) +where + +import Jikka.Common.Alpha +import Jikka.Common.Error +import Jikka.Core.Convert.ConstantPropagation (run) +import Jikka.Core.Language.BuiltinPatterns +import Jikka.Core.Language.Expr +import Test.Hspec + +run' :: Program -> Either Error Program +run' = flip evalAlphaT 0 . run + +spec :: Spec +spec = describe "run" $ do + it "works" $ do + let prog = + ResultExpr + ( Let + "x" + IntTy + Lit1 + ( Let + "f" + (FunTy [IntTy] IntTy) + (Lam [("y", IntTy)] (Var "y")) + (Plus' (Var "x") (Plus' (Var "x") (App (Var "f") [Var "x"]))) + ) + ) + let expected = + ResultExpr + ( Let + "f" + (FunTy [IntTy] IntTy) + (Lam [("y", IntTy)] (Var "y")) + (Plus' Lit1 (Plus' Lit1 (App (Var "f") [Lit1]))) + ) + run' prog `shouldBe` Right expected diff --git a/test/Jikka/Core/Convert/ImmediateAppToLetSpec.hs b/test/Jikka/Core/Convert/ImmediateAppToLetSpec.hs new file mode 100644 index 00000000..28231cf6 --- /dev/null +++ b/test/Jikka/Core/Convert/ImmediateAppToLetSpec.hs @@ -0,0 +1,41 @@ +{-# LANGUAGE OverloadedStrings #-} + +module Jikka.Core.Convert.ImmediateAppToLetSpec + ( spec, + ) +where + +import Jikka.Common.Alpha +import Jikka.Common.Error +import Jikka.Core.Convert.ImmediateAppToLet (run) +import Jikka.Core.Language.BuiltinPatterns +import Jikka.Core.Language.Expr +import Test.Hspec + +run' :: Program -> Either Error Program +run' = flip evalAlphaT 0 . run + +spec :: Spec +spec = describe "run" $ do + it "works" $ do + let prog = + ResultExpr + ( Lam + [("a", IntTy)] + ( App + (Lam [("x", IntTy)] (Plus' (Var "x") (Var "x"))) + [Var "a"] + ) + ) + let expected = + ResultExpr + ( Lam + [("a$0", IntTy)] + ( Let + "x$1" + IntTy + (Var "a$0") + (Plus' (Var "x$1") (Var "x$1")) + ) + ) + run' prog `shouldBe` Right expected diff --git a/test/Jikka/Core/Convert/LinearFunctionSpec.hs b/test/Jikka/Core/Convert/LinearFunctionSpec.hs new file mode 100644 index 00000000..f2bb3c17 --- /dev/null +++ b/test/Jikka/Core/Convert/LinearFunctionSpec.hs @@ -0,0 +1,48 @@ +{-# LANGUAGE OverloadedStrings #-} + +module Jikka.Core.Convert.LinearFunctionSpec + ( spec, + ) +where + +import Jikka.Common.Alpha +import Jikka.Common.Error +import Jikka.Core.Convert.LinearFunction (run) +import Jikka.Core.Language.BuiltinPatterns +import Jikka.Core.Language.Expr +import Test.Hspec + +run' :: Program -> Either Error Program +run' = flip evalAlphaT 0 . run + +spec :: Spec +spec = describe "run" $ do + it "works" $ do + let xs = "xs" + let ts2 = [IntTy, IntTy] + let ts3 = [IntTy, IntTy, IntTy] + let ts23 = [TupleTy ts3, TupleTy ts3] + let proj i = Proj' ts3 i (Var xs) + let prog = + ResultExpr + ( Let + "c" + IntTy + (LitInt' 10) + ( Lam + [(xs, TupleTy ts3)] + (Tuple' ts2 [Plus' (proj 0) (Mult' (Var "c") (proj 1)), Plus' (proj 0) (proj 2)]) + ) + ) + let expected = + ResultExpr + ( Let + "c" + IntTy + (LitInt' 10) + ( Lam + [(xs, TupleTy ts3)] + (MatAp' 2 3 (Tuple' ts23 [Tuple' ts3 [Lit1, Mult' (Var "c") Lit1, Lit0], Tuple' ts3 [Lit1, Lit0, Lit1]]) (Var xs)) + ) + ) + run' prog `shouldBe` Right expected diff --git a/test/Jikka/Core/Convert/PropagateModSpec.hs b/test/Jikka/Core/Convert/PropagateModSpec.hs new file mode 100644 index 00000000..be8e272f --- /dev/null +++ b/test/Jikka/Core/Convert/PropagateModSpec.hs @@ -0,0 +1,34 @@ +{-# LANGUAGE OverloadedStrings #-} + +module Jikka.Core.Convert.PropagateModSpec + ( spec, + ) +where + +import Jikka.Common.Alpha +import Jikka.Common.Error +import Jikka.Core.Convert.PropagateMod (run) +import Jikka.Core.Language.BuiltinPatterns +import Jikka.Core.Language.Expr +import Test.Hspec + +run' :: Program -> Either Error Program +run' = flip evalAlphaT 0 . run + +spec :: Spec +spec = describe "run" $ do + it "works" $ do + let f e = FloorMod' e (LitInt' 1000000007) + let prog = + ResultExpr + ( Lam + [("y", IntTy)] + (f (App (Lam [("x", IntTy)] (Plus' (Mult' (Var "x") (Var "x")) (Var "x"))) [Var "y"])) + ) + let expected = + ResultExpr + ( Lam + [("y", IntTy)] + (App (Lam [("x", IntTy)] (f (Plus' (f (Mult' (f (Var "x")) (f (Var "x")))) (f (Var "x"))))) [Var "y"]) + ) + run' prog `shouldBe` Right expected diff --git a/test/Jikka/Core/Convert/TrivialLetEliminationSpec.hs b/test/Jikka/Core/Convert/TrivialLetEliminationSpec.hs new file mode 100644 index 00000000..f1a1390d --- /dev/null +++ b/test/Jikka/Core/Convert/TrivialLetEliminationSpec.hs @@ -0,0 +1,42 @@ +{-# LANGUAGE OverloadedStrings #-} + +module Jikka.Core.Convert.TrivialLetEliminationSpec + ( spec, + ) +where + +import Jikka.Common.Alpha +import Jikka.Common.Error +import Jikka.Core.Convert.TrivialLetElimination (run) +import Jikka.Core.Language.BuiltinPatterns +import Jikka.Core.Language.Expr +import Test.Hspec + +run' :: Program -> Either Error Program +run' = flip evalAlphaT 0 . run + +spec :: Spec +spec = describe "run" $ do + it "works" $ do + let prog = + ResultExpr + ( Let + "f" + (FunTy [IntTy] IntTy) + (Lam [("y", IntTy)] (Var "y")) + ( Let + "x" + IntTy + Lit1 + (App (Var "f") [Plus' (Var "x") (Var "x")]) + ) + ) + let expected = + ResultExpr + ( Let + "x" + IntTy + Lit1 + (App (Lam [("y", IntTy)] (Var "y")) [Plus' (Var "x") (Var "x")]) + ) + run' prog `shouldBe` Right expected diff --git a/test/Jikka/Core/Language/ArithmeticalExprSpec.hs b/test/Jikka/Core/Language/ArithmeticalExprSpec.hs new file mode 100644 index 00000000..b63e2a53 --- /dev/null +++ b/test/Jikka/Core/Language/ArithmeticalExprSpec.hs @@ -0,0 +1,56 @@ +{-# LANGUAGE OverloadedStrings #-} + +module Jikka.Core.Language.ArithmeticalExprSpec + ( spec, + ) +where + +import qualified Data.Vector as V +import Jikka.Core.Language.ArithmeticalExpr +import Jikka.Core.Language.BuiltinPatterns +import Jikka.Core.Language.Expr +import Test.Hspec + +spec :: Spec +spec = do + describe "parseProductExpr" $ do + it "works" $ do + let e = Mult' (LitInt' 3) (Var "y") + let expected = ProductExpr {productExprConst = 3, productExprList = [Var "y"]} + parseProductExpr e `shouldBe` expected + describe "multSumExpr" $ do + it "may introduce empty ProductExpr" $ do + let e1 = parseSumExpr (LitInt' 3) + let e2 = parseSumExpr (Var "y") + let expected = + SumExpr + { sumExprList = + [ ProductExpr {productExprConst = 0, productExprList = []}, + ProductExpr {productExprConst = 3, productExprList = [Var "y"]} + ], + sumExprConst = 0 + } + multSumExpr e1 e2 `shouldBe` expected + describe "parseArithmeticalExpr" $ do + it "works" $ do + let e = Plus' (Var "x") (Minus' (Mult' (LitInt' 3) (Var "y")) (Plus' (Var "x") (LitInt' 10))) + let expected = + SumExpr + { sumExprList = + [ ProductExpr {productExprConst = 1, productExprList = [Var "x"]}, + ProductExpr {productExprConst = 0, productExprList = []}, + ProductExpr {productExprConst = 3, productExprList = [Var "y"]}, + ProductExpr {productExprConst = -1, productExprList = [Var "x"]} + ], + sumExprConst = -10 + } + parseArithmeticalExpr e `shouldBe` expected + describe "makeVectorFromArithmeticalExpr" $ do + it "works" $ do + let xs = V.fromList ["x", "y"] + let e = + parseArithmeticalExpr + (Plus' (Var "x") (Plus' (Mult' (LitInt' 3) (Var "y")) (Minus' (Var "x") (LitInt' 10)))) + let f = V.fromList [parseArithmeticalExpr (LitInt' 2), parseArithmeticalExpr (LitInt' 3)] + let c = parseArithmeticalExpr (LitInt' (-10)) + makeVectorFromArithmeticalExpr xs e `shouldBe` Just (f, c) diff --git a/test/Jikka/RestrictedPython/Convert/AlphaSpec.hs b/test/Jikka/RestrictedPython/Convert/AlphaSpec.hs index 113f200e..82dbbfbb 100644 --- a/test/Jikka/RestrictedPython/Convert/AlphaSpec.hs +++ b/test/Jikka/RestrictedPython/Convert/AlphaSpec.hs @@ -20,18 +20,40 @@ spec = describe "run" $ do it "works" $ do let parsed = [ ToplevelFunctionDef + "f" + [("n", IntTy)] + IntTy + [ If + (Compare (Name "n") (CmpOp' Eq' (VarTy "t")) (constIntExp 0)) + [ Return (constIntExp 1) + ] + [ Return (BinOp (Name "n") Mult (Call (Name "f") [BinOp (Name "n") Sub (constIntExp 1)])) + ] + ], + ToplevelFunctionDef "solve" - [("x", IntTy)] + [("n", IntTy)] IntTy - [ AnnAssign (NameTrg "y") IntTy (Name "x") + [ Return (BinOp (Call (Name "f") [Name "n"]) FloorMod (constIntExp 1000000007)) ] ] let expected = [ ToplevelFunctionDef + "f" + [("n$0", IntTy)] + IntTy + [ If + (Compare (Name "n$0") (CmpOp' Eq' (VarTy "t")) (constIntExp 0)) + [ Return (constIntExp 1) + ] + [ Return (BinOp (Name "n$0") Mult (Call (Name "f") [BinOp (Name "n$0") Sub (constIntExp 1)])) + ] + ], + ToplevelFunctionDef "solve" - [("x", IntTy)] + [("n$1", IntTy)] IntTy - [ AnnAssign (NameTrg "y$0") IntTy (Name "x") + [ Return (BinOp (Call (Name "f") [Name "n$1"]) FloorMod (constIntExp 1000000007)) ] ] run' parsed `shouldBe` Right expected @@ -82,15 +104,15 @@ spec = describe "run" $ do let expected = [ ToplevelFunctionDef "foo" - [("x", IntTy)] + [("x$0", IntTy)] IntTy - [ AnnAssign (NameTrg "y$0") IntTy (Name "x") + [ AnnAssign (NameTrg "y$1") IntTy (Name "x$0") ], ToplevelFunctionDef "bar" - [("x", IntTy)] + [("x$2", IntTy)] IntTy - [ AnnAssign (NameTrg "y$1") IntTy (Name "x") + [ AnnAssign (NameTrg "y$3") IntTy (Name "x$2") ] ] run' parsed `shouldBe` Right expected @@ -170,9 +192,9 @@ spec = describe "run" $ do let expected = [ ToplevelFunctionDef "f" - [("x", IntTy)] + [("x$0", IntTy)] IntTy - [ Return (Call (Name "f") [Name "x"]) + [ Return (Call (Name "f") [Name "x$0"]) ] ] run' parsed `shouldBe` Right expected @@ -188,9 +210,9 @@ spec = describe "run" $ do let expected = [ ToplevelFunctionDef "f" - [("x", VarTy "x")] + [("x$0", VarTy "x")] (VarTy "f") - [ Return (Call (Name "f") [Name "x"]) + [ Return (Call (Name "f") [Name "x$0"]) ] ] run' parsed `shouldBe` Right expected