Skip to content

Commit

Permalink
Send query heartbeat while spooled data is downloaded
Browse files Browse the repository at this point in the history
This prevents a situation when long list of segments is returned and client
slowly progresses over it which could take longer than a query abandoned timeout.

This change allows to send a heartbeat while client is progressing over data.
  • Loading branch information
wendigo committed Mar 10, 2025
1 parent 58ecb8f commit 7b599b7
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.google.common.base.Joiner;
import com.google.common.base.Splitter;
import com.google.common.collect.AbstractIterator;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
Expand All @@ -23,11 +24,13 @@
import io.airlift.units.Duration;
import jakarta.annotation.Nullable;
import okhttp3.Call;
import okhttp3.Callback;
import okhttp3.Headers;
import okhttp3.HttpUrl;
import okhttp3.MediaType;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;

import java.io.IOException;
import java.io.InterruptedIOException;
Expand All @@ -39,13 +42,15 @@
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.time.ZoneId;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.stream.Stream;
Expand All @@ -60,16 +65,19 @@
import static io.trino.client.ProtocolHeaders.TRINO_HEADERS;
import static io.trino.client.TrinoJsonCodec.jsonCodec;
import static java.lang.String.format;
import static java.net.HttpURLConnection.HTTP_NOT_FOUND;
import static java.net.HttpURLConnection.HTTP_OK;
import static java.net.HttpURLConnection.HTTP_UNAUTHORIZED;
import static java.util.Arrays.stream;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;

@ThreadSafe
class StatementClientV1
implements StatementClient
{
private static final long HEARTBEAT_INTERVAL = new Duration(30, SECONDS).toMillis() * 1_000_000;
private static final MediaType MEDIA_TYPE_TEXT = MediaType.parse("text/plain; charset=utf-8");
private static final TrinoJsonCodec<QueryResults> QUERY_RESULTS_CODEC = jsonCodec(QueryResults.class);

Expand All @@ -94,6 +102,7 @@ class StatementClientV1
private final Set<String> deallocatedPreparedStatements = Sets.newConcurrentHashSet();
private final AtomicReference<String> startedTransactionId = new AtomicReference<>();
private final AtomicBoolean clearTransactionId = new AtomicBoolean();
private final AtomicLong lastHeartbeat = new AtomicLong();
private final ZoneId timeZone;
private final Duration requestTimeoutNanos;
private final Optional<String> user;
Expand Down Expand Up @@ -388,6 +397,50 @@ public boolean advance()
return executeRequest(request, "fetching next", (e) -> true);
}

public void heartbeat()
{
if (System.nanoTime() - lastHeartbeat.get() < HEARTBEAT_INTERVAL) {
return;
}

if (!isRunning()) {
return;
}

URI nextUri = currentStatusInfo().getNextUri();
if (nextUri == null) {
return;
}

Request request = prepareRequest(HttpUrl.get(nextUri)
.newBuilder()
.addPathSegment("heartbeat").build())
.build();

lastHeartbeat.set(System.nanoTime());
httpCallFactory.newCall(request).enqueue(new Callback() {
@Override
public void onFailure(Call call, IOException e)
{
if (isTransient(e)) {
lastHeartbeat.set(0); // retry sending heartbeat
}
}

@Override
public void onResponse(Call call, Response response)
{
if (response.code() == HTTP_OK) {
// Heartbeat acknowledged, move even further
lastHeartbeat.set(System.nanoTime());
}
if (response.code() == HTTP_NOT_FOUND) {
lastHeartbeat.set(Long.MAX_VALUE); // No server-side support for heartbeats
}
}
});
}

private boolean executeRequest(Request request, String taskName, Function<Exception, Boolean> isRetryable)
{
Exception cause = null;
Expand Down Expand Up @@ -426,6 +479,7 @@ private boolean executeRequest(Request request, String taskName, Function<Except
JsonResponse<QueryResults> response;
try {
response = JsonResponse.execute(QUERY_RESULTS_CODEC, httpCallFactory, request);
lastHeartbeat.set(System.nanoTime());
}
catch (RuntimeException e) {
if (!isRetryable.apply(e)) {
Expand Down Expand Up @@ -511,7 +565,7 @@ private void processResponse(Headers headers, QueryResults results)
}

currentResults.set(results);
currentRows.set(resultRowsDecoder.toRows(results));
currentRows.set(new HeartbeatingResultRows(resultRowsDecoder.toRows(results), this::heartbeat));
}

private List<String> safeSplitToList(String value)
Expand Down Expand Up @@ -608,4 +662,50 @@ private enum State
*/
FINISHED,
}

private static class HeartbeatingResultRows
implements ResultRows
{
private final Iterator<List<Object>> iterator;
private final boolean isNull;
private final Runnable heartbeat;

public HeartbeatingResultRows(ResultRows delegate, Runnable heartbeat)
{
this.iterator = delegate.iterator();
this.isNull = delegate.isNull();
this.heartbeat = heartbeat;
}

@Override
public void close()
throws IOException
{
if (iterator instanceof CloseableIterator) {
((CloseableIterator<?>) iterator).close();
}
}

@Override
public boolean isNull()
{
return isNull;
}

@Override
public Iterator<List<Object>> iterator()
{
return new AbstractIterator<>() {
@Override
protected List<Object> computeNext()
{
heartbeat.run();
if (iterator.hasNext()) {
return iterator.next();
}
return endOfData();
}
};
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,19 @@ public void getQueryResults(
asyncQueryResults(query, token, externalUriInfo, asyncResponse);
}

@GET
@Path("{queryId}/{slug}/{token}/heartbeat")
@Produces(MediaType.APPLICATION_JSON)
public Response heartbeat(@PathParam("queryId") QueryId queryId, @PathParam("slug") String slug, @PathParam("token") long token)
{
Query query = queries.get(queryId);
if (query != null && query.isSlugValid(slug, token)) {
queryManager.recordHeartbeat(queryId);
return Response.ok().build();
}
throw new NotFoundException("Query not found");
}

protected Query getQuery(QueryId queryId, String slug, long token)
{
Query query = queries.get(queryId);
Expand Down

0 comments on commit 7b599b7

Please sign in to comment.