Skip to content

Commit

Permalink
end-to-end compiling and running via LLVM text
Browse files Browse the repository at this point in the history
  • Loading branch information
dougalm committed Jan 9, 2025
1 parent 7bf7ad9 commit d3e490c
Show file tree
Hide file tree
Showing 9 changed files with 243 additions and 291 deletions.
22 changes: 7 additions & 15 deletions dex.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,6 @@ data-files: lib/*.dx
, static/style.css
, src/lib/dexrt.bc

flag cuda
description: Enables building with CUDA support
default: False

flag optimized
description: Enables GHC optimizations
default: False
Expand Down Expand Up @@ -57,7 +53,7 @@ library
, PPrint
, RawName
-- , RuntimePrint
-- , Serialize
, Serialize
, Simplify
, Subst
, SourceRename
Expand Down Expand Up @@ -119,8 +115,8 @@ library
-Wno-unticked-promoted-constructors
-fPIC
-optP-Wno-nonportable-include-path
cxx-sources: src/lib/dexrt.cpp
cxx-options: -std=c++11 -fPIC
cxx-sources: src/lib/dexllvm.cpp
cxx-options: -std=c++17 -fPIC
-- Mimicking -XGHC2021 in GHC 8.10.1
default-extensions: BangPatterns
, BinaryLiterals
Expand Down Expand Up @@ -186,15 +182,12 @@ library
, ViewPatterns

pkgconfig-depends: libpng
if flag(cuda)
include-dirs: /usr/local/cuda/include
extra-libraries: cuda
cxx-options: -DDEX_CUDA
cpp-options: -DDEX_CUDA
extra-libraries: stdc++ LLVM-16
if flag(optimized)
ghc-options: -O3
else
ghc-options: -O0
extra-lib-dirs: /usr/lib/llvm-16/lib

executable dex
main-is: dex.hs
Expand Down Expand Up @@ -224,8 +217,7 @@ executable dex
, LambdaCase
, OverloadedStrings
, BlockArguments
if flag(cuda)
cpp-options: -DDEX_CUDA
extra-libraries: stdc++ LLVM-16
if flag(optimized)
ghc-options: -O3
else
Expand Down Expand Up @@ -253,7 +245,7 @@ foreign-library Dex
, text
hs-source-dirs: src/
c-sources: src/Dex/Foreign/rts.c
cc-options: -std=c11 -fPIC
cc-options: -std=c17 -fPIC
ghc-options: -Wall
-fPIC
-optP-Wno-nonportable-include-path
Expand Down
26 changes: 2 additions & 24 deletions makefile
Original file line number Diff line number Diff line change
Expand Up @@ -113,31 +113,9 @@ ifneq (,$(DEX_CI))
STACK_FLAGS := $(STACK_FLAGS) --flag dex:debug
endif

possible-clang-locations := clang++-9 clang++-10 clang++-11 clang++-12 clang++

CLANG := clang++

ifeq (1,$(DEX_LLVM_HEAD))
ifeq ($(PLATFORM),Darwin)
$(error LLVM head builds not supported on macOS!)
endif
STACK_FLAGS := $(STACK_FLAGS) --flag dex:llvm-head
STACK := $(STACK) --stack-yaml=stack-llvm-head.yaml
else
CLANG := $(shell for clangversion in $(possible-clang-locations) ; do \
if [[ $$(command -v "$$clangversion" 2>/dev/null) ]]; \
then echo "$$clangversion" ; break ; fi ; done)
ifeq (,$(CLANG))
$(error "Please install clang++-12")
endif
clang-version-compatible := $(shell $(CLANG) -dumpversion | awk '{ print(gsub(/^((9\.)|(10\.)|(11\.)|(12\.)).*$$/, "")) }')
ifneq (1,$(clang-version-compatible))
$(error "Please install clang++-12")
endif
endif

CXXFLAGS := $(CFLAGS) -std=c++11 -fno-exceptions -fno-rtti -pthread
CFLAGS := $(CFLAGS) -std=c11
CXXFLAGS := $(CFLAGS) -std=c++17 -fno-exceptions -fno-rtti -pthread
CFLAGS := $(CFLAGS) -std=c17

.PHONY: all
all: build
Expand Down
4 changes: 2 additions & 2 deletions src/dex.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import ConcreteSyntax (parseSourceBlocks)
import PPrint
import Util (readFileText, BString)


data EvalMode = ReplMode
| ScriptMode FilePath
| WebMode FilePath
Expand All @@ -41,7 +40,8 @@ runMode (CmdOpts evalMode cfg) = case evalMode of
forM_ blocks \block -> do
liftIO $ BS.putStr $ pprint block
evalSourceBlockRepl block
Doit -> error "This is an entry point for running ad-hoc Haskell code."
Doit -> undefined -- do whatever you want


stdOutLogger :: Outputs -> IO ()
stdOutLogger (Outputs outs) = do
Expand Down
43 changes: 28 additions & 15 deletions src/lib/LLVMFFI.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,37 @@
module LLVMFFI (LLVMContext, initializeLLVM, compileLLVM, getFunctionPtr,
callEntryFun) where

import Control.Monad
import qualified Data.ByteString as BS
import Foreign.Ptr
import qualified Types.LLVM as L
import Data.Int
import Util (BString)
import PPrint

foreign import ccall "doit_cpp" doit_cpp :: Int64 -> IO Int64
foreign import ccall "initialize_jit" initialize_jit :: IO Int
foreign import ccall "add_to_jit" add_to_jit :: Ptr () -> Int64 -> IO Int
foreign import ccall "get_function_ptr" get_function_ptr :: Ptr () -> Int64 -> IO (Ptr ())
foreign import ccall "call_function_ptr" call_function_ptr :: Ptr () -> IO (Ptr ())

type FunctionPtr = ()
type LLVMContext = ()
type DataPtr = ()
type DataListPtr = ()
type FunctionPtr = Ptr ()
type DataPtr = Ptr ()
type DataListPtr = Ptr ()

initializeLLVM :: IO LLVMContext
initializeLLVM = return undefined

compileLLVM :: LLVMContext -> BString -> IO ()
compileLLVM _ _ = return undefined

getFunctionPtr :: LLVMContext -> BString -> IO FunctionPtr
getFunctionPtr _ _ = return undefined

callEntryFun :: FunctionPtr -> [DataPtr] -> IO DataPtr
callEntryFun _ _ = return undefined
initializeLLVM = initialize_jit >> return ()

compileLLVM :: LLVMContext -> L.Module -> IO ()
compileLLVM _ f = do
BS.useAsCStringLen (pprint f) \(ptr, n) ->
void $ add_to_jit (castPtr ptr) (fromIntegral n)

getFunctionPtr :: LLVMContext -> L.Name -> IO FunctionPtr
getFunctionPtr _ fname = do
BS.useAsCStringLen fname.val \(ptr, n) ->
castPtr <$> get_function_ptr(castPtr ptr) (fromIntegral n)

callEntryFun :: FunctionPtr -> [DataPtr] -> IO ()
callEntryFun fPtr [] = do
call_function_ptr fPtr
return ()
25 changes: 20 additions & 5 deletions src/lib/ToLLVM.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,22 @@ import Util

-- === entrypoint ===

toLLVMEntryFun :: Monad m => L.Name -> TopLamExpr -> m L.Function
toLLVMEntryFun :: Monad m => L.Name -> TopLamExpr -> m L.Module
toLLVMEntryFun fname fun = do
finalState <- runTranslateM do
toLLVMEntryFun' fun
startNewBlock $ L.Name "__unused__"
let blocks = reverse finalState.basicBlocks
return $ L.Function fname [] blocks
let decl = L.FunctionDef $ L.Function fname [] blocks
return $ L.Module $ libDecls ++ [decl]

libDecls :: [L.TopDecl]
libDecls = [
L.FunctionDecl floatTy "printfloat" [floatTy]
]

floatTy :: L.Type
floatTy = L.BaseType $ Scalar Float32Type

-- === monad for the translation ===

Expand Down Expand Up @@ -71,6 +80,10 @@ extendEnv b x cont = TranslateM do
put $ updateSubst newState prevState.subst
return ans

-- lowering of ()
unitOperand :: L.Operand
unitOperand = L.Operand (L.Lit (Int32Lit 0)) (L.BaseType $ Scalar Int32Type)

lookupEnv :: Name i -> TranslateM i L.Operand
lookupEnv v = TranslateM do
env <- gets (.subst)
Expand Down Expand Up @@ -98,8 +111,8 @@ startNewBlock blockName = TranslateM $ modify \state -> do

toLLVMEntryFun' :: TopLamExpr -> TranslateM VoidS ()
toLLVMEntryFun' (TopLamExpr (Abs Empty body)) = do
trExpr body
return ()
ans <- trExpr body
emitStatement $ L.Return ans

trExpr :: Expr i -> TranslateM i L.Operand
trExpr = \case
Expand Down Expand Up @@ -134,4 +147,6 @@ trPrimOp resultTy op = case op of
BinOp b x y -> case b of
FAdd -> emitInstr resultTy $ L.FAdd x y
MiscOp op' -> case op' of
DebugPrintInt x -> undefined
DebugPrintInt x -> do
emitStatement $ L.Call floatTy "printfloat" [x]
return unitOperand
10 changes: 5 additions & 5 deletions src/lib/TopLevel.hs
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,14 @@ execUDecl decl = do
CTopLet Nothing expr <- checkPass TypePass $ inferTopUDecl renamed
simpFun <- simplifyTopFun (exprAsNullaryFun expr)
logPass SimpPass simpFun
let tempFunName = L.Name "main" -- TODO: need to get a name
let tempFunName = L.Name "__top_level_expr__" -- TODO: need to get a name
llvmContext <- TopperM $ asks topperLLVMContext
llvmFun <- toLLVMEntryFun tempFunName simpFun
logPass LLVMPass llvmFun
-- liftIO do
-- compileLLVM llvmContext llvmFun
-- f <- getFunctionPtr llvmContext tempFunName
-- callEntryFun f []
liftIO do
compileLLVM llvmContext llvmFun
f <- liftIO $ getFunctionPtr llvmContext tempFunName
callEntryFun f []
return ()

execCDecl :: CTopDecl -> TopperM ()
Expand Down
50 changes: 38 additions & 12 deletions src/lib/Types/LLVM.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

module Types.LLVM where

import Data.String
import Control.Monad
import Control.Monad.State
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Builder as BS
Expand All @@ -21,7 +23,11 @@ import Util (bs2str)
newtype Name = Name { val :: ByteString }
type Binder = (Name, Type)

data Module = Module { functions :: [Function] }
data Module = Module { functions :: [TopDecl] }

data TopDecl =
FunctionDef Function
| FunctionDecl Type Name [Type]

data Function = Function
{ name :: Name
Expand All @@ -35,6 +41,7 @@ data BasicBlock = BasicBlock
type Decl = (Maybe Name, Type, Instruction)
data Instruction =
FAdd Operand Operand
| Call Type Name [Operand]
| Return Operand

data Operand = Operand { val :: UntypedOperand, ty :: Type }
Expand All @@ -50,6 +57,24 @@ data Type =

-- This is load-bearing! We have to generate correct LLVM textual representation.

instance Pretty Module where
prLines m = forM_ m.functions \f -> do
prLines f
emitLine ""

instance Pretty TopDecl where
prLines = \case
FunctionDef f -> prLines f
FunctionDecl ty fname argTys -> do
emitLine $ "declare" <+> pr ty <+> app (prTopName fname) (prDeclArgs argTys)

prDeclArgs :: [Type] -> [BS.Builder]
prDeclArgs tys = flip evalState (0::Int) do
forM tys \ty -> do
i <- get
put (i + 1)
return $ pr ty <+> "%" <> pr i

instance Pretty Function where
prLines (Function name [] body) = do
emitLine $ "define i32" <+> prTopName name <> "() {"
Expand All @@ -66,20 +91,29 @@ prLocalName name = "%" <> BS.byteString name.val

prDecl :: Decl -> BS.Builder
prDecl (Just v, resultTy, instr) = prLocalName v <> " = " <> prInstr resultTy instr
prDecl (Nothing, resultTy, instr) = prInstr resultTy instr

prInstr :: Type -> Instruction -> BS.Builder
prInstr resultTy = \case
FAdd x y -> "fadd " <> pr resultTy <+> pr x.val <> ", " <> pr y.val
FAdd x y -> "fadd" <+> pr resultTy <+> pr x.val <> ", " <> pr y.val
Call ty f xs -> "call" <+> pr ty <+> app (prTopName f) (map pr xs)
Return x -> "ret" <+> pr x

instance Pretty BasicBlock where
prLines (BasicBlock name decls) = do
emitLine $ pr name <> ":"
indent do
forM_ decls \decl -> emitLine $ prDecl decl

instance IsString Name where
fromString s = Name $ fromString s

instance Pretty Name where
pr name = BS.byteString name.val

instance Pretty Operand where
pr x = pr x.ty <+> pr x.val

instance Pretty UntypedOperand where
pr = \case
LocalOcc v -> prLocalName v
Expand All @@ -88,14 +122,6 @@ instance Pretty UntypedOperand where
instance Pretty Type where
pr = \case
BaseType (P.Scalar b) -> case b of
P.Float32Type -> "f32"
P.Float32Type -> "float"
P.Int32Type -> "i32"
VoidType -> "void"


-- instance LLVMSer Operand where
-- lpr x = cat [lpr (getType x), ", ", printOperandWithoutType x]

-- instance Pretty Type where
-- pr = undefined


Loading

0 comments on commit d3e490c

Please sign in to comment.