Skip to content

Commit

Permalink
Add case-insensitive table lookups, and use for host and content-type. (
Browse files Browse the repository at this point in the history
#236)

A new utility function called `key` takes a table and a key. It loops
through the table items, comparing the each key with the given search
key in a case-insensitive manner. Return the first matching key;
otherwise, return the given search key. This allows operations like
these to be performed:

```lua
local t = {s=37, S=73}
t[utils.key(t,'a')] = 1  -- Insert a:1              t == {s=37,S=73,a=1}
y = t[utils.key(t,'a')]  -- Find 'a'                y == 1
z = t[utils.key(t,'A')]  -- Find 'A' (same as 'a')  z == 1
m = t[utils.key(t,'b')]  -- Show 'b' is missing.    m == nil
k = utils.key(t,'s')     -- Which 's/S' is first?   k is indeterminate
l = t['s']               -- Get 's' and 'S'         l = 37
u = t['S']               -- in the usual manner.    u = 73
t[utils.key(t,'a')] = nil  -- Delete 'a'            t == {s=37, S=73}
```

As implied by the `k = ` example above, lua associative tables are
unordered, so there's no guarantee that 's' (or 'S') is the first one to
be found. In this context, that shouldn't be too much of an issue.

Use this new function to find the 'Host' and 'Content-Type' headers'
values no matter in what case they were defined.
  • Loading branch information
PhilRunninger authored Oct 10, 2023
1 parent 8b62563 commit 5bcaa10
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 22 deletions.
11 changes: 3 additions & 8 deletions lua/rest-nvim/curl/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,9 @@ local function create_callback(curl_cmd, opts)
return
end
local res_bufnr = M.get_or_create_buf()
local content_type = nil

-- get content type
for _, header in ipairs(res.headers) do
if string.lower(header):find("^content%-type") then
content_type = header:match("application/([-a-z]+)") or header:match("text/(%l+)")
break
end
local content_type = res.headers[utils.key(res.headers,'content-type')]
if content_type then
content_type = content_type:match("application/([-a-z]+)") or content_type:match("text/(%l+)")
end

if script_str ~= nil then
Expand Down
8 changes: 1 addition & 7 deletions lua/rest-nvim/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,7 @@ local function splice_body(headers, payload)
else
lines = payload.body_tpl
end
local content_type = ""
for key, val in pairs(headers) do
if string.lower(key) == "content-type" then
content_type = val
break
end
end
local content_type = headers[utils.key(headers,"content-type")] or ""
local has_json = content_type:find("application/[^ ]*json")

local body = ""
Expand Down
9 changes: 3 additions & 6 deletions lua/rest-nvim/request/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -286,12 +286,9 @@ M.buf_get_request = function(bufnr, curpos)

local curl_args, body_start = get_curl_args(bufnr, headers_end, end_line)

if headers["host"] ~= nil then
headers["host"] = headers["host"]:gsub("%s+", "")
headers["host"] = string.gsub(headers["host"], "%s+", "")
parsed_url.url = headers["host"] .. parsed_url.url
headers["host"] = nil
end
local host = headers[utils.key(headers,"host")] or ""
parsed_url.url = host:gsub("%s+", "") .. parsed_url.url
headers[utils.key(headers,"host")] = nil

local body = get_body(bufnr, body_start, end_line)

Expand Down
14 changes: 14 additions & 0 deletions lua/rest-nvim/utils/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ M.get_file_variables = function()
end
return variables
end

-- Gets the variables from the currently selected env_file
M.get_env_variables = function()
local variables = {}
Expand Down Expand Up @@ -286,6 +287,19 @@ M.has_value = function(tbl, str)
return false
end

-- key returns the provided table's key that matches the given case-insensitive pattern.
-- if not found, return the given key.
-- @param tbl Table to iterate over
-- @param key The key to be searched in the table
M.key = function(tbl, key)
for tbl_key, _ in pairs(tbl) do
if string.lower(tbl_key) == string.lower(key) then
return tbl_key
end
end
return key
end

-- tbl_to_str recursively converts the provided table into a json string
-- @param tbl Table to convert into a String
-- @param json If the string should use a key:value syntax
Expand Down
2 changes: 1 addition & 1 deletion tests/get_with_host.http
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
###

GET /api/users?page=5
Host: https://reqres.in
host: https://reqres.in

###

Expand Down
1 change: 1 addition & 0 deletions tests/main_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ describe("rest testing framework", function()
assert(rest.run_file("tests/basic_get.http", opts) == true)
assert(rest.run_file("tests/post_json_form.http", opts) == true)
assert(rest.run_file("tests/post_create_user.http", opts) == true)
assert(rest.run_file("tests/get_with_host.http", opts) == true)
assert(rest.run_file("tests/put_update_user.http", opts) == true)
assert(rest.run_file("tests/patch_update_user.http", opts) == true)
assert(rest.run_file("tests/delete.http", opts) == true)
Expand Down

0 comments on commit 5bcaa10

Please sign in to comment.