Skip to content

Commit

Permalink
Merge pull request #7 from AMLResearchProject/patch-4-fix-formatting
Browse files Browse the repository at this point in the history
Patch 4 fix formatting
  • Loading branch information
AdamMiltonBarker authored Aug 22, 2021
2 parents 31d165d + 1575d04 commit 672f18d
Show file tree
Hide file tree
Showing 11 changed files with 1,154 additions and 1,167 deletions.
146 changes: 73 additions & 73 deletions classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,116 +44,116 @@


class classifier(AbstractClassifier):
""" ALL Jetson Nano Classifier
""" ALL Jetson Nano Classifier
Represents a AI classifier that processes data
using the ALL Jetson Nano Classifier model.
"""
Represents a AI classifier that processes data
using the ALL Jetson Nano Classifier model.
"""

def train(self):
""" Creates & trains the model. """
def train(self):
""" Creates & trains the model. """

self.model.prepare_data()
self.model.prepare_network()
self.model.train()
self.model.evaluate()
self.model.prepare_data()
self.model.prepare_network()
self.model.train()
self.model.evaluate()

def init_model(self):
""" Initializes the model class """
def init_model(self):
""" Initializes the model class """

self.model = model(self.helpers)
self.model = model(self.helpers)

def load_model(self):
""" Loads the trained model """
def load_model(self):
""" Loads the trained model """

self.model.load()
self.model.load()

def load_model_tfrt(self):
""" Loads the trained TFRT model """
def load_model_tfrt(self):
""" Loads the trained TFRT model """

self.model.load_tfrt()
self.model.load_tfrt()

def inference(self):
""" Classifies test data locally """
def inference(self):
""" Classifies test data locally """

self.load_model()
self.model.test()
self.load_model()
self.model.test()

def server(self):
""" Starts the API server """
def server(self):
""" Starts the API server """

self.load_model()
self.server = server(self.helpers, self.model,
self.model_type)
self.server.start()
self.load_model()
self.server = server(self.helpers, self.model,
self.model_type)
self.server.start()

def inference_http(self):
""" Classifies test data via HTTP requests """
def inference_http(self):
""" Classifies test data via HTTP requests """

self.model.test_http()
self.model.test_http()

def inference_tfrt(self):
""" Classifies test data via HTTP requests """
def inference_tfrt(self):
""" Classifies test data via HTTP requests """

self.load_model_tfrt()
self.model.test_tfrt()
self.load_model_tfrt()
self.model.test_tfrt()

def init_engine(self):
""" Initizializes the engine class """
def init_engine(self):
""" Initizializes the engine class """

from modules.engine import engine
from modules.engine import engine

self.engine = engine(self.helpers)
self.engine = engine(self.helpers)

def inference_tensorrt(self):
""" Classifies test data via HTTP requests """
def inference_tensorrt(self):
""" Classifies test data via HTTP requests """

self.engine.load_engine()
self.engine.test()
self.engine.load_engine()
self.engine.test()

def signal_handler(self, signal, frame):
self.helpers.logger.info("Disconnecting")
sys.exit(1)
def signal_handler(self, signal, frame):
self.helpers.logger.info("Disconnecting")
sys.exit(1)


classifier = classifier()


def main():

if len(sys.argv) < 2:
print("You must provide an argument")
exit()
elif sys.argv[1] not in classifier.helpers.confs["agent"]["params"]:
print("Mode not supported! server, train or inference")
exit()
if len(sys.argv) < 2:
print("You must provide an argument")
exit()
elif sys.argv[1] not in classifier.helpers.confs["agent"]["params"]:
print("Mode not supported! server, train or inference")
exit()

mode = sys.argv[1]
mode = sys.argv[1]

if mode == "train":
classifier.init_model()
classifier.train()
if mode == "train":
classifier.init_model()
classifier.train()

elif mode == "classify":
classifier.init_model()
classifier.inference()
elif mode == "classify":
classifier.init_model()
classifier.inference()

elif mode == "server":
classifier.init_model()
classifier.server()
elif mode == "server":
classifier.init_model()
classifier.server()

elif mode == "classify_http":
classifier.init_model()
classifier.inference_http()
elif mode == "classify_http":
classifier.init_model()
classifier.inference_http()

elif mode == "classify_tfrt":
classifier.init_model()
classifier.inference_tfrt()
elif mode == "classify_tfrt":
classifier.init_model()
classifier.inference_tfrt()

elif mode == "classify_tensorrt":
classifier.init_engine()
classifier.inference_tensorrt()
elif mode == "classify_tensorrt":
classifier.init_engine()
classifier.inference_tensorrt()


if __name__ == "__main__":
main()
main()
92 changes: 46 additions & 46 deletions modules/AbstractClassifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,49 +40,49 @@


class AbstractClassifier(ABC):
""" Abstract class representing an AI Classifier.
Represents an AI Classifier. AI Classifiers process data using AI
models. Based on HIAS AI Agents for future compatibility with
the HIAS Network.
"""

def __init__(self):
""" Initializes the AbstractClassifier object. """
super().__init__()

self.helpers = helpers("Classifier")
self.confs = self.helpers.confs
self.model_type = None

self.helpers.logger.info("Classifier initialization complete.")

@abstractmethod
def init_model(self):
""" Loads the model class """
pass

@abstractmethod
def train(self):
""" Creates & trains the model. """
pass

@abstractmethod
def load_model(self):
""" Loads the AI model """
pass

@abstractmethod
def inference(self):
""" Loads model and classifies test data """
pass

@abstractmethod
def server(self):
""" Loads the API server """
pass

@abstractmethod
def inference_http(self):
""" Classifies test data via HTTP requests """
pass
""" Abstract class representing an AI Classifier.
Represents an AI Classifier. AI Classifiers process data using AI
models. Based on HIAS AI Agents for future compatibility with
the HIAS Network.
"""

def __init__(self):
""" Initializes the AbstractClassifier object. """
super().__init__()

self.helpers = helpers("Classifier")
self.confs = self.helpers.confs
self.model_type = None

self.helpers.logger.info("Classifier initialization complete.")

@abstractmethod
def init_model(self):
""" Loads the model class """
pass

@abstractmethod
def train(self):
""" Creates & trains the model. """
pass

@abstractmethod
def load_model(self):
""" Loads the AI model """
pass

@abstractmethod
def inference(self):
""" Loads model and classifies test data """
pass

@abstractmethod
def server(self):
""" Loads the API server """
pass

@abstractmethod
def inference_http(self):
""" Classifies test data via HTTP requests """
pass
Loading

0 comments on commit 672f18d

Please sign in to comment.