package com.onaro.commons.metrics.executor;

import static com.onaro.commons.metrics.MetricsRegistryProvider.RegistryKey.*;
import static com.onaro.commons.metrics.MetricsRegistryProvider.ThreadPoolMonitoringKey.*;

import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.Future;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.ScheduledThreadPoolExecutor;

import com.codahale.metrics.Gauge;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.Timer;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import com.onaro.commons.metrics.MetricsRegistryProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * This class is the wrapper of ScheduledThreadPoolExecutor and intended to be used to capture metrics around thread
 * pool usage
 */

public class ScheduledThreadPoolMonitorExecutor extends ScheduledThreadPoolExecutor {

    private static final Logger errorLogger = LoggerFactory.getLogger(ScheduledThreadPoolMonitorExecutor.class);
    /**
     * Fetch Common registry defined in MetricsRegistryProvider
     */
    private static final MetricRegistry METRIC_REGISTRY = MetricsRegistryProvider.getMetricRegistry(THREAD_POOL);
    /**
     * Used to distinguish between metrics across thread pools
     */
    private final String metricsPrefix;

    /**
     * ThreadPoolScope init as GLOBAL
     */
    private ThreadPoolScope threadPoolScope = ThreadPoolScope.GLOBAL;

    /**
     * ThreadLocal scoped variable used for calculating task execution time
     */
    protected ThreadLocal<Timer.Context> taskExecutionTimer = new ThreadLocal<>();
    protected ThreadLocal<Long> taskExeutionTime = new ThreadLocal<>();

    public ScheduledThreadPoolMonitorExecutor(int corePoolSize, String poolName, ThreadPoolScope threadPoolScope) {
        super(corePoolSize, new ThreadFactoryBuilder().setNameFormat(poolName + "-%d").build());
        this.metricsPrefix = MetricRegistry.name(getClass(), poolName);
        this.threadPoolScope = threadPoolScope;
        registerGauges(threadPoolScope);
    }

    public ScheduledThreadPoolMonitorExecutor(int corePoolSize, RejectedExecutionHandler handler, String poolName, ThreadPoolScope threadPoolScope) {
        super(corePoolSize, new ThreadFactoryBuilder().setNameFormat(poolName + "-%d").build(), handler);
        this.metricsPrefix = MetricRegistry.name(getClass(), poolName);
        this.threadPoolScope = threadPoolScope;
        registerGauges(threadPoolScope);
    }

    /**
     * This method is wrapper on top of Executors.newScheduledThreadPool
     *
     * @param poolSize
     * @param poolName
     * @param isMethodScoped Pass as ThreadPoolScope.METHOD if threadpool is method scoped
     * @return
     */
    public static ScheduledThreadPoolMonitorExecutor newScheduledThreadPool(int poolSize, String poolName, ThreadPoolScope threadPoolScope) {
        return new ScheduledThreadPoolMonitorExecutor(poolSize, poolName, threadPoolScope);
    }

    /**
     * This method is wrapper on top of Executors.newSingleThreadScheduledExecutor
     *
     * @param poolName
     * @param isMethodScoped Pass as ThreadPoolScope.METHOD if threadpool is method scoped
     * @return
     */
    public static ScheduledThreadPoolMonitorExecutor newSingleThreadScheduledExecutor(String poolName, ThreadPoolScope threadPoolScope) {
        return new ScheduledThreadPoolMonitorExecutor(1, poolName, threadPoolScope);
    }

    /**
     * This method is wrapper on top of execute(Runnable task)
     *
     * @see java.util.concurrent.ScheduledThreadPoolExecutor#execute(java.lang.Runnable)
     */
    @Override
    public void execute(Runnable task) {
        final long startTime = System.currentTimeMillis();
        super.execute(() -> {
            MetricsRegistryProvider.updateDurationTimer(METRIC_REGISTRY, metricsPrefix, TASK_QUEUE_WAIT_TIME, startTime);
            task.run();
        });
    }

    /**
     * This method is wrapper on top of submit(Runnable task)
     *
     * @see java.util.concurrent.ScheduledThreadPoolExecutor#submit(java.lang.Runnable)
     */
    @Override
    public Future<?> submit(Runnable task) {
        final long startTime = System.currentTimeMillis();
        return super.submit(() -> {
            MetricsRegistryProvider.updateDurationTimer(METRIC_REGISTRY, metricsPrefix, TASK_QUEUE_WAIT_TIME, startTime);
            task.run();
        });
    }

    /**
     * This method is wrapper on top of submit(Callable<T> task)
     *
     * @see java.util.concurrent.ScheduledThreadPoolExecutor#submit(java.lang.Runnable)
     */
    @Override
    public <T> Future<T> submit(Callable<T> task) {
        final long startTime = System.currentTimeMillis();
        return super.submit(() -> {
            MetricsRegistryProvider.updateDurationTimer(METRIC_REGISTRY, metricsPrefix, TASK_QUEUE_WAIT_TIME, startTime);
            return task.call();
        });
    }

    @Override
    protected void beforeExecute(Thread thread, Runnable task) {
        super.beforeExecute(thread, task);
        Timer timer = MetricsRegistryProvider.timer(METRIC_REGISTRY, metricsPrefix, TASK_EXECUTION);
        taskExecutionTimer.set(timer.time());
    }

    @Override
    protected void afterExecute(Runnable task, Throwable throwable) {
        Timer.Context context = taskExecutionTimer.get();
        long elapsed = context.stop(); // nanoseconds
        long elapsedTimeInSecond = elapsed / 1_000_000_000; // in seconds
        taskExeutionTime.set(elapsedTimeInSecond);
        super.afterExecute(task, throwable);
    }

    private void registerGauges(ThreadPoolScope threadPoolScope) {
        try {
            MetricsRegistryProvider.registerThreadPool(METRIC_REGISTRY, metricsPrefix, CORE_POOL_SIZE, (Gauge<Integer>) this::getCorePoolSize, threadPoolScope);
            MetricsRegistryProvider.registerThreadPool(METRIC_REGISTRY, metricsPrefix, MAX_POOL_SIZE, (Gauge<Integer>) this::getMaximumPoolSize, threadPoolScope);
            MetricsRegistryProvider.registerThreadPool(METRIC_REGISTRY, metricsPrefix, QUEUE_SIZE, (Gauge<Integer>) () -> getQueue().size(), threadPoolScope);
            MetricsRegistryProvider.registerThreadPool(METRIC_REGISTRY, metricsPrefix, ACTIVE_THREADS, (Gauge<Integer>) this::getActiveCount, threadPoolScope);
            MetricsRegistryProvider.registerThreadPool(METRIC_REGISTRY, metricsPrefix, COMPLETED_TASKS, (Gauge<Long>) this::getCompletedTaskCount, threadPoolScope);
            MetricsRegistryProvider.registerThreadPool(METRIC_REGISTRY, metricsPrefix, SUBMITTED_TASKS, (Gauge<Long>) this::getTaskCount, threadPoolScope);
            MetricsRegistryProvider.registerThreadPool(METRIC_REGISTRY, metricsPrefix, NOT_COMPLETED_TASKS, (Gauge<Long>) (() -> this.getTaskCount() - this.getCompletedTaskCount()), threadPoolScope);
            MetricsRegistryProvider.registerThreadPool(METRIC_REGISTRY, metricsPrefix, THREAD_POOL_SCOPE, (Gauge<String>) () -> this.threadPoolScope.toString(), threadPoolScope);
        } catch (Exception e) {
            errorLogger.warn("Error while registering gauges ", e);
        }
    }
}
