diff --git a/client/pdo/client/commands/service_groups.py b/client/pdo/client/commands/service_groups.py index 078abbc7..5f793b1f 100644 --- a/client/pdo/client/commands/service_groups.py +++ b/client/pdo/client/commands/service_groups.py @@ -12,16 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import argparse import json -import mergedeep import toml import pdo.client.builder.shell as pshell import pdo.client.builder.script as pscript -import pdo.common.utility as putils import pdo.common.config as pconfig -import pdo.client.commands.service_db as pservice from pdo.service_client.service_data.service_groups import GroupsDatabaseManager as group_data @@ -32,7 +28,6 @@ 'get_group_info', 'add_group', 'remove_group', - 'clear_service_data', 'script_command_clear', 'script_command_export', 'script_command_import', @@ -60,10 +55,6 @@ def add_group(service_type : str, group_name : str, service_urls, **kwargs) : if service_type not in group_data.service_types : raise RuntimeError("unknown service type; {}".format(service_type)) - # make sure that all of the URLs are registered in the service_db - for u in service_urls : - _ = pservice.get_service_info(service_type, service_url=u) - info = group_data.service_group_map[service_type](group_name, service_urls, **kwargs) group_data.local_groups_manager.update(info) diff --git a/python/pdo/service_client/service_data/service_groups.py b/python/pdo/service_client/service_data/service_groups.py index 1e195e0c..a0d6b85e 100644 --- a/python/pdo/service_client/service_data/service_groups.py +++ b/python/pdo/service_client/service_data/service_groups.py @@ -15,20 +15,16 @@ # limitations under the License. import atexit -from functools import lru_cache import json import lmdb -import os import pdo.common.config as pconfig -import pdo.common.logger as plogger -from pdo.service_client.enclave import EnclaveServiceClient -from pdo.service_client.provisioning import ProvisioningServiceClient -from pdo.service_client.storage import StorageServiceClient +from pdo.service_client.service_data.service_data import ServiceDatabaseManager as service_data from pdo.common.utility import classproperty -from urllib.parse import urlparse +import logging +logger = logging.getLogger(__name__) ## ================================================================= ## SERVICE GROUP CLASSES @@ -64,6 +60,16 @@ def serialize(self) : service_info['urls'] = self.service_urls return service_info + def verify(self) : + """Verify that the URLs in the group are all part of the service db + + Raises an exception if verification fails. Note that this only checks + if the URLs currently exist in the database. There is no enforcement + for future changes. + """ + for u in self.service_urls : + _ = service_data.local_service_manager.get_by_url(u, self.service_type) + def clone(self) : return type(self).unpack(self.serialize()) @@ -132,7 +138,7 @@ class GroupsDatabaseManager(object) : @classproperty def local_groups_manager(cls) : if cls.__local_groups_manager__ is None : - groups_db_file = pconfig.shared_configuration(['Service','GroupDatabaseFile'], "./groups_db.mdb") + groups_db_file = pconfig.shared_configuration(['Service', 'GroupDatabaseFile'], "./groups_db.mdb") cls.__local_groups_manager__ = cls(groups_db_file, True) atexit.register(cls.__local_groups_manager__.close) @@ -183,6 +189,10 @@ def update(self, group_info : BaseGroup) : The update operation is effectively the same as the store operation except that update overwrites an existing entry, while store fails. """ + + # make sure that all of the URLs are registered in the service_db + group_info.verify() + groups_db = self.groups_db(group_info.service_type) with self.groups_db_env.begin(write=True) as txn : group_name = BaseGroup.force_to_bytes(group_info.group_name) @@ -192,6 +202,9 @@ def update(self, group_info : BaseGroup) : # ----------------------------------------------------------------- def store(self, group_info) : + # make sure that all of the URLs are registered in the service_db + group_info.verify() + groups_db = self.groups_db(group_info.service_type) with self.groups_db_env.begin(write=True) as txn : group_name = BaseGroup.force_to_bytes(group_info.group_name) @@ -264,8 +277,11 @@ def import_group_information(self, groups) : except : pass - group_info = self.service_group_map[service_type](group_name, **group_info) - self.store(group_info) + try : + group_info = self.service_group_map[service_type](group_name, **group_info) + self.store(group_info) + except : + logger.warning('failed to import {} group {}'.format(service_type, group_name)) # ----------------------------------------------------------------- def export_group_information(self) :