Enable token authorization and model control for gRPC (#3238)
* Refactor token auth implementation to extend to gRPC

* Implement token authorization for gRPC

* Add test for gRPC token authorization

* Enable model control for gRPC

* Refactor test for gRPC token authorization

* Add tests for gRPC model control

* Reorder relative path check in download URL

* Add test for gRPC allowed URLs

* Refactor bearer token parsing from header

* Fix linter errors

* Clean up config file in gRPC allowed URLs test

* Refactor gRPC tests

* Fix typo in gRPC inference test

* Update documentation for gRPC token auth and model control

* Run formatJava target
namannandan authored Jul 16, 2024
1 parent a0c3f9e commit 5dbf18e
Showing 21 changed files with 646 additions and 267 deletions.
29 changes: 28 additions & 1 deletion docs/grpc_api.md
@@ -11,6 +11,8 @@ TorchServe provides following gRPCs apis
- **Predictions** : Gets predictions from the served model
- **StreamPredictions** : Gets server side streaming predictions from the saved model

For all Inference API requests, TorchServe requires the correct Inference token to be included or token authorization must be disabled. For more details see the [token authorization documentation](./token_authorization_api.md).

* [Management API](https://github.com/pytorch/serve/blob/master/frontend/server/src/main/resources/proto/management.proto)
- **RegisterModel** : Serve a model/model-version on TorchServe
- **UnregisterModel** : Free up system resources by unregistering specific version of a model from TorchServe
@@ -19,6 +21,8 @@ TorchServe provides following gRPCs apis
- **DescribeModel** : Get detail runtime status of default version of a model
- **SetDefault** : Set any registered version of a model as default version

For all Management API requests, TorchServe requires the correct Management token to be included or token authorization must be disabled. For more details see the [token authorization documentation](./token_authorization_api.md).

By default, TorchServe listens on port 7070 for the gRPC Inference API and 7071 for the gRPC Management API on localhost.
To configure the gRPC APIs on different addresses and ports, refer to the [configuration documentation](configuration.md).
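
As a minimal illustration (assuming a local TorchServe instance with the default gRPC settings), channels to these endpoints can be opened with the Python `grpc` package:

```python
import grpc

# Default local endpoints; adjust if the gRPC addresses or ports were changed in the configuration.
inference_channel = grpc.insecure_channel("localhost:7070")
management_channel = grpc.insecure_channel("localhost:7071")
```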

@@ -45,7 +49,7 @@ pip install -U grpcio protobuf grpcio-tools googleapis-common-protos

```bash
mkdir models
torchserve --start --model-store models/
torchserve --start --disable-token-auth --enable-model-api --model-store models/
```

- Generate python gRPC client stub using the proto files
@@ -56,21 +60,44 @@ python -m grpc_tools.protoc -I third_party/google/rpc --proto_path=frontend/serv

- Register densenet161 model

__Note__: To use this API after TorchServe starts, model API control has to be enabled. Add `--enable-model-api` to the command line when starting TorchServe to enable the use of this API. For more details see the [model API control documentation](./model_api_control.md).

If token authorization is disabled, use:
```bash
python ts_scripts/torchserve_grpc_client.py register densenet161
```

If token authorization is enabled, use:
```bash
python ts_scripts/torchserve_grpc_client.py register densenet161 --auth-token <management-token>
```

- Run inference using

If token authorization is disabled, use:
```bash
python ts_scripts/torchserve_grpc_client.py infer densenet161 examples/image_classifier/kitten.jpg
```

If token authorization is enabled, use:
```bash
python ts_scripts/torchserve_grpc_client.py infer densenet161 examples/image_classifier/kitten.jpg --auth-token <inference-token>
```

- Unregister densenet161 model

__Note__: To use this API after TorchServe starts, model API control has to be enabled. Add `--enable-model-api` to the command line when starting TorchServe to enable the use of this API. For more details see the [model API control documentation](./model_api_control.md).

If token authorization is disabled, use:
```bash
python ts_scripts/torchserve_grpc_client.py unregister densenet161
```

If token authorization is enabled, use:
```bash
python ts_scripts/torchserve_grpc_client.py unregister densenet161 --auth-token <management-token>
```
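
Under the hood, the new `GRPCInterceptor` reads the token from the `authorization` gRPC metadata key and parses it as a Bearer token. The sketch below shows how a call with a token could be made directly against the generated stubs; the function name, the assumption that `inference_pb2`/`inference_pb2_grpc` are importable from the protoc output directory, and the endpoint are illustrative, not part of the shipped client script.

```python
import grpc

# Assumes the stubs generated by the protoc step above are importable (e.g. from ts_scripts/).
import inference_pb2
import inference_pb2_grpc


def infer_with_token(model_name, input_path, token):
    # The interceptor expects an "authorization" metadata entry carrying a Bearer token.
    metadata = (("authorization", f"Bearer {token}"),)
    with open(input_path, "rb") as f:
        data = f.read()
    with grpc.insecure_channel("localhost:7070") as channel:
        stub = inference_pb2_grpc.InferenceAPIsServiceStub(channel)
        request = inference_pb2.PredictionsRequest(model_name=model_name, input={"data": data})
        response = stub.Predictions(request, metadata=metadata)
    return response.prediction.decode("utf-8")
```

If token authorization is disabled, the `metadata` argument can simply be omitted.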

## GRPC Server Side Streaming
The TorchServe gRPC API adds server-side streaming to the inference API via "StreamPredictions", which allows a sequence of inference responses to be sent over the same gRPC stream. This API is only recommended for use cases where the inference latency of the full response is high and intermediate results are sent to the client as they become available. An example is LLMs for generative applications, where generating "n" tokens can have high latency; in this case, the user can receive each generated token as soon as it is ready, until the full response completes. This API automatically forces the batchSize to be one.
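
A rough sketch of consuming this stream from Python, under the same assumptions as the sketch above (generated stubs importable, default endpoint, placeholder model name and prompt encoding):

```python
import grpc

import inference_pb2
import inference_pb2_grpc


def stream_predictions(model_name, prompt, token=None):
    # Include the inference token only when token authorization is enabled.
    metadata = (("authorization", f"Bearer {token}"),) if token else ()
    with grpc.insecure_channel("localhost:7070") as channel:
        stub = inference_pb2_grpc.InferenceAPIsServiceStub(channel)
        request = inference_pb2.PredictionsRequest(
            model_name=model_name, input={"data": prompt.encode("utf-8")}
        )
        # StreamPredictions yields one PredictionResponse per intermediate result.
        for resp in stub.StreamPredictions(request, metadata=metadata):
            yield resp.prediction.decode("utf-8")


# Example usage with a hypothetical streaming model:
# for chunk in stream_predictions("my_streaming_llm", "Hello", token="<inference-token>"):
#     print(chunk, end="")
```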

2 changes: 1 addition & 1 deletion examples/stateful/sequence_batching/Readme.md
@@ -113,7 +113,7 @@ python -m grpc_tools.protoc -I third_party/google/rpc --proto_path=frontend/serv
* Start TorchServe

```bash
torchserve --ncs --start --model-store models --model stateful.mar --ts-config examples/stateful/config.properties
torchserve --ncs --start --disable-token-auth --model-store models --model stateful.mar --ts-config examples/stateful/config.properties
```

* Run sequence inference via GRPC client
2 changes: 1 addition & 1 deletion examples/stateful/sequence_continuous_batching/Readme.md
@@ -184,7 +184,7 @@ python -m grpc_tools.protoc -I third_party/google/rpc --proto_path=frontend/serv
* Start TorchServe

```bash
torchserve --ncs --start --model-store models --model stateful.mar --ts-config examples/stateful/config.properties
torchserve --ncs --start --disable-token-auth --model-store models --model stateful.mar --ts-config examples/stateful/config.properties
```

* Run sequence inference via GRPC client
@@ -54,13 +54,13 @@ public static ModelArchive downloadModel(
throw new ModelNotFoundException("empty url");
}

String marFileName = ArchiveUtils.getFilenameFromUrl(url);
File modelLocation = new File(modelStore, marFileName);

if (url.contains("..")) {
throw new ModelNotFoundException("Relative path is not allowed in url: " + url);
}

String marFileName = ArchiveUtils.getFilenameFromUrl(url);
File modelLocation = new File(modelStore, marFileName);

try {
ArchiveUtils.downloadArchive(
allowedUrls, modelLocation, marFileName, url, s3SseKmsEnabled);
@@ -40,7 +40,6 @@
import org.pytorch.serve.archive.model.ModelNotFoundException;
import org.pytorch.serve.grpcimpl.GRPCInterceptor;
import org.pytorch.serve.grpcimpl.GRPCServiceFactory;
import org.pytorch.serve.http.TokenAuthorizationHandler;
import org.pytorch.serve.http.messages.RegisterModelRequest;
import org.pytorch.serve.metrics.MetricCache;
import org.pytorch.serve.metrics.MetricManager;
@@ -54,6 +53,7 @@
import org.pytorch.serve.util.Connector;
import org.pytorch.serve.util.ConnectorType;
import org.pytorch.serve.util.ServerGroups;
import org.pytorch.serve.util.TokenAuthorization;
import org.pytorch.serve.wlm.Model;
import org.pytorch.serve.wlm.ModelManager;
import org.pytorch.serve.wlm.WorkLoadManager;
@@ -87,7 +87,7 @@ public static void main(String[] args) {
ConfigManager.Arguments arguments = new ConfigManager.Arguments(cmd);
ConfigManager.init(arguments);
ConfigManager configManager = ConfigManager.getInstance();
TokenAuthorizationHandler.setupToken();
TokenAuthorization.init();
PluginsManager.getInstance().initialize();
MetricCache.init();
InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE);
@@ -465,7 +465,7 @@ private Server startGRPCServer(ConnectorType connectorType) throws IOException {
.addService(
ServerInterceptors.intercept(
GRPCServiceFactory.getgRPCService(connectorType),
new GRPCInterceptor()));
new GRPCInterceptor(connectorType)));

if (connectorType == ConnectorType.INFERENCE_CONNECTOR
&& ConfigManager.getInstance().isOpenInferenceProtocol()) {
@@ -474,7 +474,7 @@ private Server startGRPCServer(ConnectorType connectorType) throws IOException {
ServerInterceptors.intercept(
GRPCServiceFactory.getgRPCService(
ConnectorType.OPEN_INFERENCE_CONNECTOR),
new GRPCInterceptor()));
new GRPCInterceptor(connectorType)));
}

if (configManager.isGRPCSSLEnabled()) {
@@ -19,7 +19,7 @@
import org.pytorch.serve.servingsdk.impl.PluginsManager;
import org.pytorch.serve.util.ConfigManager;
import org.pytorch.serve.util.ConnectorType;
import org.pytorch.serve.util.TokenType;
import org.pytorch.serve.util.TokenAuthorization.TokenType;
import org.pytorch.serve.workflow.api.http.WorkflowInferenceRequestHandler;
import org.pytorch.serve.workflow.api.http.WorkflowMgmtRequestHandler;
import org.slf4j.Logger;
@@ -8,13 +8,32 @@
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import org.pytorch.serve.http.Session;
import org.pytorch.serve.util.ConnectorType;
import org.pytorch.serve.util.TokenAuthorization;
import org.pytorch.serve.util.TokenAuthorization.TokenType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class GRPCInterceptor implements ServerInterceptor {

private TokenType tokenType;
private static final Metadata.Key<String> tokenAuthHeaderKey =
Metadata.Key.of("authorization", Metadata.ASCII_STRING_MARSHALLER);
private static final Logger logger = LoggerFactory.getLogger("ACCESS_LOG");

public GRPCInterceptor(ConnectorType connectorType) {
switch (connectorType) {
case MANAGEMENT_CONNECTOR:
tokenType = TokenType.MANAGEMENT;
break;
case INFERENCE_CONNECTOR:
tokenType = TokenType.INFERENCE;
break;
default:
tokenType = TokenType.INFERENCE;
}
}

@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
@@ -23,6 +42,17 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
String serviceName = call.getMethodDescriptor().getFullMethodName();
Session session = new Session(inetSocketString, serviceName);

if (TokenAuthorization.isEnabled()) {
if (!headers.containsKey(tokenAuthHeaderKey)
|| !checkTokenAuthorization(headers.get(tokenAuthHeaderKey))) {
call.close(
Status.PERMISSION_DENIED.withDescription(
"Token Authorization failed. Token either incorrect, expired, or not provided correctly"),
new Metadata());
return new ServerCall.Listener<ReqT>() {};
}
}

return next.startCall(
new ForwardingServerCall.SimpleForwardingServerCall<ReqT, RespT>(call) {
@Override
@@ -34,4 +64,13 @@ public void close(final Status status, final Metadata trailers) {
},
headers);
}

private Boolean checkTokenAuthorization(String tokenAuthHeaderValue) {
if (tokenAuthHeaderValue == null) {
return false;
}
String token = TokenAuthorization.parseTokenFromBearerTokenHeader(tokenAuthHeaderValue);

return TokenAuthorization.checkTokenAuthorization(token, tokenType);
}
}
@@ -23,6 +23,7 @@
import org.pytorch.serve.job.GRPCJob;
import org.pytorch.serve.job.Job;
import org.pytorch.serve.util.ApiUtils;
import org.pytorch.serve.util.ConfigManager;
import org.pytorch.serve.util.GRPCUtils;
import org.pytorch.serve.util.JsonUtils;
import org.pytorch.serve.util.messages.RequestInput;
@@ -32,8 +33,13 @@
import org.slf4j.LoggerFactory;

public class ManagementImpl extends ManagementAPIsServiceImplBase {
private ConfigManager configManager;
private static final Logger logger = LoggerFactory.getLogger(ManagementImpl.class);

public ManagementImpl() {
configManager = ConfigManager.getInstance();
}

@Override
public void describeModel(
DescribeModelRequest request, StreamObserver<ManagementResponse> responseObserver) {
@@ -117,6 +123,13 @@ public void registerModel(

StatusResponse statusResponse;
try {
if (!configManager.isModelApiEnabled()) {
sendErrorResponse(
responseObserver,
Status.PERMISSION_DENIED,
new ModelException("Model API disabled"));
return;
}
statusResponse = ApiUtils.registerModel(registerModelRequest);
sendStatusResponse(responseObserver, statusResponse);
} catch (InternalServerException e) {
@@ -205,6 +218,13 @@ public void unregisterModel(
.asRuntimeException());
});
try {
if (!configManager.isModelApiEnabled()) {
sendErrorResponse(
responseObserver,
Status.PERMISSION_DENIED,
new ModelException("Model API disabled"));
return;
}
String modelName = request.getModelName();
if (modelName == null || ("").equals(modelName)) {
sendErrorResponse(
