From 9833774a7d395cf47c9170e2837b09aa34538970 Mon Sep 17 00:00:00 2001 From: lxning <23464292+lxning@users.noreply.github.com> Date: Sat, 24 Jun 2023 01:42:17 +0000 Subject: [PATCH] skip send response if grpc channel is closed by client (#2420) * skip send response if channel closed * update log * update log * update requestid --------- Co-authored-by: Ankith Gunapal --- .../java/org/pytorch/serve/job/GRPCJob.java | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/frontend/server/src/main/java/org/pytorch/serve/job/GRPCJob.java b/frontend/server/src/main/java/org/pytorch/serve/job/GRPCJob.java index fe4192103a..9c4b0d9e56 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/job/GRPCJob.java +++ b/frontend/server/src/main/java/org/pytorch/serve/job/GRPCJob.java @@ -4,6 +4,7 @@ import com.google.protobuf.ByteString; import io.grpc.Status; +import io.grpc.stub.ServerCallStreamObserver; import io.grpc.stub.StreamObserver; import java.util.ArrayList; import java.util.Arrays; @@ -67,10 +68,16 @@ public void response( int statusCode, String statusPhrase, Map responseHeaders) { - ByteString output = ByteString.copyFrom(body); if (this.getCmd() == WorkerCommands.PREDICT || this.getCmd() == WorkerCommands.STREAMPREDICT) { + if (((ServerCallStreamObserver) predictionResponseObserver) + .isCancelled()) { + logger.warn( + "grpc client call already cancelled, not able to send this response for requestId: {}", + getPayload().getRequestId()); + return; + } PredictionResponse reply = PredictionResponse.newBuilder().setPrediction(output).build(); predictionResponseObserver.onNext(reply); @@ -118,6 +125,14 @@ public void sendError(int status, String error) { Status responseStatus = GRPCUtils.getGRPCStatusCode(status); if (this.getCmd() == WorkerCommands.PREDICT || this.getCmd() == WorkerCommands.STREAMPREDICT) { + if (((ServerCallStreamObserver) predictionResponseObserver) + .isCancelled()) { + logger.warn( + "grpc client call already cancelled, not able to send error: {}, for requestId: {}", + error, + getPayload().getRequestId()); + return; + } predictionResponseObserver.onError( responseStatus .withDescription(error)