diff --git a/.github/workflows/maven.yml b/.github/workflows/maven.yml index 8f1d781..350e46d 100644 --- a/.github/workflows/maven.yml +++ b/.github/workflows/maven.yml @@ -10,7 +10,7 @@ on: branches: [ master ] jobs: - build: + build-test: runs-on: ubuntu-latest @@ -22,18 +22,5 @@ jobs: java-version: '17' distribution: 'temurin' cache: maven - - name: Build with Maven - run: mvn package --file pom.xml - test: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - - name: Set up JDK 17 - uses: actions/setup-java@v3 - with: - java-version: '17' - distribution: 'temurin' - cache: maven - - name: Build with Maven - run: mvn test --file pom.xml + - name: Build & Test with Maven + run: mvn --batch-mode --update-snapshots verify diff --git a/.gitignore b/.gitignore index e415a37..06e4112 100644 --- a/.gitignore +++ b/.gitignore @@ -2,5 +2,5 @@ /target/ /src/main/resources/loginInfo.yml /.idea/ -/src/test/ +/src/test/java/api /out/ \ No newline at end of file diff --git a/pom.xml b/pom.xml index 1a2017d..22c03b8 100644 --- a/pom.xml +++ b/pom.xml @@ -7,8 +7,20 @@ dev.Jacrispys InsideAgentDev 0.1.9-beta.1 + + ${project.basedir}/src/test/java/ + + org.apache.maven.plugins + maven-surefire-plugin + 3.1.2 + + + **/SpManTests.java + + + org.apache.maven.plugins maven-compiler-plugin @@ -67,6 +79,7 @@ UTF-8 Jacrispys dev.jacrispys.JavaBot.JavaBotMain + ${project.basedir}/src/test/java/ @@ -185,9 +198,16 @@ - com.zaxxer - HikariCP - 5.0.1 + org.junit.jupiter + junit-jupiter-engine + 5.10.0 + test + + + + org.mockito + mockito-core + 5.6.0 diff --git a/src/main/java/dev/jacrispys/JavaBot/JavaBotMain.java b/src/main/java/dev/jacrispys/JavaBot/JavaBotMain.java index 1f5784f..148b3cf 100644 --- a/src/main/java/dev/jacrispys/JavaBot/JavaBotMain.java +++ b/src/main/java/dev/jacrispys/JavaBot/JavaBotMain.java @@ -41,8 +41,7 @@ */ public class JavaBotMain { - // TODO: 1/26/2023 Add Documentation to all functions - + // TODO: 10/6/2023 Add unit testing to most classes private static final Logger logger = LoggerFactory.getLogger(JavaBotMain.class); private static final String className = JavaBotMain.class.getSimpleName(); public static AudioPlayerManager audioManager; diff --git a/src/main/java/dev/jacrispys/JavaBot/api/libs/utils/async/AsyncHandlerImpl.java b/src/main/java/dev/jacrispys/JavaBot/api/libs/utils/async/AsyncHandlerImpl.java index 31267d4..cfccc7c 100644 --- a/src/main/java/dev/jacrispys/JavaBot/api/libs/utils/async/AsyncHandlerImpl.java +++ b/src/main/java/dev/jacrispys/JavaBot/api/libs/utils/async/AsyncHandlerImpl.java @@ -1,5 +1,8 @@ package dev.jacrispys.JavaBot.api.libs.utils.async; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; import java.util.concurrent.CompletableFuture; @@ -7,7 +10,9 @@ /** * Async method handler */ -public abstract class AsyncHandlerImpl implements AsyncHandler{ +public abstract class AsyncHandlerImpl implements AsyncHandler { + + private static final Logger logger = LoggerFactory.getLogger(AsyncHandlerImpl.class); public record VoidMethodRunner(Runnable runnable, CompletableFuture cf) {} public record MethodRunner(Runnable runnable, CompletableFuture cf) {} @@ -26,6 +31,7 @@ public void completeVoid() { runner.cf().complete(null); } } catch (InterruptedException e) { + logger.error("Method runner was interrupted! Please report this.", e); Thread.currentThread().interrupt(); } } @@ -44,6 +50,7 @@ public void completeMethod() { } } } catch (InterruptedException e) { + logger.error("Method runner was interrupted! Please report this.", e); Thread.currentThread().interrupt(); } } diff --git a/src/main/java/dev/jacrispys/JavaBot/audio/AudioPlayerButtons.java b/src/main/java/dev/jacrispys/JavaBot/audio/AudioPlayerButtons.java index 13b9599..8717807 100644 --- a/src/main/java/dev/jacrispys/JavaBot/audio/AudioPlayerButtons.java +++ b/src/main/java/dev/jacrispys/JavaBot/audio/AudioPlayerButtons.java @@ -25,6 +25,7 @@ public class AudioPlayerButtons extends ListenerAdapter { private GuildAudioManager audioManager; + private final SpotifyManager man = SpotifyManager.getInstance(); /** * Listen's for a ButtonInteractionEvent and then checks and edit's an embed according to what each button is mapped to. @@ -131,14 +132,14 @@ private EmbedBuilder updateEmbed(MessageEmbed embed, int page) { } else { time = ("[" + DurationFormatUtils.formatDuration(track.getDuration(), "HH:mm:ss") + "]"); } - String artistLink = "https://open.spotify.com/artist/" + SpotifyManager.getArtistId(track.getIdentifier()); + String artistLink = "https://open.spotify.com/artist/" + man.getArtistId(track.getIdentifier()); if (i < 5) { queue.append((page - 1) * 10 + i + 1).append(". [").append(track.getInfo().author).append("](").append(artistLink).append(") - [").append(track.getInfo().title).append("](").append(track.getInfo().uri).append(") ").append(time).append(" \n"); } else { queue2.append((page - 1) * 10 + i + 1).append(". [").append(track.getInfo().author).append("](").append(artistLink).append(") - [").append(track.getInfo().title).append("](").append(track.getInfo().uri).append(") ").append(time).append(" \n"); } - } catch (IndexOutOfBoundsException | IOException ex) { + } catch (IndexOutOfBoundsException ex) { break; } } diff --git a/src/main/java/dev/jacrispys/JavaBot/audio/GenerateGenrePlaylist.java b/src/main/java/dev/jacrispys/JavaBot/audio/GenerateGenrePlaylist.java index ea15057..43f96df 100644 --- a/src/main/java/dev/jacrispys/JavaBot/audio/GenerateGenrePlaylist.java +++ b/src/main/java/dev/jacrispys/JavaBot/audio/GenerateGenrePlaylist.java @@ -2,10 +2,12 @@ import com.neovisionaries.i18n.CountryCode; import dev.jacrispys.JavaBot.audio.objects.Genres; -import dev.jacrispys.JavaBot.utils.mysql.MySQLConnection; import dev.jacrispys.JavaBot.utils.SpotifyManager; +import dev.jacrispys.JavaBot.utils.mysql.MySQLConnection; import net.dv8tion.jda.api.EmbedBuilder; -import net.dv8tion.jda.api.entities.*; +import net.dv8tion.jda.api.entities.Guild; +import net.dv8tion.jda.api.entities.MessageEmbed; +import net.dv8tion.jda.api.entities.User; import net.dv8tion.jda.api.entities.channel.concrete.TextChannel; import net.dv8tion.jda.api.entities.channel.concrete.VoiceChannel; import net.dv8tion.jda.api.entities.emoji.UnicodeEmoji; @@ -15,7 +17,6 @@ import net.dv8tion.jda.api.hooks.ListenerAdapter; import net.dv8tion.jda.api.interactions.components.ActionRow; import net.dv8tion.jda.api.interactions.components.buttons.Button; -import net.dv8tion.jda.api.utils.messages.MessageCreateData; import net.dv8tion.jda.api.utils.messages.MessageEditData; import org.apache.hc.core5.http.ParseException; import org.jetbrains.annotations.NotNull; @@ -26,10 +27,11 @@ import java.io.IOException; import java.sql.SQLException; import java.util.*; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.stream.Collectors; -import static dev.jacrispys.JavaBot.audio.GuildAudioManager.nowPlayingId; - /** * Util class that uses {@link se.michaelthelin.spotify.SpotifyApi} *
to generate playlists based off of {@link Recommendations} @@ -54,8 +56,8 @@ public GenerateGenrePlaylist() { * @throws ParseException if the spotify request fails * @throws SpotifyWebApiException if the spotify request fails */ - public Recommendations generatePlaylistFromGenre(String genres, int limit) throws IOException, ParseException, SpotifyWebApiException { - final GetRecommendationsRequest request = SpotifyManager.getInstance().getSpotifyApi().getRecommendations().market(CountryCode.US).seed_genres(genres).limit(limit).build(); + public Recommendations generatePlaylistFromGenre(String genres, int limit) throws IOException, ParseException, SpotifyWebApiException, ExecutionException, InterruptedException, TimeoutException { + final GetRecommendationsRequest request = SpotifyManager.getInstance().getSpotifyApi().get(10000, TimeUnit.MILLISECONDS).getRecommendations().market(CountryCode.US).seed_genres(genres).limit(limit).build(); return request.execute(); } @@ -69,9 +71,9 @@ public Recommendations generatePlaylistFromGenre(String genres, int limit) throw * @throws ParseException if the spotify request fails * @throws SpotifyWebApiException if the spotify request fails */ - public Recommendations generatePlaylistFromGenre(String genres, int limit, int popularity) throws IOException, ParseException, SpotifyWebApiException { + public Recommendations generatePlaylistFromGenre(String genres, int limit, int popularity) throws IOException, ParseException, SpotifyWebApiException, ExecutionException, InterruptedException, TimeoutException { if (popularity > 100 || popularity < 0) popularity = 100; - final GetRecommendationsRequest request = SpotifyManager.getInstance().getSpotifyApi().getRecommendations().market(CountryCode.US).seed_genres(genres).limit(limit).target_popularity(popularity).build(); + final GetRecommendationsRequest request = SpotifyManager.getInstance().getSpotifyApi().get(10000, TimeUnit.MILLISECONDS).getRecommendations().market(CountryCode.US).seed_genres(genres).limit(limit).target_popularity(popularity).build(); return request.execute(); } diff --git a/src/main/java/dev/jacrispys/JavaBot/audio/GuildAudioManager.java b/src/main/java/dev/jacrispys/JavaBot/audio/GuildAudioManager.java index 54755a9..4936aa3 100644 --- a/src/main/java/dev/jacrispys/JavaBot/audio/GuildAudioManager.java +++ b/src/main/java/dev/jacrispys/JavaBot/audio/GuildAudioManager.java @@ -3,8 +3,6 @@ import com.sedmelluq.discord.lavaplayer.player.AudioLoadResultHandler; import com.sedmelluq.discord.lavaplayer.player.AudioPlayer; import com.sedmelluq.discord.lavaplayer.player.AudioPlayerManager; -import com.sedmelluq.discord.lavaplayer.source.youtube.YoutubeAudioSourceManager; -import com.sedmelluq.discord.lavaplayer.source.youtube.YoutubeHttpContextFilter; import com.sedmelluq.discord.lavaplayer.tools.FriendlyException; import com.sedmelluq.discord.lavaplayer.track.AudioPlaylist; import com.sedmelluq.discord.lavaplayer.track.AudioTrack; @@ -42,15 +40,12 @@ import se.michaelthelin.spotify.model_objects.specification.TrackSimplified; import java.awt.*; -import java.io.IOException; import java.sql.SQLException; import java.time.format.DateTimeFormatter; import java.time.temporal.ChronoField; import java.util.List; import java.util.*; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.*; import java.util.function.Function; import static dev.jacrispys.JavaBot.JavaBotMain.audioManager; @@ -64,6 +59,8 @@ public class GuildAudioManager { private static final String className = GuildAudioManager.class.getSimpleName(); private final LoadAudioHandler audioHandler; + private final SpotifyManager man = SpotifyManager.getInstance(); + private boolean djEnabled = false; private static final Map audioManagers = new HashMap<>(); @@ -408,7 +405,7 @@ public MessageData displayQueue() { for (int i = 1; i <= 10; i++) { try { AudioTrack track = trackList.get(i - 1); - String artistLink = "https://open.spotify.com/artist/" + SpotifyManager.getArtistId(track.getIdentifier()); + String artistLink = "https://open.spotify.com/artist/" + man.getArtistId(track.getIdentifier()).get(1000, TimeUnit.MILLISECONDS); String time; if (track.getDuration() < 3600000) { time = ("[" + DurationFormatUtils.formatDuration(track.getDuration(), "mm:ss") + "]"); @@ -420,7 +417,8 @@ public MessageData displayQueue() { } else { queue2.append(i).append(". [").append(track.getInfo().author).append("](").append(artistLink).append(") - [").append(track.getInfo().title).append("](").append(track.getInfo().uri).append(") ").append(time).append(" \n"); } - } catch (IndexOutOfBoundsException | IOException ex) { + } catch (IndexOutOfBoundsException | InterruptedException | ExecutionException | + TimeoutException ex) { break; } } diff --git a/src/main/java/dev/jacrispys/JavaBot/commands/UnclassifiedSlashCommands.java b/src/main/java/dev/jacrispys/JavaBot/commands/UnclassifiedSlashCommands.java index e61ee90..7039bc5 100644 --- a/src/main/java/dev/jacrispys/JavaBot/commands/UnclassifiedSlashCommands.java +++ b/src/main/java/dev/jacrispys/JavaBot/commands/UnclassifiedSlashCommands.java @@ -28,12 +28,12 @@ */ public class UnclassifiedSlashCommands extends ListenerAdapter { - private MySqlStats sqlStats; + //private final MySqlStats sqlStats; private static JDA jda; public UnclassifiedSlashCommands(JDA jda) { this.jda = jda; - this.sqlStats = MySqlStats.getInstance(); + //this.sqlStats = MySqlStats.getInstance(); } public void initCommands(List guilds) { @@ -41,7 +41,7 @@ public void initCommands(List guilds) { guilds.forEach(this::updateGuildCommands); } - public List updateJdaCommands() { + public static List updateJdaCommands() { List commands = new ArrayList<>(); Collections.addAll(commands, Commands.slash("setnick", "Sets the nickname of this bot, or a user.") @@ -60,7 +60,7 @@ protected void updateGuildCommands(Guild guild) { @SuppressWarnings("all") public void onSlashCommandInteraction(@NotNull SlashCommandInteractionEvent event) { // Increment SQL Stat - sqlStats.incrementGuildStat(event.getGuild().getIdLong(), StatType.COMMAND_COUNTER); + //sqlStats.incrementGuildStat(event.getGuild().getIdLong(), StatType.COMMAND_COUNTER); String commandName = event.getName(); switch (commandName) { case "setnick" -> { diff --git a/src/main/java/dev/jacrispys/JavaBot/commands/audio/SlashMusicCommands.java b/src/main/java/dev/jacrispys/JavaBot/commands/audio/SlashMusicCommands.java index d93c7c3..12c2a4e 100644 --- a/src/main/java/dev/jacrispys/JavaBot/commands/audio/SlashMusicCommands.java +++ b/src/main/java/dev/jacrispys/JavaBot/commands/audio/SlashMusicCommands.java @@ -54,7 +54,7 @@ public void initCommands(List guilds) { * Generates a list of commands to be updated {@link ListenerAdapter#onReady(ReadyEvent)} * @return the list of Commands */ - public List updateJdaCommands() { + public static List updateJdaCommands() { List commands = new ArrayList<>(); Collections.addAll(commands, Commands.slash("play", "Add a link to most streaming platforms, or use its name to search!") diff --git a/src/main/java/dev/jacrispys/JavaBot/events/BotStartup.java b/src/main/java/dev/jacrispys/JavaBot/events/BotStartup.java index 37c1408..7decede 100644 --- a/src/main/java/dev/jacrispys/JavaBot/events/BotStartup.java +++ b/src/main/java/dev/jacrispys/JavaBot/events/BotStartup.java @@ -19,23 +19,26 @@ */ public class BotStartup extends ListenerAdapter { - private final MySQLConnection connection = MySQLConnection.getInstance(); + public MySQLConnection getConnection() { + return MySQLConnection.getInstance(); + } + @Override public void onReady(@NotNull ReadyEvent event) { List commands = new ArrayList<>(); - commands.addAll(new SlashMusicCommands().updateJdaCommands()); - commands.addAll(new UnclassifiedSlashCommands(event.getJDA()).updateJdaCommands()); + commands.addAll(SlashMusicCommands.updateJdaCommands()); + commands.addAll(UnclassifiedSlashCommands.updateJdaCommands()); event.getJDA().addEventListener(new SlashDebugCommands(event.getJDA())); event.getJDA().updateCommands().addCommands(commands).queue(); for (Guild guild : event.getJDA().getGuilds()) { - connection.registerGuild(guild, guild.getTextChannels().get(0)); + getConnection().registerGuild(guild, guild.getTextChannels().get(0)); } } @Override public void onGuildJoin(@NotNull GuildJoinEvent event) { - connection.registerGuild(event.getGuild(), event.getGuild().getTextChannels().get(0)); + getConnection().registerGuild(event.getGuild(), event.getGuild().getTextChannels().get(0)); } } diff --git a/src/main/java/dev/jacrispys/JavaBot/utils/SecretData.java b/src/main/java/dev/jacrispys/JavaBot/utils/SecretData.java index f9bded8..fc291b8 100644 --- a/src/main/java/dev/jacrispys/JavaBot/utils/SecretData.java +++ b/src/main/java/dev/jacrispys/JavaBot/utils/SecretData.java @@ -14,14 +14,21 @@ public class SecretData { private static Yaml yaml = new Yaml(); private static Map loginInfo; - public static void initLoginInfo() throws IOException { + private static final String DEFAULT_PATH_DIR = "src/main/resources/loginInfo.yml"; + + public static void initLoginInfo(String path) throws IOException { yaml = new Yaml(); - loginInfo = yaml.load(generateSecretData()); + loginInfo = yaml.load(generateSecretData(path)); + } + + public static void initLoginInfo() throws IOException { + initLoginInfo(DEFAULT_PATH_DIR); } - protected static InputStream generateSecretData() throws IOException { - if (SecretData.class.getClassLoader().getResourceAsStream("loginInfo.yml") == null) { - File file = new File("src/main/resources/loginInfo.yml"); + private static InputStream generateSecretData(String path) throws IOException { + File fileExists = new File(path); + if (!fileExists.exists()) { + File file = new File(path); if (file.getParentFile() != null) file.getParentFile().mkdirs(); if (file.createNewFile()) { Map fileInfo = getDefaultConfig(); @@ -38,11 +45,12 @@ protected static InputStream generateSecretData() throws IOException { return new FileInputStream(file); } else throw new FileNotFoundException("Could not create required config file!"); - } else return SecretData.class.getClassLoader().getResourceAsStream("loginInfo.yml"); + } else return new FileInputStream(path); } + @SafeVarargs @NotNull - private static Map getDefaultConfig() { + private static Map getDefaultConfig(Map.Entry... entry) { Map fileInfo = new HashMap<>(); fileInfo.put("DATA_BASE_PASS", " "); fileInfo.put("TOKEN", " "); @@ -55,6 +63,9 @@ private static Map getDefaultConfig() { fileInfo.put("DB_HOST", "localhost"); fileInfo.put("BOT_CLIENT_ID", " "); fileInfo.put("BOT_CLIENT_SECRET", " "); + for (Map.Entry entryArgs : entry) { + fileInfo.put(entryArgs.getKey(), entryArgs.getValue()); + } return fileInfo; } diff --git a/src/main/java/dev/jacrispys/JavaBot/utils/SpotifyManager.java b/src/main/java/dev/jacrispys/JavaBot/utils/SpotifyManager.java index b453a50..e0d8964 100644 --- a/src/main/java/dev/jacrispys/JavaBot/utils/SpotifyManager.java +++ b/src/main/java/dev/jacrispys/JavaBot/utils/SpotifyManager.java @@ -3,6 +3,7 @@ import com.sedmelluq.discord.lavaplayer.tools.JsonBrowser; import com.sedmelluq.discord.lavaplayer.tools.io.HttpClientTools; import com.sedmelluq.discord.lavaplayer.tools.io.HttpInterfaceManager; +import dev.jacrispys.JavaBot.api.libs.utils.async.AsyncHandlerImpl; import org.apache.http.ParseException; import org.apache.http.client.methods.HttpGet; import org.slf4j.Logger; @@ -12,48 +13,101 @@ import se.michaelthelin.spotify.requests.authorization.client_credentials.ClientCredentialsRequest; import java.io.IOException; +import java.util.concurrent.CompletableFuture; /** * Manages instances of the {@link SpotifyApi} */ -public class SpotifyManager { +public class SpotifyManager extends AsyncHandlerImpl { - private final Thread thread; + private Thread thread; private static final Logger logger = LoggerFactory.getLogger(SpotifyManager.class); - private final SpotifyApi spotifyApi; + private SpotifyApi spotifyApi; private static SpotifyManager instance = null; private static String accessToken; + private long cooldown; + /** * Uses credentials to obtain connection to spotify api */ private SpotifyManager() { - instance = this; - this.spotifyApi = new SpotifyApi.Builder().setClientId(SecretData.getSpotifyId()).setClientSecret(SecretData.getSpotifySecret()).build(); - ClientCredentialsRequest clientCredentialsRequest = spotifyApi.clientCredentials().build(); + runThread(); + } + + + /** + * Thread update order is as follows... + * Check cooldown & execute spotify token update + * Check queues for artist ID requests + */ + private void runThread() { this.thread = new Thread(() -> { - try { - while (true) { - try { - var clientCredentials = clientCredentialsRequest.execute(); - accessToken = clientCredentials.getAccessToken(); - spotifyApi.setAccessToken(clientCredentials.getAccessToken()); - Thread.sleep((clientCredentials.getExpiresIn() - 10) * 1000L); - } catch (IOException | SpotifyWebApiException | ParseException e) { - logger.error("Failed to update the spotify access token. Retrying in 1 minute ", e); - Thread.sleep(60 * 1000); - } + while (true) { + if (System.currentTimeMillis() >= cooldown && voidMethodQueue.isEmpty()) { + CompletableFuture cf = new CompletableFuture<>(); + this.voidMethodQueue.add(new AsyncHandlerImpl.VoidMethodRunner(this::retrieveToken, cf)); + completeVoid(); + continue; + } + if (spotifyApi != null) { + completeMethod(); } - } catch (Exception e) { - logger.error("Failed to update the spotify access token", e); } }); thread.setDaemon(true); thread.start(); } - public SpotifyApi getSpotifyApi() { + private void retrieveToken() { + instance = this; + this.spotifyApi = new SpotifyApi.Builder().setClientId(SecretData.getSpotifyId()).setClientSecret(SecretData.getSpotifySecret()).build(); + ClientCredentialsRequest clientCredentialsRequest = spotifyApi.clientCredentials().build(); + try { + try { + var clientCredentials = clientCredentialsRequest.execute(); + accessToken = clientCredentials.getAccessToken(); + spotifyApi.setAccessToken(clientCredentials.getAccessToken()); + cooldown = System.currentTimeMillis() + ((clientCredentials.getExpiresIn() - 10) * 1000L); + } catch (IOException | SpotifyWebApiException | ParseException e) { + logger.error("Failed to update the spotify access token. Retrying in 1 minute ", e); + cooldown = System.currentTimeMillis() + (60 * 1000); + } + } catch (Exception e) { + logger.error("Failed to update the spotify access token", e); + } + } + + + public CompletableFuture getArtistId(String id) { + CompletableFuture cf = new CompletableFuture<>(); + this.methodQueue.add(new AsyncHandlerImpl.MethodRunner(() -> { + try { + String s = null; + while (s == null) { + s = getArtistIdAsync(id); + } + cf.complete(s); + } catch (IOException e) { + logger.error("Error occurred when fetching JSON", e); + } + }, cf)); + return cf; + } + + public CompletableFuture getSpotifyApi() { + CompletableFuture cf = new CompletableFuture<>(); + this.methodQueue.add(new AsyncHandlerImpl.MethodRunner(() -> { + do { + } while (spotifyApi == null); + cf.complete(getSpotifyApiAsync()); + }, cf)); + return cf; + + } + + private SpotifyApi getSpotifyApiAsync() { return this.spotifyApi; } @@ -70,36 +124,63 @@ public static String getAccessToken() { /** * Gets json object from a given spotify URI + * * @param uri spotify api endpoint to obtain * @return json data received from the http request * @throws IOException if http request fails */ - private static JsonBrowser getJson(String uri) throws IOException { - try { - var request = new HttpGet(uri); - request.addHeader("Authorization", "Bearer " + getAccessToken()); - return HttpClientTools.fetchResponseAsJson(httpInterfaceManager.getInterface(), request); - } catch (Exception ignored) { - return null; - } + private JsonBrowser getJson(String uri) throws IOException { + var request = new HttpGet(uri); + request.addHeader("Authorization", "Bearer " + getAccessToken()); + return HttpClientTools.fetchResponseAsJson(httpInterfaceManager.getInterface(), request); } /** * Obtains the UUID of an artist from a given song */ - public static String getArtistId(String id) throws IOException { - try { - var json = getJson(API_BASE + "tracks/" + id); - if (json == null || json.get("artists").values().isEmpty()) { - return null; - } - return json.get("artists").index(0).get("id").text(); - } catch (Exception ignored) { + private String getArtistIdAsync(String id) throws IOException { + do { + } while (accessToken == null); + var json = getJson(API_BASE + "tracks/" + id); + if (json == null || json.get("artists").values().isEmpty()) { return null; } + return json.get("artists").index(0).get("id").text(); } public static SpotifyManager getInstance() { return instance != null ? instance : new SpotifyManager(); } + + + @Override + public void completeVoid() { + try { + do { + VoidMethodRunner runner = voidMethodQueue.take(); + runner.runnable().run(); + runner.cf().complete(null); + } while (!voidMethodQueue.isEmpty()); + } catch (InterruptedException e) { + logger.error("Method runner was interrupted! Please report this.", e); + Thread.currentThread().interrupt(); + } + } + + @Override + public void completeMethod() { + try { + do { + MethodRunner runner = methodQueue.take(); + runner.runnable().run(); + while (true) { + if (!runner.cf().isDone() && !runner.cf().isCancelled()) continue; + break; + } + } while (!methodQueue.isEmpty()); + } catch (InterruptedException e) { + logger.error("Method runner was interrupted! Please report this.", e); + Thread.currentThread().interrupt(); + } + } } diff --git a/src/test/java/unit/SecretDataTests.java b/src/test/java/unit/SecretDataTests.java new file mode 100644 index 0000000..06fd1f6 --- /dev/null +++ b/src/test/java/unit/SecretDataTests.java @@ -0,0 +1,33 @@ +package unit; + +import dev.jacrispys.JavaBot.utils.SecretData; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.io.File; +import java.io.IOException; +import java.util.UUID; + +public class SecretDataTests { + + private static final String PATH_DIR = "src/test/java/unit/resources/"; + private static final UUID uuid = UUID.randomUUID(); + + @Test + void generateNewSecretFile() throws IOException { + String path = PATH_DIR + "_" + uuid + ".yml"; + SecretData.initLoginInfo(path); + File file = new File(path); + Assertions.assertTrue(file.exists()); + } + + @AfterAll + public static void clean() { + File dir = new File(PATH_DIR); + for (File file : dir.listFiles()) { + file.delete(); + } + } + +} diff --git a/src/test/java/unit/SpManTests.java b/src/test/java/unit/SpManTests.java new file mode 100644 index 0000000..6eebea4 --- /dev/null +++ b/src/test/java/unit/SpManTests.java @@ -0,0 +1,38 @@ +package unit; + +import dev.jacrispys.JavaBot.utils.SecretData; +import dev.jacrispys.JavaBot.utils.SpotifyManager; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class SpManTests { + + private final String TEST_TRACK = "6shRGWCtBUOPFLFTTqXZIC?si=f88f88fde04e49c4"; + @Test + void spotifyManagerInstance() { + SpotifyManager man = SpotifyManager.getInstance(); + Assertions.assertNotNull(man); + } + + @Test + void spotifyApi() throws IOException, ExecutionException, InterruptedException, TimeoutException { + SecretData.initLoginInfo(); + SpotifyManager man = SpotifyManager.getInstance(); + String host = man.getSpotifyApi().get().getHost(); + Assertions.assertNotNull(host); + } + + @Test + void spotifyArtistId() throws IOException, ExecutionException, InterruptedException, TimeoutException { + String expected = "3GBPw9NK25X1Wt2OUvOwY3"; + SecretData.initLoginInfo(); + SpotifyManager man = SpotifyManager.getInstance(); + String actual = man.getArtistId(TEST_TRACK).get(10000, TimeUnit.MILLISECONDS); + Assertions.assertEquals(expected, actual); + } +} diff --git a/src/test/java/unit/StartupTests.java b/src/test/java/unit/StartupTests.java new file mode 100644 index 0000000..47b0290 --- /dev/null +++ b/src/test/java/unit/StartupTests.java @@ -0,0 +1,55 @@ +package unit; + +import org.junit.jupiter.api.Test; +import unit.mocks.ReadyEventMock; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +public class StartupTests { + + @Test + void commandRegisters() { + List expected = new ArrayList<>(); + expected.add("setnick"); + expected.add("embedbuilder"); + expected.add("auth-token"); + + expected.add("play"); + expected.add("skip"); + expected.add("volume"); + expected.add("clear"); + expected.add("stop"); + expected.add("pause"); + expected.add("resume"); + expected.add("dc"); + expected.add("leave"); + expected.add("disconnect"); + expected.add("follow"); + expected.add("queue"); + expected.add("shuffle"); + expected.add("song"); + expected.add("song-info"); + expected.add("info"); + expected.add("remove"); + expected.add("seek"); + expected.add("fix"); + expected.add("loop"); + expected.add("move"); + expected.add("hijack"); + expected.add("playtop"); + expected.add("skipto"); + expected.add("fileplay"); + expected.add("radio"); + + Collections.sort(expected); + + ReadyEventMock.assertUpdateCommands(ReadyEventMock.mockBotStartup(), expected); + } + + @Test + void readyStatus() { + ReadyEventMock.assertReadyStatus(ReadyEventMock.mockBotStartup()); + } +} diff --git a/src/test/java/unit/UnclassifiedCommandsTests.java b/src/test/java/unit/UnclassifiedCommandsTests.java new file mode 100644 index 0000000..9e2436f --- /dev/null +++ b/src/test/java/unit/UnclassifiedCommandsTests.java @@ -0,0 +1,27 @@ +package unit; + +import dev.jacrispys.JavaBot.commands.UnclassifiedSlashCommands; +import net.dv8tion.jda.api.Permission; +import net.dv8tion.jda.api.entities.Member; +import net.dv8tion.jda.api.entities.Message; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import unit.mocks.JDAMock; +import unit.mocks.SlashCommandMock; + +import java.util.List; + +public class UnclassifiedCommandsTests { + + @Test + void checkPermSetNick() throws InterruptedException { + Member member = JDAMock.getMember("Jacrispy", List.of(Permission.ADMINISTRATOR)); + String name = "setnick"; + Message m = SlashCommandMock.testSlashCommandReply(new UnclassifiedSlashCommands(JDAMock.getJDA()), name, member); + String actual = m.getContentRaw(); + String expected = "You do not have permission to use this command!"; + + Assertions.assertEquals(expected, actual); + } + +} diff --git a/src/test/java/unit/mocks/BypassDb.java b/src/test/java/unit/mocks/BypassDb.java new file mode 100644 index 0000000..42df343 --- /dev/null +++ b/src/test/java/unit/mocks/BypassDb.java @@ -0,0 +1,29 @@ +package unit.mocks; + +import dev.jacrispys.JavaBot.utils.mysql.MySQLConnection; +import org.mockito.Mockito; + +import java.sql.ResultSet; + +public class BypassDb { + + public static MySQLConnection mockSqlConnection() { + MySQLConnection connection = Mockito.mock(MySQLConnection.class); + try { + Mockito.when(connection.registerGuild(Mockito.any(), Mockito.any())).thenAnswer(invocationOnMock -> true); + Mockito.when(connection.getMusicChannel(Mockito.any())).thenAnswer(invocationOnMock -> 0L); + Mockito.when(connection.queryCommand(Mockito.anyString())).thenAnswer(invocationOnMock -> mockResultSet()); + } catch (Exception e) { + throw new RuntimeException(e); + } + return connection; + } + + public static ResultSet mockResultSet() { + ResultSet set = Mockito.mock(ResultSet.class); + + Mockito.doAnswer(invocationOnMock -> null).when(set); + + return set; + } +} diff --git a/src/test/java/unit/mocks/JDAMock.java b/src/test/java/unit/mocks/JDAMock.java new file mode 100644 index 0000000..2a42e73 --- /dev/null +++ b/src/test/java/unit/mocks/JDAMock.java @@ -0,0 +1,99 @@ +package unit.mocks; + +import net.dv8tion.jda.api.JDA; +import net.dv8tion.jda.api.Permission; +import net.dv8tion.jda.api.entities.Member; +import net.dv8tion.jda.api.entities.SelfUser; +import net.dv8tion.jda.api.entities.User; +import net.dv8tion.jda.api.interactions.commands.Command; +import net.dv8tion.jda.api.interactions.commands.build.CommandData; +import net.dv8tion.jda.api.requests.restaction.CommandListUpdateAction; +import net.dv8tion.jda.internal.JDAImpl; +import net.dv8tion.jda.internal.utils.PermissionUtil; +import org.jetbrains.annotations.NotNull; +import org.mockito.MockingDetails; +import org.mockito.Mockito; + +import java.util.ArrayList; +import java.util.List; + +public class JDAMock { + private static final List commandList = new ArrayList<>(); + + + public static JDA getJDA() { + return getJDA("UnitTesting", "1111"); + } + + @NotNull + public static JDA getJDA(String name, String discriminator) { + try { + JDA jda = Mockito.mock(JDAImpl.class); + Mockito.when(jda.getStatus()).thenAnswer(invocation -> JDA.Status.CONNECTED); + Mockito.when(jda.unloadUser(Mockito.anyLong())).thenAnswer(invocation -> true); + Mockito.when(jda.awaitReady()).thenAnswer(invocationOnMock -> jda); + Mockito.when(jda.awaitStatus(Mockito.any(JDA.Status.class))).thenAnswer(invocationOnMock -> jda); + Mockito.when(jda.awaitStatus(Mockito.any(JDA.Status.class), Mockito.any(JDA.Status[].class))).thenAnswer(invocationOnMock -> jda); + Mockito.when(jda.getSelfUser()).thenAnswer(invocationOnMock -> getSelfUser(name, discriminator)); + Mockito.when(jda.updateCommands()).thenAnswer(invocationOnMock -> getCommandUpdate()); + Mockito.doNothing().when(jda).addEventListener(Mockito.any()); + + return jda; + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + private static CommandListUpdateAction getCommandUpdate() { + CommandListUpdateAction update = Mockito.mock(CommandListUpdateAction.class); + + Mockito.when(update.addCommands(Mockito.anyList())).thenAnswer(invocationOnMock -> { + commandList.addAll(invocationOnMock.getArgument(0)); + return update; + }); + + + Mockito.doAnswer(invocationOnMock -> update).when(update).queue(); + + return update; + } + + + public static SelfUser getSelfUser(String name, String discriminator) { + SelfUser selfUser = Mockito.mock(SelfUser.class); + + Mockito.when(selfUser.getName()).thenAnswer(invocationOnMock -> name); + Mockito.when(selfUser.getId()).thenAnswer(invocationOnMock -> "0"); + Mockito.when(selfUser.getIdLong()).thenAnswer(invocationOnMock -> 0L); + + return selfUser; + } + + + public static Member getMember(String name, List perms) { + Member member = Mockito.mock(Member.class); + + Mockito.when(member.getEffectiveName()).thenAnswer(invocationOnMock -> name); + Mockito.when(member.getNickname()).thenAnswer(invocationOnMock -> name); + Mockito.when(member.hasPermission(Mockito.anyList())).thenAnswer(invocationOnMock -> checkPerms(perms, invocationOnMock.getArgument(0))); + + User user = Mockito.mock(User.class); + Mockito.when(user.getName()).thenAnswer(invocationOnMock -> name); + + Mockito.when(member.getUser()).thenAnswer(invocationOnMock -> user); + + return member; + + } + + private static boolean checkPerms(List current, Permission... toCheck) { + long curr = Permission.getRaw(current); + long check = Permission.getRaw(toCheck); + System.out.println("Check Perms"); + return (curr & check) == check; + } + + public static List getCommandList() { + return commandList; + } +} diff --git a/src/test/java/unit/mocks/MessageMock.java b/src/test/java/unit/mocks/MessageMock.java new file mode 100644 index 0000000..91d277f --- /dev/null +++ b/src/test/java/unit/mocks/MessageMock.java @@ -0,0 +1,109 @@ +package unit.mocks; + +import net.dv8tion.jda.api.entities.Message; +import net.dv8tion.jda.api.entities.MessageEmbed; +import net.dv8tion.jda.api.entities.channel.concrete.NewsChannel; +import net.dv8tion.jda.api.entities.channel.concrete.TextChannel; +import net.dv8tion.jda.api.entities.channel.middleman.MessageChannel; +import net.dv8tion.jda.api.entities.channel.unions.MessageChannelUnion; +import net.dv8tion.jda.api.requests.restaction.MessageCreateAction; +import net.dv8tion.jda.api.utils.messages.MessageCreateData; +import org.mockito.Mockito; +import unit.mocks.util.Callback; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.mockito.Mockito.*; + +/** + * Mock Message impls from: here. + */ +public class MessageMock { + + /** + * Get a mocked {@link MessageChannel}. + * + * @param name the name of the channel. + * @param id the id of the channel. + * @param messageCallback the callback used for returning the {@link Message} when for example + * {@link MessageChannel#sendMessage(CharSequence)} or other methods are called. + * @return a mocked {@link MessageChannel}. + */ + public static MessageChannel getMessageChannel(String name, long id, Callback messageCallback) { + MessageChannel channel = mock(MessageChannel.class, withSettings() + .extraInterfaces(TextChannel.class, NewsChannel.class, MessageChannelUnion.class)); + when(channel.getName()).thenAnswer(invocation -> name); + when(channel.getIdLong()).thenAnswer(invocation -> id); + when(channel.getId()).thenAnswer(invocation -> String.valueOf(id)); + + when(channel.sendMessage(any(CharSequence.class))) + .thenAnswer(invocation -> getMessageCreateAction(messageCallback, + getMessage(invocation.getArgument(0), channel))); + + when(channel.sendMessage(any(MessageCreateData.class))) + .thenAnswer(invocation -> getMessageCreateAction(messageCallback, + invocation.getArgument(0))); + + when(channel.sendMessageEmbeds(any(MessageEmbed.class), any(MessageEmbed[].class))) + .thenAnswer(invocation -> { + List embeds = invocation.getArguments().length == 1 ? new ArrayList<>() : + Arrays.asList(invocation.getArgument(1)); + embeds.add(invocation.getArgument(0)); + return getMessageCreateAction(messageCallback, getMessage(null, embeds, channel)); + }); + + when(channel.sendMessageEmbeds(anyList())) + .thenAnswer(invocation -> getMessageCreateAction(messageCallback, + getMessage(null, invocation.getArgument(0), channel))); + + return channel; + } + + /** + * Get a mocked {@link MessageCreateAction}. + * + * @param messageCallback the message callback that well return the {@link Message} when + * {@link MessageCreateAction#queue()} is executed. + * @param message the message that will be used by the {@link Callback}. + * @return a mocked {@link MessageCreateAction}. + */ + public static MessageCreateAction getMessageCreateAction(Callback messageCallback, Message message) { + MessageCreateAction messageAction = mock(MessageCreateAction.class); + Mockito.doAnswer(invocation -> { + messageCallback.callback(message); + return null; + }).when(messageAction).queue(); + return messageAction; + } + + /** + * Get a mocked {@link Message}. + * + * @param content the content of the message. This is the raw, displayed and stripped content. + * @param channel the {@link MessageChannel} the message would be sent in. + * @return a mocked {@link Message}. + */ + public static Message getMessage(String content, MessageChannel channel) { + return getMessage(content, new ArrayList<>(), channel); + } + + /** + * Get a mocked {@link Message}. + * + * @param content the content of the message. This is the raw, displayed and stripped content. + * @param embeds a list of {@link MessageEmbed}s that this message contains. + * @param channel the {@link MessageChannel} the message would be sent in. + * @return a mocked {@link Message}. + */ + public static Message getMessage(String content, List embeds, MessageChannel channel) { + Message message = mock(Message.class); + when(message.getContentRaw()).thenAnswer(invocation -> content); + when(message.getContentDisplay()).thenAnswer(invocation -> content); + when(message.getContentStripped()).thenAnswer(invocation -> content); + when(message.getChannel()).thenAnswer(invocation -> channel); + when(message.getEmbeds()).thenAnswer(invocation -> embeds); + return message; + } +} diff --git a/src/test/java/unit/mocks/MockRole.java b/src/test/java/unit/mocks/MockRole.java new file mode 100644 index 0000000..61f5e62 --- /dev/null +++ b/src/test/java/unit/mocks/MockRole.java @@ -0,0 +1,18 @@ +package unit.mocks; + +import net.dv8tion.jda.api.Permission; +import net.dv8tion.jda.api.entities.Role; +import org.mockito.Mockito; + +import java.util.EnumSet; + +public class MockRole { + + public static Role mockRole(EnumSet permissions) { + Role role = Mockito.mock(Role.class); + + Mockito.when(role.getPermissionsRaw()).thenAnswer(invocationOnMock -> Permission.getRaw(permissions)); + + return role; + } +} diff --git a/src/test/java/unit/mocks/ReadyEventMock.java b/src/test/java/unit/mocks/ReadyEventMock.java new file mode 100644 index 0000000..811b300 --- /dev/null +++ b/src/test/java/unit/mocks/ReadyEventMock.java @@ -0,0 +1,72 @@ +package unit.mocks; + +import dev.jacrispys.JavaBot.events.BotStartup; +import net.dv8tion.jda.api.events.session.ReadyEvent; +import net.dv8tion.jda.api.events.session.SessionState; +import net.dv8tion.jda.api.hooks.EventListener; +import net.dv8tion.jda.api.interactions.commands.build.CommandData; +import org.junit.jupiter.api.Assertions; +import org.mockito.Mockito; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import static unit.mocks.JDAMock.getJDA; + +public class ReadyEventMock { + + + public static BotStartup mockBotStartup() { + BotStartup startup = Mockito.spy(BotStartup.class); + + Mockito.doReturn(BypassDb.mockSqlConnection()).when(startup).getConnection(); + Mockito.doCallRealMethod().when(startup).onReady(Mockito.any()); + Mockito.doCallRealMethod().when(startup).onGuildJoin(Mockito.any()); + + return startup; + } + + + public static ReadyEvent getReadyEvent() { + ReadyEvent event = Mockito.mock(ReadyEvent.class); + + Mockito.when(event.getState()).thenAnswer(invocationOnMock -> SessionState.READY); + Mockito.when(event.getJDA()).thenAnswer(invocationOnMock -> getJDA()); + Mockito.when(event.getGuildTotalCount()).thenAnswer(invocationOnMock -> 0); + + return event; + } + + public static ReadyEvent getReadyEventCommands() { + return getReadyEvent(); + } + + + + public static List testReadyEventCommands(EventListener listener) { + ReadyEvent event = getReadyEventCommands(); + + listener.onEvent(event); + return JDAMock.getCommandList(); + } + + public static void assertUpdateCommands(EventListener listener, List expectedOutput) { + List cmds = testReadyEventCommands(listener); + List actual = new ArrayList<>(cmds.stream().map(CommandData::getName).toList()); + + Collections.sort(actual); + Assertions.assertEquals(expectedOutput, actual); + } + + public static SessionState testReadyStatus(EventListener listener) { + ReadyEvent event = getReadyEvent(); + + listener.onEvent(event); + return event.getState(); + } + + public static void assertReadyStatus(EventListener listener) { + Assertions.assertEquals(SessionState.READY, testReadyStatus(listener)); + } +} diff --git a/src/test/java/unit/mocks/SlashCommandMock.java b/src/test/java/unit/mocks/SlashCommandMock.java new file mode 100644 index 0000000..8ce5925 --- /dev/null +++ b/src/test/java/unit/mocks/SlashCommandMock.java @@ -0,0 +1,167 @@ +package unit.mocks; + +import net.dv8tion.jda.api.entities.Member; +import net.dv8tion.jda.api.entities.Message; +import net.dv8tion.jda.api.entities.MessageEmbed; +import net.dv8tion.jda.api.entities.channel.middleman.MessageChannel; +import net.dv8tion.jda.api.events.interaction.command.SlashCommandInteractionEvent; +import net.dv8tion.jda.api.hooks.EventListener; +import net.dv8tion.jda.api.interactions.commands.OptionMapping; +import net.dv8tion.jda.api.requests.restaction.interactions.ReplyCallbackAction; +import net.dv8tion.jda.api.utils.messages.MessageCreateData; +import org.mockito.Mockito; +import unit.mocks.util.Callback; + +import java.util.*; + +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static unit.mocks.MessageMock.getMessage; + + +/** + * Mock SlashCommandInteraction event from: here. + */ +public class SlashCommandMock { + + /** + * Get a mocked {@link SlashCommandInteractionEvent}. + * + * @param channel the channel this event would be executed. + * @param name the name of the slash command. + * @param subcommandName the name of the subcommand of the slash command. + * @param subcommandGroup the subcommand group of the slash command. + * @param options a map with all options for the slash command a user would have inputted. + * @param messageCallback a callback to receive messages that would be sent back to the channel with. + * @param deferReply a callback that is called when a deferred reply is called. The boolean is true when + * the message is ephemeral and false if not. + * {@link MessageChannel#sendMessage(CharSequence)} + * @return a mocked {@link SlashCommandInteractionEvent}. + */ + public static SlashCommandInteractionEvent getSlashCommandInteractionEvent(MessageChannel channel, String name, + String subcommandName, + String subcommandGroup, + Map options, + Member member, + Callback messageCallback, + Callback deferReply) { + SlashCommandInteractionEvent event = mock(SlashCommandInteractionEvent.class); + when(event.getName()).thenAnswer(invocation -> name); + when(event.getSubcommandName()).thenAnswer(invocation -> subcommandName); + when(event.getSubcommandGroup()).thenAnswer(invocation -> subcommandGroup); + when(event.getChannel()).thenAnswer(invocation -> channel); + when(event.getMember()).thenAnswer(invocationOnMock -> member); + + when(event.getOption(anyString())).thenAnswer(invocation -> { + OptionMapping mapping = mock(OptionMapping.class); + // why + when(mapping.getAsAttachment()).thenAnswer(inv -> options.get((String) invocation.getArgument(0))); + when(mapping.getAsString()).thenAnswer(inv -> options.get((String) invocation.getArgument(0))); + when(mapping.getAsBoolean()).thenAnswer(inv -> options.get((String) invocation.getArgument(0))); + when(mapping.getAsLong()).thenAnswer(inv -> options.get((String) invocation.getArgument(0))); + when(mapping.getAsInt()).thenAnswer(inv -> options.get((String) invocation.getArgument(0))); + when(mapping.getAsDouble()).thenAnswer(inv -> options.get((String) invocation.getArgument(0))); + when(mapping.getAsMentionable()).thenAnswer(inv -> options.get((String) invocation.getArgument(0))); + when(mapping.getAsMember()).thenAnswer(inv -> options.get((String) invocation.getArgument(0))); + when(mapping.getAsUser()).thenAnswer(inv -> options.get((String) invocation.getArgument(0))); + when(mapping.getAsRole()).thenAnswer(inv -> options.get((String) invocation.getArgument(0))); + when(mapping.getAsChannel()).thenAnswer(inv -> options.get((String) invocation.getArgument(0))); + + return mapping; + }); + + when(event.reply(anyString())).thenAnswer(invocation -> + getReplyCallbackAction(getMessage(invocation.getArgument(0), channel), messageCallback)); + when(event.reply(any(MessageCreateData.class))).thenAnswer(invocation -> + getReplyCallbackAction(getMessage(invocation.getArgument(0, MessageCreateData.class).getContent(), + channel), messageCallback)); + when(event.replyEmbeds(anyList())).thenAnswer(invocation -> + getReplyCallbackAction(getMessage(null, invocation.getArgument(0), channel), + messageCallback)); + when(event.replyEmbeds(any(MessageEmbed.class), any(MessageEmbed[].class))).thenAnswer(invocation -> { + List embeds = invocation.getArguments().length == 1 ? new ArrayList<>() : + Arrays.asList(invocation.getArgument(1)); + embeds.add(invocation.getArgument(0)); + return getReplyCallbackAction(getMessage(null, embeds, channel), messageCallback); + }); + + when(event.deferReply()).thenAnswer(invocation -> { + deferReply.callback(false); + return mock(ReplyCallbackAction.class); + }); + + when(event.deferReply(any(Boolean.class))).thenAnswer(invocation -> { + deferReply.callback(invocation.getArgument(0)); + return mock(ReplyCallbackAction.class); + }); + + return event; + } + + /** + * Get a mocked {@link SlashCommandInteractionEvent}. + * + * @param channel the channel this event would be executed. + * @param name the name of the slash command. + * @param subcommandName the name of the subcommand of the slash command. + * @param subcommandGroup the subcommand group of the slash command. + * @param options a map with all options for the slash command a user would have inputted. + * @param messageCallback a callback to receive messages that would be sent back to the channel with. + * {@link MessageChannel#sendMessage(CharSequence)} + * @return a mocked {@link SlashCommandInteractionEvent}. + */ + public static SlashCommandInteractionEvent getSlashCommandInteractionEvent(MessageChannel channel, String name, + String subcommandName, + String subcommandGroup, + Map options, + Member member, + Callback messageCallback) { + return getSlashCommandInteractionEvent(channel, name, subcommandName, subcommandGroup, options, member, messageCallback, + new Callback<>()); + } + + public static SlashCommandInteractionEvent getSlashCommandInteractionEvent(String cmdName, + Member member + ) { + return getSlashCommandInteractionEvent(null, cmdName, null, null, null, member, null, + new Callback<>()); + } + + /** + * Get a mocked {@link ReplyCallbackAction}. + * + * @param message the message that this reply should produce. + * @param messageCallback the callback for receiving the message. + * @return a mocked {@link ReplyCallbackAction}. + */ + private static ReplyCallbackAction getReplyCallbackAction(Message message, Callback messageCallback) { + ReplyCallbackAction action = mock(ReplyCallbackAction.class); + + Mockito.doAnswer(invocation -> { + messageCallback.callback(message); + return null; + }).when(action).queue(); + + when(action.setEphemeral(anyBoolean())).thenAnswer(invocation -> { + when(message.isEphemeral()).thenReturn(true); + return action; + }); + return action; + } + + + public static Message testSlashCommandReply(EventListener listener, String name, Member member) throws InterruptedException { + Callback messageCallback = new Callback<>(); + + MessageChannel channel = MessageMock.getMessageChannel("test-channel", 0L, messageCallback); + SlashCommandInteractionEvent event = getSlashCommandInteractionEvent(channel, name, null, null, + null, member, messageCallback); + + listener.onEvent(event); + return messageCallback.await(); + } + + + +} diff --git a/src/test/java/unit/mocks/util/Callback.java b/src/test/java/unit/mocks/util/Callback.java new file mode 100644 index 0000000..46a0eb0 --- /dev/null +++ b/src/test/java/unit/mocks/util/Callback.java @@ -0,0 +1,32 @@ +package unit.mocks.util; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +public class Callback { + + private final CountDownLatch countDownLatch; + private T object; + + public Callback() { + this.countDownLatch = new CountDownLatch(1); + } + + public void callback(T object) { + this.object = object; + this.countDownLatch.countDown(); + } + + public T await() throws InterruptedException { + this.countDownLatch.await(); + return object; + } + + public T await(long timeout, TimeUnit timeUnit) throws InterruptedException { + if (!this.countDownLatch.await(timeout, timeUnit)) { + throw new InterruptedException("Timeout elapsed"); + } + return object; + } + +}