From 5803aadd3f209eba1ffbb2cf7bf16778019dbee1 Mon Sep 17 00:00:00 2001
From: fredoboulo <fredoboulo@users.noreply.github.com>
Date: Fri, 23 Feb 2018 23:55:57 +0100
Subject: [PATCH] Fix #524 : V1 and V2 protocol downgrades handle received data
 in handshake buffer

This patch is upstream pull request, see:
https://gihub.com/zeromq/jeromq/pull/527.

It is merged on commit c2afa9c, and we can drop it on the
0.4.4 release.

---
 src/main/java/zmq/io/StreamEngine.java            | 21 ++++++++++--
 src/test/java/zmq/io/AbstractProtocolVersion.java | 41 +++++++++++++----------
 src/test/java/zmq/io/V0ProtocolTest.java          | 12 +++++++
 src/test/java/zmq/io/V1ProtocolTest.java          | 16 +++++++--
 src/test/java/zmq/io/V2ProtocolTest.java          | 16 +++++++--
 5 files changed, 81 insertions(+), 25 deletions(-)

diff --git a/src/main/java/zmq/io/StreamEngine.java b/src/main/java/zmq/io/StreamEngine.java
index b8933c92..fe2f2d8d 100644
--- a/src/main/java/zmq/io/StreamEngine.java
+++ b/src/main/java/zmq/io/StreamEngine.java
@@ -816,9 +816,7 @@ private boolean handshake()
             assert (bufferSize == headerSize);
 
             //  Make sure the decoder sees the data we have already received.
-            greetingRecv.flip();
-            inpos = greetingRecv;
-            insize = greetingRecv.limit();
+            decodeDataAfterHandshake(0);
 
             //  To allow for interoperability with peers that do not forward
             //  their subscriptions, we inject a phantom subscription message
@@ -846,6 +844,8 @@ else if (greetingRecv.get(revisionPos) == Protocol.V1.revision) {
             }
             encoder = new V1Encoder(errno, Config.OUT_BATCH_SIZE.getValue());
             decoder = new V1Decoder(errno, Config.IN_BATCH_SIZE.getValue(), options.maxMsgSize, options.allocator);
