Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom docs UI fixes #126

Open
wants to merge 14 commits into
base: dev
Choose a base branch
from
7 changes: 4 additions & 3 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ jobs:
run: |
pre-commit run --all-files
- name: Test with pytest
# config.py won't import if schema.yaml doesn't exist yet
run: |
cp examples/mouse_and_cheese/schema.yaml .
python generate.py copy_server_files examples/mouse_and_cheese
pytest tests/test_app.py
cp examples/lpoc/schema.yaml .
python generate.py copy_server_files "examples/mouse_and_cheese"
pytest tests/
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,21 @@ parameters:
discrete: True
tuple: False
api_key: "1234567asdfgh"
project_id: 1284
model_id: 2817
```

With this configuration of the policy server, the user will only have access to one predictive endpoint, namely
`predict_raw/`, which requires users to send a JSON with the following structure `{"obs": []}` where the list elements
have to have precisely the right cardinality and ordering. In other words, this endpoint strips all validation that
the other endpoints have, but is quicker to set up due to not having to specify the structure of observations.

Note that if `discrete` is `True` is for models with all discrete actions, while setting this flag to `False` means
that the policy emits continuous actions.

Providing `project_id` and `model_id` is optional. If you provide both, the `/docs` and `/redoc` endpoints will have
feature a link to go back to the respective Pathmind experiment this policy came from.

### Starting the app

Once you have both `saved_model.zip` and `schema.yaml` ready, you can start the policy server like this:
Expand Down
37 changes: 29 additions & 8 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,43 @@
from fastapi.responses import FileResponse
from fastapi.security.api_key import APIKey
from ray import serve
from ray.rllib.evaluation.sample_batch_builder import SampleBatchBuilder
from ray.rllib.offline.json_writer import JsonWriter

import config
from api import Action, Observation, RawObservation
from docs import get_redoc_html, get_swagger_ui_html
from generate import CLI
from offline import EpisodeCache
from security import get_api_key

cache = EpisodeCache()
batch_builder = SampleBatchBuilder() # or MultiAgentSampleBatchBuilder
writer = JsonWriter(config.EXPERIENCE_LOCATION)

url_path = config.parameters.get("url_path")

app = FastAPI(root_path=f"/{url_path}") if url_path else FastAPI()
app = (
FastAPI(root_path=f"/{url_path}", docs_url=None, redoc_url=None)
if url_path
else FastAPI(docs_url=None, redoc_url=None)
)


@app.get("/docs", include_in_schema=False)
def overridden_swagger():
return get_swagger_ui_html(
openapi_url="/openapi.json",
title="Pathmind Policy Server",
swagger_favicon_url="https://www.google.com/s2/favicons?domain_url=pathmind.com",
project_id=config.project_id,
model_id=config.model_id,
)


@app.get("/redoc", include_in_schema=False)
def overridden_redoc():
return get_redoc_html(
openapi_url="/openapi.json",
title="Pathmind Policy Server",
redoc_favicon_url="https://www.google.com/s2/favicons?domain_url=pathmind.com",
project_id=config.project_id,
model_id=config.model_id,
)


tags_metadata = [
{
Expand Down
3 changes: 3 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ def base_path(local_file):
parameters = schema.get("parameters")
action_type = int if parameters.get("discrete") else float

model_id = parameters.get("model_id", None)
project_id = parameters.get("project_id", None)

payload_data = {}
# If the schema includes `max_items` set the constraints for the array
if observations:
Expand Down
208 changes: 208 additions & 0 deletions docs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
import json
from typing import Optional

from fastapi.encoders import jsonable_encoder
from starlette.responses import HTMLResponse


def get_swagger_ui_html(
*,
openapi_url: str,
title: str,
swagger_js_url: str = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui-bundle.js",
swagger_css_url: str = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css",
swagger_favicon_url: str = "https://fastapi.tiangolo.com/img/favicon.png",
oauth2_redirect_url: Optional[str] = None,
init_oauth: Optional[dict] = None,
project_id: Optional[int] = None,
model_id: Optional[int] = None,
) -> HTMLResponse:

html = f"""
<!DOCTYPE html>
<html>
<head>
<link type="text/css" rel="stylesheet" href="{swagger_css_url}">
<link rel="shortcut icon" href="{swagger_favicon_url}">
<title>{title}</title>
</head>
<body>
"""

if project_id and model_id:
html += f"""
<div
style="padding: 0 20px; box-sizing: border-box; margin: 0 auto; max-width: 1460px;
width: 100%; margin-top: 20px">
<a href=https://app.pathmind.com/project/{project_id}/model/{model_id}
style="color: black; font-family: sans-serif; font-size:18px;"
>← Back to Pathmind experiment</a>
</div>
"""

html += f"""
<div id="swagger-ui">
</div>
<script src="{swagger_js_url}"></script>
<!-- `SwaggerUIBundle` is now available on the page -->
<script>
const ui = SwaggerUIBundle({{
url: '{openapi_url}',
"""

if oauth2_redirect_url:
html += f"oauth2RedirectUrl: window.location.origin + '{oauth2_redirect_url}',"

html += """
dom_id: '#swagger-ui',
presets: [
SwaggerUIBundle.presets.apis,
SwaggerUIBundle.SwaggerUIStandalonePreset
],
layout: "BaseLayout",
deepLinking: true,
showExtensions: true,
showCommonExtensions: true
})"""

if init_oauth:
html += f"""
ui.initOAuth({json.dumps(jsonable_encoder(init_oauth))})
"""

html += """
</script>
</body>
</html>
"""
return HTMLResponse(html)


def get_redoc_html(
*,
openapi_url: str,
title: str,
redoc_js_url: str = "https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js",
redoc_favicon_url: str = "https://fastapi.tiangolo.com/img/favicon.png",
with_google_fonts: bool = True,
project_id: Optional[int] = None,
model_id: Optional[int] = None,
) -> HTMLResponse:
html = f"""
<!DOCTYPE html>
<html>
<head>
<title>{title}</title>
<!-- needed for adaptive design -->
<meta charset="utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1">
"""
if with_google_fonts:
html += """
<link href="https://fonts.googleapis.com/css?family=Montserrat:300,400,700|Roboto:300,400,700" rel="stylesheet">
"""
html += f"""
<link rel="shortcut icon" href="{redoc_favicon_url}">
<!--
ReDoc doesn't change outer page styles
-->
<style>
body {{
margin: 0;
padding: 0;
}}
</style>
</head>
<body>
"""

if project_id and model_id:
html += f"""
<div
style="margin: 15px">
<a href=https://app.pathmind.com/project/{project_id}/model/{model_id}
style="color: black; font-family: sans-serif; font-size:18px;"
>← Back to Pathmind experiment</a>
</div>
"""

html += f"""
<redoc spec-url="{openapi_url}"></redoc>
<script src="{redoc_js_url}"> </script>
</body>
</html>
"""
return HTMLResponse(html)


def get_swagger_ui_oauth2_redirect_html() -> HTMLResponse:
html = """
<!DOCTYPE html>
<html lang="en-US">
<body onload="run()">
</body>
</html>
<script>
'use strict';
function run () {
var oauth2 = window.opener.swaggerUIRedirectOauth2;
var sentState = oauth2.state;
var redirectUrl = oauth2.redirectUrl;
var isValid, qp, arr;

if (/code|token|error/.test(window.location.hash)) {
qp = window.location.hash.substring(1);
} else {
qp = location.search.substring(1);
}

arr = qp.split("&")
arr.forEach(function (v,i,_arr) { _arr[i] = '"' + v.replace('=', '":"') + '"';})
qp = qp ? JSON.parse('{' + arr.join() + '}',
function (key, value) {
return key === "" ? value : decodeURIComponent(value)
}
) : {}

isValid = qp.state === sentState

if ((
oauth2.auth.schema.get("flow") === "accessCode"||
oauth2.auth.schema.get("flow") === "authorizationCode"
) && !oauth2.auth.code) {
if (!isValid) {
oauth2.errCb({
authId: oauth2.auth.name,
source: "auth",
level: "warning",
message: "Authorization may be unsafe, passed state was changed in server Passed state wasn't returned from auth server"
});
}

if (qp.code) {
delete oauth2.state;
oauth2.auth.code = qp.code;
oauth2.callback({auth: oauth2.auth, redirectUrl: redirectUrl});
} else {
let oauthErrorMsg
if (qp.error) {
oauthErrorMsg = "["+qp.error+"]: " +
(qp.error_description ? qp.error_description+ ". " : "no accessCode received from the server. ") +
(qp.error_uri ? "More info: "+qp.error_uri : "");
}

oauth2.errCb({
authId: oauth2.auth.name,
source: "auth",
level: "error",
message: oauthErrorMsg || "[Authorization failed]: no accessCode received from the server"
});
}
} else {
oauth2.callback({auth: oauth2.auth, token: qp, isValid: isValid, redirectUrl: redirectUrl});
}
window.close();
}
</script>
"""
return HTMLResponse(content=html)
33 changes: 21 additions & 12 deletions tests/test_api_speed.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
"""This assumes the server is fully configured for the LPoC example"""
import timeit

import requests
import ray
from fastapi.testclient import TestClient

data = {
from app import app
from generate import CLI

CLI.copy_server_files("examples/lpoc")
client = TestClient(app)


payload = {
"coordinates": [1, 1],
"has_core": True,
"has_down_neighbour": True,
Expand All @@ -21,15 +28,17 @@


def predict():
return requests.post(
"https://localhost:8080/api/predict",
verify=False,
auth=("foo", "bar"),
json=data,
return client.post(
"http://localhost:8000/predict/",
json=payload,
headers={"access-token": "1234567asdfgh"},
)


predict()
number = 1000
res = timeit.timeit(predict, number=1000)
print(f"A total of {number} requests took {res} milliseconds to process on average.")
def test_predict():
number = 1000
res = timeit.timeit(predict, number=number)
print(
f"A total of {number} requests took {res} milliseconds to process on average."
)
ray.shutdown()
4 changes: 4 additions & 0 deletions tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
client = TestClient(app)


def setup_function():
ray.shutdown()


def test_health_check():
response = client.get("/")
assert response.status_code == 200
Expand Down
Loading