Skip to content

Commit

Permalink
change docs, favicon etc. fixes #87, fixes #74
Browse files Browse the repository at this point in the history
  • Loading branch information
maxpumperla committed Aug 13, 2021
1 parent 1bfafd5 commit 469e515
Showing 1 changed file with 60 additions and 22 deletions.
82 changes: 60 additions & 22 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from security import get_api_key

from fastapi.responses import FileResponse
from fastapi import Depends, FastAPI, HTTPException, status, Security
from fastapi.security.api_key import APIKeyQuery, APIKeyCookie, APIKeyHeader, APIKey
from fastapi import Depends
from fastapi.security.api_key import APIKey
from fastapi.openapi.utils import get_openapi

from ray.rllib.evaluation.sample_batch_builder import SampleBatchBuilder
Expand All @@ -27,7 +27,34 @@
batch_builder = SampleBatchBuilder() # or MultiAgentSampleBatchBuilder
writer = JsonWriter(config.EXPERIENCE_LOCATION)

app = FastAPI()
from fastapi import FastAPI
from docs import get_swagger_ui_html, get_redoc_html


app = FastAPI(docs_url=None, redoc_url=None)


@app.get("/docs", include_in_schema=False)
def overridden_swagger():
return get_swagger_ui_html(
openapi_url="/openapi.json",
title="Pathmind Policy Server",
swagger_favicon_url="https://www.google.com/s2/favicons?domain_url=pathmind.com",
project_id=config.project_id,
model_id=config.model_id,
)


@app.get("/redoc", include_in_schema=False)
def overridden_redoc():
return get_redoc_html(
openapi_url="/openapi.json",
title="Pathmind Policy Server",
redoc_favicon_url="https://www.google.com/s2/favicons?domain_url=pathmind.com",
project_id=config.project_id,
model_id=config.model_id,
)


tags_metadata = [
{
Expand Down Expand Up @@ -67,6 +94,7 @@ def custom_openapi():


if config.USE_RAY:

@app.on_event("startup")
async def startup_event():

Expand All @@ -76,6 +104,7 @@ async def startup_event():
backend_config = serve.BackendConfig(num_replicas=4)

from api import PathmindPolicy

client.create_backend("pathmind_policy", PathmindPolicy, config=backend_config)
client.create_endpoint("predict", backend="pathmind_policy")

Expand All @@ -87,25 +116,32 @@ async def startup_event():
# Note: for basic auth, use "logged_in: bool = Depends(verify_credentials)" as parameter
@app.post("/predict/", response_model=Action, tags=["Predictions"])
async def predict(payload: Observation, api_key: APIKey = Depends(get_api_key)):
lists = [[getattr(payload, obs)] if not isinstance(getattr(payload, obs), List) else getattr(payload, obs)
for obs in config.observations.keys()]
lists = [
[getattr(payload, obs)]
if not isinstance(getattr(payload, obs), List)
else getattr(payload, obs)
for obs in config.observations.keys()
]
observations = list(itertools.chain(*lists))
# Note that ray can't pickle the pydantic "Observation" model, so we need to convert it here.
return await SERVE_HANDLE.remote(observations)


@app.post("/predict_deterministic/", response_model=Action, tags=["Predictions"])
async def predict_deterministic(payload: Observation, api_key: APIKey = Depends(get_api_key)):
async def predict_deterministic(
payload: Observation, api_key: APIKey = Depends(get_api_key)
):
return _predict_deterministic(payload)


@app.post("/distribution/", tags=["Predictions"])
async def distribution(payload: Observation, api_key: APIKey = Depends(get_api_key)):
async def distribution(
payload: Observation, api_key: APIKey = Depends(get_api_key)
):
return _distribution(payload)


@app.post("/collect_experience/", response_model=Action, tags=["Predictions"])
async def collect_experience(payload: Experience, api_key: APIKey = Depends(get_api_key)):
async def collect_experience(
payload: Experience, api_key: APIKey = Depends(get_api_key)
):

global cache
global batch_builder
Expand All @@ -114,8 +150,12 @@ async def collect_experience(payload: Experience, api_key: APIKey = Depends(get_
rew = payload.reward
done = payload.done

lists = [[getattr(observation, obs)] if not isinstance(getattr(observation, obs), List) else getattr(observation, obs)
for obs in config.observations.keys()]
lists = [
[getattr(observation, obs)]
if not isinstance(getattr(observation, obs), List)
else getattr(observation, obs)
for obs in config.observations.keys()
]

obs = list(itertools.chain(*lists))
obs = np.reshape(np.asarray(obs), (4,))
Expand All @@ -128,7 +168,7 @@ async def collect_experience(payload: Experience, api_key: APIKey = Depends(get_
if cache.is_empty():
cache.store(t=0, prev_obs=obs, prev_action=action.actions, prev_reward=rew)

act =action.actions[0]
act = action.actions[0]

batch_builder.add_values(
agent_index=0,
Expand All @@ -137,18 +177,16 @@ async def collect_experience(payload: Experience, api_key: APIKey = Depends(get_
action_logp=np.log(action.probability),
t=cache.t,
eps_id=cache.episode,

prev_actions=cache.prev_action,
prev_rewards=cache.prev_reward,
obs=cache.prev_obs, # prep.transform(...)

# sent from environment
new_obs=obs,
dones=done,
infos=None,
rewards=rew,
)
cache.store(t=cache.t+1, prev_obs=obs, prev_action=act, prev_reward=rew)
cache.store(t=cache.t + 1, prev_obs=obs, prev_action=act, prev_reward=rew)

if done:
writer.write(batch_builder.build_and_reset())
Expand All @@ -164,13 +202,13 @@ async def predict_raw(payload: RawObservation, api_key: APIKey = Depends(get_api

@app.get("/clients", tags=["Clients"])
async def clients(api_key: APIKey = Depends(get_api_key)):
shutil.make_archive('clients', 'zip', './clients')
return FileResponse(path='clients.zip', filename='clients.zip')
shutil.make_archive("clients", "zip", "./clients")
return FileResponse(path="clients.zip", filename="clients.zip")


@app.get("/schema", tags=["Clients"])
async def server_schema(api_key: APIKey = Depends(get_api_key)):
with open(config.PATHMIND_SCHEMA, 'r') as schema_file:
with open(config.PATHMIND_SCHEMA, "r") as schema_file:
schema_str = schema_file.read()
schema = yaml.safe_load(schema_str)
return schema
Expand All @@ -184,8 +222,8 @@ async def health_check():
app.openapi = custom_openapi

# Write the swagger file locally
with open(config.LOCAL_SWAGGER, 'w') as f:
with open(config.LOCAL_SWAGGER, "w") as f:
f.write(json.dumps(app.openapi()))

# Generate all clients on startup
CLI.clients()
CLI.clients()

0 comments on commit 469e515

Please sign in to comment.