From 5dbf18ef0b0a44d08897178ca92c1be0772d83cd Mon Sep 17 00:00:00 2001 From: Naman Nandan Date: Mon, 15 Jul 2024 17:44:02 -0700 Subject: [PATCH] Enable token authorization and model control for gRPC (#3238) * Refactor token auth implementation to extend to gRPC * Implement token authorization for gRPC * Add test for gRPC token authorization * Enable model control for gRPC * Refactor test for gRPC token authorization * Add tets for gRPC model control * Reorder relative path check in download URL * add test for gRPC allowed urls * Refactor bearer token parsing from header * Fix linter errors * Clean up config file in gRPC allowd urls test * Refactor gRPC tests * Fix typo in gRPC inference test * Update documentation for gRPC token auth and model control * Run formatJava target --- docs/grpc_api.md | 29 ++- examples/stateful/sequence_batching/Readme.md | 2 +- .../sequence_continuous_batching/Readme.md | 2 +- .../serve/archive/model/ModelArchive.java | 6 +- .../java/org/pytorch/serve/ModelServer.java | 8 +- .../org/pytorch/serve/ServerInitializer.java | 2 +- .../serve/grpcimpl/GRPCInterceptor.java | 39 ++++ .../serve/grpcimpl/ManagementImpl.java | 20 ++ .../serve/http/TokenAuthorizationHandler.java | 220 ++---------------- .../api/rest/ManagementRequestHandler.java | 9 +- .../org/pytorch/serve/util/ConfigManager.java | 4 +- .../serve/util/TokenAuthorization.java | 209 +++++++++++++++++ .../org/pytorch/serve/util/TokenType.java | 7 - .../org/pytorch/serve/ModelServerTest.java | 3 + test/pytest/test_gRPC_allowed_urls.py | 73 ++++++ test/pytest/test_gRPC_inference_api.py | 6 +- test/pytest/test_gRPC_management_apis.py | 42 ++-- test/pytest/test_gRPC_model_control.py | 58 +++++ test/pytest/test_gRPC_token_authorization.py | 112 +++++++++ test/pytest/test_gRPC_utils.py | 9 +- ts_scripts/torchserve_grpc_client.py | 53 +++-- 21 files changed, 646 insertions(+), 267 deletions(-) create mode 100644 frontend/server/src/main/java/org/pytorch/serve/util/TokenAuthorization.java delete mode 100644 frontend/server/src/main/java/org/pytorch/serve/util/TokenType.java create mode 100644 test/pytest/test_gRPC_allowed_urls.py create mode 100644 test/pytest/test_gRPC_model_control.py create mode 100644 test/pytest/test_gRPC_token_authorization.py diff --git a/docs/grpc_api.md b/docs/grpc_api.md index 5cd54bed74..a41387ebb9 100644 --- a/docs/grpc_api.md +++ b/docs/grpc_api.md @@ -11,6 +11,8 @@ TorchServe provides following gRPCs apis - **Predictions** : Gets predictions from the served model - **StreamPredictions** : Gets server side streaming predictions from the saved model +For all Inference API requests, TorchServe requires the correct Inference token to be included or token authorization must be disable. For more details see [token authorization documentation](./token_authorization_api.md) + * [Management API](https://github.com/pytorch/serve/blob/master/frontend/server/src/main/resources/proto/management.proto) - **RegisterModel** : Serve a model/model-version on TorchServe - **UnregisterModel** : Free up system resources by unregistering specific version of a model from TorchServe @@ -19,6 +21,8 @@ TorchServe provides following gRPCs apis - **DescribeModel** : Get detail runtime status of default version of a model - **SetDefault** : Set any registered version of a model as default version +For all Management API requests, TorchServe requires the correct Management token to be included or token authorization must be disabled. For more details see [token authorization documentation](./token_authorization_api.md) + By default, TorchServe listens on port 7070 for the gRPC Inference API and 7071 for the gRPC Management API on localhost. To configure gRPC APIs on different addresses and ports refer [configuration documentation](configuration.md) @@ -45,7 +49,7 @@ pip install -U grpcio protobuf grpcio-tools googleapis-common-protos ```bash mkdir models -torchserve --start --model-store models/ +torchserve --start --disable-token-auth --enable-model-api --model-store models/ ``` - Generate python gRPC client stub using the proto files @@ -56,21 +60,44 @@ python -m grpc_tools.protoc -I third_party/google/rpc --proto_path=frontend/serv - Register densenet161 model +__Note__: To use this API after TorchServe starts, model API control has to be enabled. Add `--enable-model-api` to command line when starting TorchServe to enable the use of this API. For more details see [model API control](./model_api_control.md) + +If token authorization is disabled, use: ```bash python ts_scripts/torchserve_grpc_client.py register densenet161 +``` + +If token authorization is enabled, use: +```bash +python ts_scripts/torchserve_grpc_client.py register densenet161 --auth-token ``` - Run inference using +If token authorization is disabled, use: ```bash python ts_scripts/torchserve_grpc_client.py infer densenet161 examples/image_classifier/kitten.jpg +``` + +If token authorization is enabled, use: +```bash +python ts_scripts/torchserve_grpc_client.py infer densenet161 examples/image_classifier/kitten.jpg --auth-token ``` - Unregister densenet161 model +__Note__: To use this API after TorchServe starts, model API control has to be enabled. Add `--enable-model-api` to command line when starting TorchServe to enable the use of this API. For more details see [model API control](./model_api_control.md) + +If token authorization is disabled, use: ```bash python ts_scripts/torchserve_grpc_client.py unregister densenet161 ``` + +If token authorization is enabled, use: +```bash +python ts_scripts/torchserve_grpc_client.py unregister densenet161 --auth-token +``` + ## GRPC Server Side Streaming TorchServe GRPC APIs adds a server side streaming of the inference API "StreamPredictions" to allow a sequence of inference responses to be sent over the same GRPC stream. This new API is only recommended for use case when the inference latency of the full response is high and the inference intermediate results are sent to client. An example could be LLMs for generative applications, where generating "n" number of tokens can have high latency, in this case user can receive each generated token once ready until the full response completes. This new API automatically forces the batchSize to be one. diff --git a/examples/stateful/sequence_batching/Readme.md b/examples/stateful/sequence_batching/Readme.md index d1cae6c257..a3311666cd 100644 --- a/examples/stateful/sequence_batching/Readme.md +++ b/examples/stateful/sequence_batching/Readme.md @@ -113,7 +113,7 @@ python -m grpc_tools.protoc -I third_party/google/rpc --proto_path=frontend/serv * Start TorchServe ```bash -torchserve --ncs --start --model-store models --model stateful.mar --ts-config examples/stateful/config.properties +torchserve --ncs --start --disable-token-auth --model-store models --model stateful.mar --ts-config examples/stateful/config.properties ``` * Run sequence inference via GRPC client diff --git a/examples/stateful/sequence_continuous_batching/Readme.md b/examples/stateful/sequence_continuous_batching/Readme.md index 7d4e9a9ed9..83d50db9ed 100644 --- a/examples/stateful/sequence_continuous_batching/Readme.md +++ b/examples/stateful/sequence_continuous_batching/Readme.md @@ -184,7 +184,7 @@ python -m grpc_tools.protoc -I third_party/google/rpc --proto_path=frontend/serv * Start TorchServe ```bash -torchserve --ncs --start --model-store models --model stateful.mar --ts-config examples/stateful/config.properties +torchserve --ncs --start --disable-token-auth --model-store models --model stateful.mar --ts-config examples/stateful/config.properties ``` * Run sequence inference via GRPC client diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java index 9d26ff765f..4e73e3ad56 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelArchive.java @@ -54,13 +54,13 @@ public static ModelArchive downloadModel( throw new ModelNotFoundException("empty url"); } - String marFileName = ArchiveUtils.getFilenameFromUrl(url); - File modelLocation = new File(modelStore, marFileName); - if (url.contains("..")) { throw new ModelNotFoundException("Relative path is not allowed in url: " + url); } + String marFileName = ArchiveUtils.getFilenameFromUrl(url); + File modelLocation = new File(modelStore, marFileName); + try { ArchiveUtils.downloadArchive( allowedUrls, modelLocation, marFileName, url, s3SseKmsEnabled); diff --git a/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java b/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java index bcad57cafa..cbc4613dfb 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java +++ b/frontend/server/src/main/java/org/pytorch/serve/ModelServer.java @@ -40,7 +40,6 @@ import org.pytorch.serve.archive.model.ModelNotFoundException; import org.pytorch.serve.grpcimpl.GRPCInterceptor; import org.pytorch.serve.grpcimpl.GRPCServiceFactory; -import org.pytorch.serve.http.TokenAuthorizationHandler; import org.pytorch.serve.http.messages.RegisterModelRequest; import org.pytorch.serve.metrics.MetricCache; import org.pytorch.serve.metrics.MetricManager; @@ -54,6 +53,7 @@ import org.pytorch.serve.util.Connector; import org.pytorch.serve.util.ConnectorType; import org.pytorch.serve.util.ServerGroups; +import org.pytorch.serve.util.TokenAuthorization; import org.pytorch.serve.wlm.Model; import org.pytorch.serve.wlm.ModelManager; import org.pytorch.serve.wlm.WorkLoadManager; @@ -87,7 +87,7 @@ public static void main(String[] args) { ConfigManager.Arguments arguments = new ConfigManager.Arguments(cmd); ConfigManager.init(arguments); ConfigManager configManager = ConfigManager.getInstance(); - TokenAuthorizationHandler.setupToken(); + TokenAuthorization.init(); PluginsManager.getInstance().initialize(); MetricCache.init(); InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE); @@ -465,7 +465,7 @@ private Server startGRPCServer(ConnectorType connectorType) throws IOException { .addService( ServerInterceptors.intercept( GRPCServiceFactory.getgRPCService(connectorType), - new GRPCInterceptor())); + new GRPCInterceptor(connectorType))); if (connectorType == ConnectorType.INFERENCE_CONNECTOR && ConfigManager.getInstance().isOpenInferenceProtocol()) { @@ -474,7 +474,7 @@ private Server startGRPCServer(ConnectorType connectorType) throws IOException { ServerInterceptors.intercept( GRPCServiceFactory.getgRPCService( ConnectorType.OPEN_INFERENCE_CONNECTOR), - new GRPCInterceptor())); + new GRPCInterceptor(connectorType))); } if (configManager.isGRPCSSLEnabled()) { diff --git a/frontend/server/src/main/java/org/pytorch/serve/ServerInitializer.java b/frontend/server/src/main/java/org/pytorch/serve/ServerInitializer.java index b362ceb958..b9cec4e07b 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/ServerInitializer.java +++ b/frontend/server/src/main/java/org/pytorch/serve/ServerInitializer.java @@ -19,7 +19,7 @@ import org.pytorch.serve.servingsdk.impl.PluginsManager; import org.pytorch.serve.util.ConfigManager; import org.pytorch.serve.util.ConnectorType; -import org.pytorch.serve.util.TokenType; +import org.pytorch.serve.util.TokenAuthorization.TokenType; import org.pytorch.serve.workflow.api.http.WorkflowInferenceRequestHandler; import org.pytorch.serve.workflow.api.http.WorkflowMgmtRequestHandler; import org.slf4j.Logger; diff --git a/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/GRPCInterceptor.java b/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/GRPCInterceptor.java index 252427ca4c..e3c76abf9b 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/GRPCInterceptor.java +++ b/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/GRPCInterceptor.java @@ -8,13 +8,32 @@ import io.grpc.ServerInterceptor; import io.grpc.Status; import org.pytorch.serve.http.Session; +import org.pytorch.serve.util.ConnectorType; +import org.pytorch.serve.util.TokenAuthorization; +import org.pytorch.serve.util.TokenAuthorization.TokenType; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class GRPCInterceptor implements ServerInterceptor { + private TokenType tokenType; + private static final Metadata.Key tokenAuthHeaderKey = + Metadata.Key.of("authorization", Metadata.ASCII_STRING_MARSHALLER); private static final Logger logger = LoggerFactory.getLogger("ACCESS_LOG"); + public GRPCInterceptor(ConnectorType connectorType) { + switch (connectorType) { + case MANAGEMENT_CONNECTOR: + tokenType = TokenType.MANAGEMENT; + break; + case INFERENCE_CONNECTOR: + tokenType = TokenType.INFERENCE; + break; + default: + tokenType = TokenType.INFERENCE; + } + } + @Override public ServerCall.Listener interceptCall( ServerCall call, Metadata headers, ServerCallHandler next) { @@ -23,6 +42,17 @@ public ServerCall.Listener interceptCall( String serviceName = call.getMethodDescriptor().getFullMethodName(); Session session = new Session(inetSocketString, serviceName); + if (TokenAuthorization.isEnabled()) { + if (!headers.containsKey(tokenAuthHeaderKey) + || !checkTokenAuthorization(headers.get(tokenAuthHeaderKey))) { + call.close( + Status.PERMISSION_DENIED.withDescription( + "Token Authorization failed. Token either incorrect, expired, or not provided correctly"), + new Metadata()); + return new ServerCall.Listener() {}; + } + } + return next.startCall( new ForwardingServerCall.SimpleForwardingServerCall(call) { @Override @@ -34,4 +64,13 @@ public void close(final Status status, final Metadata trailers) { }, headers); } + + private Boolean checkTokenAuthorization(String tokenAuthHeaderValue) { + if (tokenAuthHeaderValue == null) { + return false; + } + String token = TokenAuthorization.parseTokenFromBearerTokenHeader(tokenAuthHeaderValue); + + return TokenAuthorization.checkTokenAuthorization(token, tokenType); + } } diff --git a/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/ManagementImpl.java b/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/ManagementImpl.java index 33b13997a3..cc5e9e3c2f 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/ManagementImpl.java +++ b/frontend/server/src/main/java/org/pytorch/serve/grpcimpl/ManagementImpl.java @@ -23,6 +23,7 @@ import org.pytorch.serve.job.GRPCJob; import org.pytorch.serve.job.Job; import org.pytorch.serve.util.ApiUtils; +import org.pytorch.serve.util.ConfigManager; import org.pytorch.serve.util.GRPCUtils; import org.pytorch.serve.util.JsonUtils; import org.pytorch.serve.util.messages.RequestInput; @@ -32,8 +33,13 @@ import org.slf4j.LoggerFactory; public class ManagementImpl extends ManagementAPIsServiceImplBase { + private ConfigManager configManager; private static final Logger logger = LoggerFactory.getLogger(ManagementImpl.class); + public ManagementImpl() { + configManager = ConfigManager.getInstance(); + } + @Override public void describeModel( DescribeModelRequest request, StreamObserver responseObserver) { @@ -117,6 +123,13 @@ public void registerModel( StatusResponse statusResponse; try { + if (!configManager.isModelApiEnabled()) { + sendErrorResponse( + responseObserver, + Status.PERMISSION_DENIED, + new ModelException("Model API disabled")); + return; + } statusResponse = ApiUtils.registerModel(registerModelRequest); sendStatusResponse(responseObserver, statusResponse); } catch (InternalServerException e) { @@ -205,6 +218,13 @@ public void unregisterModel( .asRuntimeException()); }); try { + if (!configManager.isModelApiEnabled()) { + sendErrorResponse( + responseObserver, + Status.PERMISSION_DENIED, + new ModelException("Model API disabled")); + return; + } String modelName = request.getModelName(); if (modelName == null || ("").equals(modelName)) { sendErrorResponse( diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java index 065ba80762..65a709d681 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java @@ -1,31 +1,17 @@ package org.pytorch.serve.http; -import com.google.gson.GsonBuilder; -import com.google.gson.JsonObject; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.QueryStringDecoder; -import java.io.File; -import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.file.Paths; -import java.nio.file.attribute.PosixFilePermission; -import java.nio.file.attribute.PosixFilePermissions; -import java.security.SecureRandom; -import java.time.Instant; -import java.util.Base64; import java.util.List; import java.util.Map; -import java.util.Set; import org.pytorch.serve.archive.DownloadArchiveException; import org.pytorch.serve.archive.model.InvalidKeyException; import org.pytorch.serve.archive.model.ModelException; import org.pytorch.serve.archive.workflow.WorkflowException; -import org.pytorch.serve.util.ConfigManager; import org.pytorch.serve.util.NettyUtils; -import org.pytorch.serve.util.TokenType; +import org.pytorch.serve.util.TokenAuthorization; +import org.pytorch.serve.util.TokenAuthorization.TokenType; import org.pytorch.serve.wlm.WorkerInitializationException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -36,12 +22,8 @@ *

This class // */ public class TokenAuthorizationHandler extends HttpRequestHandlerChain { - + private TokenType tokenType; private static final Logger logger = LoggerFactory.getLogger(TokenAuthorizationHandler.class); - private static TokenType tokenType; - private static Boolean tokenEnabled = false; - private static Token token; - private static Object tokenObject; /** Creates a new {@code InferenceRequestHandler} instance. */ public TokenAuthorizationHandler(TokenType type) { @@ -56,72 +38,43 @@ public void handleRequest( String[] segments) throws ModelException, DownloadArchiveException, WorkflowException, WorkerInitializationException { - if (tokenEnabled) { + if (TokenAuthorization.isEnabled()) { if (tokenType == TokenType.MANAGEMENT) { if (req.toString().contains("/token")) { try { - checkTokenAuthorization(req, "token"); + checkTokenAuthorization(req, TokenType.TOKEN_API); String queryResponse = parseQuery(req); - String resp = token.updateKeyFile(queryResponse); + String resp = + TokenAuthorization.updateKeyFile( + TokenType.valueOf(queryResponse.toUpperCase())); NettyUtils.sendJsonResponse(ctx, resp); return; } catch (Exception e) { - logger.error("Key file updated unsuccessfully"); + logger.error("Failed to update key file"); throw new InvalidKeyException( "Token Authentication failed. Token either incorrect, expired, or not provided correctly"); } } else { - checkTokenAuthorization(req, "management"); + checkTokenAuthorization(req, TokenType.MANAGEMENT); } } else if (tokenType == TokenType.INFERENCE) { - checkTokenAuthorization(req, "inference"); + checkTokenAuthorization(req, TokenType.INFERENCE); } } chain.handleRequest(ctx, req, decoder, segments); } - public static void setupToken() { - if (!ConfigManager.getInstance().getDisableTokenAuthorization()) { - try { - token = new Token(); - if (token.generateKeyFile("token")) { - String loggingMessage = - "\n######\n" - + "TorchServe now enforces token authorization by default.\n" - + "This requires the correct token to be provided when calling an API.\n" - + "Key file located at " - + ConfigManager.getInstance().getModelServerHome() - + "/key_file.json\n" - + "Check token authorization documenation for information: https://github.com/pytorch/serve/blob/master/docs/token_authorization_api.md \n" - + "######\n"; - logger.info(loggingMessage); - } - } catch (IOException e) { - e.printStackTrace(); - logger.error("Token Authorization setup unsuccessfully"); - throw new IllegalStateException("Token Authorization setup unsuccessfully", e); - } - tokenEnabled = true; - } - } - - private void checkTokenAuthorization(FullHttpRequest req, String type) throws ModelException { + private void checkTokenAuthorization(FullHttpRequest req, TokenType tokenType) + throws ModelException { String tokenBearer = req.headers().get("Authorization"); if (tokenBearer == null) { throw new InvalidKeyException( - "Token Authentication failed. Token either incorrect, expired, or not provided correctly"); + "Token Authorization failed. Token either incorrect, expired, or not provided correctly"); } - String[] arrOfStr = tokenBearer.split(" ", 2); - if (arrOfStr.length == 1) { + String token = TokenAuthorization.parseTokenFromBearerTokenHeader(tokenBearer); + if (!TokenAuthorization.checkTokenAuthorization(token, tokenType)) { throw new InvalidKeyException( - "Token Authentication failed. Token either incorrect, expired, or not provided correctly"); - } - String currToken = arrOfStr[1]; - - boolean result = token.checkTokenAuthorization(currToken, type); - if (!result) { - throw new InvalidKeyException( - "Token Authentication failed. Token either incorrect, expired, or not provided correctly"); + "Token Authorization failed. Token either incorrect, expired, or not provided correctly"); } } @@ -140,142 +93,3 @@ private String parseQuery(FullHttpRequest req) { return "NO TYPE PROVIDED"; } } - -class Token { - private static String apiKey; - private static String managementKey; - private static String inferenceKey; - private static Instant managementExpirationTimeMinutes; - private static Instant inferenceExpirationTimeMinutes; - private SecureRandom secureRandom = new SecureRandom(); - private Base64.Encoder baseEncoder = Base64.getUrlEncoder(); - private String fileName = "key_file.json"; - private String filePath = ConfigManager.getInstance().getModelServerHome(); - - public String updateKeyFile(String queryResponse) throws IOException { - String test = ""; - if ("management".equals(queryResponse)) { - generateKeyFile("management"); - } else if ("inference".equals(queryResponse)) { - generateKeyFile("inference"); - } else { - test = "{\n\t\"Error\": " + queryResponse + "\n}\n"; - } - return test; - } - - public String generateKey() { - byte[] randomBytes = new byte[6]; - secureRandom.nextBytes(randomBytes); - return baseEncoder.encodeToString(randomBytes); - } - - public Instant generateTokenExpiration() { - long secondsToAdd = (long) (ConfigManager.getInstance().getTimeToExpiration() * 60); - return Instant.now().plusSeconds(secondsToAdd); - } - - // generates a key file with new keys depending on the parameter provided - public boolean generateKeyFile(String type) throws IOException { - String userDirectory = filePath + "/" + fileName; - File file = new File(userDirectory); - if (!file.createNewFile() && !file.exists()) { - return false; - } - if (apiKey == null) { - apiKey = generateKey(); - } - switch (type) { - case "management": - managementKey = generateKey(); - managementExpirationTimeMinutes = generateTokenExpiration(); - break; - case "inference": - inferenceKey = generateKey(); - inferenceExpirationTimeMinutes = generateTokenExpiration(); - break; - default: - managementKey = generateKey(); - inferenceKey = generateKey(); - inferenceExpirationTimeMinutes = generateTokenExpiration(); - managementExpirationTimeMinutes = generateTokenExpiration(); - } - - JsonObject parentObject = new JsonObject(); - - JsonObject managementObject = new JsonObject(); - managementObject.addProperty("key", managementKey); - managementObject.addProperty("expiration time", managementExpirationTimeMinutes.toString()); - parentObject.add("management", managementObject); - - JsonObject inferenceObject = new JsonObject(); - inferenceObject.addProperty("key", inferenceKey); - inferenceObject.addProperty("expiration time", inferenceExpirationTimeMinutes.toString()); - parentObject.add("inference", inferenceObject); - - JsonObject apiObject = new JsonObject(); - apiObject.addProperty("key", apiKey); - parentObject.add("API", apiObject); - - Files.write( - Paths.get(fileName), - new GsonBuilder() - .setPrettyPrinting() - .create() - .toJson(parentObject) - .getBytes(StandardCharsets.UTF_8)); - - if (!setFilePermissions()) { - try { - Files.delete(Paths.get(fileName)); - } catch (IOException e) { - return false; - } - return false; - } - return true; - } - - public boolean setFilePermissions() { - Path path = Paths.get(fileName); - try { - Set permissions = PosixFilePermissions.fromString("rw-------"); - Files.setPosixFilePermissions(path, permissions); - } catch (Exception e) { - return false; - } - return true; - } - - // checks the token provided in the http with the saved keys depening on parameters - public boolean checkTokenAuthorization(String token, String type) { - String key; - Instant expiration; - switch (type) { - case "token": - key = apiKey; - expiration = null; - break; - case "management": - key = managementKey; - expiration = managementExpirationTimeMinutes; - break; - default: - key = inferenceKey; - expiration = inferenceExpirationTimeMinutes; - } - - if (token.equals(key)) { - if (expiration != null && isTokenExpired(expiration)) { - return false; - } - } else { - return false; - } - return true; - } - - public boolean isTokenExpired(Instant expirationTime) { - return !(Instant.now().isBefore(expirationTime)); - } -} diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ManagementRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ManagementRequestHandler.java index eebb3aeb82..d5f041a106 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ManagementRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ManagementRequestHandler.java @@ -78,7 +78,8 @@ public void handleRequest( if (HttpMethod.GET.equals(method)) { handleListModels(ctx, decoder); return; - } else if (HttpMethod.POST.equals(method) && isModeEnabled()) { + } else if (HttpMethod.POST.equals(method) + && configManager.isModelApiEnabled()) { handleRegisterModel(ctx, decoder, req); return; } @@ -97,7 +98,7 @@ public void handleRequest( } else { handleScaleModel(ctx, decoder, segments[2], modelVersion); } - } else if (HttpMethod.DELETE.equals(method) && isModeEnabled()) { + } else if (HttpMethod.DELETE.equals(method) && configManager.isModelApiEnabled()) { handleUnregisterModel(ctx, segments[2], modelVersion); } else if (HttpMethod.OPTIONS.equals(method)) { ModelManager modelManager = ModelManager.getInstance(); @@ -133,10 +134,6 @@ private boolean isManagementReq(String[] segments) { || endpointMap.containsKey(segments[1]); } - private boolean isModeEnabled() { - return configManager.getModelControlMode(); - } - private boolean isKFV1ManagementReq(String[] segments) { return segments.length == 4 && "v1".equals(segments[1]) && "models".equals(segments[2]); } diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java index 2780f88548..12ece0b54b 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java @@ -498,7 +498,7 @@ public int getNumberOfGpu() { return getIntProperty(TS_NUMBER_OF_GPU, 0); } - public boolean getModelControlMode() { + public boolean isModelApiEnabled() { return Boolean.parseBoolean(getProperty(TS_ENABLE_MODEL_API, "false")); } @@ -832,7 +832,7 @@ public String dumpConfigurations() { + "\nSystem metrics command: " + (getSystemMetricsCmd().isEmpty() ? "default" : getSystemMetricsCmd()) + "\nModel API enabled: " - + (getModelControlMode() ? "true" : "false"); + + (isModelApiEnabled() ? "true" : "false"); } public boolean useNativeIo() { diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/TokenAuthorization.java b/frontend/server/src/main/java/org/pytorch/serve/util/TokenAuthorization.java new file mode 100644 index 0000000000..885540e089 --- /dev/null +++ b/frontend/server/src/main/java/org/pytorch/serve/util/TokenAuthorization.java @@ -0,0 +1,209 @@ +package org.pytorch.serve.util; + +import com.google.gson.GsonBuilder; +import com.google.gson.JsonObject; +import java.io.File; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.nio.file.attribute.PosixFilePermission; +import java.nio.file.attribute.PosixFilePermissions; +import java.security.SecureRandom; +import java.time.Instant; +import java.util.Base64; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class TokenAuthorization { + private static String apiKey; + private static String managementKey; + private static String inferenceKey; + private static Instant managementExpirationTimeMinutes; + private static Instant inferenceExpirationTimeMinutes; + private static Boolean tokenAuthEnabled; + private static String keyFilePath; + private static final SecureRandom secureRandom = new SecureRandom(); + private static final Base64.Encoder baseEncoder = Base64.getUrlEncoder(); + private static final Pattern bearerTokenHeaderPattern = Pattern.compile("^Bearer\\s+(\\S+)$"); + private static final Logger logger = LoggerFactory.getLogger(TokenAuthorization.class); + + public enum TokenType { + INFERENCE, + MANAGEMENT, + TOKEN_API + } + + public static void init() { + if (ConfigManager.getInstance().getDisableTokenAuthorization()) { + tokenAuthEnabled = false; + return; + } + + tokenAuthEnabled = true; + apiKey = generateKey(); + keyFilePath = Paths.get(System.getProperty("user.dir"), "key_file.json").toString(); + + try { + if (generateKeyFile(TokenType.TOKEN_API)) { + String loggingMessage = + "\n######\n" + + "TorchServe now enforces token authorization by default.\n" + + "This requires the correct token to be provided when calling an API.\n" + + "Key file located at " + + keyFilePath + + "\nCheck token authorization documenation for information: https://github.com/pytorch/serve/blob/master/docs/token_authorization_api.md \n" + + "######\n"; + logger.info(loggingMessage); + } + } catch (IOException e) { + e.printStackTrace(); + logger.error("Token Authorization setup unsuccessful"); + throw new IllegalStateException("Token Authorization setup unsuccessful", e); + } + } + + public static Boolean isEnabled() { + return tokenAuthEnabled; + } + + public static String updateKeyFile(TokenType tokenType) throws IOException { + String status = ""; + + switch (tokenType) { + case MANAGEMENT: + generateKeyFile(TokenType.MANAGEMENT); + break; + case INFERENCE: + generateKeyFile(TokenType.INFERENCE); + break; + default: + status = "{\n\t\"Error\": " + tokenType + "\n}\n"; + } + + return status; + } + + public static boolean checkTokenAuthorization(String token, TokenType tokenType) { + String key; + Instant expiration; + switch (tokenType) { + case TOKEN_API: + key = apiKey; + expiration = null; + break; + case MANAGEMENT: + key = managementKey; + expiration = managementExpirationTimeMinutes; + break; + default: + key = inferenceKey; + expiration = inferenceExpirationTimeMinutes; + } + + if (token.equals(key)) { + if (expiration != null && isTokenExpired(expiration)) { + return false; + } + } else { + return false; + } + return true; + } + + public static String parseTokenFromBearerTokenHeader(String bearerTokenHeader) { + String token = ""; + Matcher matcher = bearerTokenHeaderPattern.matcher(bearerTokenHeader); + if (matcher.matches()) { + token = matcher.group(1); + } + + return token; + } + + private static String generateKey() { + byte[] randomBytes = new byte[6]; + secureRandom.nextBytes(randomBytes); + return baseEncoder.encodeToString(randomBytes); + } + + private static Instant generateTokenExpiration() { + long secondsToAdd = (long) (ConfigManager.getInstance().getTimeToExpiration() * 60); + return Instant.now().plusSeconds(secondsToAdd); + } + + private static boolean generateKeyFile(TokenType tokenType) throws IOException { + File file = new File(keyFilePath); + if (!file.createNewFile() && !file.exists()) { + return false; + } + switch (tokenType) { + case MANAGEMENT: + managementKey = generateKey(); + managementExpirationTimeMinutes = generateTokenExpiration(); + break; + case INFERENCE: + inferenceKey = generateKey(); + inferenceExpirationTimeMinutes = generateTokenExpiration(); + break; + default: + managementKey = generateKey(); + inferenceKey = generateKey(); + inferenceExpirationTimeMinutes = generateTokenExpiration(); + managementExpirationTimeMinutes = generateTokenExpiration(); + } + + JsonObject parentObject = new JsonObject(); + + JsonObject managementObject = new JsonObject(); + managementObject.addProperty("key", managementKey); + managementObject.addProperty("expiration time", managementExpirationTimeMinutes.toString()); + parentObject.add("management", managementObject); + + JsonObject inferenceObject = new JsonObject(); + inferenceObject.addProperty("key", inferenceKey); + inferenceObject.addProperty("expiration time", inferenceExpirationTimeMinutes.toString()); + parentObject.add("inference", inferenceObject); + + JsonObject apiObject = new JsonObject(); + apiObject.addProperty("key", apiKey); + parentObject.add("API", apiObject); + + Files.write( + Paths.get(keyFilePath), + new GsonBuilder() + .setPrettyPrinting() + .create() + .toJson(parentObject) + .getBytes(StandardCharsets.UTF_8)); + + if (!setFilePermissions()) { + try { + Files.delete(Paths.get(keyFilePath)); + } catch (IOException e) { + return false; + } + return false; + } + return true; + } + + private static boolean setFilePermissions() { + Path path = Paths.get(keyFilePath); + try { + Set permissions = PosixFilePermissions.fromString("rw-------"); + Files.setPosixFilePermissions(path, permissions); + } catch (Exception e) { + return false; + } + return true; + } + + private static boolean isTokenExpired(Instant expirationTime) { + return !(Instant.now().isBefore(expirationTime)); + } +} diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/TokenType.java b/frontend/server/src/main/java/org/pytorch/serve/util/TokenType.java deleted file mode 100644 index 9bf6318d2e..0000000000 --- a/frontend/server/src/main/java/org/pytorch/serve/util/TokenType.java +++ /dev/null @@ -1,7 +0,0 @@ -package org.pytorch.serve.util; - -public enum TokenType { - INFERENCE, - MANAGEMENT, - TOKEN_API -} diff --git a/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java b/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java index b3929fad98..518143ec04 100644 --- a/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java +++ b/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java @@ -46,6 +46,7 @@ import org.pytorch.serve.util.ConfigManager; import org.pytorch.serve.util.ConnectorType; import org.pytorch.serve.util.JsonUtils; +import org.pytorch.serve.util.TokenAuthorization; import org.pytorch.serve.wlm.Model; import org.testng.Assert; import org.testng.SkipException; @@ -77,7 +78,9 @@ public void beforeSuite() InvalidSnapshotException { ConfigManager.init(new ConfigManager.Arguments()); configManager = ConfigManager.getInstance(); + configManager.setProperty("disable_token_authorization", "true"); configManager.setProperty("metrics_mode", "prometheus"); + TokenAuthorization.init(); PluginsManager.getInstance().initialize(); MetricCache.init(); diff --git a/test/pytest/test_gRPC_allowed_urls.py b/test/pytest/test_gRPC_allowed_urls.py new file mode 100644 index 0000000000..77ad3f1982 --- /dev/null +++ b/test/pytest/test_gRPC_allowed_urls.py @@ -0,0 +1,73 @@ +import os + +import management_pb2 +import pytest +import test_gRPC_utils +import test_utils + +CONFIG_FILE = test_utils.ROOT_DIR + "/config.properties" + + +def setup_module(module): + test_utils.torchserve_cleanup() + with open(CONFIG_FILE, "w") as f: + f.write( + "allowed_urls=https://torchserve.s3.amazonaws.com/mar_files/densenet161.mar" + ) + test_utils.start_torchserve(snapshot_file=CONFIG_FILE, gen_mar=False) + + +def teardown_module(module): + test_utils.torchserve_cleanup() + if os.path.exists(CONFIG_FILE): + os.remove(CONFIG_FILE) + + +def register(stub, model_url, model_name, metadata): + params = { + "url": model_url, + "initial_workers": 1, + "synchronous": True, + "model_name": model_name, + } + stub.RegisterModel( + management_pb2.RegisterModelRequest(**params), metadata=metadata, timeout=120 + ) + + +def test_gRPC_allowed_urls(): + management_stub = test_gRPC_utils.get_management_stub() + + # register model with permitted url + metadata = (("protocol", "gRPC"),) + register( + management_stub, + "https://torchserve.s3.amazonaws.com/mar_files/densenet161.mar", + "densenet161", + metadata, + ) + + # register model with unpermitted url + with pytest.raises( + Exception, match=r".*Given URL.*does not match any allowed URL.*" + ): + register( + management_stub, + "https://torchserve.s3.amazonaws.com/mar_files/resnet-18.mar", + "resnet-18", + metadata, + ) + + +def test_gRPC_allowed_urls_relative_path(): + management_stub = test_gRPC_utils.get_management_stub() + + # register model with relative path in model url + metadata = (("protocol", "gRPC"),) + with pytest.raises(Exception, match=r".*Relative path is not allowed in url.*"): + register( + management_stub, + "https://torchserve.s3.amazonaws.com/mar_files/../mar_files/densenet161.mar", + "densenet161-relative-path", + metadata, + ) diff --git a/test/pytest/test_gRPC_inference_api.py b/test/pytest/test_gRPC_inference_api.py index 888e6e756c..65a6f717b4 100644 --- a/test/pytest/test_gRPC_inference_api.py +++ b/test/pytest/test_gRPC_inference_api.py @@ -12,17 +12,19 @@ inference_stream_data_json = "../postman/inference_stream_data.json" inference_stream2_data_json = "../postman/inference_stream2_data.json" config_file = test_utils.ROOT_DIR + "/config.properties" -with open(config_file, "w") as f: - f.write("install_py_dep_per_model=true") def setup_module(module): test_utils.torchserve_cleanup() + with open(config_file, "w") as f: + f.write("install_py_dep_per_model=true") test_utils.start_torchserve(snapshot_file=config_file) def teardown_module(module): test_utils.torchserve_cleanup() + if os.path.exists(config_file): + os.remove(config_file) def __get_change(current, previous): diff --git a/test/pytest/test_gRPC_management_apis.py b/test/pytest/test_gRPC_management_apis.py index d805d813c0..b1dbe7f268 100644 --- a/test/pytest/test_gRPC_management_apis.py +++ b/test/pytest/test_gRPC_management_apis.py @@ -1,10 +1,10 @@ -import grpc import json import os -import test_gRPC_utils -import test_utils from urllib import parse +import grpc +import test_gRPC_utils +import test_utils management_data_json = "../postman/management_data.json" @@ -26,12 +26,22 @@ def __get_query_params(parsed_url): query_params = dict(parse.parse_qsl(parsed_url.query)) for key, value in query_params.items(): - if key in ['min_worker', 'max_worker', 'initial_workers', 'timeout', 'number_gpu', 'batch_size', - 'max_batch_delay', 'response_timeout', "limit", "next_page_token"]: + if key in [ + "min_worker", + "max_worker", + "initial_workers", + "timeout", + "number_gpu", + "batch_size", + "max_batch_delay", + "response_timeout", + "limit", + "next_page_token", + ]: query_params[key] = int(query_params[key]) - if key in ['synchronous']: + if key in ["synchronous"]: query_params[key] = bool(query_params[key]) - if key in ['url'] and query_params[key].startswith('{{mar_path_'): + if key in ["url"] and query_params[key].startswith("{{mar_path_"): query_params[key] = test_utils.mar_file_table[query_params[key][2:-2]] return query_params @@ -43,9 +53,7 @@ def __get_path_params(parsed_url): if len(path) == 1: return {} - path_params = { - "model_name": path[1] - } + path_params = {"model_name": path[1]} if len(path) == 3: path_params.update({"model_version": path[2]}) @@ -89,21 +97,21 @@ def test_management_apis(): "scale": "ScaleWorker", "set_default": "SetDefault", "list": "ListModels", - "describe": "DescribeModel" + "describe": "DescribeModel", } - with open(os.path.join(os.path.dirname(__file__), management_data_json), 'rb') as f: + with open(os.path.join(os.path.dirname(__file__), management_data_json), "rb") as f: test_data = json.loads(f.read()) for item in test_data: try: - api_name = item['type'] + api_name = item["type"] api = globals()["__get_" + api_name + "_params"] - params = api(parse.urlsplit(item['path'])) + params = api(parse.urlsplit(item["path"])) test_gRPC_utils.run_management_api(api_mapping[api_name], **params) except grpc.RpcError as e: - if 'grpc_status_code' in item: - assert e.code().value[0] == item['grpc_status_code'] + if "grpc_status_code" in item: + assert e.code().value[0] == item["grpc_status_code"] except ValueError as e: # gRPC has more stricter check on the input types hence ignoring the test case from data file - continue \ No newline at end of file + continue diff --git a/test/pytest/test_gRPC_model_control.py b/test/pytest/test_gRPC_model_control.py new file mode 100644 index 0000000000..56039ac349 --- /dev/null +++ b/test/pytest/test_gRPC_model_control.py @@ -0,0 +1,58 @@ +import management_pb2 +import pytest +import test_gRPC_utils +import test_utils + + +def teardown_module(module): + test_utils.torchserve_cleanup() + + +def register(stub, model_name, metadata): + marfile = f"https://torchserve.s3.amazonaws.com/mar_files/{model_name}.mar" + params = { + "url": marfile, + "initial_workers": 1, + "synchronous": True, + "model_name": model_name, + } + stub.RegisterModel( + management_pb2.RegisterModelRequest(**params), metadata=metadata, timeout=120 + ) + + +def unregister(stub, model_name, metadata): + params = { + "model_name": model_name, + } + stub.UnregisterModel( + management_pb2.UnregisterModelRequest(**params), metadata=metadata, timeout=60 + ) + + +def test_grpc_register_model_with_model_control(): + test_utils.torchserve_cleanup() + test_utils.start_torchserve(enable_model_api=False, gen_mar=False) + management_stub = test_gRPC_utils.get_management_stub() + + metadata = (("protocol", "gRPC"),) + with pytest.raises(Exception, match=r".*Model API disabled.*"): + register(management_stub, "densenet161", metadata) + + test_utils.torchserve_cleanup() + + +def test_grpc_unregister_model_with_model_control(): + test_utils.torchserve_cleanup() + test_utils.start_torchserve( + enable_model_api=False, + gen_mar=False, + models="densenet161=https://torchserve.s3.amazonaws.com/mar_files/densenet161.mar", + ) + management_stub = test_gRPC_utils.get_management_stub() + + metadata = (("protocol", "gRPC"),) + with pytest.raises(Exception, match=r".*Model API disabled.*"): + unregister(management_stub, "densenet161", metadata) + + test_utils.torchserve_cleanup() diff --git a/test/pytest/test_gRPC_token_authorization.py b/test/pytest/test_gRPC_token_authorization.py new file mode 100644 index 0000000000..e532765d97 --- /dev/null +++ b/test/pytest/test_gRPC_token_authorization.py @@ -0,0 +1,112 @@ +import json +import os + +import inference_pb2 +import management_pb2 +import pytest +import test_gRPC_utils +import test_utils + +CURR_DIR = os.path.dirname(os.path.abspath(__file__)) + + +def setup_module(module): + test_utils.torchserve_cleanup() + test_utils.start_torchserve(disable_token=False, gen_mar=False) + + +def teardown_module(module): + test_utils.torchserve_cleanup() + + +def register(stub, model_name, metadata): + marfile = f"https://torchserve.s3.amazonaws.com/mar_files/{model_name}.mar" + params = { + "url": marfile, + "initial_workers": 1, + "synchronous": True, + "model_name": model_name, + } + stub.RegisterModel( + management_pb2.RegisterModelRequest(**params), metadata=metadata, timeout=120 + ) + + +def unregister(stub, model_name, metadata): + params = { + "model_name": model_name, + } + stub.UnregisterModel( + management_pb2.UnregisterModelRequest(**params), metadata=metadata, timeout=60 + ) + + +def infer(stub, model_name, model_input, metadata): + with open(model_input, "rb") as f: + data = f.read() + + input_data = {"data": data} + params = {"model_name": model_name, "input": input_data} + response = stub.Predictions( + inference_pb2.PredictionsRequest(**params), metadata=metadata, timeout=60 + ) + + return response.prediction.decode("utf-8") + + +def read_key_file(type): + json_file_path = os.path.join(CURR_DIR, "key_file.json") + with open(json_file_path) as json_file: + json_data = json.load(json_file) + + options = { + "management": json_data.get("management", {}).get("key", "NOT_PRESENT"), + "inference": json_data.get("inference", {}).get("key", "NOT_PRESENT"), + "token": json_data.get("API", {}).get("key", "NOT_PRESENT"), + } + key = options.get(type, "Invalid data type") + return key + + +def test_grpc_api_with_token_auth(): + management_stub = test_gRPC_utils.get_management_stub() + inference_stub = test_gRPC_utils.get_inference_stub() + management_key = read_key_file("management") + inference_key = read_key_file("inference") + + # register model with incorrect authorization token + metadata = (("protocol", "gRPC"), ("authorization", f"Bearer incorrect-token")) + with pytest.raises(Exception, match=r".*Token Authorization failed.*"): + register(management_stub, "densenet161", metadata) + + # register model with correct authorization token + metadata = (("protocol", "gRPC"), ("authorization", f"Bearer {management_key}")) + register(management_stub, "densenet161", metadata) + + # make inference request with incorrect auth token + metadata = (("protocol", "gRPC"), ("authorization", f"Bearer incorrect-token")) + with pytest.raises(Exception, match=r".*Token Authorization failed.*"): + infer( + inference_stub, + "densenet161", + os.path.join(test_utils.REPO_ROOT, "examples/image_classifier/kitten.jpg"), + metadata, + ) + + # make inference request with correct auth token + metadata = (("protocol", "gRPC"), ("authorization", f"Bearer {inference_key}")) + infer( + inference_stub, + "densenet161", + os.path.join(test_utils.REPO_ROOT, "examples/image_classifier/kitten.jpg"), + metadata, + ) + + # unregister model with incorrect authorization token + metadata = (("protocol", "gRPC"), ("authorization", f"Bearer incorrect-token")) + with pytest.raises(Exception, match=r".*Token Authorization failed.*"): + unregister(management_stub, "densenet161", metadata) + + # unregister model with correct authorization token + metadata = (("protocol", "gRPC"), ("authorization", f"Bearer {management_key}")) + unregister(management_stub, "densenet161", metadata) diff --git a/test/pytest/test_gRPC_utils.py b/test/pytest/test_gRPC_utils.py index 3c786ed0a4..450710a27c 100644 --- a/test/pytest/test_gRPC_utils.py +++ b/test/pytest/test_gRPC_utils.py @@ -5,18 +5,19 @@ def get_inference_stub(): - channel = grpc.insecure_channel('localhost:7070') + channel = grpc.insecure_channel("localhost:7070") stub = inference_pb2_grpc.InferenceAPIsServiceStub(channel) return stub def get_management_stub(): - channel = grpc.insecure_channel('localhost:7071') + channel = grpc.insecure_channel("localhost:7071") stub = management_pb2_grpc.ManagementAPIsServiceStub(channel) return stub def run_management_api(api_name, **kwargs): management_stub = get_management_stub() - return getattr(management_stub, api_name)(getattr(management_pb2, f"{api_name}Request")(**kwargs)) - + return getattr(management_stub, api_name)( + getattr(management_pb2, f"{api_name}Request")(**kwargs) + ) diff --git a/ts_scripts/torchserve_grpc_client.py b/ts_scripts/torchserve_grpc_client.py index 396f78d8e5..1afbc12b1c 100644 --- a/ts_scripts/torchserve_grpc_client.py +++ b/ts_scripts/torchserve_grpc_client.py @@ -57,7 +57,7 @@ def infer_stream(stub, model_name, model_input, metadata): exit(1) -def infer_stream2(model_name, sequence_id, input_files): +def infer_stream2(model_name, sequence_id, input_files, metadata): response_queue = queue.Queue() process_response_func = partial( InferStream2.default_process_response, response_queue @@ -69,6 +69,7 @@ def infer_stream2(model_name, sequence_id, input_files): model_name=model_name, sequence_id=sequence_id, process_response=process_response_func, + metadata=metadata, ) sequence = input_files.split(",") @@ -89,7 +90,7 @@ def infer_stream2(model_name, sequence_id, input_files): client.stop() -def register(stub, model_name, mar_set_str): +def register(stub, model_name, mar_set_str, metadata): mar_set = set() if mar_set_str: mar_set = set(mar_set_str.split(",")) @@ -108,7 +109,9 @@ def register(stub, model_name, mar_set_str): "model_name": model_name, } try: - response = stub.RegisterModel(management_pb2.RegisterModelRequest(**params)) + response = stub.RegisterModel( + management_pb2.RegisterModelRequest(**params), metadata=metadata + ) print(f"Model {model_name} registered successfully") except grpc.RpcError as e: print(f"Failed to register model {model_name}.") @@ -116,10 +119,11 @@ def register(stub, model_name, mar_set_str): exit(1) -def unregister(stub, model_name): +def unregister(stub, model_name, metadata): try: response = stub.UnregisterModel( - management_pb2.UnregisterModelRequest(model_name=model_name) + management_pb2.UnregisterModelRequest(model_name=model_name), + metadata=metadata, ) print(f"Model {model_name} unregistered successfully") except grpc.RpcError as e: @@ -170,13 +174,16 @@ def init_handler(self, response_iterator): self._handler.start() print("InferStream2 started") - def enqueue_request(self, model_input): + def enqueue_request(self, model_input, metadata): with open(model_input, "rb") as f: data = f.read() input_data = {"data": data} request = inference_pb2.PredictionsRequest( - model_name=self._model_name, sequence_id=self._sequence_id, input=input_data + model_name=self._model_name, + sequence_id=self._sequence_id, + input=input_data, + metadata=metadata, ) if self._alive: self._request_queue.put(request) @@ -233,7 +240,9 @@ def __init__(self): self._channel = grpc.insecure_channel("localhost:7070") self._stub = inference_pb2_grpc.InferenceAPIsServiceStub(self._channel) - def start_stream(self, model_name: str, sequence_id: str, process_response): + def start_stream( + self, model_name: str, sequence_id: str, process_response, metadata + ): if self._stream is not None: raise RuntimeError( "Cannot start InferStream2SimpleClient since " @@ -247,7 +256,7 @@ def start_stream(self, model_name: str, sequence_id: str, process_response): ) try: response_iterator = self._stub.StreamPredictions2( - RequestIterator(self._stream) + RequestIterator(self._stream), metadata=metadata ) self._stream.init_handler(response_iterator) except grpc.RpcError as e: @@ -276,6 +285,14 @@ def stop(self): default=None, help="Name of the model used.", ) + parent_parser.add_argument( + "--auth-token", + dest="auth_token", + type=str, + default=None, + required=False, + help="Authorization token", + ) parser = argparse.ArgumentParser( description="TorchServe gRPC client", @@ -333,16 +350,22 @@ def stop(self): ) args = parser.parse_args() - - metadata = (("protocol", "gRPC"), ("session_id", "12345")) + if args.auth_token: + metadata = ( + ("protocol", "gRPC"), + ("session_id", "12345"), + ("authorization", f"Bearer {args.auth_token}"), + ) + else: + metadata = (("protocol", "gRPC"), ("session_id", "12345")) if args.action == "infer": infer(get_inference_stub(), args.model_name, args.model_input, metadata) elif args.action == "infer_stream": - infer_stream(get_inference_stub(), args.model_name, args.model_input) + infer_stream(get_inference_stub(), args.model_name, args.model_input, metadata) elif args.action == "infer_stream2": - infer_stream2(args.model_name, args.sequence_id, args.input_files) + infer_stream2(args.model_name, args.sequence_id, args.input_files, metadata) elif args.action == "register": - register(get_management_stub(), args.model_name, args.mar_set) + register(get_management_stub(), args.model_name, args.mar_set, metadata) elif args.action == "unregister": - unregister(get_management_stub(), args.model_name) + unregister(get_management_stub(), args.model_name, metadata)