Verifying logging with Log4j2, JUnit 5, and Mockito 3+

By devin, 15 August, 2024

My team at work used to use PowerMock to verify logging. You rarely need to verify logging, but occasionally an outage makes you realize that a unit test should confirm a particular message really is logged. So we wrote this utility class, which avoids the need for static mocking.

Why is logging so hard to test? Because you typically have a class like this:

public class ExampleClass {
  private static final Logger LOGGER = LogManager.getLogger(ExampleClass.class);

  private final ServiceClient serviceClient;

  // Constructor injection, so a test can pass in a mock client
  public ExampleClass(ServiceClient serviceClient) {
    this.serviceClient = serviceClient;
  }

  public Optional<String> doSomeCalculation() {
    try {
      return Optional.of(serviceClient.call());
    } catch (SomeServiceException e) {
      LOGGER.warn("The call failed {}", e.getDebugInfo());
      return Optional.empty();
    }
  }
}

It's quite difficult to stub the behaviour of the LOGGER constant because it's initialized statically. I fought with the internals of Mockito, JUnit, and Log4j for a while before coming up with the following class:

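// Assumed: toImmutableList() comes from Guava and @Getter from Lombok
import static com.google.common.collect.ImmutableList.toImmutableList;

import java.util.ArrayList;
import java.util.List;
import lombok.Getter;
import org.apache.logging.log4j.Level;
import org.apache.logging.log4j.core.Filter;
import org.apache.logging.log4j.core.LogEvent;
import org.apache.logging.log4j.core.Logger;
import org.apache.logging.log4j.core.appender.AbstractAppender;
import org.apache.logging.log4j.core.config.Property;
import org.apache.logging.log4j.core.filter.LevelRangeFilter;
import org.apache.logging.log4j.message.Message;
import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.ExtensionContext;

public class LogReplayHijacker implements BeforeEachCallback, AfterEachCallback {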
    private final Logger logger;
    private final ReplayAppender replayAppender;

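    // Context-wide filter that ACCEPTs every level, so events reach the appender
    // even when the logger's configured level would normally drop them
    private static final Filter FILTER = LevelRangeFilter.createFilter(Level.OFF, Level.ALL, Filter.Result.ACCEPT, Filter.Result.DENY);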

    public LogReplayHijacker(org.apache.logging.log4j.Logger logger) {
        this.logger = (org.apache.logging.log4j.core.Logger) logger;
        this.replayAppender = new ReplayAppender(String.format("%s-%s", logger.getName(), getClass().getSimpleName()));
    }

    /**
     * Get the log events for all log levels
     */
    public List<LogEvent> getLogEvents() {
        return replayAppender.getLogEvents();
    }

    /**
     * Get the log events filtered by a specific Level
     */
    public List<LogEvent> getLogEvents(Level level) {
        return replayAppender.getLogEvents(level);
    }

    /**
     * Get the log messages for all log levels
     */
    public List<String> getLogMessages() {
        return getLogEvents().stream()
            .map(LogEvent::getMessage)
            .map(Message::getFormattedMessage)
            .collect(toImmutableList());
    }

    /**
     * Get the log messages filtered by a specific Level
     */
    public List<String> getLogMessages(Level level) {
        return getLogEvents(level).stream()
            .map(LogEvent::getMessage)
            .map(Message::getFormattedMessage)
            .collect(toImmutableList());
    }

    @Override
    public void beforeEach(ExtensionContext extensionContext) {
        replayAppender.start();
        logger.addAppender(replayAppender);
        logger.getContext().addFilter(FILTER);
    }

    @Override
    public void afterEach(ExtensionContext extensionContext) {
        logger.getContext().removeFilter(FILTER);
        logger.removeAppender(replayAppender);
        replayAppender.stop();
    }

    private static class ReplayAppender extends AbstractAppender  {
        @Getter
        private final List<LogEvent> logEvents = new ArrayList<>();

        public List<LogEvent> getLogEvents(Level level) {
            return logEvents.stream()
                .filter(event -> event.getLevel().equals(level))
                .collect(toImmutableList());
        }

        ReplayAppender(String name) {
            super(name, null, null, true, Property.EMPTY_ARRAY);
        }

        @Override
        public void append(final LogEvent event) {
            // Log4j2 may reuse mutable event instances, so store an immutable snapshot
            logEvents.add(event.toImmutable());
        }
    }
}

This helper class packages up all of that effort. Your tests no longer need any static mocking: register the class with @RegisterExtension so that it runs as a JUnit 5 callback, then read back the messages that were logged during the test. Something like this, using our example:

import static org.assertj.core.api.Assertions.assertThat;
//other imports

class ExampleClassTest {
  @RegisterExtension
  final LogReplayHijacker logReplay = new LogReplayHijacker(LogManager.getLogger(ExampleClass.class));
  
  @Test
  void test_happyCase() {
    // TODO
  }
  
  @Test
  void test_errorCase() {
    // Same exception type that ExampleClass catches
    SomeServiceException mockException = mock(SomeServiceException.class);
    when(mockException.getDebugInfo()).thenReturn("Oh no");
    
    ServiceClient mockClient = mock(ServiceClient.class);
    when(mockClient.call()).thenThrow(mockException);
    
    Optional<String> output = new ExampleClass(mockClient).doSomeCalculation();
    assertThat(output).isEmpty();
    
    assertThat(logReplay.getLogEvents(Level.ERROR)).isEmpty();
    assertThat(logReplay.getLogEvents(Level.WARN)).hasSize(1);
    assertThat(logReplay.getLogMessages(Level.WARN)).contains("The call failed Oh no");
  }
}
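
If you want to see the underlying idea without any of Log4j2's moving parts, it boils down to this: leave the static logger alone and attach an in-memory sink to the real logger for the duration of the test. Below is a minimal sketch of that idea using java.util.logging, chosen only because it ships with the JDK. It is not the Log4j2 API, and the names CaptureSketch, Greeter, CapturingHandler, and capture are all made up for illustration.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Handler;
import java.util.logging.Level;
import java.util.logging.LogRecord;
import java.util.logging.Logger;

public class CaptureSketch {
    /** Hypothetical class under test with a static logger, like ExampleClass. */
    static class Greeter {
        private static final Logger LOGGER = Logger.getLogger(Greeter.class.getName());

        void greet(String name) {
            if (name == null) {
                LOGGER.warning("No name given");
                return;
            }
            LOGGER.info("Hello " + name);
        }
    }

    /** In-memory sink, playing the role that ReplayAppender plays for Log4j2. */
    static class CapturingHandler extends Handler {
        private final List<LogRecord> records = new ArrayList<>();

        @Override
        public synchronized void publish(LogRecord record) {
            records.add(record);
        }

        @Override
        public void flush() { }

        @Override
        public void close() { }

        /** Raw messages of the captured records at exactly the given level. */
        synchronized List<String> messages(Level level) {
            List<String> out = new ArrayList<>();
            for (LogRecord r : records) {
                if (r.getLevel().equals(level)) {
                    out.add(r.getMessage());
                }
            }
            return out;
        }
    }

    /** Attaches a handler to the named logger, runs the action, then detaches it. */
    static CapturingHandler capture(String loggerName, Runnable action) {
        Logger logger = Logger.getLogger(loggerName);
        CapturingHandler handler = new CapturingHandler();
        handler.setLevel(Level.ALL);
        Level previous = logger.getLevel();
        logger.setLevel(Level.ALL); // let every level through while capturing
        logger.addHandler(handler);
        try {
            action.run();
        } finally {
            logger.removeHandler(handler);
            logger.setLevel(previous); // restore whatever was configured before
        }
        return handler;
    }

    public static void main(String[] args) {
        CapturingHandler h = capture(Greeter.class.getName(), () -> new Greeter().greet(null));
        System.out.println(h.messages(Level.WARNING));
    }
}
```

The shape is the same as LogReplayHijacker: attach before the code runs, detach afterwards, and assert on what was recorded. The JUnit 5 extension just moves the attach and detach steps into beforeEach and afterEach so each test does not have to repeat them.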
