Skip to content

Commit

Permalink
Merge pull request #1 from lewan42/add_data_getter
Browse files Browse the repository at this point in the history
add data getter
  • Loading branch information
lewan42 authored Oct 24, 2024
2 parents 2017dfb + 36dc9f6 commit 27d3ea2
Show file tree
Hide file tree
Showing 7 changed files with 215 additions and 75 deletions.
237 changes: 167 additions & 70 deletions centrifuge/src/main/java/io/github/centrifugal/centrifuge/Client.java
Original file line number Diff line number Diff line change
Expand Up @@ -395,34 +395,67 @@ private void handleConnectionOpen() {
if (this.getState() != ClientState.CONNECTING) {
return;
}
if (this.refreshRequired || (this.token.equals("") && this.opts.getTokenGetter() != null)) {
ConnectionTokenEvent connectionTokenEvent = new ConnectionTokenEvent();
if (this.opts.getTokenGetter() == null) {
this.listener.onError(Client.this, new ErrorEvent(new ConfigurationError(new Exception("tokenGetter function should be provided in Client options to handle token refresh, see Options.setTokenGetter"))));
this.processDisconnect(DISCONNECTED_UNAUTHORIZED, "unauthorized", false);
return;
}
this.opts.getTokenGetter().getConnectionToken(connectionTokenEvent, (err, token) -> this.executor.submit(() -> {
if (Client.this.getState() != ClientState.CONNECTING) {
if (this.refreshRequired) {
if (this.data == null && this.opts.getDataGetter() != null) {
ConnectionDataEvent connectionDataEvent = new ConnectionDataEvent();
if (this.opts.getDataGetter() == null) {
this.listener.onError(Client.this, new ErrorEvent(new ConfigurationError(new Exception("dataGetter function should be provided in Client options to handle token refresh, see Options.setTokenGetter"))));
this.processDisconnect(DISCONNECTED_UNAUTHORIZED, "unauthorized", false);
return;
}
if (err != null) {
if (err instanceof UnauthorizedException) {
Client.this.failUnauthorized();
this.opts.getDataGetter().getConnectionData(connectionDataEvent, (err, data) -> this.executor.submit(() -> {
if (Client.this.getState() != ClientState.CONNECTING) {
return;
}
Client.this.listener.onError(Client.this, new ErrorEvent(new TokenError(err)));
this.ws.close(NORMAL_CLOSURE_STATUS, "");
return;
}
if (token == null) {
Client.this.processDisconnect(DISCONNECTED_BAD_PROTOCOL, "bad protocol (token)", false);
if (err != null) {
if (err instanceof UnauthorizedException) {
Client.this.failUnauthorized();
return;
}
Client.this.listener.onError(Client.this, new ErrorEvent(new TokenError(err)));
this.ws.close(NORMAL_CLOSURE_STATUS, "");
return;
}
if (data == null) {
Client.this.processDisconnect(DISCONNECTED_BAD_PROTOCOL, "bad protocol (data)", false);
return;
}
Client.this.data = com.google.protobuf.ByteString.copyFrom(data);
Client.this.refreshRequired = false;
this.sendConnect();
}));

} else if (this.token.equals("") && this.opts.getTokenGetter() != null) {
ConnectionTokenEvent connectionTokenEvent = new ConnectionTokenEvent();
if (this.opts.getTokenGetter() == null) {
this.listener.onError(Client.this, new ErrorEvent(new ConfigurationError(new Exception("dataGetter function should be provided in Client options to handle token refresh, see Options.setTokenGetter"))));
this.processDisconnect(DISCONNECTED_UNAUTHORIZED, "unauthorized", false);
return;
}
Client.this.token = token;
Client.this.refreshRequired = false;
this.opts.getTokenGetter().getConnectionToken(connectionTokenEvent, (err, token) -> this.executor.submit(() -> {
if (Client.this.getState() != ClientState.CONNECTING) {
return;
}
if (err != null) {
if (err instanceof UnauthorizedException) {
Client.this.failUnauthorized();
return;
}
Client.this.listener.onError(Client.this, new ErrorEvent(new TokenError(err)));
this.ws.close(NORMAL_CLOSURE_STATUS, "");
return;
}
if (token == null) {
Client.this.processDisconnect(DISCONNECTED_BAD_PROTOCOL, "bad protocol (data)", false);
return;
}
Client.this.token = token;
Client.this.refreshRequired = false;
this.sendConnect();
}));
} else {
this.sendConnect();
}));
}
} else {
this.sendConnect();
}
Expand Down Expand Up @@ -547,7 +580,7 @@ private ServerSubscription getServerSub(String channel) {
* Create new subscription to channel with SubscriptionOptions and SubscriptionEventListener
*
* @param channel: to create Subscription for.
* @param options: to pass SubscriptionOptions, e.g. token.
* @param options: to pass SubscriptionOptions, e.g. token.
* @param listener: to pass event handler.
* @return Subscription.
* @throws DuplicateSubscriptionException if Subscription already exists in internal registry.
Expand Down Expand Up @@ -640,6 +673,7 @@ private void handleConnectReply(Protocol.Reply reply) {
this.handleConnectionError(new ReplyError(reply.getError().getCode(), reply.getError().getMessage(), reply.getError().getTemporary()));
if (reply.getError().getCode() == 109) { // Token expired.
this.refreshRequired = true;
this.data = null;
this.ws.close(NORMAL_CLOSURE_STATUS, "");
} else if (reply.getError().getTemporary()) {
this.ws.close(NORMAL_CLOSURE_STATUS, "");
Expand Down Expand Up @@ -751,65 +785,122 @@ private void handleConnectReply(Protocol.Reply reply) {
}

private void sendRefresh() {
if (this.opts.getTokenGetter() == null) {
return;
}
this.executor.submit(() -> Client.this.opts.getTokenGetter().getConnectionToken(new ConnectionTokenEvent(), (err, token) -> Client.this.executor.submit(() -> {
if (Client.this.getState() != ClientState.CONNECTED) {
return;
}
if (err != null) {
if (err instanceof UnauthorizedException) {
Client.this.failUnauthorized();

if (this.opts.getDataGetter() != null) {
this.executor.submit(() -> Client.this.opts.getDataGetter().getConnectionData(new ConnectionDataEvent(), (err, data) -> Client.this.executor.submit(() -> {
if (Client.this.getState() != ClientState.CONNECTED) {
return;
}
Client.this.listener.onError(Client.this, new ErrorEvent(new TokenError(err)));
Client.this.refreshTask = Client.this.scheduler.schedule(
Client.this::sendRefresh,
Client.this.backoff.duration(0, 10000, 20000),
TimeUnit.MILLISECONDS
);
return;
}
if (token == null || token.equals("")) {
this.failUnauthorized();
return;
}
Client.this.token = token;
refreshSynchronized(token, (error, result) -> {
if (Client.this.getState() != ClientState.CONNECTED) {
if (err != null) {
if (err instanceof UnauthorizedException) {
Client.this.failUnauthorized();
return;
}
Client.this.listener.onError(Client.this, new ErrorEvent(new TokenError(err)));
Client.this.refreshTask = Client.this.scheduler.schedule(
Client.this::sendRefresh,
Client.this.backoff.duration(0, 10000, 20000),
TimeUnit.MILLISECONDS
);
return;
}
if (data == null || data.length == 0) {
this.failUnauthorized();
return;
}
if (error != null) {
Client.this.listener.onError(Client.this, new ErrorEvent(new RefreshError(error)));
if (error instanceof ReplyError) {
ReplyError e;
e = (ReplyError) error;
if (e.isTemporary()) {
Client.this.data = com.google.protobuf.ByteString.copyFrom(data);
refreshSynchronized(data, null, (error, result) -> {
if (Client.this.getState() != ClientState.CONNECTED) {
return;
}
if (error != null) {
Client.this.listener.onError(Client.this, new ErrorEvent(new RefreshError(error)));
if (error instanceof ReplyError) {
ReplyError e;
e = (ReplyError) error;
if (e.isTemporary()) {
Client.this.refreshTask = Client.this.scheduler.schedule(
Client.this::sendRefresh,
Client.this.backoff.duration(0, 10000, 20000),
TimeUnit.MILLISECONDS
);
} else {
Client.this.processDisconnect(e.getCode(), e.getMessage(), false);
}
return;
} else {
Client.this.refreshTask = Client.this.scheduler.schedule(
Client.this::sendRefresh,
Client.this.backoff.duration(0, 10000, 20000),
TimeUnit.MILLISECONDS
);
} else {
Client.this.processDisconnect(e.getCode(), e.getMessage(), false);
}
return;
} else {
Client.this.refreshTask = Client.this.scheduler.schedule(
Client.this::sendRefresh,
Client.this.backoff.duration(0, 10000, 20000),
TimeUnit.MILLISECONDS
);
}
if (result.getExpires()) {
int ttl = result.getTtl();
Client.this.refreshTask = Client.this.scheduler.schedule(Client.this::sendRefresh, ttl, TimeUnit.SECONDS);
}
});
})));
} else if (this.opts.getTokenGetter() != null) {
this.executor.submit(() -> Client.this.opts.getTokenGetter().getConnectionToken(new ConnectionTokenEvent(), (err, token) -> Client.this.executor.submit(() -> {
if (Client.this.getState() != ClientState.CONNECTED) {
return;
}
if (result.getExpires()) {
int ttl = result.getTtl();
Client.this.refreshTask = Client.this.scheduler.schedule(Client.this::sendRefresh, ttl, TimeUnit.SECONDS);
if (err != null) {
if (err instanceof UnauthorizedException) {
Client.this.failUnauthorized();
return;
}
Client.this.listener.onError(Client.this, new ErrorEvent(new TokenError(err)));
Client.this.refreshTask = Client.this.scheduler.schedule(
Client.this::sendRefresh,
Client.this.backoff.duration(0, 10000, 20000),
TimeUnit.MILLISECONDS
);
return;
}
});
})));
if (token == null || token.equals("")) {
this.failUnauthorized();
return;
}
Client.this.token = token;
refreshSynchronized(null, token, (error, result) -> {
if (Client.this.getState() != ClientState.CONNECTED) {
return;
}
if (error != null) {
Client.this.listener.onError(Client.this, new ErrorEvent(new RefreshError(error)));
if (error instanceof ReplyError) {
ReplyError e;
e = (ReplyError) error;
if (e.isTemporary()) {
Client.this.refreshTask = Client.this.scheduler.schedule(
Client.this::sendRefresh,
Client.this.backoff.duration(0, 10000, 20000),
TimeUnit.MILLISECONDS
);
} else {
Client.this.processDisconnect(e.getCode(), e.getMessage(), false);
}
return;
} else {
Client.this.refreshTask = Client.this.scheduler.schedule(
Client.this::sendRefresh,
Client.this.backoff.duration(0, 10000, 20000),
TimeUnit.MILLISECONDS
);
}
return;
}
if (result.getExpires()) {
int ttl = result.getTtl();
Client.this.refreshTask = Client.this.scheduler.schedule(Client.this::sendRefresh, ttl, TimeUnit.SECONDS);
}
});
})));
}
}

private void sendConnect() {
Expand Down Expand Up @@ -1299,10 +1390,16 @@ private void presenceStatsSynchronized(String channel, ResultCallback<PresenceSt
this.enqueueCommandFuture(cmd, f);
}

private void refreshSynchronized(String token, ResultCallback<Protocol.RefreshResult> cb) {
Protocol.RefreshRequest req = Protocol.RefreshRequest.newBuilder()
.setToken(token)
.build();
private void refreshSynchronized(byte[] data, String token, ResultCallback<Protocol.RefreshResult> cb) {
Protocol.RefreshRequest.Builder req = Protocol.RefreshRequest.newBuilder();

if (data != null) {
req.setTokenBytes(com.google.protobuf.ByteString.copyFrom(data));
}

if (token != null) {
req.setToken(token);
}

Protocol.Command cmd = Protocol.Command.newBuilder()
.setId(this.getNextId())
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package io.github.centrifugal.centrifuge;

public class ConnectionDataEvent { }
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package io.github.centrifugal.centrifuge;

public abstract class ConnectionDataGetter {
public void getConnectionData(ConnectionDataEvent event, DataCallback cb) {
cb.Done(null, null);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package io.github.centrifugal.centrifuge;

public interface DataCallback {
void Done(Throwable e, byte[] data);
}
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,19 @@ public void setData(byte[] data) {

private byte[] data;

public ConnectionDataGetter getDataGetter() {
return dataGetter;
}

/**
* Set a method to extract new connection data.
*/
public void setDataGetter(ConnectionDataGetter dataGetter) {
this.dataGetter = dataGetter;
}

private ConnectionDataGetter dataGetter;

/**
* Set custom headers for WebSocket Upgrade request.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ void subscribeError(ReplyError err) {
this.listener.onError(this, new SubscriptionErrorEvent(new SubscriptionSubscribeError(err)));
if (err.getCode() == 109) { // Token expired.
this.token = "";
this.data = null;
this.scheduleResubscribe();
} if (err.isTemporary()) {
this.scheduleResubscribe();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.github.centrifugal.centrifuge.demo;

import static java.nio.charset.StandardCharsets.UTF_8;

import android.annotation.SuppressLint;
import android.os.Bundle;
import android.support.v7.app.AppCompatActivity;
Expand All @@ -9,23 +11,24 @@
import io.github.centrifugal.centrifuge.Client;
import io.github.centrifugal.centrifuge.ConnectedEvent;
import io.github.centrifugal.centrifuge.ConnectingEvent;
import io.github.centrifugal.centrifuge.ConnectionDataEvent;
import io.github.centrifugal.centrifuge.ConnectionDataGetter;
import io.github.centrifugal.centrifuge.DataCallback;
import io.github.centrifugal.centrifuge.DisconnectedEvent;
import io.github.centrifugal.centrifuge.ErrorEvent;
import io.github.centrifugal.centrifuge.EventListener;
import io.github.centrifugal.centrifuge.Options;
import io.github.centrifugal.centrifuge.PublicationEvent;
import io.github.centrifugal.centrifuge.ServerPublicationEvent;
import io.github.centrifugal.centrifuge.ServerSubscribedEvent;
import io.github.centrifugal.centrifuge.SubscribingEvent;
import io.github.centrifugal.centrifuge.SubscriptionErrorEvent;
import io.github.centrifugal.centrifuge.SubscribedEvent;
import io.github.centrifugal.centrifuge.SubscribingEvent;
import io.github.centrifugal.centrifuge.Subscription;
import io.github.centrifugal.centrifuge.SubscriptionErrorEvent;
import io.github.centrifugal.centrifuge.SubscriptionEventListener;
import io.github.centrifugal.centrifuge.SubscriptionOptions;
import io.github.centrifugal.centrifuge.UnsubscribedEvent;

import static java.nio.charset.StandardCharsets.UTF_8;

public class MainActivity extends AppCompatActivity {

private Client client;
Expand Down Expand Up @@ -72,7 +75,18 @@ public void onPublication(Client client, ServerPublicationEvent event) {
};

Options opts = new Options();
// opts.setToken("eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJ0ZXN0c3VpdGVfand0In0.hPmHsVqvtY88PvK4EmJlcdwNuKFuy3BGaF7dMaKdPlw");
String dataStr = "{ \"data\": { \"token\": \"jwt_token\" } }";
byte[] data = dataStr.getBytes();

opts.setData(data);

//Example with data
opts.setDataGetter(new ConnectionDataGetter() {
@Override
public void getConnectionData(ConnectionDataEvent event, DataCallback cb) {
cb.Done(null, data);
}
});

// Change the endpoint if needed.
String endpoint = "ws://10.0.2.2:8000/connection/websocket?cf_protocol_version=v2";
Expand Down

0 comments on commit 27d3ea2

Please sign in to comment.