diff --git a/src/main/java/com/sproutsocial/nsq/Subscriber.java b/src/main/java/com/sproutsocial/nsq/Subscriber.java index 03d1bae..b52da3d 100644 --- a/src/main/java/com/sproutsocial/nsq/Subscriber.java +++ b/src/main/java/com/sproutsocial/nsq/Subscriber.java @@ -19,6 +19,7 @@ import java.util.Set; import java.util.concurrent.ExecutorService; import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.atomic.AtomicLong; import static com.sproutsocial.nsq.Util.checkArgument; import static com.sproutsocial.nsq.Util.checkNotNull; @@ -28,6 +29,7 @@ public class Subscriber extends BasePubSub { private final List lookups = new ArrayList(); private final List subscriptions = new ArrayList(); + private final AtomicLong subscriptionIdCounter = new AtomicLong(0L); private final int lookupIntervalSecs; private int maxLookupFailuresBeforeError; private int defaultMaxInFlight = 200; @@ -65,12 +67,12 @@ public Subscriber(String... lookupHosts) { lookupHosts); } - public synchronized void subscribe(String topic, String channel, MessageHandler handler) { - subscribe(topic, channel, defaultMaxInFlight, handler); + public synchronized SubscriptionId subscribe(String topic, String channel, MessageHandler handler) { + return subscribe(topic, channel, defaultMaxInFlight, handler); } - public synchronized void subscribe(String topic, String channel, final MessageDataHandler handler) { - subscribe(topic, channel, defaultMaxInFlight, new BackoffHandler(new MessageHandler() { + public synchronized SubscriptionId subscribe(String topic, String channel, final MessageDataHandler handler) { + return subscribe(topic, channel, defaultMaxInFlight, new BackoffHandler(new MessageHandler() { @Override public void accept(Message msg) { handler.accept(msg.getData()); @@ -82,18 +84,22 @@ public void accept(Message msg) { * Subscribe to a topic. * If the configured executor is multi-threaded and maxInFlight > 1 (the defaults) * then the MessageHandler must be thread safe. + * + * @returns a {@link SubscriptionId} that can be passed back to an {@link Subscriber#unsubscribe} call. */ - public synchronized void subscribe(String topic, String channel, int maxInFlight, MessageHandler handler) { + public synchronized SubscriptionId subscribe(String topic, String channel, int maxInFlight, MessageHandler handler) { checkNotNull(topic); checkNotNull(channel); checkNotNull(handler); client.addSubscriber(this); - Subscription sub = new Subscription(client, topic, channel, handler, this, maxInFlight); + final SubscriptionId subscriptionId = SubscriptionId.fromCounter(subscriptionIdCounter); + final Subscription sub = new Subscription(subscriptionId, client, topic, channel, handler, this, maxInFlight); if (handler instanceof BackoffHandler) { ((BackoffHandler)handler).setSubscription(sub); //awkward } subscriptions.add(sub); sub.checkConnections(lookupTopic(topic)); + return subscriptionId; } /** @@ -102,15 +108,18 @@ public synchronized void subscribe(String topic, String channel, int maxInFlight * * NOTE: This will *not* delete the underlying channel that might have been created during the initial subscribe * call. + * + * @param subscriptionId A SubscriptionId returned from a previous {@link Subscriber#subscribe call}. + * @return true if the subscription was successfully stopped and removed. */ - public synchronized boolean unsubscribe(String topic, String channel) { - return unsubscribeSubscription(topic, channel) != null; + public synchronized boolean unsubscribe(final SubscriptionId subscriptionId) { + return unsubscribeSubscription(subscriptionId) != null; } - synchronized Subscription unsubscribeSubscription(String topic, String channel) { + synchronized Subscription unsubscribeSubscription(final SubscriptionId subscriptionId) { for (int i = 0; i < subscriptions.size(); i++) { final Subscription sub = subscriptions.get(i); - if (sub.getTopic().equals(topic) && sub.getChannel().equals(channel)) { + if (sub.getSubscriptionId().equals(subscriptionId)) { sub.stop(); return subscriptions.remove(i); } @@ -118,6 +127,22 @@ synchronized Subscription unsubscribeSubscription(String topic, String channel) return null; } + /** + * Deprecated. This method cannot handle when a single client creates two subscriptions to the same topic with the + * same channel name correctly. + */ + @Deprecated + public synchronized boolean unsubscribe(String topic, String channel) { + for (int i = 0; i < subscriptions.size(); i++) { + final Subscription sub = subscriptions.get(i); + if (sub.getTopic().equals(topic) && sub.getChannel().equals(channel)) { + sub.stop(); + return subscriptions.remove(i) != null; + } + } + return false; + } + public synchronized void setMaxInFlight(String topic, String channel, int maxInFlight) { for (Subscription sub : subscriptions) { if (sub.getTopic().equals(topic) && sub.getChannel().equals(channel)) { diff --git a/src/main/java/com/sproutsocial/nsq/Subscription.java b/src/main/java/com/sproutsocial/nsq/Subscription.java index f83d9c9..05123a0 100644 --- a/src/main/java/com/sproutsocial/nsq/Subscription.java +++ b/src/main/java/com/sproutsocial/nsq/Subscription.java @@ -9,7 +9,7 @@ import static com.sproutsocial.nsq.Util.copy; class Subscription extends BasePubSub { - + private final SubscriptionId subscriptionId; private final String topic; private final String channel; private final MessageHandler handler; @@ -20,8 +20,15 @@ class Subscription extends BasePubSub { private static final Logger logger = LoggerFactory.getLogger(Subscription.class); - public Subscription(Client client, String topic, String channel, MessageHandler handler, Subscriber subscriber, int maxInFlight) { + public Subscription(final SubscriptionId subscriptionId, + final Client client, + final String topic, + final String channel, + final MessageHandler handler, + final Subscriber subscriber, + final int maxInFlight) { super(client); + this.subscriptionId = subscriptionId; this.topic = topic; this.channel = channel; this.handler = handler; @@ -29,6 +36,10 @@ public Subscription(Client client, String topic, String channel, MessageHandler this.maxInFlight = maxInFlight; } + public SubscriptionId getSubscriptionId() { + return subscriptionId; + } + public synchronized int getMaxInFlight() { return maxInFlight; } @@ -187,7 +198,7 @@ public String getChannel() { @Override public String toString() { - return String.format("subscription %s.%s connections:%s", topic, channel, connectionMap.size()); + return String.format("subscription id %s, %s.%s connections:%s", subscriptionId, topic, channel, connectionMap.size()); } public int getConnectionCount() { diff --git a/src/main/java/com/sproutsocial/nsq/SubscriptionId.java b/src/main/java/com/sproutsocial/nsq/SubscriptionId.java new file mode 100644 index 0000000..73e2a6b --- /dev/null +++ b/src/main/java/com/sproutsocial/nsq/SubscriptionId.java @@ -0,0 +1,38 @@ +package com.sproutsocial.nsq; + +import java.util.Objects; +import java.util.concurrent.atomic.AtomicLong; + +/** + * Represents a unique subscription that's returned from a call to {@link Subscriber#subscribe}. + * Can be passed to methods such as {@link Subscriber#unsubscribe} to remove a subscription. + */ +public class SubscriptionId { + private final long id; + + protected SubscriptionId(final long id) { + this.id = id; + } + + static SubscriptionId fromCounter(final AtomicLong counter) { + return new SubscriptionId(counter.getAndIncrement()); + } + + @Override + public boolean equals(Object other) { + if (other == this) { return true; } + if (!(other instanceof SubscriptionId)) { return false; } + SubscriptionId that = (SubscriptionId)other; + return id == that.id; + } + + @Override + public int hashCode() { + return Objects.hash(id); + } + + @Override + public String toString() { + return "SubscriptionId { " + id + " }"; + } +} diff --git a/src/test/java/com/sproutsocial/nsq/SubscriberFocusedDockerTestIT.java b/src/test/java/com/sproutsocial/nsq/SubscriberFocusedDockerTestIT.java index 8376385..dead57c 100644 --- a/src/test/java/com/sproutsocial/nsq/SubscriberFocusedDockerTestIT.java +++ b/src/test/java/com/sproutsocial/nsq/SubscriberFocusedDockerTestIT.java @@ -8,6 +8,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.concurrent.atomic.AtomicReference; public class SubscriberFocusedDockerTestIT extends BaseDockerTestIT { private static Logger logger = LoggerFactory.getLogger(SubscriberFocusedDockerTestIT.class); @@ -24,8 +25,8 @@ public void setup() { public void twoDifferentSubscribersShareMessages() { TestMessageHandler handler1 = new TestMessageHandler(); TestMessageHandler handler2 = new TestMessageHandler(); - final Subscriber subscriber1 = startSubscriber(handler1, "channelA", null); - final Subscriber subscriber2 = startSubscriber(handler2, "channelA", null); + final Subscriber subscriber1 = startSubscriber(handler1, "channelA", null, null); + final Subscriber subscriber2 = startSubscriber(handler2, "channelA", null, null); List messages = messages(20, 40); send(topic, messages, 1, 200, publisher); @@ -47,7 +48,29 @@ public void twoDifferentSubscribersShareMessages() { @Test public void unsubscribingSubscribers() { TestMessageHandler handler = new TestMessageHandler(); - Subscriber subscriber = startSubscriber(handler, "channelA", null); + AtomicReference subscriptionId = new AtomicReference<>(); + Subscriber subscriber = startSubscriber(handler, "channelA", null, subscriptionId); + List batch1 = messages(20, 40); + List batch2 = messages(20, 40); + + send(topic, batch1, 0, 0, publisher); + Util.sleepQuietly(5000); + // Unsubscribe after the first batch. + Assert.assertTrue(subscriber.unsubscribe(subscriptionId.get())); + send(topic, batch2, 0, 0, publisher); + + Util.sleepQuietly(5000); + + // Ensure we only get 20 messages, even though we sent 40. + List consumerMessages = handler.drainMessages(20); + Assert.assertEquals(20, consumerMessages.size()); + } + + @Test + @Deprecated + public void unsubscribingSubscribersByTopicAndChannel() { + TestMessageHandler handler = new TestMessageHandler(); + Subscriber subscriber = startSubscriber(handler, "channelA", null, null); List batch1 = messages(20, 40); List batch2 = messages(20, 40); @@ -83,13 +106,13 @@ public void unsubscribeWithMessagesInFlight() { // Deliberately use a message handler that hangs forever, causing messages // to stay in flight. HangingMessageHandler handler = new HangingMessageHandler(); - Subscriber subscriber = startSubscriber(handler, "channelA", null); + AtomicReference subscriptionId = new AtomicReference<>(); + Subscriber subscriber = startSubscriber(handler, "channelA", null, subscriptionId); List batch1 = messages(20, 40); send(topic, batch1, 0, 0, publisher); Util.sleepQuietly(5000); - final Subscription subscription = subscriber.unsubscribeSubscription(topic, "channelA"); - Assert.assertTrue(subscription != null); + final Subscription subscription = subscriber.unsubscribeSubscription(subscriptionId.get()); // Since messages are in flight, we won't close the subscription immediately Assert.assertEquals(1, subscription.getConnectionCount()); @@ -112,16 +135,17 @@ public void unsubscribeWithMessagesInFlight() { @Test public void unsubscribeBeforeSubscriptionIsEstablished() { TestMessageHandler handler = new TestMessageHandler(); - Subscriber subscriber = startSubscriber(handler, "channelA", null); - Assert.assertTrue(subscriber.unsubscribe(topic, "channelA")); + AtomicReference subscriptionId = new AtomicReference<>(); + Subscriber subscriber = startSubscriber(handler, "channelA", null, subscriptionId); + Assert.assertTrue(subscriber.unsubscribe(subscriptionId.get())); } @Test public void verySlowConsumer_allMessagesReceivedByResponsiveConsumer() { TestMessageHandler handler = new TestMessageHandler(); NoAckReceiver delayHandler = new NoAckReceiver(8000); - final Subscriber subscriber1 = startSubscriber(handler, "channelA", null); - final Subscriber subscriber2 = startSubscriber(delayHandler, "channelA", null); + final Subscriber subscriber1 = startSubscriber(handler, "channelA", null, null); + final Subscriber subscriber2 = startSubscriber(delayHandler, "channelA", null, null); List messages = messages(40, 40); send(topic, messages, 1, 100, publisher); @@ -146,14 +170,17 @@ public void teardown() throws InterruptedException { super.teardown(); } - private Subscriber startSubscriber(MessageHandler handler, String channel, FailedMessageHandler failedMessageHandler) { + private Subscriber startSubscriber(MessageHandler handler, String channel, FailedMessageHandler failedMessageHandler, AtomicReference subscriptionIdRef) { Subscriber subscriber = new Subscriber(client, 1, 10, cluster.getLookupNode().getHttpHostAndPort().toString()); subscriber.setDefaultMaxInFlight(1); subscriber.setMaxAttempts(5); if (failedMessageHandler != null) { subscriber.setFailedMessageHandler(failedMessageHandler); } - subscriber.subscribe(topic, channel, handler); + final SubscriptionId subscriptionId = subscriber.subscribe(topic, channel, handler); + if (subscriptionIdRef != null) { + subscriptionIdRef.set(subscriptionId); + } this.subscribers.add(subscriber); return subscriber; }