diff --git a/src/main/java/io/antmedia/AntMediaApplicationAdapter.java b/src/main/java/io/antmedia/AntMediaApplicationAdapter.java index 76ddc7c2f..10fad6c68 100644 --- a/src/main/java/io/antmedia/AntMediaApplicationAdapter.java +++ b/src/main/java/io/antmedia/AntMediaApplicationAdapter.java @@ -1509,6 +1509,8 @@ public static boolean updateAppSettingsFile(String appName, AppSettings newAppse store.put(AppSettings.SETTINGS_PUBLISH_JWT_CONTROL_ENABLED, String.valueOf(newAppsettings.isPublishJwtControlEnabled())); store.put(AppSettings.SETTINGS_PLAY_JWT_CONTROL_ENABLED, String.valueOf(newAppsettings.isPlayJwtControlEnabled())); store.put(AppSettings.SETTINGS_JWT_STREAM_SECRET_KEY, newAppsettings.getJwtStreamSecretKey() != null ? newAppsettings.getJwtStreamSecretKey() : ""); + store.put(AppSettings.SETTINGS_JWT_BLACKLIST_ENABLED, String.valueOf(newAppsettings.isJwtBlacklistEnabled())); + store.put(AppSettings.SETTINGS_WEBRTC_ENABLED, String.valueOf(newAppsettings.isWebRTCEnabled())); store.put(AppSettings.SETTINGS_WEBRTC_FRAME_RATE, String.valueOf(newAppsettings.getWebRTCFrameRate())); diff --git a/src/main/java/io/antmedia/AppSettings.java b/src/main/java/io/antmedia/AppSettings.java index 71ed3d78e..7a184b296 100644 --- a/src/main/java/io/antmedia/AppSettings.java +++ b/src/main/java/io/antmedia/AppSettings.java @@ -267,6 +267,7 @@ public class AppSettings implements Serializable{ public static final String SETTINGS_JWT_SECRET_KEY = "settings.jwtSecretKey"; public static final String SETTINGS_JWT_CONTROL_ENABLED = "settings.jwtControlEnabled"; + public static final String SETTINGS_JWT_BLACKLIST_ENABLED = "settings.jwtBlacklistEnabled"; public static final String SETTINGS_IP_FILTER_ENABLED = "settings.ipFilterEnabled"; @@ -1297,6 +1298,8 @@ public class AppSettings implements Serializable{ @Value( "${"+SETTINGS_JWT_CONTROL_ENABLED+":false}" ) private boolean jwtControlEnabled; + @Value( "${"+SETTINGS_JWT_BLACKLIST_ENABLED+":false}" ) + private boolean jwtBlacklistEnabled; /** * Application IP Filter Enabled */ @@ -2668,6 +2671,13 @@ public boolean isJwtControlEnabled() { return jwtControlEnabled; } + public void setJwtBlacklistEnabled(boolean jwtBlacklistEnabled){ + this.jwtBlacklistEnabled = jwtBlacklistEnabled; + } + + public boolean isJwtBlacklistEnabled() { + return jwtBlacklistEnabled; + } public void setJwtControlEnabled(boolean jwtControlEnabled) { this.jwtControlEnabled = jwtControlEnabled; } diff --git a/src/main/java/io/antmedia/datastore/db/DataStore.java b/src/main/java/io/antmedia/datastore/db/DataStore.java index e64af421a..030c55bd8 100644 --- a/src/main/java/io/antmedia/datastore/db/DataStore.java +++ b/src/main/java/io/antmedia/datastore/db/DataStore.java @@ -10,6 +10,8 @@ import java.util.List; import java.util.Map; +import io.antmedia.rest.model.Result; +import io.antmedia.security.ITokenService; import org.apache.commons.lang3.RandomStringUtils; import org.apache.commons.lang3.exception.ExceptionUtils; import org.slf4j.Logger; @@ -424,6 +426,29 @@ public List listAllTokens (Map tokenMap, String streamId, public abstract boolean deleteToken (String tokenId); + /** + * Whitelist specific token. + * @param tokenId id of the token + */ + public abstract boolean whiteListToken(String tokenId); + + /** + * Get all blacklisted tokens. + */ + public abstract List getBlackListedTokens(); + + /** + * Delete all blacklisted expired tokens. + */ + public abstract Result deleteAllBlacklistedExpiredTokens(ITokenService tokenService); + + /** + * Whitelist all blacklisted tokens. + * + * @return + */ + public abstract boolean whiteListAllTokens(); + /** * retrieve specific token * @param tokenId id of the token @@ -1364,7 +1389,18 @@ public List getWebRTCViewerList(Map webRTCView * @param metaData new meta data */ public abstract boolean updateStreamMetaData(String streamId, String metaData); - + + /** + * Blacklist token. + * @param token which will be blacklisted. + */ + public abstract boolean blackListToken(Token token); + + /** + * Get token from blacklist. + * @param tokenId id of the token. + */ + public abstract Token getBlackListedToken(String tokenId); //************************************** //ATTENTION: Write function descriptions while adding new functions diff --git a/src/main/java/io/antmedia/datastore/db/InMemoryDataStore.java b/src/main/java/io/antmedia/datastore/db/InMemoryDataStore.java index f9a1d8906..79632d5de 100644 --- a/src/main/java/io/antmedia/datastore/db/InMemoryDataStore.java +++ b/src/main/java/io/antmedia/datastore/db/InMemoryDataStore.java @@ -2,15 +2,15 @@ import java.io.File; import java.time.Instant; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Iterator; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.Map.Entry; -import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; import java.util.regex.Pattern; + +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import io.antmedia.rest.model.Result; +import io.antmedia.security.ITokenService; import org.apache.commons.io.FilenameUtils; import org.apache.commons.lang3.RandomStringUtils; import org.apache.commons.lang3.exception.ExceptionUtils; @@ -41,8 +41,11 @@ public class InMemoryDataStore extends DataStore { private Map roomMap = new LinkedHashMap<>(); private Map webRTCViewerMap = new LinkedHashMap<>(); + private Gson gson; public InMemoryDataStore(String dbName) { + GsonBuilder builder = new GsonBuilder(); + gson = builder.create(); available = true; } @@ -900,6 +903,69 @@ public boolean deleteToken(String tokenId) { } + @Override + public boolean whiteListToken(String tokenId) { + Token token = getToken(tokenId); + if(token != null && token.isBlackListed()){ + token.setBlackListed(false); + return saveToken(token); + } + + + return false; + } + + @Override + public List getBlackListedTokens() { + + ArrayList tokenBlacklist = new ArrayList<>(); + tokenMap.forEach((tokenId, token) -> { + if(token.isBlackListed()){ + tokenBlacklist.add(tokenId); + } + }); + return tokenBlacklist; + + } + + @Override + public Result deleteAllBlacklistedExpiredTokens(ITokenService tokenService) { + logger.info("Deleting all expired JWTs from token storage."); + AtomicInteger deletedTokenCount = new AtomicInteger(); + + tokenMap.forEach((tokenId, token) -> { + if(token.isBlackListed() && !tokenService.verifyJwt(tokenId,token.getStreamId(),token.getType())){ + if(deleteToken(tokenId)){ + deletedTokenCount.getAndIncrement(); + }else{ + logger.warn("Couldn't delete JWT:{}", tokenId); + } + } + }); + + + if(deletedTokenCount.get() > 0){ + final String successMsg = deletedTokenCount+" JWT deleted successfully from storage."; + logger.info(successMsg); + return new Result(true, successMsg); + }else{ + final String failMsg = "No JWT deleted from storage."; + logger.warn(failMsg); + return new Result(false, failMsg); + } + } + + @Override + public boolean whiteListAllTokens() { + tokenMap.forEach((tokenId, token) -> { + if(token.isBlackListed()){ + whiteListToken(tokenId); + } + }); + + return true; + } + @Override public Token getToken(String tokenId) { @@ -1040,4 +1106,25 @@ public boolean updateStreamMetaData(String streamId, String metaData) { } return result; } + + @Override + public boolean blackListToken(Token token) { + boolean result = false; + + if (token.getStreamId() != null && token.getTokenId() != null) { + token.setBlackListed(true); + return saveToken(token); + } + + return result; + } + + @Override + public Token getBlackListedToken(String tokenId) { + Token token = getToken(tokenId); + if(token != null && token.isBlackListed()){ + return token; + } + return null; + } } \ No newline at end of file diff --git a/src/main/java/io/antmedia/datastore/db/MapBasedDataStore.java b/src/main/java/io/antmedia/datastore/db/MapBasedDataStore.java index 11cd097fb..295f1d6e4 100644 --- a/src/main/java/io/antmedia/datastore/db/MapBasedDataStore.java +++ b/src/main/java/io/antmedia/datastore/db/MapBasedDataStore.java @@ -9,8 +9,11 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; import java.util.regex.Pattern; +import io.antmedia.rest.model.Result; +import io.antmedia.security.ITokenService; import org.apache.commons.io.FilenameUtils; import org.apache.commons.lang3.RandomStringUtils; import org.apache.commons.lang3.exception.ExceptionUtils; @@ -943,6 +946,79 @@ public boolean deleteToken(String tokenId) { return result; } + @Override + public boolean whiteListToken(String tokenId) { + synchronized (this){ + Token token = getToken(tokenId); + if(token != null && token.isBlackListed()){ + token.setBlackListed(false); + return saveToken(token); + } + } + + return false; + } + + @Override + public List getBlackListedTokens(){ + ArrayList tokenBlacklist = new ArrayList<>(); + synchronized (this){ + tokenMap.forEach((tokenId, tokenAsJson) -> { + Token token = gson.fromJson(tokenAsJson,Token.class); + if(token.isBlackListed()){ + tokenBlacklist.add(tokenId); + } + }); + return tokenBlacklist; + } + } + + @Override + public Result deleteAllBlacklistedExpiredTokens(ITokenService tokenService){ + logger.info("Deleting all expired JWTs from token storage."); + AtomicInteger deletedTokenCount = new AtomicInteger(); + + synchronized (this) { + + tokenMap.forEach((tokenId, tokenAsJson) -> { + Token token = gson.fromJson(tokenAsJson,Token.class); + if(token.isBlackListed() && !tokenService.verifyJwt(tokenId,token.getStreamId(),token.getType())){ + if(deleteToken(tokenId)){ + deletedTokenCount.getAndIncrement(); + }else{ + logger.warn("Couldn't delete JWT:{}", tokenId); + } + } + }); + } + + if(deletedTokenCount.get() > 0){ + final String successMsg = deletedTokenCount+" JWT deleted successfully from storage."; + logger.info(successMsg); + return new Result(true, successMsg); + }else{ + final String failMsg = "No JWT deleted from storage."; + logger.warn(failMsg); + return new Result(false, failMsg); + } + + } + + @Override + public boolean whiteListAllTokens(){ + + synchronized (this) { + tokenMap.forEach((tokenId, tokenAsJson) -> { + Token token = gson.fromJson(tokenAsJson,Token.class); + if(token.isBlackListed()){ + whiteListToken(tokenId); + } + }); + } + return true; + + } + @Override public Token getToken(String tokenId) { return super.getToken(tokenMap, tokenId, gson); @@ -1074,4 +1150,28 @@ public Broadcast getBroadcastFromMap(String streamId) return null; } + @Override + public boolean blackListToken(Token token) { + boolean result = false; + + synchronized (this) { + + if (token.getStreamId() != null && token.getTokenId() != null) { + token.setBlackListed(true); + return saveToken(token); + } + } + return result; + + } + + @Override + public Token getBlackListedToken(String tokenId) { + Token token = getToken(tokenId); + if(token != null && token.isBlackListed()){ + return token; + } + return null; + } + } diff --git a/src/main/java/io/antmedia/datastore/db/MapDBStore.java b/src/main/java/io/antmedia/datastore/db/MapDBStore.java index e3de68fe0..291e0d246 100644 --- a/src/main/java/io/antmedia/datastore/db/MapDBStore.java +++ b/src/main/java/io/antmedia/datastore/db/MapDBStore.java @@ -4,10 +4,7 @@ import java.io.IOException; import java.nio.file.Files; import java.util.ArrayList; -import java.util.Iterator; import java.util.List; -import java.util.Set; -import java.util.Map.Entry; import org.apache.commons.lang3.exception.ExceptionUtils; import org.mapdb.DB; @@ -16,9 +13,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import io.antmedia.datastore.db.types.Broadcast; import io.antmedia.datastore.db.types.StreamInfo; -import io.antmedia.muxer.IAntMediaStreamHandler; import io.vertx.core.Vertx; @@ -73,6 +68,7 @@ public MapDBStore(String dbName, Vertx vertx) { webRTCViewerMap = db.treeMap(WEBRTC_VIEWER).keySerializer(Serializer.STRING).valueSerializer(Serializer.STRING) .counterEnable().createOrOpen(); + timerId = vertx.setPeriodic(5000, id -> vertx.executeBlocking(b -> { @@ -124,7 +120,10 @@ public void close(boolean deleteDB) { public void clearStreamInfoList(String streamId) { //used in mongo for cluster mode. useless here. } - + + + + @Override public List getStreamInfoList(String streamId) { return new ArrayList<>(); diff --git a/src/main/java/io/antmedia/datastore/db/MongoStore.java b/src/main/java/io/antmedia/datastore/db/MongoStore.java index 11af10fe3..bc0ab3201 100644 --- a/src/main/java/io/antmedia/datastore/db/MongoStore.java +++ b/src/main/java/io/antmedia/datastore/db/MongoStore.java @@ -10,8 +10,11 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; import java.util.regex.Pattern; +import io.antmedia.rest.model.Result; +import io.antmedia.security.ITokenService; import org.apache.commons.io.FilenameUtils; import org.apache.commons.lang3.RandomStringUtils; import org.apache.commons.lang3.exception.ExceptionUtils; @@ -56,6 +59,7 @@ public class MongoStore extends DataStore { public static final String VOD_ID = "vodId"; private static final String VIEWER_ID = "viewerId"; private static final String TOKEN_ID = "tokenId"; + private static final String BLACKLISTED = "blackListed"; public static final String STREAM_ID = "streamId"; private Datastore datastore; private Datastore vodDatastore; @@ -94,7 +98,6 @@ public MongoStore(String host, String username, String password, String dbName) subscriberDatastore = Morphia.createDatastore(mongoClient, dbName + "_subscriber"); detectionMap = Morphia.createDatastore(mongoClient, dbName + "detection"); conferenceRoomDatastore = Morphia.createDatastore(mongoClient, dbName + "room"); - //************************************************* //do not create data store for each type as we do above //************************************************* @@ -107,7 +110,7 @@ public MongoStore(String host, String username, String password, String dbName) vodDatastore.getMapper().mapPackage("io.antmedia.datastore.db.types"); detectionMap.getMapper().mapPackage("io.antmedia.datastore.db.types"); conferenceRoomDatastore.getMapper().mapPackage("io.antmedia.datastore.db.types"); - + tokenDatastore.ensureIndexes(); subscriberDatastore.ensureIndexes(); datastore.ensureIndexes(); @@ -1451,4 +1454,103 @@ public boolean updateStreamMetaData(String streamId, String metaData) { } return false; } + + @Override + public boolean blackListToken(Token token) { + boolean result = false; + //update if exists, else insert + synchronized (this) { + if (token.getStreamId() != null && token.getTokenId() != null) { + Query query = tokenDatastore.find(Token.class).filter(Filters.eq(TOKEN_ID, token.getTokenId())); + + if(query.first() != null){ + final UpdateResult results = query.update(new UpdateOptions().multi(false), set(BLACKLISTED, true)); + if(results.getModifiedCount() == 1){ + result = true; + } + }else{ + token.setBlackListed(true); + result = saveToken(token); + } + } + } + + return result; + + } + + @Override + public Token getBlackListedToken(String tokenId) { + synchronized (this){ + Query query = tokenDatastore.find(Token.class).filter(Filters.eq(TOKEN_ID, tokenId)); + Token fetchedToken = query.first(); + if(fetchedToken != null && fetchedToken.isBlackListed()){ + return fetchedToken; + } + } + return null; + } + + @Override + public boolean whiteListToken(String tokenId) { + synchronized (this){ + Query query = tokenDatastore.find(Token.class).filter(Filters.eq(TOKEN_ID, tokenId)); + final UpdateResult results = query.update(new UpdateOptions().multi(false), set(BLACKLISTED, false)); + return results.wasAcknowledged(); + } + } + + @Override + public List getBlackListedTokens() { + List tokenBlacklist = new ArrayList<>(); + synchronized (this){ + Query query = tokenDatastore.find(Token.class).filter(Filters.eq(BLACKLISTED, true)); + for (Token token : query) { + tokenBlacklist.add(token.getTokenId()); + } + return tokenBlacklist; + } + } + + @Override + public Result deleteAllBlacklistedExpiredTokens(ITokenService tokenService) { + AtomicInteger deletedTokenCount = new AtomicInteger(); + + synchronized (this){ + List tokenBlacklist = getBlackListedTokens(); + tokenBlacklist.forEach(tokenId ->{ + + Token token = getToken(tokenId); + if(!tokenService.verifyJwt(tokenId,token.getStreamId(),token.getType())){ + if(whiteListToken(tokenId)){ + deletedTokenCount.getAndIncrement(); + }else{ + logger.warn("Couldn't delete JWT:{}", tokenId); + } + } + + }); + + } + if(deletedTokenCount.get() > 0){ + final String successMsg = deletedTokenCount+" JWT deleted successfully from blacklist."; + logger.info(successMsg); + return new Result(true, successMsg); + }else{ + final String failMsg = "No JWT deleted from blacklist."; + logger.warn(failMsg); + return new Result(false, failMsg); + } + + } + + @Override + public boolean whiteListAllTokens() { + synchronized (this) { + Query query = tokenDatastore.find(Token.class).filter(Filters.eq(BLACKLISTED, true)); + final UpdateResult results = query.update(new UpdateOptions().multi(true), set(BLACKLISTED, false)); + return results.wasAcknowledged(); + } + } + } diff --git a/src/main/java/io/antmedia/datastore/db/RedisStore.java b/src/main/java/io/antmedia/datastore/db/RedisStore.java index 4208e7fbb..dd8e5fd4f 100644 --- a/src/main/java/io/antmedia/datastore/db/RedisStore.java +++ b/src/main/java/io/antmedia/datastore/db/RedisStore.java @@ -6,9 +6,8 @@ import java.util.Collection; import java.util.Iterator; import java.util.List; -import java.util.Set; -import java.util.Map.Entry; +import io.antmedia.datastore.db.types.Token; import org.apache.commons.lang3.exception.ExceptionUtils; import org.redisson.Redisson; import org.redisson.api.RMap; @@ -17,10 +16,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import io.antmedia.datastore.db.types.Broadcast; import io.antmedia.datastore.db.types.P2PConnection; import io.antmedia.datastore.db.types.StreamInfo; -import io.antmedia.muxer.IAntMediaStreamHandler; public class RedisStore extends MapBasedDataStore { @@ -123,6 +120,18 @@ public int resetBroadcasts(String hostAddress) { } } + @Override + public boolean blackListToken(Token token) { + + return false; + } + + @Override + public Token getBlackListedToken(String tokenId) { + + return null; + } + @Override public List getStreamInfoList(String streamId) { diff --git a/src/main/java/io/antmedia/datastore/db/types/Token.java b/src/main/java/io/antmedia/datastore/db/types/Token.java index a4cf8bdef..d38b79b48 100644 --- a/src/main/java/io/antmedia/datastore/db/types/Token.java +++ b/src/main/java/io/antmedia/datastore/db/types/Token.java @@ -51,6 +51,16 @@ public class Token { */ @ApiModelProperty(value = "the type of the token") private String type; + + public ObjectId getDbId() { + return dbId; + } + + public void setDbId(ObjectId dbId) { + this.dbId = dbId; + } + + private boolean blackListed; /** * the id of the conference room which requested streams belongs to. @@ -98,6 +108,13 @@ public long getExpireDate() { public void setExpireDate(long expireDate) { this.expireDate = expireDate; } - + + public boolean isBlackListed() { + return blackListed; + } + + public void setBlackListed(boolean blackListed) { + this.blackListed = blackListed; + } } diff --git a/src/main/java/io/antmedia/rest/BroadcastRestService.java b/src/main/java/io/antmedia/rest/BroadcastRestService.java index 50f670d93..b1380f021 100644 --- a/src/main/java/io/antmedia/rest/BroadcastRestService.java +++ b/src/main/java/io/antmedia/rest/BroadcastRestService.java @@ -1,8 +1,17 @@ package io.antmedia.rest; +import java.nio.charset.Charset; +import java.util.Base64; import java.util.List; +import com.auth0.jwt.JWT; +import com.auth0.jwt.exceptions.JWTDecodeException; +import com.auth0.jwt.interfaces.DecodedJWT; +import io.antmedia.rest.model.Jwt; import org.apache.commons.lang3.exception.ExceptionUtils; +import org.json.simple.JSONObject; +import org.json.simple.parser.JSONParser; +import org.json.simple.parser.ParseException; import org.springframework.stereotype.Component; import io.antmedia.AntMediaApplicationAdapter; @@ -40,6 +49,8 @@ import io.swagger.annotations.Info; import io.swagger.annotations.License; import io.swagger.annotations.SwaggerDefinition; +import org.springframework.web.bind.annotation.RequestBody; + import javax.ws.rs.Consumes; import javax.ws.rs.DELETE; import javax.ws.rs.GET; @@ -78,6 +89,9 @@ public class BroadcastRestService extends RestServiceBase{ private static final String ABSOLUTE_MOVE = "absolute"; private static final String CONTINUOUS_MOVE = "continuous"; + private final String blacklistNotEnabledMsg = "JWT blacklist is not enabled for this application."; + + @ApiModel(value="SimpleStat", description="Simple generic statistics class to return single values") public static class SimpleStat { @ApiModelProperty(value = "the stat value") @@ -593,8 +607,130 @@ public Result validateTokenV2(@ApiParam(value = "Token to be validated", require return new Result(result); } + @ApiOperation(value = "Add jwt token to blacklist. If added, success field is true, " + + "if not added success field false", response = Result.class) + @POST + @Consumes(MediaType.APPLICATION_JSON) + @Path("/jwt-black-list") + @Produces(MediaType.APPLICATION_JSON) + public Result blackListJwt(@ApiParam(value = "jwt to be added to blacklist.", required = true) Jwt jwt) + { + if(getAppSettings().isJwtBlacklistEnabled()){ + + try{ + final DecodedJWT decodedJWT = JWT.decode(jwt.getJwt()); + final String payload = new String(Base64.getDecoder().decode(decodedJWT.getPayload()), Charset.defaultCharset()); + final JSONParser parser = new JSONParser(); + try{ + final JSONObject jwtPayload = (JSONObject) parser.parse(payload); + final String tokenType = jwtPayload.get("type").toString(); + final String streamId = jwtPayload.get("streamId").toString(); + final long exp = (Long) jwtPayload.get("exp"); + + final Token token = new Token(); + token.setTokenId(jwt.getJwt()); + token.setType(tokenType); + token.setStreamId(streamId); + token.setExpireDate(exp); + + if(!super.verifyJwt(jwt.getJwt(), streamId, tokenType)){ + return new Result(false,"JWT is not valid."); + }else if(getDataStore().getBlackListedToken(jwt.getJwt()) != null){ + return new Result(false, "JWT is already in blacklist."); + }else if(getDataStore().blackListToken(token)){ + return new Result(true, "JWT successfully added to blacklist."); + } + + }catch (ParseException e) { + return new Result(false,"Invalid JWT"); + } + }catch (JWTDecodeException e){ + return new Result(false,"Invalid JWT"); + } + + }else{ + logger.warn(blacklistNotEnabledMsg); + return new Result(false, blacklistNotEnabledMsg); + } - @ApiOperation(value = " Removes all tokens related with requested stream", notes = "", response = Result.class) + return new Result(false); + } + + @ApiOperation(value = "Remove jwt from blacklist. If removed, success field is true, " + + "if not removed success field false", response = Result.class) + @DELETE + @Consumes(MediaType.APPLICATION_JSON) + @Path("/jwt-black-list") + @Produces(MediaType.APPLICATION_JSON) + public Result whiteListJwt(@ApiParam(value = "Jwt to be removed from blacklist.", required = true) Jwt jwt) + { + if(getAppSettings().isJwtBlacklistEnabled()){ + + if(getDataStore().getBlackListedToken(jwt.getJwt()) == null){ + return new Result(false, "JWT does not exist in blacklist."); + + }else if(getDataStore().whiteListToken(jwt.getJwt())){ + return new Result(true, "JWT successfully removed from blacklist."); + + }else{ + return new Result(false, "JWT cannot be removed from blacklist."); + } + }else{ + logger.warn(blacklistNotEnabledMsg); + return new Result(false, blacklistNotEnabledMsg); + } + + } + + @ApiOperation(value = "Get all blacklisted JWTs.", response = Result.class) + @GET + @Path("/jwt-black-list") + @Produces(MediaType.APPLICATION_JSON) + public List getJwtBlacklist() + { + if(getAppSettings().isJwtBlacklistEnabled()) { + return getDataStore().getBlackListedTokens(); + }else{ + logger.warn(blacklistNotEnabledMsg); + return null; + } + } + + @ApiOperation(value = "Delete all expired blacklisted JWTs.", response = Result.class) + @DELETE + @Path("/jwt-black-list-delete-expired") + @Produces(MediaType.APPLICATION_JSON) + public Result deleteAllExpiredJwtFromBlacklist() + { + if(getAppSettings().isJwtBlacklistEnabled()) { + return getDataStore().deleteAllBlacklistedExpiredTokens(getTokenService()); + }else{ + logger.warn(blacklistNotEnabledMsg); + return new Result(false, blacklistNotEnabledMsg); + } + } + + @ApiOperation(value = "White list all blacklisted JWTs.", response = Result.class) + @DELETE + @Path("/jwt-black-list-clear") + @Produces(MediaType.APPLICATION_JSON) + public Result clearJwtBlacklist() + { + if(getAppSettings().isJwtBlacklistEnabled()) { + getDataStore().whiteListAllTokens(); + if(getDataStore().getBlackListedTokens().isEmpty()){ + return new Result(true, "All blacklisted tokens are removed successfully."); + }else{ + return new Result(false, "JWT blacklist clear failed."); + } + }else{ + logger.warn(blacklistNotEnabledMsg); + return new Result(false, blacklistNotEnabledMsg); + } + + } + + @ApiOperation(value = "Removes all tokens related with requested stream", notes = "", response = Result.class) @DELETE @Consumes(MediaType.APPLICATION_JSON) @Path("/{id}/tokens") @@ -603,7 +739,6 @@ public Result revokeTokensV2(@ApiParam(value = "the id of the stream", required return super.revokeTokens(streamId); } - @ApiOperation(value = "Get the all tokens of requested stream", notes = "",responseContainer = "List", response = Token.class) @GET @Path("/{id}/tokens/list/{offset}/{size}") @@ -1221,6 +1356,5 @@ public Result stopPlaying(@ApiParam(value = "the id of the webrtc viewer.", requ boolean result = getApplication().stopPlaying(viewerId); return new Result(result); } - } diff --git a/src/main/java/io/antmedia/rest/RestServiceBase.java b/src/main/java/io/antmedia/rest/RestServiceBase.java index 2c35eaee3..35ea43938 100755 --- a/src/main/java/io/antmedia/rest/RestServiceBase.java +++ b/src/main/java/io/antmedia/rest/RestServiceBase.java @@ -1613,6 +1613,14 @@ protected Object getJwtToken (String streamId, long expireDate, String type, Str return new Result(false, message); } + protected boolean verifyJwt(String jwt, String streamId, String type){ + return getTokenService().verifyJwt(jwt, streamId, type); + } + + protected ITokenService getTokenService(){ + return (ITokenService) getAppContext().getBean(ITokenService.BeanName.TOKEN_SERVICE.toString()); + } + protected Token validateToken (Token token) { Token validatedToken = null; diff --git a/src/main/java/io/antmedia/rest/model/Jwt.java b/src/main/java/io/antmedia/rest/model/Jwt.java new file mode 100644 index 000000000..5139c4e50 --- /dev/null +++ b/src/main/java/io/antmedia/rest/model/Jwt.java @@ -0,0 +1,30 @@ +package io.antmedia.rest.model; + +import dev.morphia.annotations.Entity; +import io.swagger.annotations.ApiModel; +import io.swagger.annotations.ApiModelProperty; + +@ApiModel(value="jwt", description="The basic jwt class for jwt blacklist") +@Entity(value = "jwt") +public class Jwt { + @ApiModelProperty(value = "the jwt") + private String jwt; + + // Default constructor + public Jwt() { + } + + // Constructor with jwt parameter + public Jwt(String jwt) { + this.jwt = jwt; + } + + // Getter and setter for jwt + public String getJwt() { + return jwt; + } + + public void setJwt(String jwt) { + this.jwt = jwt; + } +} diff --git a/src/main/java/io/antmedia/security/ITokenService.java b/src/main/java/io/antmedia/security/ITokenService.java index ec838d16e..1096a3007 100644 --- a/src/main/java/io/antmedia/security/ITokenService.java +++ b/src/main/java/io/antmedia/security/ITokenService.java @@ -112,4 +112,5 @@ public String toString() { Map getSubscriberAuthenticatedMap(); + boolean verifyJwt(String jwtTokenId, String streamId, String type); } diff --git a/src/main/java/io/antmedia/security/MockTokenService.java b/src/main/java/io/antmedia/security/MockTokenService.java index d39a84769..44f5e9494 100644 --- a/src/main/java/io/antmedia/security/MockTokenService.java +++ b/src/main/java/io/antmedia/security/MockTokenService.java @@ -42,7 +42,12 @@ public Map getAuthenticatedMap() { public Map getSubscriberAuthenticatedMap() { return subscriberAuthenticatedMap; } - + + @Override + public boolean verifyJwt(String jwtTokenId, String streamId, String type) { + return false; + } + @Override public boolean checkHash(String hash, String streamId, String sessionId, String type) { return true; diff --git a/src/test/java/io/antmedia/integration/RestServiceV2Test.java b/src/test/java/io/antmedia/integration/RestServiceV2Test.java index 83157d4eb..eb3264011 100644 --- a/src/test/java/io/antmedia/integration/RestServiceV2Test.java +++ b/src/test/java/io/antmedia/integration/RestServiceV2Test.java @@ -28,6 +28,7 @@ import javax.servlet.ServletContext; import javax.ws.rs.core.Context; +import com.google.gson.JsonObject; import io.antmedia.AppSettings; import io.antmedia.EncoderSettings; @@ -38,10 +39,8 @@ import org.apache.http.NameValuePair; import org.apache.http.client.HttpClient; import org.apache.http.client.entity.UrlEncodedFormEntity; -import org.apache.http.client.methods.CloseableHttpResponse; -import org.apache.http.client.methods.HttpPost; -import org.apache.http.client.methods.HttpUriRequest; -import org.apache.http.client.methods.RequestBuilder; +import org.apache.http.client.methods.*; +import org.apache.http.client.utils.URIBuilder; import org.apache.http.entity.StringEntity; import org.apache.http.entity.mime.HttpMultipartMode; import org.apache.http.entity.mime.MultipartEntityBuilder; @@ -55,6 +54,7 @@ import org.awaitility.Awaitility; import org.bytedeco.ffmpeg.global.avformat; import org.bytedeco.ffmpeg.global.avutil; +import org.json.simple.JSONObject; import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; @@ -1997,4 +1997,177 @@ public Result callIsEnterpriseEdition() throws Exception { } + @Test + public void testJwtBlacklist(){ + ConsoleAppRestServiceTest.resetCookieStore(); + + try { + Result result = ConsoleAppRestServiceTest.callisFirstLogin(); + + if (result.isSuccess()) { + Result createInitialUser = ConsoleAppRestServiceTest.createDefaultInitialUser(); + assertTrue(createInitialUser.isSuccess()); + } + + result = ConsoleAppRestServiceTest.authenticateDefaultUser(); + assertTrue(result.isSuccess()); + + + boolean isEnterprise = callIsEnterpriseEdition().getMessage().contains("Enterprise"); + if(!isEnterprise) { + logger.info("This is not enterprise edition so skipping this test"); + return; + } + + final AppSettings appSettingsModel = ConsoleAppRestServiceTest.callGetAppSettings("cleanapp"); + appSettingsModel.setJwtStreamSecretKey("testtesttesttesttesttesttesttest"); + appSettingsModel.setJwtBlacklistEnabled(true); + + result = ConsoleAppRestServiceTest.callSetAppSettings("cleanapp", appSettingsModel); + assertTrue(result.isSuccess()); + + final String clearJwtBlacklistUrl = ROOT_SERVICE_URL + "/v2/broadcasts/jwt-black-list-clear"; + final String jwtBlacklistUrl = ROOT_SERVICE_URL + "/v2/broadcasts/jwt-black-list"; + + CloseableHttpClient client = HttpClients.custom().setRedirectStrategy(new LaxRedirectStrategy()).build(); + + HttpUriRequest clearJwtBlacklistRequest = RequestBuilder.delete().setUri(clearJwtBlacklistUrl).build(); + HttpResponse clearJwtBlacklistResponse = client.execute(clearJwtBlacklistRequest); + StringBuffer clearJwtBlacklistResult = readResponse(clearJwtBlacklistResponse); + + if (clearJwtBlacklistResponse.getStatusLine().getStatusCode() != 200) { + throw new Exception(clearJwtBlacklistResult.toString()); + } + result = gson.fromJson(clearJwtBlacklistResult.toString(), Result.class); + + assertTrue(result.isSuccess()); + + final String validJwt1 = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdHJlYW1JZCI6InRlc3RzdHJlYW0iLCJ0eXBlIjoicHVibGlzaCIsImV4cCI6OTg4NzUwNzUwMH0.aZRIBC6zHDPw3od9tBCn9gGg3Taab8RpuPUxGr46YM8"; + final String validJwt2 = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdHJlYW1JZCI6InRlc3RzdHJlYW0iLCJ0eXBlIjoicHVibGlzaCIsImV4cCI6OTg4NzUwNzUwMX0.f4YTJUOmO7yuGpD7W4i_fffv2IVi1JB3mZVxNv8LSdI"; + final String validJwt3 = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdHJlYW1JZCI6InRlc3RzdHJlYW0iLCJ0eXBlIjoicHVibGlzaCIsImV4cCI6OTg4NzUwNzUwMn0.bSbIumeAeM-k5zLndtII49z_458L8Lqg3eVweahvpb4"; + + + JsonObject jwtRequest1 = new JsonObject(); + jwtRequest1.addProperty("jwt",validJwt1); + + JsonObject jwtRequest2 = new JsonObject(); + jwtRequest2.addProperty("jwt",validJwt2); + + JsonObject jwtRequest3 = new JsonObject(); + jwtRequest3.addProperty("jwt",validJwt3); + + String jwtRequestBody1 = gson.toJson(jwtRequest1); + HttpPost addJwtRequest1 = new HttpPost(jwtBlacklistUrl); + addJwtRequest1.setHeader("Content-Type", "application/json"); + addJwtRequest1.setEntity(new StringEntity(jwtRequestBody1)); + + String jwtRequestBody2 = gson.toJson(jwtRequest2); + HttpPost addJwtRequest2 = new HttpPost(jwtBlacklistUrl); + addJwtRequest2.setHeader("Content-Type", "application/json"); + addJwtRequest2.setEntity(new StringEntity(jwtRequestBody2)); + + String jwtRequestBody3 = gson.toJson(jwtRequest3); + HttpPost addJwtRequest3 = new HttpPost(jwtBlacklistUrl); + addJwtRequest3.setHeader("Content-Type", "application/json"); + addJwtRequest3.setEntity(new StringEntity(jwtRequestBody3)); + + HttpResponse addJwtResponse1 = client.execute(addJwtRequest1); + + StringBuffer addJwtResult1 = readResponse(addJwtResponse1); + + if (addJwtResponse1.getStatusLine().getStatusCode() != 200) { + throw new Exception(addJwtResult1.toString()); + } + + result = gson.fromJson(addJwtResult1.toString(), Result.class); + assertTrue(result.isSuccess()); + + HttpResponse addJwtResponse2 = client.execute(addJwtRequest2); + + StringBuffer addJwtResult2 = readResponse(addJwtResponse2); + + if (addJwtResponse2.getStatusLine().getStatusCode() != 200) { + throw new Exception(addJwtResult2.toString()); + } + + result = gson.fromJson(addJwtResult2.toString(), Result.class); + assertTrue(result.isSuccess()); + + HttpResponse addJwtResponse3 = client.execute(addJwtRequest3); + + StringBuffer addJwtResult3 = readResponse(addJwtResponse3); + + if (addJwtResponse3.getStatusLine().getStatusCode() != 200) { + throw new Exception(addJwtResult3.toString()); + } + + result = gson.fromJson(addJwtResult3.toString(), Result.class); + assertTrue(result.isSuccess()); + + HttpUriRequest getJwtBlacklistRequest = RequestBuilder.get().setUri(jwtBlacklistUrl) + .setHeader(HttpHeaders.CONTENT_TYPE, "application/json").build(); + + HttpResponse getJwtBlacklistRequestResponse = client.execute(getJwtBlacklistRequest); + + StringBuffer getJwtBlacklistRequestResult = readResponse(getJwtBlacklistRequestResponse); + + ArrayList jwtBlacklist = gson.fromJson(getJwtBlacklistRequestResult.toString(), ArrayList.class); + + int expectedJwtCount = 3; + + assertEquals(expectedJwtCount, jwtBlacklist.size()); + + + HttpUriRequest whiteListJwtRequest = RequestBuilder.delete().setUri(new URIBuilder(jwtBlacklistUrl).addParameter("jwt",validJwt1).build()).build(); + HttpResponse whiteListJwtResponse = client.execute(whiteListJwtRequest); + StringBuffer whiteListJwtResult = readResponse(whiteListJwtResponse); + + + if (whiteListJwtResponse.getStatusLine().getStatusCode() != 200) { + throw new Exception(whiteListJwtResult.toString()); + } + result = gson.fromJson(whiteListJwtResult.toString(), Result.class); + + assertTrue(result.isSuccess()); + + getJwtBlacklistRequestResponse = client.execute(getJwtBlacklistRequest); + + getJwtBlacklistRequestResult = readResponse(getJwtBlacklistRequestResponse); + + jwtBlacklist = gson.fromJson(getJwtBlacklistRequestResult.toString(), ArrayList.class); + + expectedJwtCount = 2; + + assertEquals(expectedJwtCount, jwtBlacklist.size()); + + clearJwtBlacklistResponse = client.execute(clearJwtBlacklistRequest); + clearJwtBlacklistResult = readResponse(clearJwtBlacklistResponse); + + if (clearJwtBlacklistResponse.getStatusLine().getStatusCode() != 200) { + throw new Exception(clearJwtBlacklistResult.toString()); + } + result = gson.fromJson(clearJwtBlacklistResult.toString(), Result.class); + + assertTrue(result.isSuccess()); + + getJwtBlacklistRequestResponse = client.execute(getJwtBlacklistRequest); + + getJwtBlacklistRequestResult = readResponse(getJwtBlacklistRequestResponse); + + jwtBlacklist = gson.fromJson(getJwtBlacklistRequestResult.toString(), ArrayList.class); + + expectedJwtCount = 0; + + assertEquals(expectedJwtCount, jwtBlacklist.size()); + + + } catch (Exception e){ + e.printStackTrace(); + + fail(e.getMessage()); + + } + + } + } diff --git a/src/test/java/io/antmedia/test/AppSettingsUnitTest.java b/src/test/java/io/antmedia/test/AppSettingsUnitTest.java index 99ce24cc6..0561c5372 100644 --- a/src/test/java/io/antmedia/test/AppSettingsUnitTest.java +++ b/src/test/java/io/antmedia/test/AppSettingsUnitTest.java @@ -503,6 +503,7 @@ public void testUnsetAppSettings(AppSettings appSettings) { assertEquals(-1, appSettings.getMaxVideoTrackCount()); assertEquals(2, appSettings.getOriginEdgeIdleTimeout()); assertEquals(false, appSettings.isAddDateTimeToHlsFileName()); + assertEquals(false, appSettings.isJwtControlEnabled()); assertEquals(true, appSettings.isPlayWebRTCStreamOnceForEachSession()); assertEquals(true, appSettings.isStatsBasedABREnabled()); assertEquals(1, appSettings.getAbrDownScalePacketLostRatio(), 0.0001); @@ -510,6 +511,7 @@ public void testUnsetAppSettings(AppSettings appSettings) { assertEquals(30, appSettings.getAbrUpScaleJitterMs(), 0.0001); assertEquals(150, appSettings.getAbrUpScaleRTTMs(), 0.0001); assertNotNull(appSettings.getClusterCommunicationKey()); + assertEquals(false, appSettings.isJwtBlacklistEnabled()); @@ -517,8 +519,11 @@ public void testUnsetAppSettings(AppSettings appSettings) { //When a new field is added or removed please update the number of fields and make this test pass //by also checking its default value. assertEquals("New field is added to settings. PAY ATTENTION: Please CHECK ITS DEFAULT VALUE and fix the number of fields.", + 165, numberOfFields); + + } } diff --git a/src/test/java/io/antmedia/test/db/DBStoresUnitTest.java b/src/test/java/io/antmedia/test/db/DBStoresUnitTest.java index cd420a38f..2f429c152 100644 --- a/src/test/java/io/antmedia/test/db/DBStoresUnitTest.java +++ b/src/test/java/io/antmedia/test/db/DBStoresUnitTest.java @@ -7,6 +7,8 @@ import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import java.io.File; import java.io.IOException; @@ -15,13 +17,14 @@ import java.util.*; import java.util.concurrent.TimeUnit; +import io.antmedia.rest.model.Result; +import io.antmedia.security.ITokenService; import org.apache.commons.lang3.RandomStringUtils; import org.apache.commons.lang3.RandomUtils; import org.awaitility.Awaitility; import org.junit.After; import org.junit.Before; import org.junit.Test; -import org.mockito.Mockito; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.context.ApplicationContext; @@ -132,6 +135,8 @@ public void testMapDBStore() throws Exception { testWebRTCViewerOperations(dataStore); testUpdateMetaData(dataStore); testStreamSourceList(dataStore); + testWhitelistAllExpiredTokens(dataStore); + testWhitelistAllTokens(dataStore); } @@ -268,6 +273,8 @@ public void testMongoStore() throws Exception { testUpdateEndpointStatus(dataStore); testWebRTCViewerOperations(dataStore); testUpdateMetaData(dataStore); + testWhitelistAllExpiredTokens(dataStore); + testWhitelistAllTokens(dataStore); } @Test @@ -2142,9 +2149,9 @@ private DataStore createDB(String type, boolean writeStats) { dsf.setDbType(type); dsf.setDbName("testdb"); dsf.setDbHost("localhost"); - ApplicationContext context = Mockito.mock(ApplicationContext.class); - Mockito.when(context.getBean(IAntMediaStreamHandler.VERTX_BEAN_NAME)).thenReturn(vertx); - Mockito.when(context.getBean(ServerSettings.BEAN_NAME)).thenReturn(new ServerSettings()); + ApplicationContext context = mock(ApplicationContext.class); + when(context.getBean(IAntMediaStreamHandler.VERTX_BEAN_NAME)).thenReturn(vertx); + when(context.getBean(ServerSettings.BEAN_NAME)).thenReturn(new ServerSettings()); dsf.setApplicationContext(context); return dsf.getDataStore(); } @@ -2290,6 +2297,18 @@ public void testMongoDBSaveStreamInfo() { deleteStreamInfos(dataStore); } + @Test + public void testMongoDBJwtBlacklist(){ + MongoStore dataStore = new MongoStore("localhost", "", "", "testdb"); + dataStore.whiteListToken(""); + dataStore.getBlackListedTokens(); + dataStore.deleteAllBlacklistedExpiredTokens(null); + dataStore.whiteListAllTokens(); + dataStore.blackListToken(new Token()); + dataStore.getBlackListedToken(""); + + } + public void deleteStreamInfos(MongoStore datastore) { datastore.getDataStore().find(StreamInfo.class).delete(new DeleteOptions() .multi(true)); @@ -2945,4 +2964,98 @@ public void testUpdateMetaData(DataStore dataStore) { assertFalse(dataStore.updateStreamMetaData("someDummyStream"+RandomStringUtils.randomAlphanumeric(8), UPDATED_DATA)); } + + public void testWhitelistAllExpiredTokens(DataStore dataStore){ + + Token token1 = new Token(); + Token token2 = new Token(); + Token token3 = new Token(); + + String jwt1 = "jwt1"; + String jwt2 = "jwt2"; + String jwt3 = "jwt3"; + + String tokenType = "publish"; + String streamId = "test-stream"; + + token1.setTokenId(jwt1); + token2.setTokenId(jwt2); + token3.setTokenId(jwt3); + + token1.setType(tokenType); + token2.setType(tokenType); + token3.setType(tokenType); + + token1.setStreamId(streamId); + token2.setStreamId(streamId); + token3.setStreamId(streamId); + ITokenService tokenService = mock(ITokenService.class); + + Result res = dataStore.deleteAllBlacklistedExpiredTokens(tokenService); + assertFalse(res.isSuccess()); + + dataStore.blackListToken(token1); + dataStore.blackListToken(token2); + dataStore.blackListToken(token3); + + Token jwt = dataStore.getBlackListedToken(token1.getTokenId()); + assertNotNull(jwt); + + when(tokenService.verifyJwt(jwt1,streamId,tokenType)).thenReturn(false); + when(tokenService.verifyJwt(jwt2,streamId,tokenType)).thenReturn(false); + when(tokenService.verifyJwt(jwt3,streamId,tokenType)).thenReturn(false); + + + res = dataStore.deleteAllBlacklistedExpiredTokens(tokenService); + assertTrue(res.isSuccess()); + + } + + public void testWhitelistAllTokens(DataStore dataStore){ + + addJwtsToBlacklist(dataStore); + + + List jwtBlacklist = dataStore.getBlackListedTokens(); + assertEquals(3, jwtBlacklist.size()); + + Token token = dataStore.getBlackListedToken("jwt1"); + dataStore.whiteListAllTokens(); + jwtBlacklist = dataStore.getBlackListedTokens(); + + assertEquals(0, jwtBlacklist.size()); + + + } + + private void addJwtsToBlacklist(DataStore dataStore){ + Token token1 = new Token(); + Token token2 = new Token(); + Token token3 = new Token(); + + String jwt1 = "jwt1"; + String jwt2 = "jwt2"; + String jwt3 = "jwt3"; + + String tokenType = "publish"; + String streamId = "test-stream"; + + token1.setTokenId(jwt1); + token2.setTokenId(jwt2); + token3.setTokenId(jwt3); + + + token1.setType(tokenType); + token2.setType(tokenType); + token3.setType(tokenType); + + token1.setStreamId(streamId); + token2.setStreamId(streamId); + token3.setStreamId(streamId); + + dataStore.blackListToken(token1); + dataStore.blackListToken(token2); + dataStore.blackListToken(token3); + } + } diff --git a/src/test/java/io/antmedia/test/rest/BroadcastRestServiceV2UnitTest.java b/src/test/java/io/antmedia/test/rest/BroadcastRestServiceV2UnitTest.java index 8a3b18a57..a13ff45e4 100644 --- a/src/test/java/io/antmedia/test/rest/BroadcastRestServiceV2UnitTest.java +++ b/src/test/java/io/antmedia/test/rest/BroadcastRestServiceV2UnitTest.java @@ -33,6 +33,8 @@ import javax.ws.rs.core.Response; import javax.ws.rs.core.Response.Status; +import io.antmedia.datastore.db.MapDBStore; +import io.antmedia.rest.model.Jwt; import org.apache.commons.lang3.RandomStringUtils; import org.awaitility.Awaitility; import org.bytedeco.ffmpeg.global.avformat; @@ -3112,5 +3114,185 @@ public void testGetCameraProfiles() { assertNull(streamSourceRest.getOnvifDeviceProfiles("invalid id")); } - + + @Test + public void testAddJwtToBlacklist() { + + AppSettings appSettings = mock(AppSettings.class); + ApplicationContext appContext = mock(ApplicationContext.class); + + ITokenService tokenService = mock(ITokenService.class); + + when(appContext.getBean(ITokenService.BeanName.TOKEN_SERVICE.toString())).thenReturn(tokenService); + + when(appSettings.isJwtBlacklistEnabled()).thenReturn(false); + + BroadcastRestService restServiceSpy = Mockito.spy(restServiceReal); + restServiceSpy.setAppCtx(appContext); + + restServiceSpy.setAppSettings(appSettings); + String jwtStr = "test-jwt"; + Jwt jwt = new Jwt(); + jwt.setJwt(jwtStr); + + DataStore store = mock(MapDBStore.class); + + restServiceSpy.setDataStore(store); + + + Result result = restServiceSpy.blackListJwt(jwt); + assertFalse(result.isSuccess()); + when(appSettings.isJwtBlacklistEnabled()).thenReturn(true); + Token token = mock(Token.class); + + when(token.getTokenId()).thenReturn(jwt.getJwt()); + when(token.getStreamId()).thenReturn("test-stream"); + + + when(store.blackListToken(token)).thenReturn(true); + result = restServiceSpy.blackListJwt(jwt); + assertFalse(result.isSuccess()); + String invalidJwtStr = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdHJlYW1JZCI6InRlc3Qtc3RyZWFtIiwidHlwZSI6InF3ZSIsImV4cCI6OTk5OTk5OTk5OTk5OX0.DqfFkRJgKPVXgAkIzucuQtfwP2Oj-Qf9dhUuO_-04bU"; + Jwt invalidJwt = new Jwt(); + invalidJwt.setJwt(invalidJwtStr); + result = restServiceSpy.blackListJwt(invalidJwt); + assertFalse(result.isSuccess()); + + String validJwtStr = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdHJlYW1JZCI6InRlc3Qtc3RyZWFtIiwidHlwZSI6InB1Ymxpc2giLCJleHAiOjk5OTk5OTk5OTk5OTl9.ichno9utOYwVv1qoQWtUpDap7PGYze-zfXRZU31CMnQ"; + Jwt validJwt = new Jwt(); + validJwt.setJwt(validJwtStr); + when(tokenService.verifyJwt(validJwt.getJwt(),"test-stream","publish")).thenReturn(true); + + when(store.blackListToken(any())).thenReturn(true); + + result = restServiceSpy.blackListJwt(validJwt); + assertTrue(result.isSuccess()); + + when(store.getBlackListedToken(validJwt.getJwt())).thenReturn(token); + + result = restServiceSpy.blackListJwt(validJwt); + assertFalse(result.isSuccess()); + + } + + @Test + public void testWhitelistJwt() { + + AppSettings appSettings = mock(AppSettings.class); + ApplicationContext appContext = mock(ApplicationContext.class); + + ITokenService tokenService = mock(ITokenService.class); + + when(appContext.getBean(ITokenService.BeanName.TOKEN_SERVICE.toString())).thenReturn(tokenService); + + when(appSettings.isJwtBlacklistEnabled()).thenReturn(false); + + BroadcastRestService restServiceSpy = Mockito.spy(restServiceReal); + restServiceSpy.setAppCtx(appContext); + + restServiceSpy.setAppSettings(appSettings); + String jwtStr = "test-jwt"; + Jwt jwt = new Jwt(); + jwt.setJwt(jwtStr); + + DataStore store = mock(MapDBStore.class); + + restServiceSpy.setDataStore(store); + + Result result1 = restServiceSpy.whiteListJwt(jwt); + assertFalse(result1.isSuccess()); + + when(appSettings.isJwtBlacklistEnabled()).thenReturn(true); + + when(store.getBlackListedToken(jwtStr)).thenReturn(null); + Result result2 = restServiceSpy.whiteListJwt(jwt); + assertFalse(result2.isSuccess()); + + Token token = mock(Token.class); + when(store.getBlackListedToken(jwtStr)).thenReturn(token); + Result result3 = restServiceSpy.whiteListJwt(jwt); + assertFalse(result3.isSuccess()); + + when(store.whiteListToken(jwtStr)).thenReturn(true); + Result result4 = restServiceSpy.whiteListJwt(jwt); + assertTrue(result4.isSuccess()); + + } + + @Test + public void testGetJwtBlacklist(){ + AppSettings appSettings = mock(AppSettings.class); + ApplicationContext appContext = mock(ApplicationContext.class); + when(appSettings.isJwtBlacklistEnabled()).thenReturn(false); + + String jwt = "test-jwt"; + + DataStore store = mock(MapDBStore.class); + Token token = new Token(); + token.setTokenId(jwt); + token.setType("publish"); + + BroadcastRestService restServiceSpy = Mockito.spy(restServiceReal); + restServiceSpy.setAppCtx(appContext); + store.blackListToken(token); + restServiceSpy.setDataStore(store); + + when(restServiceSpy.getDataStore()).thenReturn(store); + ArrayList tokenList = new ArrayList<>(); + tokenList.add(jwt); + when(restServiceSpy.getAppSettings()).thenReturn(appSettings); + when(appSettings.isJwtBlacklistEnabled()).thenReturn(true); + when(store.getBlackListedTokens()).thenReturn(tokenList); + + List jwtBlacklist = restServiceSpy.getJwtBlacklist(); + + assertTrue(jwtBlacklist.size() > 0); + + } + + @Test + public void testClearJwtBlacklist() { + AppSettings appSettings = mock(AppSettings.class); + when(appSettings.isJwtBlacklistEnabled()).thenReturn(false); + + BroadcastRestService restServiceSpy = Mockito.spy(restServiceReal); + restServiceSpy.setAppSettings(appSettings); + assertFalse(restServiceSpy.clearJwtBlacklist().isSuccess()); + + String jwt = "test-jwt"; + + DataStore store = mock(MapDBStore.class); + Token token = new Token(); + token.setTokenId(jwt); + token.setType("publish"); + store.blackListToken(token); + restServiceSpy.setDataStore(store); + when(appSettings.isJwtBlacklistEnabled()).thenReturn(true); + assertTrue(restServiceSpy.clearJwtBlacklist().isSuccess()); + } + + @Test + public void testDeleteAllExpiredJwtFromBlacklist(){ + AppSettings appSettings = mock(AppSettings.class); + when(appSettings.isJwtBlacklistEnabled()).thenReturn(false); + ApplicationContext appContext = mock(ApplicationContext.class); + + BroadcastRestService restServiceSpy = Mockito.spy(restServiceReal); + restServiceSpy.setAppCtx(appContext); + restServiceSpy.setAppSettings(appSettings); + assertFalse(restServiceSpy.deleteAllExpiredJwtFromBlacklist().isSuccess()); + DataStore store = mock(MapDBStore.class); + Token token = new Token(); + token.setTokenId("token"); + token.setType("publish"); + store.blackListToken(token); + restServiceSpy.setDataStore(store); + when(appSettings.isJwtBlacklistEnabled()).thenReturn(true); + ITokenService tokenService = mock(ITokenService.class); + + when(appContext.getBean(ITokenService.BeanName.TOKEN_SERVICE.toString())).thenReturn(tokenService); + when(store.deleteAllBlacklistedExpiredTokens(tokenService)).thenReturn(new Result(true)); + assertTrue(restServiceSpy.deleteAllExpiredJwtFromBlacklist().isSuccess()); + + } }