"""
This module contains utilities for the DCP CLI and API bindings. These
utility classes and functions are used under the hood by both the DCP
API bindings library and its CLI. There is no need to use these
utilities directly unless you are extending DCP client functionality.
``SwaggerClient`` is a base class for a general purpose Swagger API
client connection manager. User classes such as ``hca.dss.DSSClient``
extend it as follows:
class APIClient(SwaggerClient):
def __init__(self, *args, **kwargs):
super(APIClient, self).__init__(*args, **kwargs)
self.commands += [self.special_cli_command]
def special_cli_command(self, required_argument, optional_argument=""):
return {}
Each user class should have a configuration subtree keyed by its name
(such as ``APIClient`` above) under the DCP-wide config manager
(available via ``hca.get_config()``; the static defaults for the
config manager are stored in ``hca/default_config.json``). Within that
subtree, the key ``swagger_url`` should point to the HTTPS URL
containing the Swagger API definition that the client is providing an
interface for. On first use, this API definition will be downloaded
and saved into the user config directory (for example,
/Users/Alice/.config/hca) with a name determined by the base64
encoding of ``swagger_url``. On subsequent uses, this file will be
loaded instead to get the Swagger API definition.
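For illustration (the URL below is hypothetical), the subtree for the
``APIClient`` example above might look like this in the config
manager's JSON representation:

    "APIClient": {
        "swagger_url": "https://api.example.org/v1/swagger.json"
    }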
The Swagger API definition is then used to dynamically construct API
client methods, attach them as class methods, and decorate them with
API metadata such as docstrings and I/O signatures. Method names are
determined using a mapping heuristic:
GET /foo/bar -> APIClient.get_foo_bar()
POST /widgets/{id} -> APIClient.post_widget()
Client methods (provided by ``_ClientMethodFactory``) build the HTTP
request payload by matching their keyword argument inputs to the
Swagger API definition, and use the ``requests`` library to call the
API. JSON body input and output are assumed by default:
json_results = APIClient().post_widget(id="foo", qs_param="x", body_param=123)
The ``stream()`` method can be used instead to stream the raw body as
bytes, or to otherwise provide access to the ``requests.Response``
object:
with APIClient().get_foo_bar.stream() as response:
while True:
chunk = response.raw.read(1024)
...
if not chunk:
break
Results from API routes that support GitHub/RFC 5988 style pagination
can be paged like this:
for result in APIClient().get_foo_bar.iterate():
...
Routes that require authentication trigger the use of the auth
middleware provided by requests_oauthlib (work in progress).
CLI parsers for argparse can be generated and injected into a parent
``argparse.ArgumentParser`` object by passing a subparsers object to
SwaggerClient.build_argparse_subparsers. The resulting CLI entry point
can be called like this:
$ dss post_widget --id ID
In addition to the bindings for API methods in the Swagger definition,
SwaggerClient designates certain methods as *commands*, meaning they
are part of the public bindings API and are also exposed as CLI
subcommands. SwaggerClient itself provides two such commands, ``login``
and ``logout``, which manage the cached authentication credentials for
the client. Subclasses can add more commands by appending them to the
``SwaggerClient.commands`` list, as shown with ``special_cli_command``
in the example above.
"""
import os
import multiprocessing
import types
import collections
import typing
import json
import errno
import base64
import argparse
import time
import jwt
import requests
from inspect import signature, Parameter
from requests.adapters import HTTPAdapter, DEFAULT_POOLSIZE
from requests_oauthlib import OAuth2Session
from urllib3.util import retry, timeout
from urllib.parse import urljoin
from jsonpointer import resolve_pointer
from threading import Lock
from argparse import RawTextHelpFormatter
from dcplib.networking import Session
from .. import get_config, logger
from .exceptions import SwaggerAPIException, SwaggerClientInternalError
from ._docs import _pagination_docstring, _streaming_docstring, _md2rst, _parse_docstring
from .fs_helper import FSHelper as fs
"""Based on https://askubuntu.com/questions/668538/cores-vs-threads-how-many-threads-should-i-run-on-this-machine
and https://github.com/bloomreach/s4cmd/blob/master/s4cmd.py#L121."""
DEFAULT_THREAD_COUNT = multiprocessing.cpu_count() * 2
class RetryPolicy(retry.Retry):
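    """Retry policy applied to client HTTP sessions; a subclass of ``urllib3.util.retry.Retry`` provided so it can
    be customized or overridden via ``SwaggerClient.retry_policy``."""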
pass
class _ClientMethodFactory(object):
def __init__(self, client, parameters, path_parameters, http_method, method_name, method_data, body_props):
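        # Store the constructor arguments (client, parameters, path_parameters, etc.) as instance attributes.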
self.__dict__.update(locals())
self._context_manager_response = None
def _request(self, req_args, url=None, stream=False, headers=None):
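        # Pick the path template matching the set of path parameters the caller actually supplied.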
supplied_path_params = [p for p in req_args if p in self.path_parameters and req_args[p] is not None]
if url is None:
url = self.client.host + self.client.http_paths[self.method_name][frozenset(supplied_path_params)]
url = url.format(**req_args)
logger.debug("%s %s %s", self.http_method, url, req_args)
query = {k: v for k, v in req_args.items()
if self.parameters.get(k, {}).get("in") == "query" and v is not None}
body = {k: v for k, v in req_args.items() if k in self.body_props and v is not None}
if "security" in self.method_data:
session = self.client.get_authenticated_session()
else:
session = self.client.get_session()
# TODO: (akislyuk) if using service account credentials, use manual refresh here
json_input = body if self.body_props else None
headers = headers or {}
headers.update({k: v for k, v in req_args.items() if self.parameters.get(k, {}).get('in') == 'header'})
res = session.request(self.http_method, url, params=query, json=json_input, stream=stream,
headers=headers, timeout=self.client.timeout_policy)
if res.status_code >= 400:
raise SwaggerAPIException(response=res)
return res
def _consume_response(self, response):
if self.http_method.upper() == "HEAD":
return response
elif response.headers["content-type"].startswith("application/json"):
return response.json()
else:
return response.content
def __call__(self, client, **kwargs):
return self._consume_response(self._request(kwargs))
def _cli_call(self, cli_args):
return self._consume_response(self._request(vars(cli_args)))
def stream(self, **kwargs):
self._context_manager_response = self._request(kwargs, stream=True)
return self
    def __enter__(self):
assert self._context_manager_response is not None
return self._context_manager_response
def __exit__(self, exc_type, exc_val, exc_tb):
self._context_manager_response.close()
self._context_manager_response = None
class _PaginatingClientMethodFactory(_ClientMethodFactory):
def _get_raw_pages(self, **kwargs):
page = None
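        # Follow RFC 5988 ``Link`` headers (rel="next") until the server stops returning one.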
while page is None or page.links.get("next", {}).get("url"):
page = self._request(kwargs, url=page.links["next"]["url"] if page else None)
yield page
def iterate(self, **kwargs):
"""
Yield specific items from each response depending on its contents.
For example, GET /bundles/{id} and GET /collections/{id} yield the
items contained within; POST /search yields search result items.
"""
for page in self._get_raw_pages(**kwargs):
if page.json().get('results'):
for result in page.json()['results']:
yield result
elif page.json().get('bundle'):
for file in page.json()['bundle']['files']:
yield file
else:
for collection in page.json().get('collections', []):
yield collection
def paginate(self, **kwargs):
"""Yield paginated responses one response body at a time."""
for page in self._get_raw_pages(**kwargs):
yield page.json()
class SwaggerClient(object):
scheme = "https"
retry_policy = RetryPolicy(read=10,
status=10,
backoff_factor=0.1,
status_forcelist=frozenset({500, 502, 503, 504}))
token_expiration = 3600
_authenticated_session = None
_session = None
_spec_valid_for_days = 7
_swagger_spec_lock = Lock()
_type_map = {
"string": str,
"number": float,
"integer": int,
"boolean": bool,
"array": typing.List,
"object": typing.Mapping
}
_audience = "https://dev.data.humancellatlas.org/" # TODO derive from swagger
    # The read timeout should be longer than the DSS API Gateway timeout to avoid races between the client and the
    # gateway hanging up at the same time. It's better to consistently get a 504 from the server than a read timeout
    # from the client, or sometimes one and sometimes the other.
timeout_policy = timeout.Timeout(connect=20, read=40)
max_redirects = 1024
def __init__(self, config=None, swagger_url=None, **session_kwargs):
self.config = config or get_config()
self.swagger_url = swagger_url or self.config[self.__class__.__name__].swagger_url
self._session_kwargs = session_kwargs
self._swagger_spec = None
self.__class__.__doc__ = _md2rst(self.swagger_spec["info"]["description"])
self.methods = {}
self.commands = [self.login, self.logout]
self.http_paths = collections.defaultdict(dict)
if "openapi" in self.swagger_spec:
server = self.swagger_spec["servers"][0]
variables = {k: v["default"] for k, v in server.get("variables", {}).items()}
self.host = server["url"].format(**variables)
else:
self.host = "{scheme}://{host}{base}".format(scheme=self.scheme,
host=self.swagger_spec["host"],
base=self.swagger_spec["basePath"])
for http_path, path_data in self.swagger_spec["paths"].items():
for http_method, method_data in path_data.items():
self._build_client_method(http_method, http_path, method_data)
@staticmethod
def load_swagger_json(swagger_json, ptr_str="$ref"):
"""
Load the Swagger JSON and resolve {"$ref": "#/..."} internal JSON Pointer references.
"""
refs = []
def store_refs(d):
if len(d) == 1 and ptr_str in d:
refs.append(d)
return d
swagger_content = json.load(swagger_json, object_hook=store_refs)
for ref in refs:
_, target = ref.popitem()
assert target[0] == "#"
ref.update(resolve_pointer(swagger_content, target[1:]))
return swagger_content
@property
def swagger_spec(self):
with self._swagger_spec_lock:
if not self._swagger_spec:
if "swagger_filename" in self.config:
swagger_filename = self.config.swagger_filename
if not swagger_filename.startswith("/"):
swagger_filename = os.path.join(os.path.dirname(__file__), swagger_filename)
else:
swagger_filename = self._get_swagger_filename(self.swagger_url)
if (("swagger_filename" not in self.config) and
((not os.path.exists(swagger_filename)) or
(fs.get_days_since_last_modified(swagger_filename) >= self._spec_valid_for_days))):
try:
os.makedirs(self.config.user_config_dir)
except OSError as e:
if not (e.errno == errno.EEXIST and os.path.isdir(self.config.user_config_dir)):
raise
res = self.get_session().get(self.swagger_url)
res.raise_for_status()
res_json = res.json()
assert "swagger" in res_json or "openapi" in res_json
fs.atomic_write(swagger_filename, res.content)
with open(swagger_filename) as fh:
self._swagger_spec = self.load_swagger_json(fh)
return self._swagger_spec
def _get_swagger_filename(self, swagger_url):
swagger_filename = base64.urlsafe_b64encode(swagger_url.encode()).decode() + ".json"
swagger_filename = os.path.join(self.config.user_config_dir, swagger_filename)
return swagger_filename
def clear_cache(self):
"""
        Clear the cached API definitions for a component. This can help resolve errors when communicating with the API.
"""
try:
os.remove(self._get_swagger_filename(self.swagger_url))
except EnvironmentError as e:
            logger.warning(os.strerror(e.errno))
else:
self.__init__()
@property
def application_secrets(self):
if "application_secrets" not in self.config:
app_secrets_url = "https://{}/internal/application_secrets".format(self._swagger_spec["host"])
self.config.application_secrets = requests.get(app_secrets_url).json()
return self.config.application_secrets
def get_session(self):
if self._session is None:
self._session = Session(**self._session_kwargs)
self._session.max_redirects = self.max_redirects
self._session.headers.update({"User-Agent": self.__class__.__name__})
self._set_retry_policy(self._session)
return self._session
def logout(self):
"""
Clear {prog} authentication credentials previously configured with ``{prog} login``.
"""
        for key in ["application_secrets", "oauth2_token"]:
            try:
                del self.config[key]
except KeyError:
pass
def login(self, access_token="", remote=False):
"""
Configure and save {prog} authentication credentials.
This command may open a browser window to ask for your
consent to use web service authentication credentials.
Use --remote if using the CLI in a remote environment
"""
if access_token:
credentials = argparse.Namespace(token=access_token, refresh_token=None, id_token=None)
else:
scopes = ["openid", "email", "offline_access"]
if remote:
import google_auth_oauthlib.flow
application_secrets = self.application_secrets
redirect_uri = urljoin(application_secrets['installed']['auth_uri'], "/echo")
flow = google_auth_oauthlib.flow.Flow.from_client_config(self.application_secrets, scopes=scopes,
redirect_uri=redirect_uri)
authorization_url, _ = flow.authorization_url()
print("please authenticate at the url: {}".format(authorization_url))
code = input("pass 'code' value from within query_params: ")
flow.fetch_token(code=code)
credentials = flow.credentials
else:
from google_auth_oauthlib.flow import InstalledAppFlow
flow = InstalledAppFlow.from_client_config(self.application_secrets, scopes=scopes)
msg = "Authentication successful. Please close this tab and run HCA CLI commands in the terminal."
credentials = flow.run_local_server(success_message=msg, audience=self._audience)
# TODO: (akislyuk) test token autorefresh on expiration
self.config.oauth2_token = dict(access_token=credentials.token,
refresh_token=credentials.refresh_token,
id_token=credentials.id_token,
expires_at="-1",
token_type="Bearer")
print("Storing access credentials")
def _get_oauth_token_from_service_account_credentials(self):
scopes = ["https://www.googleapis.com/auth/userinfo.email"]
assert 'GOOGLE_APPLICATION_CREDENTIALS' in os.environ
from google.auth.transport.requests import Request as GoogleAuthRequest
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
logger.info("Found GOOGLE_APPLICATION_CREDENTIALS environment variable. "
"Using service account credentials for authentication.")
service_account_credentials_filename = os.environ['GOOGLE_APPLICATION_CREDENTIALS']
if not os.path.isfile(service_account_credentials_filename):
msg = 'File "{}" referenced by the GOOGLE_APPLICATION_CREDENTIALS environment variable does not exist'
raise Exception(msg.format(service_account_credentials_filename))
credentials = ServiceAccountCredentials.from_service_account_file(
service_account_credentials_filename,
scopes=scopes
)
r = GoogleAuthRequest()
credentials.refresh(r)
r.session.close()
return credentials.token, credentials.expiry
def _get_jwt_from_service_account_credentials(self):
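        """
        Construct and sign a short-lived JWT using the Google service account key referenced by the
        GOOGLE_APPLICATION_CREDENTIALS environment variable, returning the signed token and its expiration timestamp.
        """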
assert 'GOOGLE_APPLICATION_CREDENTIALS' in os.environ
service_account_credentials_filename = os.environ['GOOGLE_APPLICATION_CREDENTIALS']
if not os.path.isfile(service_account_credentials_filename):
msg = 'File "{}" referenced by the GOOGLE_APPLICATION_CREDENTIALS environment variable does not exist'
raise Exception(msg.format(service_account_credentials_filename))
with open(service_account_credentials_filename) as fh:
service_credentials = json.load(fh)
iat = time.time()
exp = iat + self.token_expiration
payload = {'iss': service_credentials["client_email"],
'sub': service_credentials["client_email"],
'aud': self._audience,
'iat': iat,
'exp': exp,
'email': service_credentials["client_email"],
'scope': ['email', 'openid', 'offline_access'],
'https://auth.data.humancellatlas.org/group': 'hca'
}
additional_headers = {'kid': service_credentials["private_key_id"]}
signed_jwt = jwt.encode(payload, service_credentials["private_key"], headers=additional_headers,
algorithm='RS256').decode()
return signed_jwt, exp
def get_authenticated_session(self):
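        """
        Return a cached ``OAuth2Session``. If GOOGLE_APPLICATION_CREDENTIALS is set, the session is built from a
        service account JWT; otherwise it uses the OAuth2 token saved by ``login``, refreshing it as needed.
        """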
if self._authenticated_session is None:
oauth2_client_data = self.application_secrets["installed"]
if 'GOOGLE_APPLICATION_CREDENTIALS' in os.environ:
token, expires_at = self._get_jwt_from_service_account_credentials()
# TODO: (akislyuk) figure out the right strategy for persisting the service account oauth2 token
self._authenticated_session = OAuth2Session(client_id=oauth2_client_data["client_id"],
token=dict(access_token=token),
**self._session_kwargs)
else:
if "oauth2_token" not in self.config:
msg = ('Please configure {prog} authentication credentials using "{prog} login" '
'or set the GOOGLE_APPLICATION_CREDENTIALS environment variable')
raise Exception(msg.format(prog=self.__module__.replace(".", " ")))
self._authenticated_session = OAuth2Session(
client_id=oauth2_client_data["client_id"],
token=self.config.oauth2_token,
auto_refresh_url=oauth2_client_data["token_uri"],
auto_refresh_kwargs=dict(client_id=oauth2_client_data["client_id"],
client_secret=oauth2_client_data["client_secret"]),
token_updater=self._save_auth_token_refresh_result,
**self._session_kwargs
)
self._authenticated_session.headers.update({"User-Agent": self.__class__.__name__})
self._set_retry_policy(self._authenticated_session)
return self._authenticated_session
def _set_retry_policy(self, session):
adapter = HTTPAdapter(max_retries=self.retry_policy, pool_maxsize=max(DEFAULT_THREAD_COUNT, DEFAULT_POOLSIZE))
session.mount('http://', adapter)
session.mount('https://', adapter)
def _save_auth_token_refresh_result(self, result):
self.config.oauth2_token = result
def _process_method_args(self, parameters, body_json_schema):
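        """
        Build keyword argument metadata (``inspect.Parameter`` objects, docs, choices, required flags) from the
        Swagger parameter and request body definitions, and collect the JSON body properties for the method.
        """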
body_props = {}
method_args = collections.OrderedDict()
def _parse_properties(properties, schema):
for prop_name, prop_data in properties.items():
enum_values = prop_data.get("enum")
type_ = prop_data.get("type") if enum_values is None else 'string'
anno = self._type_map[type_]
if prop_name not in body_json_schema.get("required", []):
anno = typing.Optional[anno]
param = Parameter(prop_name, Parameter.POSITIONAL_OR_KEYWORD, default=prop_data.get("default"),
annotation=anno)
method_args.setdefault(prop_name, {}).update(dict(param=param, doc=prop_data.get("description"),
choices=enum_values,
required=prop_name in body_json_schema.get("required", [])))
                body_props[prop_name] = _merge_dict(schema, body_props.get(prop_name, {}))
if body_json_schema.get('properties', {}):
_parse_properties(body_json_schema["properties"], body_json_schema)
for schema in body_json_schema.get('allOf', []):
_parse_properties(schema.get('properties', {}), schema)
for parameter in parameters.values():
annotation = str if parameter.get("required") else typing.Optional[str]
param = Parameter(parameter["name"], Parameter.POSITIONAL_OR_KEYWORD, default=parameter.get("default"),
annotation=annotation)
method_args[parameter["name"]] = dict(param=param, doc=parameter.get("description"),
choices=parameter.get("enum"), required=parameter.get("required"))
return body_props, method_args
@staticmethod
def _build_method_name(http_method, http_path):
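        """
        Map an HTTP method and path to a Python method name, e.g. ``GET /foo/bar`` becomes ``get_foo_bar``.
        Path parameters and the ``/.well-known`` prefix are dropped, and trailing plurals are trimmed for
        POST/PUT routes and for paths ending in ``/{uuid}``.
        """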
method_name = http_path.replace('/.well-known', '').replace('-', '_')
method_name_parts = [http_method] + [p for p in method_name.split("/")[1:] if not p.startswith("{")]
method_name = "_".join(method_name_parts)
if method_name.endswith("s") and (http_method.upper() in {"POST", "PUT"} or http_path.endswith("/{uuid}")):
method_name = method_name[:-1]
return method_name
def _build_client_method(self, http_method, http_path, method_data):
method_name = self._build_method_name(http_method, http_path)
parameters = {p["name"]: p for p in method_data.get("parameters", [])}
body_json_schema = {"properties": {}}
if "requestBody" in method_data and "application/json" in method_data["requestBody"]["content"]:
body_json_schema = method_data["requestBody"]["content"]["application/json"]["schema"]
else:
for p in parameters:
if parameters[p]["in"] == "body":
body_json_schema = parameters.pop(p)["schema"]
break
path_parameters = [p_name for p_name, p_data in parameters.items() if p_data["in"] == "path"]
self.http_paths[method_name][frozenset(path_parameters)] = http_path
body_props, method_args = self._process_method_args(parameters=parameters, body_json_schema=body_json_schema)
        method_supports_pagination = str(requests.codes.partial) in method_data["responses"]
        highlight_streaming_support = str(requests.codes.found) in method_data["responses"]
factory = _PaginatingClientMethodFactory if method_supports_pagination else _ClientMethodFactory
client_method = factory(self, parameters, path_parameters, http_method, method_name, method_data, body_props)
client_method.__name__ = method_name
client_method.__qualname__ = self.__class__.__name__ + "." + method_name
params = [Parameter("factory", Parameter.POSITIONAL_OR_KEYWORD),
Parameter("client", Parameter.POSITIONAL_OR_KEYWORD)]
params += [v["param"] for k, v in method_args.items() if not k.startswith("_")]
client_method.__signature__ = signature(client_method).replace(parameters=params)
docstring = method_data["summary"] + "\n\n"
if method_supports_pagination:
docstring += _pagination_docstring.format(client_name=self.__class__.__name__, method_name=method_name)
if highlight_streaming_support:
docstring += _streaming_docstring.format(client_name=self.__class__.__name__, method_name=method_name)
for param in method_args:
if not param.startswith("_"):
param_doc = _md2rst(method_args[param]["doc"] or "")
docstring += ":param {}: {}\n".format(param, param_doc.replace("\n", " "))
docstring += ":type {}: {}\n".format(param, method_args[param]["param"].annotation)
docstring += "\n\n" + _md2rst(method_data["description"])
client_method.__doc__ = docstring
setattr(self.__class__, method_name, types.MethodType(client_method, SwaggerClient))
self.methods[method_name] = dict(method_data, entry_point=getattr(self, method_name)._cli_call,
signature=client_method.__signature__, args=method_args)
def _command_arg_forwarder_factory(self, command, command_sig):
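        # Wrap a command so that only the parsed CLI arguments matching its signature are forwarded as keyword arguments.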
def arg_forwarder(parsed_args):
command_args = {k: v for k, v in vars(parsed_args).items() if k in command_sig.parameters}
return command(**command_args)
return arg_forwarder
def _get_command_arg_settings(self, param_data):
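        """Translate a command parameter's default value into the corresponding ``argparse.add_argument`` settings
        (required flag, store_true/store_false action, nargs, or type)."""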
if param_data.default is Parameter.empty:
return dict(required=True)
elif param_data.default is True:
return dict(action='store_false', default=True)
elif param_data.default is False:
return dict(action='store_true', default=False)
elif isinstance(param_data.default, (list, tuple)):
return dict(nargs="+", default=param_data.default)
else:
return dict(type=type(param_data.default), default=param_data.default)
def _get_param_argparse_type(self, anno):
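        """Map a method parameter's typing annotation to an argparse ``type`` callable: lists and mappings are
        parsed as JSON, and ``Optional[X]`` unwraps to ``X``."""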
if anno in {typing.List, typing.Mapping, typing.Union[typing.Mapping, None]}:
return json.loads
elif isinstance(getattr(anno, "__args__", None), tuple) and anno == typing.Optional[anno.__args__[0]]:
return anno.__args__[0]
return anno
def build_argparse_subparsers(self, subparsers, help_menu=False):
for method_name, method_data in self.methods.items():
subcommand_name = method_name.replace("_", "-")
subparser = subparsers.add_parser(subcommand_name,
help=method_data.get("summary"),
description=method_data.get("description"),
formatter_class=RawTextHelpFormatter)
if help_menu:
required_group_parser = subparser.add_argument_group('Required Arguments')
for param_name, param in method_data["signature"].parameters.items():
if param_name in {"client", "factory"}:
continue
logger.debug("Registering %s %s %s", method_name, param_name, param.annotation)
nargs = "+" if param.annotation == typing.List else None
if help_menu:
subparser = required_group_parser if method_data["args"][param_name]["required"] else subparser
subparser.add_argument("--" + param_name.replace("_", "-").replace("/", "-"),
dest=param_name,
type=self._get_param_argparse_type(param.annotation),
nargs=nargs,
help=method_data["args"][param_name]["doc"],
choices=method_data["args"][param_name]["choices"],
required=method_data["args"][param_name]["required"])
subparser.set_defaults(entry_point=method_data["entry_point"])
for command in self.commands:
sig = signature(command)
if not getattr(command, "__doc__", None):
raise SwaggerClientInternalError("Command {} has no docstring".format(command))
docstring = command.__doc__.format(prog=subparsers._prog_prefix)
method_args = _parse_docstring(docstring)
command_subparser = subparsers.add_parser(command.__name__.replace("_", "-"),
help=method_args['summary'],
description=method_args['description'],
formatter_class=RawTextHelpFormatter)
if help_menu:
required_group_parser = command_subparser.add_argument_group('Required Arguments')
for param_name, param_data in sig.parameters.items():
params = self._get_command_arg_settings(param_data)
if help_menu:
command_subparser = required_group_parser if params.get('required', False) else command_subparser
command_subparser.add_argument("--" + param_name.replace("_", "-"),
help=method_args['params'].get(param_name, None),
**params)
command_subparser.set_defaults(entry_point=self._command_arg_forwarder_factory(command, sig))
def _merge_dict(source, destination):
"""Recursive dict merge"""
for key, value in source.items():
if isinstance(value, dict):
node = destination.setdefault(key, {})
_merge_dict(value, node)
else:
destination[key] = value
return destination