diff --git a/lib/logstash/inputs/beats.rb b/lib/logstash/inputs/beats.rb index 36817798..665f9d09 100644 --- a/lib/logstash/inputs/beats.rb +++ b/lib/logstash/inputs/beats.rb @@ -118,9 +118,6 @@ class LogStash::Inputs::Beats < LogStash::Inputs::Base # Close Idle clients after X seconds of inactivity. config :client_inactivity_timeout, :validate => :number, :default => 60 - # Beats handler executor thread - config :executor_threads, :validate => :number, :default => LogStash::Config::CpuCoreStrategy.maximum - def register # For Logstash 2.4 we need to make sure that the logger is correctly set for the # java classes before actually loading them. @@ -162,7 +159,7 @@ def register end # def register def create_server - server = org.logstash.beats.Server.new(@host, @port, @client_inactivity_timeout, @executor_threads) + server = org.logstash.beats.Server.new(@host, @port, @client_inactivity_timeout) if @ssl ssl_context_builder = new_ssl_context_builder if client_authentification? diff --git a/spec/inputs/beats_spec.rb b/spec/inputs/beats_spec.rb index 51aec7ea..5692aa0b 100644 --- a/spec/inputs/beats_spec.rb +++ b/spec/inputs/beats_spec.rb @@ -25,16 +25,15 @@ context "#register" do context "host related configuration" do - let(:config) { super.merge("host" => host, "port" => port, "client_inactivity_timeout" => client_inactivity_timeout, "executor_threads" => threads) } + let(:config) { super.merge("host" => host, "port" => port, "client_inactivity_timeout" => client_inactivity_timeout) } let(:host) { "192.168.1.20" } let(:port) { 9000 } let(:client_inactivity_timeout) { 400 } - let(:threads) { 10 } subject(:plugin) { LogStash::Inputs::Beats.new(config) } it "sends the required options to the server" do - expect(org.logstash.beats.Server).to receive(:new).with(host, port, client_inactivity_timeout, threads) + expect(org.logstash.beats.Server).to receive(:new).with(host, port, client_inactivity_timeout) subject.register end end diff --git a/src/main/java/org/logstash/beats/BeatsHandler.java b/src/main/java/org/logstash/beats/BeatsHandler.java index 16564222..123895d6 100644 --- a/src/main/java/org/logstash/beats/BeatsHandler.java +++ b/src/main/java/org/logstash/beats/BeatsHandler.java @@ -2,7 +2,6 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; -import io.netty.util.AttributeKey; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; diff --git a/src/main/java/org/logstash/beats/BeatsParser.java b/src/main/java/org/logstash/beats/BeatsParser.java index 9e852e49..4c42ecaa 100644 --- a/src/main/java/org/logstash/beats/BeatsParser.java +++ b/src/main/java/org/logstash/beats/BeatsParser.java @@ -3,8 +3,10 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufOutputStream; +import io.netty.buffer.PooledByteBufAllocator; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.DecoderException; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -14,12 +16,13 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.TimeUnit; import java.util.zip.Inflater; import java.util.zip.InflaterOutputStream; - public class BeatsParser extends ByteToMessageDecoder { private final static Logger logger = LogManager.getLogger(BeatsParser.class); + private final static long maxDirectMemory = io.netty.util.internal.PlatformDependent.maxDirectMemory(); private Batch batch; @@ -45,15 +48,19 @@ private enum States { private int requiredBytes = 0; private int sequence = 0; private boolean decodingCompressedBuffer = false; + private long usedDirectMemory; + private boolean closeCalled = false; @Override protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + if(!hasEnoughBytes(in)) { if (decodingCompressedBuffer){ throw new InvalidFrameProtocolException("Insufficient bytes in compressed content to decode: " + currentState); } return; } + usedDirectMemory = ((PooledByteBufAllocator) ctx.alloc()).metric().usedDirectMemory(); switch (currentState) { case READ_HEADER: { @@ -178,6 +185,14 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) t case READ_COMPRESSED_FRAME: { logger.trace("Running: READ_COMPRESSED_FRAME"); + + if (usedDirectMemory + requiredBytes > maxDirectMemory * 0.90) { + ctx.channel().config().setAutoRead(false); + ctx.close(); + closeCalled = true; + throw new IOException("not enough memory to decompress this from " + ctx.channel().id()); + } + // Use the compressed size as the safe start for the buffer. ByteBuf buffer = inflateCompressedFrame(ctx, in); transition(States.READ_HEADER); @@ -190,14 +205,18 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) t } finally { decodingCompressedBuffer = false; buffer.release(); + ctx.channel().config().setAutoRead(false); + ctx.channel().eventLoop().schedule(() -> { + ctx.channel().config().setAutoRead(true); + }, 5, TimeUnit.MILLISECONDS); transition(States.READ_HEADER); } break; } case READ_JSON: { logger.trace("Running: READ_JSON"); - ((V2Batch)batch).addMessage(sequence, in, requiredBytes); - if(batch.isComplete()) { + ((V2Batch) batch).addMessage(sequence, in, requiredBytes); + if (batch.isComplete()) { if(logger.isTraceEnabled()) { logger.trace("Sending batch size: " + this.batch.size() + ", windowSize: " + batch.getBatchSize() + " , seq: " + sequence); } @@ -212,6 +231,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) t } private ByteBuf inflateCompressedFrame(final ChannelHandlerContext ctx, final ByteBuf in) throws IOException { + ByteBuf buffer = ctx.alloc().buffer(requiredBytes); Inflater inflater = new Inflater(); try ( @@ -219,7 +239,7 @@ private ByteBuf inflateCompressedFrame(final ChannelHandlerContext ctx, final By InflaterOutputStream inflaterStream = new InflaterOutputStream(buffOutput, inflater) ) { in.readBytes(inflaterStream, requiredBytes); - }finally{ + }finally { inflater.end(); } return buffer; @@ -247,4 +267,59 @@ private void batchComplete() { batch = null; } + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + //System.out.println("channelRead(" + ctx.channel().isActive() + ": " + ctx.channel().id() + ":" + currentState + ":" + decodingCompressedBuffer); + if (closeCalled) { + ((ByteBuf) msg).release(); + //if(batch != null) batch.release(); + return; + } + usedDirectMemory = ((PooledByteBufAllocator) ctx.alloc()).metric().usedDirectMemory(); + + // If we're just beginning a new frame on this channel, + // don't accumulate more data for 25 ms if usage of direct memory is above 20% + // + // The goal here is to avoid thundering herd: many beats connecting and sending data + // at the same time. As some channels progress to other states they'll use more memory + // but also give it back once a full batch is read. + if ((!decodingCompressedBuffer) && (this.currentState != States.READ_COMPRESSED_FRAME)) { + if (usedDirectMemory > (maxDirectMemory * 0.40)) { + ctx.channel().config().setAutoRead(false); + //System.out.println("pausing reads on " + ctx.channel().id()); + ctx.channel().eventLoop().schedule(() -> { + //System.out.println("resuming reads on " + ctx.channel().id()); + ctx.channel().config().setAutoRead(true); + }, 200, TimeUnit.MILLISECONDS); + } else { + //System.out.println("no need to pause reads on " + ctx.channel().id()); + } + } else if (usedDirectMemory > maxDirectMemory * 0.90) { + ctx.channel().config().setAutoRead(false); + ctx.close(); + closeCalled = true; + ((ByteBuf) msg).release(); + if (batch != null) batch.release(); + throw new IOException("about to explode, cut them all down " + ctx.channel().id()); + } + super.channelRead(ctx, msg); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + System.out.println(cause.getClass().toString() + ":" + ctx.channel().id().toString() + ":" + this.currentState + "|" + cause.getMessage()); + if (cause instanceof DecoderException) { + ctx.channel().config().setAutoRead(false); + if (!closeCalled) ctx.close(); + } else if (cause instanceof OutOfMemoryError) { + cause.printStackTrace(); + ctx.channel().config().setAutoRead(false); + if (!closeCalled) ctx.close(); + } else if (cause instanceof IOException) { + ctx.channel().config().setAutoRead(false); + if (!closeCalled) ctx.close(); + } else { + super.exceptionCaught(ctx, cause); + } + } } diff --git a/src/main/java/org/logstash/beats/Runner.java b/src/main/java/org/logstash/beats/Runner.java index 9a8813a8..19889efa 100644 --- a/src/main/java/org/logstash/beats/Runner.java +++ b/src/main/java/org/logstash/beats/Runner.java @@ -20,7 +20,7 @@ static public void main(String[] args) throws Exception { // Check for leaks. // ResourceLeakDetector.setLevel(ResourceLeakDetector.Level.PARANOID); - Server server = new Server("0.0.0.0", DEFAULT_PORT, 15, Runtime.getRuntime().availableProcessors()); + Server server = new Server("0.0.0.0", DEFAULT_PORT, 15); if(args.length > 0 && args[0].equals("ssl")) { logger.debug("Using SSL"); diff --git a/src/main/java/org/logstash/beats/Server.java b/src/main/java/org/logstash/beats/Server.java index 5da86c09..4a1e60bd 100644 --- a/src/main/java/org/logstash/beats/Server.java +++ b/src/main/java/org/logstash/beats/Server.java @@ -17,7 +17,6 @@ public class Server { private final int port; private final String host; - private final int beatsHeandlerThreadCount; private NioEventLoopGroup workGroup; private IMessageListener messageListener = new MessageListener(); private SslHandlerProvider sslHandlerProvider; @@ -25,11 +24,10 @@ public class Server { private final int clientInactivityTimeoutSeconds; - public Server(String host, int p, int clientInactivityTimeoutSeconds, int threadCount) { + public Server(String host, int p, int clientInactivityTimeoutSeconds) { this.host = host; port = p; this.clientInactivityTimeoutSeconds = clientInactivityTimeoutSeconds; - beatsHeandlerThreadCount = threadCount; } public void setSslHandlerProvider(SslHandlerProvider sslHandlerProvider){ @@ -49,7 +47,7 @@ public Server listen() throws InterruptedException { try { logger.info("Starting server on port: {}", this.port); - beatsInitializer = new BeatsInitializer(messageListener, clientInactivityTimeoutSeconds, beatsHeandlerThreadCount); + beatsInitializer = new BeatsInitializer(messageListener, clientInactivityTimeoutSeconds); ServerBootstrap server = new ServerBootstrap(); server.group(workGroup) @@ -99,21 +97,18 @@ private class BeatsInitializer extends ChannelInitializer { private final String CONNECTION_HANDLER = "connection-handler"; private final String BEATS_ACKER = "beats-acker"; - private final int DEFAULT_IDLESTATEHANDLER_THREAD = 4; private final int IDLESTATE_WRITER_IDLE_TIME_SECONDS = 5; private final EventExecutorGroup idleExecutorGroup; - private final EventExecutorGroup beatsHandlerExecutorGroup; private final IMessageListener localMessageListener; private final int localClientInactivityTimeoutSeconds; - BeatsInitializer(IMessageListener messageListener, int clientInactivityTimeoutSeconds, int beatsHandlerThread) { + BeatsInitializer(IMessageListener messageListener, int clientInactivityTimeoutSeconds) { // Keeps a local copy of Server settings, so they can't be modified once it starts listening this.localMessageListener = messageListener; this.localClientInactivityTimeoutSeconds = clientInactivityTimeoutSeconds; idleExecutorGroup = new DefaultEventExecutorGroup(DEFAULT_IDLESTATEHANDLER_THREAD); - beatsHandlerExecutorGroup = new DefaultEventExecutorGroup(beatsHandlerThread); } public void initChannel(SocketChannel socket){ @@ -126,11 +121,10 @@ public void initChannel(SocketChannel socket){ new IdleStateHandler(localClientInactivityTimeoutSeconds, IDLESTATE_WRITER_IDLE_TIME_SECONDS, localClientInactivityTimeoutSeconds)); pipeline.addLast(BEATS_ACKER, new AckEncoder()); pipeline.addLast(CONNECTION_HANDLER, new ConnectionHandler()); - pipeline.addLast(beatsHandlerExecutorGroup, new BeatsParser(), new BeatsHandler(localMessageListener)); + pipeline.addLast(new BeatsParser()); + pipeline.addLast(new BeatsHandler(localMessageListener)); } - - @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { logger.warn("Exception caught in channel initializer", cause); @@ -144,7 +138,6 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E public void shutdownEventExecutor() { try { idleExecutorGroup.shutdownGracefully().sync(); - beatsHandlerExecutorGroup.shutdownGracefully().sync(); } catch (InterruptedException e) { throw new IllegalStateException(e); } diff --git a/src/test/java/org/logstash/beats/ServerTest.java b/src/test/java/org/logstash/beats/ServerTest.java index 37512cdc..a3f88258 100644 --- a/src/test/java/org/logstash/beats/ServerTest.java +++ b/src/test/java/org/logstash/beats/ServerTest.java @@ -33,7 +33,6 @@ public class ServerTest { private int randomPort; private EventLoopGroup group; private final String host = "0.0.0.0"; - private final int threadCount = 10; @Before public void setUp() { @@ -50,7 +49,7 @@ public void testServerShouldTerminateConnectionWhenExceptionHappen() throws Inte final CountDownLatch latch = new CountDownLatch(concurrentConnections); - final Server server = new Server(host, randomPort, inactivityTime, threadCount); + final Server server = new Server(host, randomPort, inactivityTime); final AtomicBoolean otherCause = new AtomicBoolean(false); server.setMessageListener(new MessageListener() { public void onNewConnection(ChannelHandlerContext ctx) { @@ -114,7 +113,7 @@ public void testServerShouldTerminateConnectionIdleForTooLong() throws Interrupt final CountDownLatch latch = new CountDownLatch(concurrentConnections); final AtomicBoolean exceptionClose = new AtomicBoolean(false); - final Server server = new Server(host, randomPort, inactivityTime, threadCount); + final Server server = new Server(host, randomPort, inactivityTime); server.setMessageListener(new MessageListener() { @Override public void onNewConnection(ChannelHandlerContext ctx) { @@ -170,7 +169,7 @@ public void run() { @Test public void testServerShouldAcceptConcurrentConnection() throws InterruptedException { - final Server server = new Server(host, randomPort, 30, threadCount); + final Server server = new Server(host, randomPort, 30); SpyListener listener = new SpyListener(); server.setMessageListener(listener); Runnable serverTask = new Runnable() {