+
+            decodeDataAfterHandshake(V2_GREETING_SIZE);
         }
         else if (greetingRecv.get(revisionPos) == Protocol.V2.revision) {
             //  ZMTP/2.0 framing.
@@ -859,6 +859,8 @@ else if (greetingRecv.get(revisionPos) == Protocol.V2.revision) {
             }
             encoder = new V2Encoder(errno, Config.OUT_BATCH_SIZE.getValue());
             decoder = new V2Decoder(errno, Config.IN_BATCH_SIZE.getValue(), options.maxMsgSize, options.allocator);
+
+            decodeDataAfterHandshake(V2_GREETING_SIZE);
         }
         else {
             zmtpVersion = Protocol.V3;
@@ -904,6 +906,19 @@ else if (greetingRecv.get(revisionPos) == Protocol.V2.revision) {
         return true;
     }
 
+    private void decodeDataAfterHandshake(int greetingSize)
+    {
+        final int pos = greetingRecv.position();
+        if (pos > greetingSize) {
+            // data is present after handshake
+            greetingRecv.position(greetingSize).limit(pos);
+
+            //  Make sure the decoder sees this extra data.
+            inpos = greetingRecv;
+            insize = greetingRecv.remaining();
+        }
+    }
+
     private Msg identityMsg()
     {
         Msg msg = new Msg(options.identitySize);
diff --git a/src/test/java/zmq/io/AbstractProtocolVersion.java b/src/test/java/zmq/io/AbstractProtocolVersion.java
index e60db403..aa06b4a7 100644
--- a/src/test/java/zmq/io/AbstractProtocolVersion.java
+++ b/src/test/java/zmq/io/AbstractProtocolVersion.java
@@ -18,15 +18,18 @@
 import zmq.SocketBase;
 import zmq.ZError;
 import zmq.ZMQ;
+import zmq.ZMQ.Event;
 import zmq.util.Utils;
 
 public abstract class AbstractProtocolVersion
 {
+    protected static final int REPETITIONS = 1000;
+
     static class SocketMonitor extends Thread
     {
-        private final Ctx             ctx;
-        private final String          monitorAddr;
-        private final List<ZMQ.Event> events = new ArrayList<>();
+        private final Ctx         ctx;
+        private final String      monitorAddr;
+        private final ZMQ.Event[] events = new ZMQ.Event[1];
 
         public SocketMonitor(Ctx ctx, String monitorAddr)
         {
@@ -41,15 +44,15 @@ public void run()
             boolean rc = s.connect(monitorAddr);
             assertThat(rc, is(true));
             // Only some of the exceptional events could fire
-            while (true) {
-                ZMQ.Event event = ZMQ.Event.read(s);
-                if (event == null && s.errno() == ZError.ETERM) {
-                    break;
-                }
-                assertThat(event, notNullValue());
-
-                events.add(event);
+
+            ZMQ.Event event = ZMQ.Event.read(s);
+            if (event == null && s.errno() == ZError.ETERM) {
+                s.close();
+                return;
             }
+            assertThat(event, notNullValue());
+
+            events[0] = event;
             s.close();
         }
     }
@@ -69,11 +72,12 @@ public void run()
         boolean rc = ZMQ.setSocketOption(receiver, ZMQ.ZMQ_LINGER, 0);
         assertThat(rc, is(true));
 
-        SocketMonitor monitor = new SocketMonitor(ctx, "inproc://monitor");
-        monitor.start();
         rc = ZMQ.monitorSocket(receiver, "inproc://monitor", ZMQ.ZMQ_EVENT_HANDSHAKE_PROTOCOL);
         assertThat(rc, is(true));
 
+        SocketMonitor monitor = new SocketMonitor(ctx, "inproc://monitor");
+        monitor.start();
+
         rc = ZMQ.bind(receiver, host);
         assertThat(rc, is(true));
 
@@ -81,17 +85,18 @@ public void run()
         OutputStream out = sender.getOutputStream();
         for (ByteBuffer raw : raws) {
             out.write(raw.array());
-            ZMQ.msleep(100);
         }
 
         Msg msg = ZMQ.recv(receiver, 0);
         assertThat(msg, notNullValue());
         assertThat(new String(msg.data(), ZMQ.CHARSET), is(payload));
 
-        ZMQ.msleep(500);
-        assertThat(monitor.events.size(), is(1));
-        assertThat(monitor.events.get(0).event, is(ZMQ.ZMQ_EVENT_HANDSHAKE_PROTOCOL));
-        assertThat((Integer) monitor.events.get(0).arg, is(version));
+        monitor.join();
+
+        final Event event = monitor.events[0];
+        assertThat(event, notNullValue());
+        assertThat(event.event, is(ZMQ.ZMQ_EVENT_HANDSHAKE_PROTOCOL));
+        assertThat((Integer) event.arg, is(version));
 
         InputStream in = sender.getInputStream();
         byte[] data = new byte[255];
diff --git a/src/test/java/zmq/io/V0ProtocolTest.java b/src/test/java/zmq/io/V0ProtocolTest.java
index bd547d23..1a5b7aef 100644
--- a/src/test/java/zmq/io/V0ProtocolTest.java
+++ b/src/test/java/zmq/io/V0ProtocolTest.java
@@ -10,6 +10,18 @@
 
 public class V0ProtocolTest extends AbstractProtocolVersion
 {
+    @Test
+    public void testFixIssue524() throws IOException, InterruptedException
+    {
+        for (int idx = 0; idx < REPETITIONS; ++idx) {
+            if (idx % 100 == 0) {
+                System.out.print(idx + " ");
+            }
+            testProtocolVersion0short();
+        }
+        System.out.println();
+    }
+
     @Test(timeout = 2000)
     public void testProtocolVersion0short() throws IOException, InterruptedException
     {
diff --git a/src/test/java/zmq/io/V1ProtocolTest.java b/src/test/java/zmq/io/V1ProtocolTest.java
index e1045f34..764159d0 100644
--- a/src/test/java/zmq/io/V1ProtocolTest.java
+++ b/src/test/java/zmq/io/V1ProtocolTest.java
@@ -10,7 +10,19 @@
 
 public class V1ProtocolTest extends AbstractProtocolVersion
 {
-    @Test(timeout = 2000)
+    @Test
+    public void testFixIssue524() throws IOException, InterruptedException
+    {
+        for (int idx = 0; idx < REPETITIONS; ++idx) {
+            if (idx % 100 == 0) {
+                System.out.print(idx + " ");
+            }
+            testProtocolVersion1short();
+        }
+        System.out.println();
+    }
+
+    @Test
     public void testProtocolVersion1short() throws IOException, InterruptedException
     {
         List<ByteBuffer> raws = raws(0);
@@ -25,7 +37,7 @@ public void testProtocolVersion1short() throws IOException, InterruptedException
         assertProtocolVersion(1, raws, "abcdefg");
     }
 
-    @Test(timeout = 2000)
+    @Test
     public void testProtocolVersion1long() throws IOException, InterruptedException
     {
         List<ByteBuffer> raws = raws(0);
diff --git a/src/test/java/zmq/io/V2ProtocolTest.java b/src/test/java/zmq/io/V2ProtocolTest.java
index d5e64bce..7fda31bc 100644
--- a/src/test/java/zmq/io/V2ProtocolTest.java
+++ b/src/test/java/zmq/io/V2ProtocolTest.java
@@ -21,7 +21,19 @@ protected ByteBuffer identity()
                 .put((byte) 0);
     }
 
-    @Test(timeout = 2000)
+    @Test
+    public void testFixIssue524() throws IOException, InterruptedException
+    {
+        for (int idx = 0; idx < REPETITIONS; ++idx) {
+            if (idx % 100 == 0) {
+                System.out.print(idx + " ");
+            }
+            testProtocolVersion2short();
+        }
+        System.out.println();
+    }
+
+    @Test
     public void testProtocolVersion2short() throws IOException, InterruptedException
     {
         List<ByteBuffer> raws = raws(1);
@@ -38,7 +50,7 @@ public void testProtocolVersion2short() throws IOException, InterruptedException
         assertProtocolVersion(2, raws, "abcdefg");
     }
 
-    @Test(timeout = 2000)
+    @Test
     public void testProtocolVersion2long() throws IOException, InterruptedException
     {
         List<ByteBuffer> raws = raws(1);