Skip to content

Commit

Permalink
Merge branch 'master' into UsageDashboard_timezone_error
Browse files Browse the repository at this point in the history
  • Loading branch information
clr-li committed Nov 6, 2023
2 parents d75565c + 781fa7a commit cba5a51
Show file tree
Hide file tree
Showing 23 changed files with 371 additions and 231 deletions.
8 changes: 7 additions & 1 deletion daras_ai_v2/asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,13 @@ def check_wav_audio_format(filename: str) -> bool:
filename,
]
print("\t$ " + " ".join(args))
data = json.loads(subprocess.check_output(args))
try:
data = json.loads(subprocess.check_output(args, stderr=subprocess.STDOUT))
except subprocess.CalledProcessError as e:
ffmpeg_output_error = ValueError(e.output, e)
raise ValueError(
"Invalid audio file. Please confirm the file is not corrupted and has a supported format (google 'ffmpeg supported audio file types')"
) from ffmpeg_output_error
return (
len(data["streams"]) == 1
and data["streams"][0]["codec_name"] == "pcm_s16le"
Expand Down
189 changes: 140 additions & 49 deletions daras_ai_v2/azure_doc_extract.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,30 @@
import csv
import io
import re
import typing
from time import sleep

import requests
from furl import furl
from tabulate import tabulate
from jinja2.lexer import whitespace_re

from daras_ai_v2 import settings
from daras_ai_v2.redis_cache import redis_cache_decorator
from gooeysite import wsgi

assert wsgi

from time import sleep
from daras_ai_v2.text_splitter import default_length_function

auth_headers = {"Ocp-Apim-Subscription-Key": settings.AZURE_FORM_RECOGNIZER_KEY}


def azure_doc_extract_pages(pdf_url: str, model_id: str = "prebuilt-layout"):
    """Analyze *pdf_url* with azure form recognizer and return one text blob per page."""
    analysis = azure_form_recognizer(pdf_url, model_id)
    pages = []
    for page in analysis["pages"]:
        records = extract_records(analysis, page["pageNumber"])
        pages.append(records_to_text(records))
    return pages


@redis_cache_decorator
def azure_form_recognizer(pdf_url: str, model_id: str):
r = requests.post(
str(
furl(settings.AZURE_FORM_RECOGNIZER_ENDPOINT)
Expand All @@ -26,19 +37,12 @@ def azure_doc_extract_pages(pdf_url: str, model_id: str = "prebuilt-layout"):
r.raise_for_status()
location = r.headers["Operation-Location"]
while True:
r = requests.get(
location,
headers=auth_headers,
)
r = requests.get(location, headers=auth_headers)
r.raise_for_status()
r_json = r.json()
match r_json.get("status"):
case "succeeded":
result = r_json["analyzeResult"]
return [
records_to_text(extract_records(result, page["pageNumber"]))
for page in result["pages"]
]
return r_json["analyzeResult"]
case "failed":
raise Exception(r_json)
case _:
Expand All @@ -59,11 +63,16 @@ def extract_records(result: dict, page_num: int) -> list[dict]:
outer=table["polygon"], inner=para["boundingRegions"][0]["polygon"]
):
if not table.get("added"):
records.append({"role": "table", "content": table["content"]})
records.append({"role": "csv", "content": table["content"]})
table["added"] = True
break
else:
records.append({"role": para.get("role", ""), "content": para["content"]})
records.append(
{
"role": para.get("role", ""),
"content": strip_content(para["content"]),
}
)
return records


Expand All @@ -81,26 +90,13 @@ def records_to_text(records: list[dict]) -> str:
return ret.strip()


# def table_to_html(table):
# with redirect_stdout(io.StringIO()) as f:
# print("<table>")
# print("<tr>")
# idx = 0
# for cell in table["cells"]:
# if idx != cell["rowIndex"]:
# print("</tr>")
# print("<tr>")
# idx = cell["rowIndex"]
# if cell.get("kind") == "columnHeader":
# tag = "th"
# else:
# tag = "td"
# print(
# f"<{tag} rowspan={cell.get('rowSpan', 1)} colspan={cell.get('columnSpan',1)}>{cell['content'].strip()}</{tag}>"
# )
# print("</tr>")
# print("</table>")
# return f.getvalue()
def rect_contains(*, outer: list[int], inner: list[int]):
    """Return True if any vertex of *inner* lies within *outer*'s axis-aligned box.

    Both polygons are flat [x0, y0, x1, y1, ...] lists; only the outer box's
    left/right/top/bottom edges are used (corners are assumed axis-aligned).
    """
    left, top = outer[0], outer[1]
    right = outer[2]
    bottom = outer[7]
    return any(
        left <= x <= right and top <= y <= bottom
        for x, y in zip(inner[::2], inner[1::2])
    )


def extract_tables(result, page):
Expand All @@ -111,7 +107,7 @@ def extract_tables(result, page):
continue
except (KeyError, IndexError):
continue
plain = table_to_plain(table)
plain = table_to_csv(table)
table_polys.append(
{
"polygon": table["boundingRegions"][0]["polygon"],
Expand All @@ -122,17 +118,112 @@ def extract_tables(result, page):
return table_polys


def rect_contains(*, outer: list[int], inner: list[int]):
tl_x, tl_y, tr_x, tr_y, br_x, br_y, bl_x, bl_y = outer
for pt_x, pt_y in zip(inner[::2], inner[1::2]):
# if the point is inside the bounding box, return True
if tl_x <= pt_x <= tr_x and tl_y <= pt_y <= bl_y:
return True
return False
def table_to_csv(table: dict) -> str:
    """Render an azure form-recognizer table dict as CSV text."""
    rows = table_to_arr(table)
    return table_arr_to_csv(rows)


# Marker wrapped around header-cell text by table_to_arr so that later passes
# (_strip_header_from_row) can tell header rows apart from body rows.
THEAD = "**"

def table_to_plain(table):
ret = [["" for _ in range(table["columnCount"])] for _ in range(table["rowCount"])]

def table_to_arr(table: dict) -> list[list[str]]:
    """Convert an azure form-recognizer table into a 2D array of cell strings.

    A cell spanning multiple rows/columns is replicated into every grid
    position it covers.  Header-ish cells are wrapped in THEAD markers so
    downstream prompt/CSV builders can recognise them.

    Returns a rowCount x columnCount list-of-lists; positions with no cell
    stay "".
    """
    # NOTE: removed a leftover debug dump that wrote `table-<n>.json` into the
    # current working directory on every call.
    arr = [["" for _ in range(table["columnCount"])] for _ in range(table["rowCount"])]
    for cell in table["cells"]:
        # Compute the cell text once; it is identical for every spanned slot.
        content = strip_content(cell["content"])
        if cell.get("kind") in ("rowHeader", "columnHeader", "stubHead"):
            content = THEAD + content + THEAD
        for i in range(cell.get("rowSpan", 1)):
            row_idx = cell["rowIndex"] + i
            for j in range(cell.get("columnSpan", 1)):
                col_idx = cell["columnIndex"] + j
                arr[row_idx][col_idx] = content
    return arr


# NOTE: These are individual tokens in the gpt-4 vocab, and must be handled with care
THEAD_SEP = "|--"
TROW_END = "|\n"
TROW_SEP = " |"


def table_arr_to_prompt(arr: typing.Iterable[list[str]]) -> str:
    """Format table rows as a markdown-ish prompt string.

    Header rows (marked by table_to_arr) come first; a THEAD_SEP separator row
    is inserted when the rows switch from header to body.
    """
    parts = []
    was_header = True
    for row in arr:
        now_header = _strip_header_from_row(row)
        row = _remove_long_dupe_header(row)
        # Emit the header/body separator exactly once, at the transition.
        if was_header and not now_header:
            parts.append(THEAD_SEP * len(row) + TROW_END)
        parts.append(TROW_SEP + TROW_SEP.join(row) + TROW_END)
        was_header = now_header
    return "".join(parts)


def table_arr_to_prompt_chunked(
    arr: typing.Iterable[list[str]], chunk_size: int
) -> typing.Iterable[str]:
    """Yield prompt-formatted table chunks, each at most ~chunk_size long.

    Header rows are accumulated into ``header`` and prepended to every yielded
    chunk so each chunk is self-describing.  Lengths are measured with
    default_length_function (presumably a token count — confirm in
    daras_ai_v2.text_splitter).
    """
    header = ""
    chunk = ""
    prev_is_header = True
    for row in arr:
        is_header = _strip_header_from_row(row)
        row = _remove_long_dupe_header(row)
        if prev_is_header and not is_header:
            # First body row after the header block: close the header with a separator row.
            header += THEAD_SEP * len(row) + TROW_END
        next_chunk = TROW_SEP + TROW_SEP.join(row) + TROW_END
        if is_header:
            header += next_chunk
            if default_length_function(header) > chunk_size:
                # The header alone exceeds the budget — flush it as its own chunk.
                yield header
                header = ""
        else:
            if default_length_function(header + chunk + next_chunk) > chunk_size:
                # Adding this row would overflow: emit what we have (header repeated).
                yield header + chunk.rstrip()
                chunk = ""
            chunk += next_chunk
        prev_is_header = is_header
    if chunk:
        # Flush the trailing partial chunk.
        yield header + chunk.rstrip()


def _strip_header_from_row(row):
    """Remove THEAD markers from header cells IN PLACE.

    Returns True if any cell in *row* was marked as a header.
    """
    found = False
    marker_len = len(THEAD)
    for idx, text in enumerate(row):
        if text.startswith(THEAD) and text.endswith(THEAD):
            row[idx] = text[marker_len:-marker_len]
            found = True
    return found


def _remove_long_dupe_header(row: list[str], cutoff: int = 2) -> list[str]:
r = -1
l = 0
for cell in row[-2::-1]:
if not row[r]:
r -= 1
l -= 1
continue
if cell == row[r]:
l -= 1
else:
break
if -l >= cutoff:
row = row[:l] + [""] * -l
return row


def table_arr_to_csv(arr: typing.Iterable[list[str]]) -> str:
    """Serialize rows of cell strings to CSV text (csv-module defaults, \r\n rows)."""
    buffer = io.StringIO()
    csv.writer(buffer).writerows(arr)
    return buffer.getvalue()


# Matches azure form-recognizer checkbox markers (":selected:" / ":unselected:")
# embedded in extracted text; stripped out by strip_content.
selection_marks_re = re.compile(r":(un)?selected:")


def strip_content(text: str) -> str:
    """Normalize extracted cell/paragraph text.

    Drops azure :selected:/:unselected: marks, collapses whitespace runs to a
    single space, and trims the ends.
    """
    without_marks = selection_marks_re.sub("", text)
    collapsed = whitespace_re.sub(" ", without_marks)
    return collapsed.strip()
82 changes: 51 additions & 31 deletions daras_ai_v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,44 +449,60 @@ def render_author(self):
if self.run_user.display_name:
html += f"<div>{self.run_user.display_name}</div>"
html += "</div>"
st.markdown(
html,
unsafe_allow_html=True,
)

if self.is_current_user_admin():
linkto = lambda: st.link(
to=self.app_url(
tab_name=MenuTabs.paths[MenuTabs.history],
query_params={"uid": self.run_user.uid},
)
)
else:
linkto = st.dummy

with linkto():
st.html(html)

def get_credits_click_url(self):
if self.request.user and self.request.user.is_anonymous:
return "/pricing/"
else:
return "/account/"

def get_submit_container_props(self):
return dict(className="position-sticky bottom-0 bg-white")

def render_submit_button(self, key="--submit-1"):
col1, col2 = st.columns([2, 1], responsive=False)
col2.node.props["className"] += " d-flex justify-content-end align-items-center"
with col1:
st.caption(
f"Run cost = [{self.get_price_roundoff(st.session_state)} credits]({self.get_credits_click_url()}) \\\n"
f"_By submitting, you agree to Gooey.AI's [terms](https://gooey.ai/terms) & [privacy policy](https://gooey.ai/privacy)._ ",
)
additional_notes = self.additional_notes()
if additional_notes:
st.caption(additional_notes)
with col2:
submitted = st.button(
"🏃 Submit",
key=key,
type="primary",
# disabled=bool(st.session_state.get(StateKeys.run_status)),
)
if not submitted:
return False
try:
self.validate_form_v2()
except AssertionError as e:
st.error(e)
return False
else:
return True
with st.div(**self.get_submit_container_props()):
st.write("---")
col1, col2 = st.columns([2, 1], responsive=False)
col2.node.props[
"className"
] += " d-flex justify-content-end align-items-center"
with col1:
st.caption(
f"Run cost = [{self.get_price_roundoff(st.session_state)} credits]({self.get_credits_click_url()}) \\\n"
f"_By submitting, you agree to Gooey.AI's [terms](https://gooey.ai/terms) & [privacy policy](https://gooey.ai/privacy)._ ",
)
additional_notes = self.additional_notes()
if additional_notes:
st.caption(additional_notes)
with col2:
submitted = st.button(
"🏃 Submit",
key=key,
type="primary",
# disabled=bool(st.session_state.get(StateKeys.run_status)),
)
if not submitted:
return False
try:
self.validate_form_v2()
except AssertionError as e:
st.error(e)
return False
else:
return True

def _render_step_row(self):
with st.expander("**ℹ️ Details**"):
Expand Down Expand Up @@ -889,6 +905,8 @@ def _history_tab(self):
)
raise RedirectException(str(redirect_url))
uid = self.request.user.uid
if self.is_current_user_admin():
uid = self.request.query_params.get("uid", uid)

before = gooey_get_query_params().get("updated_at__lt", None)
if before:
Expand Down Expand Up @@ -927,7 +945,9 @@ def _render(sr: SavedRun):
grid_layout(3, run_history, _render)

next_url = (
furl(self._get_current_app_url()) / MenuTabs.paths[MenuTabs.history] / "/"
furl(self._get_current_app_url(), query_params=self.request.query_params)
/ MenuTabs.paths[MenuTabs.history]
/ "/"
)
next_url.query.params.set(
"updated_at__lt", run_history[-1].to_dict()["updated_at"]
Expand Down
Loading

0 comments on commit cba5a51

Please sign in to comment.