package cn.com.poc.common.aspect;

import cn.com.poc.common.annotation.RedisLimit;
import cn.com.poc.common.utils.BlContext;
import cn.com.poc.common.utils.DateUtils;
import cn.com.poc.support.security.oauth.entity.UserBaseEntity;
import cn.com.yict.framemax.core.exception.BusinessException;
import cn.com.yict.framemax.core.i18n.I18nMessageException;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;

import javax.annotation.Resource;
import java.lang.reflect.Method;
import java.util.Date;
import java.util.concurrent.TimeUnit;

/**
 * Rate-limiting aspect.
 *
 * <p>Intercepts every method annotated with {@link RedisLimit} and keeps an invocation counter
 * in Redis under the key {@code "Limit_" + key + ":"}. The user id is appended when
 * {@link RedisLimit#currentUser()} is set. The counter's TTL is derived from
 * {@link RedisLimit#timeout()} and {@link RedisLimit#timeUnit()}. Once the counter already holds
 * {@link RedisLimit#count()} hits in the current window, further calls are rejected with an
 * {@link I18nMessageException} carrying {@link RedisLimit#exceptionInfo()}.
 */
@Aspect
@Component
public class RedisLimitAspect {


    private static final Logger logger = LoggerFactory.getLogger(RedisLimitAspect.class);

    /** Redis "no expiry set" marker returned by {@code getExpire}. */
    private static final long NO_EXPIRE = -1L;

    @Resource
    private RedisTemplate<String, Integer> redisTemplate;

    @Pointcut("@annotation( cn.com.poc.common.annotation.RedisLimit)")
    public void redisLimitAnnotation() {
    }

    /**
     * Counts the invocation and either proceeds or rejects it.
     *
     * <p>The counter is updated with a single atomic {@code INCR}. A separate exists/get/increment
     * sequence would race with key expiry and with concurrent callers. The window starts on the
     * first hit ({@code INCR} returned 1). A call is rejected when the incremented value exceeds
     * {@code count}, i.e. when the counter already held {@code count} hits before this call.
     *
     * @param joinPoint the intercepted invocation
     * @return whatever the intercepted method returns
     * @throws I18nMessageException when the limit for the current window is reached
     * @throws Throwable anything thrown by the intercepted method
     */
    @Around(value = "redisLimitAnnotation()")
    public Object around(ProceedingJoinPoint joinPoint) throws Throwable {
        RedisLimit annotation = resolveAnnotation(joinPoint);
        String redisKey = buildKey(annotation);

        Long current = redisTemplate.opsForValue().increment(redisKey, 1);
        if (current == null) {
            // increment() yields null when executed inside a pipeline/transaction; fail open.
            logger.warn("Rate-limit counter unavailable for key {}, allowing invocation", redisKey);
            return joinPoint.proceed();
        }

        if (current == 1L) {
            // First hit of a new window: start the window's TTL.
            applyExpire(redisKey, annotation);
        } else if (current > annotation.count()) {
            // Self-heal: if an earlier expire() never ran (e.g. crash after INCR), the key would
            // otherwise live forever and block this caller permanently.
            Long ttl = redisTemplate.getExpire(redisKey);
            if (ttl != null && ttl == NO_EXPIRE) {
                applyExpire(redisKey, annotation);
            }
            throw new I18nMessageException(annotation.exceptionInfo());
        }

        return joinPoint.proceed();
    }

    /**
     * Finds the {@link RedisLimit} annotation for the intercepted method.
     *
     * <p>With JDK dynamic proxies, the signature method may be the interface method. In that case
     * the annotation can only be present on the implementation, so the target class is checked too.
     */
    private RedisLimit resolveAnnotation(ProceedingJoinPoint joinPoint) throws NoSuchMethodException {
        Method method = ((MethodSignature) joinPoint.getSignature()).getMethod();
        RedisLimit annotation = method.getAnnotation(RedisLimit.class);
        if (annotation == null && joinPoint.getTarget() != null) {
            Method targetMethod = joinPoint.getTarget().getClass()
                    .getMethod(method.getName(), method.getParameterTypes());
            annotation = targetMethod.getAnnotation(RedisLimit.class);
        }
        if (annotation == null) {
            throw new IllegalStateException("@RedisLimit not found on " + method);
        }
        return annotation;
    }

    /** Builds the Redis counter key: {@code Limit_<key>:[userId]}. */
    private String buildKey(RedisLimit annotation) {
        StringBuilder redisKey = new StringBuilder("Limit_")
                .append(annotation.key())
                .append(":");
        if (annotation.currentUser()) {
            // NOTE(review): getCurrentUserNotException() presumably returns null for anonymous
            // callers instead of throwing; that would NPE here — confirm the intended behavior.
            UserBaseEntity currentUser = BlContext.getCurrentUserNotException();
            redisKey.append(currentUser.getUserId());
        }
        return redisKey.toString();
    }

    /** Sets the counter's TTL for the window configured on the annotation. */
    private void applyExpire(String redisKey, RedisLimit annotation) {
        redisTemplate.expire(redisKey, expireTime(annotation.timeout(), annotation.timeUnit()), TimeUnit.MILLISECONDS);
    }

    /**
     * Converts the annotation's window into a TTL in milliseconds.
     *
     * <p>SECONDS, MINUTES and HOURS are fixed-length windows. DAYS is the span from now until
     * {@code timeout} days later. DAY_OF_MONTH and MONTH_OF_YEAR measure up to the beginning of
     * the day/month {@code timeout} units after {@code DateUtils.getToday()}. Whether getToday()
     * is "now" or today's midnight depends on DateUtils — TODO confirm.
     *
     * @param timeout       window length in {@code limitTimeUnit} units
     * @param limitTimeUnit unit of the window
     * @return TTL in milliseconds (assumes DateUtils.diffTwoDate returns milliseconds — TODO confirm)
     * @throws BusinessException for an unsupported unit
     */
    private Long expireTime(Long timeout, RedisLimit.LimitTimeUnit limitTimeUnit) {
        switch (limitTimeUnit) {
            case SECONDS:
                return timeout * 1000;
            case MINUTES:
                return timeout * 60 * 1000;
            case HOURS:
                return timeout * 60 * 60 * 1000;
            case DAYS:
                Date date = new Date();
                return DateUtils.diffTwoDate(DateUtils.addDays(date, timeout.intValue()), date);
            case DAY_OF_MONTH:
                Date dayBegin = DateUtils.getDayBegin(DateUtils.addDays(DateUtils.getToday(), timeout.intValue()));
                return DateUtils.diffTwoDate(dayBegin, DateUtils.getToday());

            case MONTH_OF_YEAR:
                Date monthBegin = DateUtils.getMonthBegin(DateUtils.getMonthAfter(DateUtils.getToday(), timeout.intValue()));
                return DateUtils.diffTwoDate(monthBegin, DateUtils.getToday());
            default:
                throw new BusinessException("不支持的单位");
        }
    }

}
