From 97c908e0f4d80fb7f8c368119c512ad3c0b799f7 Mon Sep 17 00:00:00 2001 From: Max Pumperla Date: Fri, 13 Aug 2021 11:27:01 +0200 Subject: [PATCH 01/12] custom docs html --- docs.py | 208 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 208 insertions(+) create mode 100644 docs.py diff --git a/docs.py b/docs.py new file mode 100644 index 0000000..ad25b91 --- /dev/null +++ b/docs.py @@ -0,0 +1,208 @@ +import json +from typing import Optional + +from fastapi.encoders import jsonable_encoder +from starlette.responses import HTMLResponse + + +def get_swagger_ui_html( + *, + openapi_url: str, + title: str, + swagger_js_url: str = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui-bundle.js", + swagger_css_url: str = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css", + swagger_favicon_url: str = "https://fastapi.tiangolo.com/img/favicon.png", + oauth2_redirect_url: Optional[str] = None, + init_oauth: Optional[dict] = None, + project_id: Optional[int] = None, + model_id: Optional[int] = None, +) -> HTMLResponse: + + html = f""" + + + + + + {title} + + + """ + + if project_id and model_id: + html += f""" +
+ ← Back to Pathmind experiment +
+ """ + + html += f""" +
+
+ + + + + + """ + return HTMLResponse(html) + + +def get_redoc_html( + *, + openapi_url: str, + title: str, + redoc_js_url: str = "https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js", + redoc_favicon_url: str = "https://fastapi.tiangolo.com/img/favicon.png", + with_google_fonts: bool = True, + project_id: Optional[int] = None, + model_id: Optional[int] = None, +) -> HTMLResponse: + html = f""" + + + + {title} + + + + """ + if with_google_fonts: + html += """ + + """ + html += f""" + + + + + + """ + + if project_id and model_id: + html += f""" +
+ ← Back to Pathmind experiment +
+ """ + + html += f""" + + + + + """ + return HTMLResponse(html) + + +def get_swagger_ui_oauth2_redirect_html() -> HTMLResponse: + html = """ + + + + + + + """ + return HTMLResponse(content=html) From 1bfafd58a059f37f403e7e2779311ffaa32849c7 Mon Sep 17 00:00:00 2001 From: Max Pumperla Date: Fri, 13 Aug 2021 11:27:16 +0200 Subject: [PATCH 02/12] document changes --- README.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/README.md b/README.md index 4d22143..628649e 100644 --- a/README.md +++ b/README.md @@ -102,6 +102,8 @@ parameters: discrete: True tuple: False api_key: "1234567asdfgh" + project_id: 1284 + model_id: 2817 ``` With this configuration of the policy server, the user will only have access to one predictive endpoint, namely @@ -109,6 +111,12 @@ With this configuration of the policy server, the user will only have access to have to have precisely the right cardinality and ordering. In other words, this endpoint strips all validation that the other endpoints have, but is quicker to set up due to not having to specify the structure of observations. +Note that if `discrete` is `True` is for models with all discrete actions, while setting this flag to `False` means +that the policy emits continuous actions. + +Providing `project_id` and `model_id` is optional. If you provide both, the `/docs` and `/redoc` endpoints will have +feature a link to go back to the respective Pathmind experiment this policy came from. + ### Starting the app Once you have both `saved_model.zip` and `schema.yaml` ready, you can start the policy server like this: From 469e5158a0374ce2f9bf517fb358bc2531152eeb Mon Sep 17 00:00:00 2001 From: Max Pumperla Date: Fri, 13 Aug 2021 13:09:14 +0200 Subject: [PATCH 03/12] change docs, favicon etc. fixes #87, fixes #74 --- app.py | 82 ++++++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 60 insertions(+), 22 deletions(-) diff --git a/app.py b/app.py index 69ed661..2c246cb 100644 --- a/app.py +++ b/app.py @@ -16,8 +16,8 @@ from security import get_api_key from fastapi.responses import FileResponse -from fastapi import Depends, FastAPI, HTTPException, status, Security -from fastapi.security.api_key import APIKeyQuery, APIKeyCookie, APIKeyHeader, APIKey +from fastapi import Depends +from fastapi.security.api_key import APIKey from fastapi.openapi.utils import get_openapi from ray.rllib.evaluation.sample_batch_builder import SampleBatchBuilder @@ -27,7 +27,34 @@ batch_builder = SampleBatchBuilder() # or MultiAgentSampleBatchBuilder writer = JsonWriter(config.EXPERIENCE_LOCATION) -app = FastAPI() +from fastapi import FastAPI +from docs import get_swagger_ui_html, get_redoc_html + + +app = FastAPI(docs_url=None, redoc_url=None) + + +@app.get("/docs", include_in_schema=False) +def overridden_swagger(): + return get_swagger_ui_html( + openapi_url="/openapi.json", + title="Pathmind Policy Server", + swagger_favicon_url="https://www.google.com/s2/favicons?domain_url=pathmind.com", + project_id=config.project_id, + model_id=config.model_id, + ) + + +@app.get("/redoc", include_in_schema=False) +def overridden_redoc(): + return get_redoc_html( + openapi_url="/openapi.json", + title="Pathmind Policy Server", + redoc_favicon_url="https://www.google.com/s2/favicons?domain_url=pathmind.com", + project_id=config.project_id, + model_id=config.model_id, + ) + tags_metadata = [ { @@ -67,6 +94,7 @@ def custom_openapi(): if config.USE_RAY: + @app.on_event("startup") async def startup_event(): @@ -76,6 +104,7 @@ async def startup_event(): backend_config = serve.BackendConfig(num_replicas=4) from api import PathmindPolicy + client.create_backend("pathmind_policy", PathmindPolicy, config=backend_config) client.create_endpoint("predict", backend="pathmind_policy") @@ -87,25 +116,32 @@ async def startup_event(): # Note: for basic auth, use "logged_in: bool = Depends(verify_credentials)" as parameter @app.post("/predict/", response_model=Action, tags=["Predictions"]) async def predict(payload: Observation, api_key: APIKey = Depends(get_api_key)): - lists = [[getattr(payload, obs)] if not isinstance(getattr(payload, obs), List) else getattr(payload, obs) - for obs in config.observations.keys()] + lists = [ + [getattr(payload, obs)] + if not isinstance(getattr(payload, obs), List) + else getattr(payload, obs) + for obs in config.observations.keys() + ] observations = list(itertools.chain(*lists)) # Note that ray can't pickle the pydantic "Observation" model, so we need to convert it here. return await SERVE_HANDLE.remote(observations) - @app.post("/predict_deterministic/", response_model=Action, tags=["Predictions"]) - async def predict_deterministic(payload: Observation, api_key: APIKey = Depends(get_api_key)): + async def predict_deterministic( + payload: Observation, api_key: APIKey = Depends(get_api_key) + ): return _predict_deterministic(payload) - @app.post("/distribution/", tags=["Predictions"]) - async def distribution(payload: Observation, api_key: APIKey = Depends(get_api_key)): + async def distribution( + payload: Observation, api_key: APIKey = Depends(get_api_key) + ): return _distribution(payload) - @app.post("/collect_experience/", response_model=Action, tags=["Predictions"]) - async def collect_experience(payload: Experience, api_key: APIKey = Depends(get_api_key)): + async def collect_experience( + payload: Experience, api_key: APIKey = Depends(get_api_key) + ): global cache global batch_builder @@ -114,8 +150,12 @@ async def collect_experience(payload: Experience, api_key: APIKey = Depends(get_ rew = payload.reward done = payload.done - lists = [[getattr(observation, obs)] if not isinstance(getattr(observation, obs), List) else getattr(observation, obs) - for obs in config.observations.keys()] + lists = [ + [getattr(observation, obs)] + if not isinstance(getattr(observation, obs), List) + else getattr(observation, obs) + for obs in config.observations.keys() + ] obs = list(itertools.chain(*lists)) obs = np.reshape(np.asarray(obs), (4,)) @@ -128,7 +168,7 @@ async def collect_experience(payload: Experience, api_key: APIKey = Depends(get_ if cache.is_empty(): cache.store(t=0, prev_obs=obs, prev_action=action.actions, prev_reward=rew) - act =action.actions[0] + act = action.actions[0] batch_builder.add_values( agent_index=0, @@ -137,18 +177,16 @@ async def collect_experience(payload: Experience, api_key: APIKey = Depends(get_ action_logp=np.log(action.probability), t=cache.t, eps_id=cache.episode, - prev_actions=cache.prev_action, prev_rewards=cache.prev_reward, obs=cache.prev_obs, # prep.transform(...) - # sent from environment new_obs=obs, dones=done, infos=None, rewards=rew, ) - cache.store(t=cache.t+1, prev_obs=obs, prev_action=act, prev_reward=rew) + cache.store(t=cache.t + 1, prev_obs=obs, prev_action=act, prev_reward=rew) if done: writer.write(batch_builder.build_and_reset()) @@ -164,13 +202,13 @@ async def predict_raw(payload: RawObservation, api_key: APIKey = Depends(get_api @app.get("/clients", tags=["Clients"]) async def clients(api_key: APIKey = Depends(get_api_key)): - shutil.make_archive('clients', 'zip', './clients') - return FileResponse(path='clients.zip', filename='clients.zip') + shutil.make_archive("clients", "zip", "./clients") + return FileResponse(path="clients.zip", filename="clients.zip") @app.get("/schema", tags=["Clients"]) async def server_schema(api_key: APIKey = Depends(get_api_key)): - with open(config.PATHMIND_SCHEMA, 'r') as schema_file: + with open(config.PATHMIND_SCHEMA, "r") as schema_file: schema_str = schema_file.read() schema = yaml.safe_load(schema_str) return schema @@ -184,8 +222,8 @@ async def health_check(): app.openapi = custom_openapi # Write the swagger file locally -with open(config.LOCAL_SWAGGER, 'w') as f: +with open(config.LOCAL_SWAGGER, "w") as f: f.write(json.dumps(app.openapi())) # Generate all clients on startup -CLI.clients() \ No newline at end of file +CLI.clients() From a1e233d3230a0d1232169619302aa7021f092a24 Mon Sep 17 00:00:00 2001 From: Max Pumperla Date: Fri, 13 Aug 2021 13:09:22 +0200 Subject: [PATCH 04/12] black --- api.py | 48 ++++++++++++++++++++----------- config.py | 7 +++-- examples/rail_drl/preprocessor.py | 5 +--- frontend.py | 24 ++++++++++++---- generate.py | 22 ++++++++++---- security.py | 8 +++--- test_client.py | 18 ++++++------ tests/test_api_speed.py | 40 ++++++++++++-------------- tests/test_predict.py | 10 +++---- utils.py | 2 +- 10 files changed, 109 insertions(+), 75 deletions(-) diff --git a/api.py b/api.py index a038a4c..d46718d 100644 --- a/api.py +++ b/api.py @@ -6,18 +6,13 @@ from fluent import sender from fastapi import HTTPException -logger = sender.FluentSender('policy_server', host='0.0.0.0', port=24224) +logger = sender.FluentSender("policy_server", host="0.0.0.0", port=24224) -RawObservation = create_model( - 'RawObservation', - **{"obs": (List[float], ...)} -) +RawObservation = create_model("RawObservation", **{"obs": (List[float], ...)}) + +Observation = create_model("Observation", **config.payload_data) -Observation = create_model( - 'Observation', - **config.payload_data -) class Experience(BaseModel): observation: Observation @@ -43,11 +38,14 @@ def __init__(self): def __call__(self, request): array = np.asarray(request.data) op = np.reshape(array, (1, array.size)) - tensors = tf.convert_to_tensor(op, dtype=tf.float32, name='observations') + tensors = tf.convert_to_tensor(op, dtype=tf.float32, name="observations") result = self.model( - is_training=self.is_training_tensor, observations=tensors, prev_action=self.prev_action_tensor, - prev_reward=self.prev_reward_tensor, seq_lens=self.seq_lens_tensor + is_training=self.is_training_tensor, + observations=tensors, + prev_action=self.prev_action_tensor, + prev_reward=self.prev_reward_tensor, + seq_lens=self.seq_lens_tensor, ) action_keys = [k for k in result.keys() if "actions" in k] @@ -64,7 +62,14 @@ def __call__(self, request): actions = [config.action_type(x) for x in numpy_tensors] global logger - logger.emit('predict', {'observation': request.data, 'action': actions, 'probability': probability}) + logger.emit( + "predict", + { + "observation": request.data, + "action": actions, + "probability": probability, + }, + ) return Action(actions=actions, probability=probability) @@ -75,6 +80,7 @@ def __call__(self, request): def _predict(payload: Observation): class Dummy: data = None + dummy = Dummy() dummy.data = payload return pm(dummy) @@ -85,9 +91,13 @@ def _predict_deterministic(payload: Observation): to restore the agent. Not in itself a problem, just less convenient compared to what we have now (don't need big JARs hanging around).""" if not config.parameters.get("discrete"): - raise HTTPException(status_code=405, detail="Endpoint only available for discrete actions") + raise HTTPException( + status_code=405, detail="Endpoint only available for discrete actions" + ) if config.parameters.get("tuple"): - raise HTTPException(status_code=405, detail="Endpoint only available for non-tuple scenarios") + raise HTTPException( + status_code=405, detail="Endpoint only available for non-tuple scenarios" + ) max_action = None max_prob = 0.0 @@ -106,9 +116,13 @@ def _predict_deterministic(payload: Observation): def _distribution(payload: Observation): if not config.parameters.get("discrete"): - raise HTTPException(status_code=405, detail="Endpoint only available for discrete actions") + raise HTTPException( + status_code=405, detail="Endpoint only available for discrete actions" + ) if config.parameters.get("tuple"): - raise HTTPException(status_code=405, detail="Endpoint only available for non-tuple scenarios") + raise HTTPException( + status_code=405, detail="Endpoint only available for non-tuple scenarios" + ) distro_dict = {} found_all_actions = False trials = 0 diff --git a/config.py b/config.py index ddf08f8..44cb993 100644 --- a/config.py +++ b/config.py @@ -17,7 +17,7 @@ def base_path(local_file): # If you put BASE_PATH on your PATH we use that, otherwise the current working directory. -BASE_PATH = os.environ.get('BASE_PATH', os.path.expanduser(".")) +BASE_PATH = os.environ.get("BASE_PATH", os.path.expanduser(".")) PATHMIND_POLICY = base_path("saved_model.zip") PATHMIND_SCHEMA = base_path("schema.yaml") @@ -41,8 +41,9 @@ def base_path(local_file): parameters = schema.get("parameters") action_type = int if parameters.get("discrete") else float +model_id = parameters.get("model_id", None) +project_id = parameters.get("project_id", None) + payload_data = {} if observations: payload_data = {k: (v.get("type"), ...) for k, v in observations.items()} - - diff --git a/examples/rail_drl/preprocessor.py b/examples/rail_drl/preprocessor.py index 761969a..418144d 100644 --- a/examples/rail_drl/preprocessor.py +++ b/examples/rail_drl/preprocessor.py @@ -1,5 +1,2 @@ def preprocess(obs): - return [ - obs.get("speed_trainA"), - obs.get("speed_trainB") - ] + return [obs.get("speed_trainA"), obs.get("speed_trainB")] diff --git a/frontend.py b/frontend.py index a5be1dc..d7aa0f0 100644 --- a/frontend.py +++ b/frontend.py @@ -81,7 +81,9 @@ def generate_frontend_from_observations(schema: dict): validate_list_items(values, val) check_min_max_items(values, val) else: - raise Exception(f"Unsupported data type {prop_type} in YAML schema for model server.") + raise Exception( + f"Unsupported data type {prop_type} in YAML schema for model server." + ) result[key] = val return result @@ -91,21 +93,31 @@ def validate_list_items(values, val): if items_type == "integer": assert all(isinstance(v, int) for v in val), "All list items must be integers." elif items_type == "number": - assert all(isinstance(v, float) for v in val), "All list items must be floating point numbers." + assert all( + isinstance(v, float) for v in val + ), "All list items must be floating point numbers." def check_min_max_items(values, val): if "minItems" in values.keys(): - assert len(val) >= values.get("minItems"), f"Array too small, expected at least {values.get('minItems')} items." + assert len(val) >= values.get( + "minItems" + ), f"Array too small, expected at least {values.get('minItems')} items." if "maxItems" in values.keys(): - assert len(val) <= values.get("maxItems"), f"Array too large, expected at most {values.get('maxItems')} items." + assert len(val) <= values.get( + "maxItems" + ), f"Array too large, expected at most {values.get('maxItems')} items." def check_min_max(values, val): if "minimum" in values.keys(): - assert val >= values.get("minimum"), f"value {val} too small, expected a minimum of {values.get('minimum')}." + assert val >= values.get( + "minimum" + ), f"value {val} too small, expected a minimum of {values.get('minimum')}." if "maximum" in values.keys(): - assert val <= values.get("maximum"), f"value {val} too large, expected a maximum of {values.get('maximum')}." + assert val <= values.get( + "maximum" + ), f"value {val} too large, expected a maximum of {values.get('maximum')}." if __name__ == "__main__": diff --git a/generate.py b/generate.py index 37ec7be..85123ea 100644 --- a/generate.py +++ b/generate.py @@ -5,6 +5,7 @@ import subprocess from utils import safe_remove, unzip + class CLI: """Simple wrapper class to expose to "fire" to auto-generate a command line interface for this project. @@ -20,7 +21,7 @@ def clients(): generate_client_for("python") generate_client_for("java") generate_client_for("scala") - #generate_client_for("r") + # generate_client_for("r") @staticmethod def copy_server_files(path): @@ -44,8 +45,19 @@ def clean(): def generate_client_for(lang): os.makedirs(f"./clients/{lang}") - subprocess.run(["swagger-codegen", "generate", "-i", config.LOCAL_SWAGGER, "-l", lang, "-o", f"clients/{lang}"]) - - -if __name__ == '__main__': + subprocess.run( + [ + "swagger-codegen", + "generate", + "-i", + config.LOCAL_SWAGGER, + "-l", + lang, + "-o", + f"clients/{lang}", + ] + ) + + +if __name__ == "__main__": fire.Fire(CLI) diff --git a/security.py b/security.py index 954e6fb..1e8398d 100644 --- a/security.py +++ b/security.py @@ -20,9 +20,9 @@ async def get_api_key( - api_key_query: str = Security(api_key_query), - api_key_header: str = Security(api_key_header), - api_key_cookie: str = Security(api_key_cookie), + api_key_query: str = Security(api_key_query), + api_key_header: str = Security(api_key_header), + api_key_cookie: str = Security(api_key_cookie), ): if api_key_query == API_KEY: return api_key_query @@ -45,4 +45,4 @@ def verify_credentials(credentials: HTTPBasicCredentials = Depends(basic_auth)): detail="Incorrect email or password", headers={"WWW-Authenticate": "Basic"}, ) - return True \ No newline at end of file + return True diff --git a/test_client.py b/test_client.py index 60f6d7e..ff64135 100644 --- a/test_client.py +++ b/test_client.py @@ -44,7 +44,7 @@ def get_observation(self) -> typing.List[float]: float(self.mouse[1]) / 5.0, abs(self.cheese[0] - self.mouse[0]) / 5.0, abs(self.cheese[1] - self.mouse[1]) / 5.0, - ] + ] def get_reward(self) -> float: return 1 if self.mouse == self.cheese else 0 @@ -65,11 +65,7 @@ def get_payload(obs, reward, done): "mouse_row_dist": obs[2], "mouse_col_dist": obs[3], } - payload = { - "observation": obs_dict, - "reward": reward, - "done": done - } + payload = {"observation": obs_dict, "reward": reward, "done": done} return payload @@ -79,10 +75,14 @@ def get_payload(obs, reward, done): done = False while not done: payload = get_payload(obs, reward, done) - response = requests.post("http://localhost:8000/collect_experience/", json=payload, auth=auth).json() + response = requests.post( + "http://localhost:8000/collect_experience/", json=payload, auth=auth + ).json() action = response.get("actions")[0] obs, reward, done, info = env.step(action) if done: payload = get_payload(obs, reward, done) - response = requests.post("http://localhost:8000/collect_experience/", json=payload, auth=auth).json() - print(">>> Episode complete.") \ No newline at end of file + response = requests.post( + "http://localhost:8000/collect_experience/", json=payload, auth=auth + ).json() + print(">>> Episode complete.") diff --git a/tests/test_api_speed.py b/tests/test_api_speed.py index d2e0000..33f4b80 100644 --- a/tests/test_api_speed.py +++ b/tests/test_api_speed.py @@ -4,34 +4,32 @@ data = { - "coordinates": [ - 1, - 1 - ], - "has_core": True, - "has_down_neighbour": True, - "has_left_neighbour": True, - "has_right_neighbour": True, - "has_up_neighbour": True, - "id": 0, - "is_down_free": True, - "is_left_free": False, - "is_right_free": True, - "is_up_free": True, - "name": "pt07_c", - "target": [ - 6, - 2 - ] + "coordinates": [1, 1], + "has_core": True, + "has_down_neighbour": True, + "has_left_neighbour": True, + "has_right_neighbour": True, + "has_up_neighbour": True, + "id": 0, + "is_down_free": True, + "is_left_free": False, + "is_right_free": True, + "is_up_free": True, + "name": "pt07_c", + "target": [6, 2], } def predict(): - return requests.post("https://localhost:8080/api/predict", verify=False, auth=("foo", "bar"), json=data) + return requests.post( + "https://localhost:8080/api/predict", + verify=False, + auth=("foo", "bar"), + json=data, + ) predict() number = 1000 res = timeit.timeit(predict, number=1000) print(f"A total of {number} requests took {res} milliseconds to process on average.") - diff --git a/tests/test_predict.py b/tests/test_predict.py index b24f760..aa3cc91 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -2,10 +2,10 @@ import json payload = { - 'mouse_row': 1, - 'mouse_col': 1, - 'mouse_row_dist': 1, - 'mouse_col_dist': 1, + "mouse_row": 1, + "mouse_col": 1, + "mouse_row_dist": 1, + "mouse_col_dist": 1, } @@ -20,5 +20,5 @@ def predict(): res = requests.get("http://localhost:8000/openapi.json") -with open("../openapi.json", 'w') as f: +with open("../openapi.json", "w") as f: f.write(res.content.decode("utf-8")) diff --git a/utils.py b/utils.py index 412ed34..d013520 100644 --- a/utils.py +++ b/utils.py @@ -28,6 +28,6 @@ def unzip(local_file): """Unzip a file if it has a zip extension, otherwise leave as is.""" extension = get_extension(local_file) if extension == "zip": - with zipfile.ZipFile(local_file, 'r') as zip_ref: + with zipfile.ZipFile(local_file, "r") as zip_ref: zip_ref.extractall(MODEL_FOLDER) print(">>> Successfully unzipped model.") From 891b7eac587e364faac28fe0d370cd62d59a1813 Mon Sep 17 00:00:00 2001 From: Max Pumperla Date: Fri, 13 Aug 2021 14:04:25 +0200 Subject: [PATCH 05/12] black formatting --- app.py | 8 +++++++- tests/test_api_speed.py | 4 +++- tests/test_predict.py | 1 - 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/app.py b/app.py index ac2f8c6..349cb0f 100644 --- a/app.py +++ b/app.py @@ -23,7 +23,12 @@ url_path = config.parameters.get("url_path") -app = FastAPI(root_path=f"/{url_path}", docs_url=None, redoc_url=None) if url_path else FastAPI(docs_url=None, redoc_url=None) +app = ( + FastAPI(root_path=f"/{url_path}", docs_url=None, redoc_url=None) + if url_path + else FastAPI(docs_url=None, redoc_url=None) +) + @app.get("/docs", include_in_schema=False) def overridden_swagger(): @@ -46,6 +51,7 @@ def overridden_redoc(): model_id=config.model_id, ) + tags_metadata = [ { "name": "Predictions", diff --git a/tests/test_api_speed.py b/tests/test_api_speed.py index 7479784..a5f1de0 100644 --- a/tests/test_api_speed.py +++ b/tests/test_api_speed.py @@ -37,5 +37,7 @@ def predict(): def test_predict(): number = 1000 res = timeit.timeit(predict, number=number) - print(f"A total of {number} requests took {res} milliseconds to process on average.") + print( + f"A total of {number} requests took {res} milliseconds to process on average." + ) ray.shutdown() diff --git a/tests/test_predict.py b/tests/test_predict.py index 1bad269..463b9eb 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -37,4 +37,3 @@ def test_write_openapi_json(): f.write(res.content.decode("utf-8")) ray.shutdown() - From 74a932ae0dc43053988408de07ef736703c9efad Mon Sep 17 00:00:00 2001 From: Slin Lee Date: Mon, 30 Aug 2021 15:00:44 -0700 Subject: [PATCH 06/12] Apply pre-commit changes to pass tests --- app.py | 14 +++++--------- docs.py | 2 +- tests/test_api_speed.py | 3 ++- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/app.py b/app.py index 349cb0f..67d11c4 100644 --- a/app.py +++ b/app.py @@ -5,22 +5,18 @@ import ray import yaml +from fastapi import Depends, FastAPI +from fastapi.openapi.utils import get_openapi +from fastapi.responses import FileResponse +from fastapi.security.api_key import APIKey from ray import serve import config from api import Action, Observation, RawObservation +from docs import get_redoc_html, get_swagger_ui_html from generate import CLI from security import get_api_key -from fastapi.responses import FileResponse -from fastapi import Depends -from fastapi.security.api_key import APIKey -from fastapi.openapi.utils import get_openapi - -from fastapi import FastAPI -from docs import get_swagger_ui_html, get_redoc_html - - url_path = config.parameters.get("url_path") app = ( diff --git a/docs.py b/docs.py index ad25b91..7597baa 100644 --- a/docs.py +++ b/docs.py @@ -32,7 +32,7 @@ def get_swagger_ui_html( if project_id and model_id: html += f"""
Date: Mon, 30 Aug 2021 18:33:01 -0700 Subject: [PATCH 07/12] Run all pytest suite --- .github/workflows/python-package.yml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index fc32e3b..e316aab 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -40,6 +40,4 @@ jobs: pre-commit run --all-files - name: Test with pytest run: | - cp examples/mouse_and_cheese/schema.yaml . - python generate.py copy_server_files examples/mouse_and_cheese - pytest tests/test_app.py + pytest . From 58fd92428768e9abe70beb0f2c45bd28a3ba5a23 Mon Sep 17 00:00:00 2001 From: Slin Lee Date: Mon, 30 Aug 2021 19:12:00 -0700 Subject: [PATCH 08/12] Fix pytest command --- .github/workflows/python-package.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index e316aab..fbd8ac7 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -40,4 +40,4 @@ jobs: pre-commit run --all-files - name: Test with pytest run: | - pytest . + pytest tests/ From 596ac892fdf618301a4e73425e81a2607bbc6b74 Mon Sep 17 00:00:00 2001 From: Slin Lee Date: Mon, 30 Aug 2021 21:51:14 -0700 Subject: [PATCH 09/12] Copy at least one set of server files first before running auto tests --- .github/workflows/python-package.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index fbd8ac7..22cdea6 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -40,4 +40,5 @@ jobs: pre-commit run --all-files - name: Test with pytest run: | + python generate.py copy_server_files "examples/lpoc" pytest tests/ From 10b7837b673d8b473e99762f3b7960559bab67ea Mon Sep 17 00:00:00 2001 From: Slin Lee Date: Mon, 30 Aug 2021 21:57:06 -0700 Subject: [PATCH 10/12] Allow copy_server_files to run first --- .github/workflows/python-package.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 22cdea6..c8207ec 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -39,6 +39,8 @@ jobs: run: | pre-commit run --all-files - name: Test with pytest + # config.py won't import if schema.yaml doesn't exist yet run: | + cp examples/lpoc/schema.yaml . python generate.py copy_server_files "examples/lpoc" pytest tests/ From a721f357cd5640d930f467010b4b2f6b55dea0e8 Mon Sep 17 00:00:00 2001 From: Slin Lee Date: Mon, 30 Aug 2021 22:32:45 -0700 Subject: [PATCH 11/12] Hack for getting ray to shutdown before each test --- tests/test_app.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_app.py b/tests/test_app.py index ce50ab6..d0a0457 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -9,6 +9,10 @@ client = TestClient(app) +def setup_function(): + ray.shutdown() + + def test_health_check(): response = client.get("/") assert response.status_code == 200 From 7d1ffe216968e4dbad84f94f37a632e285dd51fa Mon Sep 17 00:00:00 2001 From: Slin Lee Date: Tue, 31 Aug 2021 00:26:26 -0700 Subject: [PATCH 12/12] Load mouse and cheese schema to get tests to pass This isn't a real fix. It just happens to work since test_app.py is the only one that actually tests against the schema. pytest seems to load config.py once, which sets the global params like schema with whatever files are there first. --- .github/workflows/python-package.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index c8207ec..f42d56c 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -42,5 +42,5 @@ jobs: # config.py won't import if schema.yaml doesn't exist yet run: | cp examples/lpoc/schema.yaml . - python generate.py copy_server_files "examples/lpoc" + python generate.py copy_server_files "examples/mouse_and_cheese" pytest tests/