From da3ec68123ab4e2b0b580c3a5e946de4545550f9 Mon Sep 17 00:00:00 2001 From: Alistair Michael Date: Wed, 18 Dec 2024 14:54:50 +1000 Subject: [PATCH] client/server --- bin/dune | 4 +- bin/main.ml | 351 ++++++++++++++++++++++++++++++++++--------- dune-project | 2 +- gtirb_semantics.opam | 3 + 4 files changed, 284 insertions(+), 76 deletions(-) diff --git a/bin/dune b/bin/dune index 548ed2c..3c89207 100644 --- a/bin/dune +++ b/bin/dune @@ -1,7 +1,7 @@ (executable (public_name gtirb_semantics) (name main) - (flags (:standard -w -69)) + (flags (:standard -w -69 -w -32 -w -27)) (preprocess (pps ppx_jane -dont-apply=sexp_message)) - (libraries base64 yojson gtirb_semantics asli.libASL janestreet_lru_cache)) + (libraries base64 yojson gtirb_semantics asli.libASL janestreet_lru_cache lwt.unix mtime mtime.clock)) diff --git a/bin/main.ml b/bin/main.ml index ab67916..11d70b6 100644 --- a/bin/main.ml +++ b/bin/main.ml @@ -13,6 +13,7 @@ module Result = OcamlResult (* TYPES *) +let () = Printexc.record_backtrace true (* These could probably be simplified *) (* OCaml representation of mid-evaluation code block *) type rectified_block = { @@ -30,10 +31,13 @@ type dis_error = { error: string } + +type opcode_sem = ((string list, dis_error) result) + (* ASLi semantic info for a block *) type ast_block = { auuid : bytes; - asts : ((string list, dis_error) result) list; + asts : opcode_sem list; } (* Wrapper for polymorphic code/data/not-set block pre-rectification *) @@ -43,14 +47,17 @@ type content_block = { address : int; } -let decode_instr_success = ref 0 -let decode_instr_fail = ref 0 - (* CONSTANTS *) let opcode_length = 4 let json_file = ref "" +let serve = ref false +let client = ref true +let shutdown_server = ref false let speclist = [ ("--json", Arg.Set_string json_file, "output json semantics to given file (default: none, use /dev/stderr for stderr)"); + ("--serve", Arg.Set serve, "Start server process (in foreground)"); + ("--local", Arg.Clear client, "Do not use client to server"); + ("--shutdown-server", Arg.Set shutdown_server, "Stop server process"); ] let count_pos_args = ref (0) let in_file = ref "/nowhere/input" @@ -62,9 +69,10 @@ let handle_rest_arg arg = | 2 -> out_file := arg | _ -> () +let usage_string = "[options] [input.gtirb output.gts]" +let usage_message = Printf.sprintf "usage: %s %s\n" Sys.argv.(0) usage_string + -let usage_string = "GTIRB_FILE OUTPUT_FILE [--json JSON_SEMANTICS_OUTPUT]" -let usage_message = Printf.sprintf "usage: %s [--help] %s\n" Sys.argv.(0) usage_string (* ASL specifications are from the bundled ARM semantics in libASL. *) (* Protobuf spelunking *) @@ -100,65 +108,236 @@ let do_block ~(need_flip: bool) (b, c : content_block * CodeBlock.t): rectified_ { size; offset; ruuid; contents; opcodes; address } +let (let*) = Lwt.bind + +module Rpc = struct + + let message_count = ref 0 + + let sockfpath = match (Sys.getenv_opt "GTIRB_SEM_SOCKET") with + | Some x -> (x) + | None -> ("gtirb_semantics_socket") + + + let sockaddr = Lwt_unix.ADDR_UNIX sockfpath + + type msg_call = + | Shutdown + | Lift of {addr: int; opcode_be: string} + | LiftAll of (string * int) list + + type msg_resp = + | Ok of opcode_sem + | All of opcode_sem list +end -module DisCache = Lru_cache.Make (struct - open! Core.Bytes - open Core - open! Lru_cache - type t = (string * int) [@@deriving compare, hash, sexp_of] - let invariant = ignore - end) -let disas_cache : ((string list, dis_error) result) DisCache.t = DisCache.create ~max_size:10000 () +module InsnLifter = struct + module DisCache = Lru_cache.Make (struct + open! Core.Bytes + open Core + open! Lru_cache + type t = (string * int) [@@deriving compare, hash, sexp_of] + let invariant = ignore + end) -let env = - match Arm_env.aarch64_evaluation_environment () with - | Some e -> e - | None -> Printf.eprintf "unable to load bundled asl files. has aslp been installed correctly?"; exit 1 + let disas_cache : ((string list, dis_error) result) DisCache.t = DisCache.create ~max_size:5000 () -let to_asli_impl (opcode_be: string) (addr : int) : ((string list, dis_error) result) = - let p_raw a = Utils.to_string (Asl_parser_pp.pp_raw_stmt a) |> String.trim in - let p_pretty a = Asl_utils.pp_stmt a |> String.trim in - let p_byte (b: char) = Printf.sprintf "%02X" (Char.code b) in - let address = Some (string_of_int addr) in + (* number of cache misses *) + let decode_instr_success = ref 0 + (* number of serviced decode requests*) + let decode_instr_total = ref 0 + (* number of errors *) + let decode_instr_fail = ref 0 - (* below, opnum is the numeric opcode (necessarily BE) and opcode_* are always LE. *) - (* TODO: change argument of to_asli to follow this convention. *) - let opnum = Int32.to_int String.(get_int32_be opcode_be 0) in - let opnum_str = Printf.sprintf "0x%08lx" Int32.(of_int opnum) in - let opcode_list : char list = List.(rev @@ of_seq @@ String.to_seq opcode_be) in - let opcode_str = String.concat " " List.(map p_byte opcode_list) in - let _opcode : bytes = Bytes.of_seq List.(to_seq opcode_list) in - let do_dis () : ((string list * string list), dis_error) result = - (match Dis.retrieveDisassembly ?address env (Dis.build_env env) opnum_str with - | res -> Ok (List.map p_raw res, List.map p_pretty res) - | exception exc -> - Printf.eprintf - "error during aslp disassembly (unsupported opcode %s, bytes %s):\n\nException : %s\n" - opnum_str opcode_str (Printexc.to_string exc); - (* Printexc.print_backtrace stderr; *) - Error { - opcode = opnum_str; - error = (Printexc.to_string exc) - } - ) - in Result.map fst (do_dis ()) + let env = lazy begin + match Arm_env.aarch64_evaluation_environment () with + | Some e -> e + | None -> Printf.eprintf "unable to load bundled asl files. has aslp been installed correctly?"; exit 1 + end -let cache = true + let to_asli_impl (opcode_be: string) (addr : int) : ((string list, dis_error) result) = + let p_raw a = Utils.to_string (Asl_parser_pp.pp_raw_stmt a) |> String.trim in + let p_pretty a = Asl_utils.pp_stmt a |> String.trim in + let p_byte (b: char) = Printf.sprintf "%02X" (Char.code b) in + let address = Some (string_of_int addr) in -let to_asli (opcode_be: string) (addr : int) : ((string list, dis_error) result) = - if cache then ( - let k : (string * int) = (opcode_be, addr) in - DisCache.find_or_add disas_cache k ~default:(fun () -> to_asli_impl opcode_be addr) - ) else (to_asli_impl opcode_be addr) + (* below, opnum is the numeric opcode (necessarily BE) and opcode_* are always LE. *) + (* TODO: change argument of to_asli to follow this convention. *) + let opnum = Int32.to_int String.(get_int32_be opcode_be 0) in + let opnum_str = Printf.sprintf "0x%08lx" Int32.(of_int opnum) in -let do_module (m: Module.t): Module.t = + let opcode_list : char list = List.(rev @@ of_seq @@ String.to_seq opcode_be) in + let opcode_str = String.concat " " List.(map p_byte opcode_list) in + let _opcode : bytes = Bytes.of_seq List.(to_seq opcode_list) in + + let do_dis () : ((string list * string list), dis_error) result = + (match Dis.retrieveDisassembly ?address (Lazy.force env) (Dis.build_env (Lazy.force env)) opnum_str with + | res -> + decode_instr_success := !decode_instr_success + 1 ; + Ok (List.map p_raw res, List.map p_pretty res) + | exception exc -> + Printf.eprintf + "error during aslp disassembly (unsupported opcode %s, bytes %s):\n\nException : %s\n" + opnum_str opcode_str (Printexc.to_string exc); + decode_instr_fail := !decode_instr_fail + 1 ; + (* Printexc.print_backtrace stderr; *) + Error { + opcode = opnum_str; + error = (Printexc.to_string exc) + } + ) + in Result.map fst (do_dis ()) + + let to_asli ?(cache=true) (opcode_be: string) (addr : int) : ((string list, dis_error) result) = + if cache then ( + let k : (string * int) = (opcode_be, addr) in + DisCache.find_or_add disas_cache k ~default:(fun () -> to_asli_impl opcode_be addr) + ) else (to_asli_impl opcode_be addr) + +end + + +module Server = struct + + let shutdown = ref false + + let rec respond (ic: Lwt_io.input_channel) (oc:Lwt_io.output_channel) : unit Lwt.t = + + let stop () = + let* () = Lwt_io.close ic in + let* () = Lwt_io.close oc in + Lwt.return () + in + if (Lwt_io.is_closed ic || Lwt_io.is_closed oc || !shutdown) + then stop () + else + let* r: Rpc.msg_call = Lwt.catch (fun () -> Lwt_io.read_value ic) (function + | exn -> + let* () = stop () in + Lwt.fail exn + ) + in + Rpc.message_count := !Rpc.message_count + 1 ; + let* () = match r with + | Shutdown -> + shutdown := true ; + stop () + | Lift {addr; opcode_be} -> + let lifted : opcode_sem = InsnLifter.to_asli opcode_be addr in + let resp : Rpc.msg_resp = Ok lifted in + Lwt_io.write_value oc resp + | LiftAll (ops) -> + let lifted = List.map (fun (op, addr) -> InsnLifter.to_asli op addr) ops in + let resp : Rpc.msg_resp = All lifted in + Lwt_io.write_value oc resp + in + respond ic oc + + and handle_conn (addr: Lwt_unix.sockaddr) ((ic: Lwt_io.input_channel) , (oc:Lwt_io.output_channel)) = + Lwt.catch (fun () -> respond ic oc) (function + | End_of_file -> (let* () = Lwt_io.close ic in let* () = Lwt_io.close oc; in Lwt.return ()) + | x -> Lwt_io.printf "%s" (Printexc.to_string x) + ) + + + let server = lazy (Lwt_io.establish_server_with_client_address Rpc.sockaddr handle_conn) + + + let rec run_server () = + if !shutdown + then + Lwt.return () + else + let* () = Lwt_io.printf "Decoded %d instructions (%d failure) (%f cache hit rate) (%d messages)\n" + !InsnLifter.decode_instr_success !InsnLifter.decode_instr_fail (InsnLifter.DisCache.hit_rate InsnLifter.disas_cache) !Rpc.message_count + in + let* () = Lwt_unix.sleep 5.0 in + run_server () + + let start_server () = + let start = + let* _ = Lwt.return ( + let* r = Lwt.return (Lazy.force InsnLifter.env) in + let* m = Lwt.return ((Mtime.Span.to_float_ns (Mtime_clock.elapsed ())) /. 1000000000.0) in + Lwt_io.printf "Initialiesd lifter environment in %f seconds\n" m + ) in + let* s = Lazy.force server in + let* _ = Lwt_io.printf "Serving on domain socket GTIRB_SEM_SOCKET=%s\n" Rpc.sockfpath in + + Lwt_unix.on_signal + Sys.sigint + (fun _ -> exit 0) + |> ignore; + + Lwt_main.at_exit (fun () -> begin + print_endline "shutdown server" ; + (Lwt_io.shutdown_server s) + end + ); + (run_server ()) + in Lwt_main.run start + +end + +module Client = struct + open Lwt + + let connection = lazy ( Lwt_io.open_connection Rpc.sockaddr ) + + let cin () = let* (ic,oc) = Lazy.force connection in + if (Lwt_io.is_closed ic) then (failwith "connection (in) closed") ; + return ic + let cout () = let* (ic,oc) = Lazy.force connection in + if (Lwt_io.is_closed oc) then (failwith "connection (out) closed") ; + return oc + + let shutdown_server () = + let* cout = cout () in + let m : Rpc.msg_call = Shutdown in + Lwt_io.write_value cout m + + let lift (opcode_be: string) (addr : int) = + let* cout = cout() in + let* cin = cin() in + let cm : Rpc.msg_call = Lift {opcode_be; addr} in + let*() = Lwt_io.write_value cout cm in + let* resp : Rpc.msg_resp = Lwt_io.read_value cin in + match resp with + | Ok x -> return x + | All x -> Lwt.fail_with "did not expect multi response" + + let lift_one (opcode_be: string) (addr : int) = + Lwt_main.run (lift opcode_be addr) + + let lift_multi (opcodes: (string * int) list) : opcode_sem list Lwt.t = + let* lift_m = + let* cout = cout() in let* cin = cin() in + let cm : Rpc.msg_call = LiftAll opcodes in + let*() = Lwt_io.write_value cout cm in + let* resp : Rpc.msg_resp = Lwt_io.read_value cin in + match resp with + | All x -> return x + | Ok x -> return [x] + in + let* _ = Lwt_list.iter_s (function + | Ok x -> InsnLifter.decode_instr_success := !InsnLifter.decode_instr_success + 1; Lwt.return () ; + | Error ({opcode; error}) -> ( + InsnLifter.decode_instr_fail := !InsnLifter.decode_instr_fail + 1; + Lwt_io.printf "Lift error : %s :: %s\n" opcode error; + ) + ) lift_m + in + return lift_m +end + +let do_module (m: Module.t): Module.t Lwt.t = let all_sects = m.sections in let intervals = List.flatten @@ List.map (fun (s : Section.t) -> s.byte_intervals) all_sects in @@ -179,21 +358,34 @@ let do_module (m: Module.t): Module.t = let need_flip = m.byte_order = ByteOrder.LittleEndian in let rblocks = List.map (do_block ~need_flip) cblocks in - Printexc.record_backtrace true; (* Evaluate each instruction one by one with a new environment for each *) - let rec asts opcodes addr = + + let rec ops opcodes addr = match opcodes with | [] -> [] - | h :: t -> (to_asli (String.of_bytes h) addr) :: (asts t (addr + opcode_length)) + | h :: t -> ((String.of_bytes h), addr) :: (ops t (addr + opcode_length)) + in + let asts opcodes addr = if (!client) then (Client.lift_multi (ops opcodes addr)) + else begin + let rec getasts opcodes addr = + match opcodes with + | [] -> [] + | h :: t -> (InsnLifter.to_asli (String.of_bytes h) addr) :: (getasts t (addr + opcode_length)) + in Lwt.return @@ getasts opcodes addr + end in - (* let map' f l = + + (* + let map' f l = if List.length blk_orded > 10000 then Parmap.parmap ~ncores:2 f Parmap.(L l) else map f l in *) - let with_asts = List.map (fun b -> { + let* with_asts = Lwt_list.map_p (fun b -> + let* asts = asts b.opcodes b.address in + Lwt.return { auuid = b.ruuid; - asts = asts b.opcodes b.address; + asts = asts; }) rblocks in @@ -235,17 +427,11 @@ let do_module (m: Module.t): Module.t = let new_aux = ast_aux serialisable in let full_auxes = (aux_key, Some new_aux) :: orig_auxes in let mod_fixed = {m with aux_data = full_auxes} in - mod_fixed + Lwt.return mod_fixed -(* MAIN *) -let () = - (* BEGINNING *) - Arg.parse speclist handle_rest_arg usage_message; - (* Printf.eprintf "gtirb-semantics: %s -> %s\n" !in_file !out_file; *) - if !count_pos_args <> 2 then - (output_string stderr usage_message; exit 1); - +let gtirb_to_gts () : unit = + let bt = Sys.time() in (* Read bytes from the file, skip first 8 *) let bytes = let ic = open_in_bin !in_file in @@ -270,21 +456,40 @@ let () = Printf.sprintf "%s%s" "Could not reply request: " (Ocaml_protoc_plugin.Result.show_error e) ) in - let bt = Sys.time() in - let modules' = List.map do_module ir.modules in + let modules' = Lwt_main.run @@ Lwt_list.map_p do_module ir.modules in let new_ir = {ir with modules = modules'} in let serial = IR.to_proto new_ir in let encoded = Writer.contents serial in - let et = Sys.time() in - let time_delta = et -. bt in (* Reserialise to disk *) let out = open_out_bin !out_file in output_string out encoded; close_out out; - Printf.printf "Decoded %d instructions in %f seconds (%d failure) (%f cache hit rate)\n" - !decode_instr_success time_delta !decode_instr_fail (DisCache.hit_rate disas_cache); + let et = Sys.time () in + let usr_time_delta = et -. bt in + let time_delta = Float.div (Mtime.Span.to_float_ns (Mtime_clock.elapsed ())) (1000000000.0) in + Printf.printf "Lifted %d instructions in %f sec (%f user time) (%d failure) (%f cache hit rate)\n" + !InsnLifter.decode_instr_success time_delta usr_time_delta !InsnLifter.decode_instr_fail (InsnLifter.DisCache.hit_rate InsnLifter.disas_cache) + + +(* MAIN *) +let () = + (* BEGINNING *) + Arg.parse speclist handle_rest_arg usage_message; + (* Printf.eprintf "gtirb-semantics: %s -> %s\n" !in_file !out_file; *) + if (not !serve) && (not !shutdown_server) && !count_pos_args <> 2 then + (output_string stderr usage_message; exit 1); + + if (!shutdown_server) then begin + Lwt_main.run @@ Client.shutdown_server () + end + else + if (!serve) then begin + Server.start_server () + end else begin + output_string stdout "Lifting\n" ; + gtirb_to_gts () + end - diff --git a/dune-project b/dune-project index d91bfc1..ad73d7a 100644 --- a/dune-project +++ b/dune-project @@ -19,7 +19,7 @@ (name gtirb_semantics) (synopsis "Add semantic information to the IR of a disassembled ARM64 binary") (description "A longer description") - (depends ocaml dune yojson asli (ocaml-protoc-plugin (>= 6.1.0)) base64 janestreet_lru_cache) + (depends ocaml dune yojson asli (ocaml-protoc-plugin (>= 6.1.0)) base64 janestreet_lru_cache lwt lwt_ppx mtime) (tags (decompilers instruction-lifters static-analysis))) diff --git a/gtirb_semantics.opam b/gtirb_semantics.opam index dac3dcf..085e664 100644 --- a/gtirb_semantics.opam +++ b/gtirb_semantics.opam @@ -17,6 +17,9 @@ depends: [ "ocaml-protoc-plugin" {>= "6.1.0"} "base64" "janestreet_lru_cache" + "lwt" + "lwt_ppx" + "mtime" "odoc" {with-doc} ] build: [