/*
 * $Id: //depot/prod/bard/main/modules/dfm-common/src/main/java/com/netapp/dfm/common/metrics/executor/ThreadPoolMonitorExecutor.java#9 $
 *
 * Copyright (c) 2020 NetApp, Inc.
 * All rights reserved.
 */
package com.onaro.commons.metrics.executor;


import static com.onaro.commons.metrics.MetricsRegistryProvider.RegistryKey.*;
import static com.onaro.commons.metrics.MetricsRegistryProvider.ThreadPoolMonitoringKey.*;

import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Callable;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

import com.codahale.metrics.Gauge;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.Timer;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import com.onaro.commons.metrics.MetricsRegistryProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link ThreadPoolExecutor} that publishes metrics about its own usage.
 *
 * <p>On construction it registers gauges for pool size, queue size, active threads and task counts
 * under a prefix built from the concrete class and the pool name. While running it records, per task,
 * the time between submission and the moment a worker starts it ({@code TASK_QUEUE_WAIT_TIME}) and the
 * time the task spent executing ({@code TASK_EXECUTION}).
 *
 * <p>Note that tasks are wrapped before being handed to the underlying executor, so the objects
 * visible through {@link #getQueue()} or passed to {@link #beforeExecute(Thread, Runnable)} and
 * {@link #afterExecute(Runnable, Throwable)} are the wrappers, not the caller's original instances.
 */
public class ThreadPoolMonitorExecutor extends ThreadPoolExecutor {

    private static final Logger errorLogger = LoggerFactory.getLogger(ThreadPoolMonitorExecutor.class);
    /**
     * Registry shared by all monitored thread pools, looked up from MetricsRegistryProvider.
     */
    private static final MetricRegistry METRIC_REGISTRY = MetricsRegistryProvider.getMetricRegistry(THREAD_POOL);
    /**
     * Used to distinguish between metrics across thread pools
     */
    private final String metricsPrefix;

    /**
     * Scope this pool was created with; exposed through the THREAD_POOL_SCOPE gauge.
     */
    private final ThreadPoolScope threadPoolScope;

    /**
     * Per-thread timer context for the task currently executing on that worker thread. Set in
     * {@link #beforeExecute(Thread, Runnable)} and cleared in {@link #afterExecute(Runnable, Throwable)}.
     */
    protected ThreadLocal<Timer.Context> taskExecutionTimer = new ThreadLocal<>();
    /**
     * Per-thread execution time, in whole seconds, of the task that most recently finished on that
     * worker thread. Written in {@link #afterExecute(Runnable, Throwable)}.
     */
    protected ThreadLocal<Long> taskExeutionTime = new ThreadLocal<>();

    /**
     * Creates a monitored pool whose threads are named {@code <poolName>-<n>} and which uses the
     * default rejection policy of {@link ThreadPoolExecutor}.
     *
     * @param corePoolSize    number of threads to keep in the pool
     * @param maximumPoolSize maximum number of threads allowed in the pool
     * @param keepAliveTime   how long excess idle threads wait for work before terminating
     * @param unit            unit of {@code keepAliveTime}
     * @param workQueue       queue holding tasks before they are executed
     * @param poolName        name used for the thread names and for the metrics prefix
     * @param threadPoolScope scope reported for this pool and passed to the gauge registration
     */
    public ThreadPoolMonitorExecutor(
            int corePoolSize,
            int maximumPoolSize,
            long keepAliveTime,
            TimeUnit unit,
            BlockingQueue<Runnable> workQueue,
            String poolName,
            ThreadPoolScope threadPoolScope
    ) {
        super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, namedThreadFactory(poolName));
        this.metricsPrefix = MetricRegistry.name(getClass(), poolName);
        this.threadPoolScope = threadPoolScope;
        registerGauges(threadPoolScope);
    }

    /**
     * Creates a monitored pool with a custom rejection handler. Threads are named
     * {@code <poolName>-<n>}, consistent with the constructor that takes no handler.
     *
     * @param corePoolSize    number of threads to keep in the pool
     * @param maximumPoolSize maximum number of threads allowed in the pool
     * @param keepAliveTime   how long excess idle threads wait for work before terminating
     * @param unit            unit of {@code keepAliveTime}
     * @param workQueue       queue holding tasks before they are executed
     * @param handler         handler invoked when a task cannot be accepted
     * @param poolName        name used for the thread names and for the metrics prefix
     * @param threadPoolScope scope reported for this pool and passed to the gauge registration
     */
    public ThreadPoolMonitorExecutor(
            int corePoolSize,
            int maximumPoolSize,
            long keepAliveTime,
            TimeUnit unit,
            BlockingQueue<Runnable> workQueue,
            RejectedExecutionHandler handler,
            String poolName,
            ThreadPoolScope threadPoolScope
    ) {
        super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, namedThreadFactory(poolName), handler);
        this.metricsPrefix = MetricRegistry.name(getClass(), poolName);
        this.threadPoolScope = threadPoolScope;
        registerGauges(threadPoolScope);
    }

    /**
     * Creates a monitored pool with a caller-supplied thread factory and rejection handler. The
     * supplied factory is used as is; {@code poolName} only affects the metrics prefix here.
     *
     * @param corePoolSize    number of threads to keep in the pool
     * @param maximumPoolSize maximum number of threads allowed in the pool
     * @param keepAliveTime   how long excess idle threads wait for work before terminating
     * @param unit            unit of {@code keepAliveTime}
     * @param workQueue       queue holding tasks before they are executed
     * @param threadFactory   factory used to create worker threads
     * @param handler         handler invoked when a task cannot be accepted
     * @param poolName        name used for the metrics prefix
     * @param threadPoolScope scope reported for this pool and passed to the gauge registration
     */
    public ThreadPoolMonitorExecutor(
            int corePoolSize,
            int maximumPoolSize,
            long keepAliveTime,
            TimeUnit unit,
            BlockingQueue<Runnable> workQueue,
            ThreadFactory threadFactory,
            RejectedExecutionHandler handler,
            String poolName,
            ThreadPoolScope threadPoolScope
    ) {
        super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory, handler);
        this.metricsPrefix = MetricRegistry.name(getClass(), poolName);
        this.threadPoolScope = threadPoolScope;
        registerGauges(threadPoolScope);
    }

    /**
     * Monitored counterpart of Executors.newFixedThreadPool.
     *
     * @param poolSize        fixed number of threads in the pool
     * @param poolName        name used for the thread names and for the metrics prefix
     * @param threadPoolScope pass ThreadPoolScope.METHOD if the thread pool is method scoped
     * @return a new monitored fixed-size pool backed by an unbounded queue
     */
    public static ThreadPoolMonitorExecutor newFixedThreadPool(int poolSize, String poolName, ThreadPoolScope threadPoolScope) {
        return new ThreadPoolMonitorExecutor(poolSize, poolSize, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<Runnable>(), poolName, threadPoolScope);
    }

    /**
     * Monitored pool mimicking the one built by Spring's ThreadPoolTaskExecutor: an unbounded
     * LinkedBlockingQueue and a keep-alive of 60 seconds (Spring's default keepAliveSeconds).
     *
     * @param corePoolSize    number of threads to keep in the pool
     * @param maximumPoolSize maximum number of threads allowed in the pool
     * @param poolName        name used for the thread names and for the metrics prefix
     * @param threadPoolScope pass ThreadPoolScope.METHOD if the thread pool is method scoped
     * @return a new monitored pool
     */
    public static ThreadPoolMonitorExecutor newSpringFixedThreadPool(int corePoolSize, int maximumPoolSize, String poolName, ThreadPoolScope threadPoolScope) {
        // Spring expresses its keep-alive in seconds; the unit must match or idle threads would be
        // reclaimed after 60 milliseconds instead of 60 seconds.
        return new ThreadPoolMonitorExecutor(corePoolSize, maximumPoolSize, 60L, TimeUnit.SECONDS, new LinkedBlockingQueue<Runnable>(), poolName, threadPoolScope);
    }

    /**
     * Monitored counterpart of Executors.newSingleThreadExecutor.
     *
     * @param poolName        name used for the thread names and for the metrics prefix
     * @param threadPoolScope pass ThreadPoolScope.METHOD if the thread pool is method scoped
     * @return a new monitored single-thread pool backed by an unbounded queue
     */
    public static ThreadPoolMonitorExecutor newSingleThreadExecutor(String poolName, ThreadPoolScope threadPoolScope) {
        return new ThreadPoolMonitorExecutor(1, 1, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<Runnable>(), poolName, threadPoolScope);
    }

    /**
     * Monitored counterpart of Executors.newCachedThreadPool.
     *
     * @param poolName        name used for the thread names and for the metrics prefix
     * @param threadPoolScope pass ThreadPoolScope.METHOD if the thread pool is method scoped
     * @return a new monitored pool that creates threads on demand and hands tasks off directly
     */
    public static ThreadPoolMonitorExecutor newCachedThreadPool(String poolName, ThreadPoolScope threadPoolScope) {
        return new ThreadPoolMonitorExecutor(0, Integer.MAX_VALUE, 60L, TimeUnit.SECONDS, new SynchronousQueue<Runnable>(), poolName, threadPoolScope);
    }


    /**
     * Executes the task and records how long it waited before a worker picked it up.
     *
     * <p>This is the single place where the queue wait time is measured. The inherited
     * {@code submit} implementations hand their work to this method, so measuring there as well
     * would record every submitted task twice.
     *
     * @param task the task to run
     * @throws NullPointerException if {@code task} is null
     * @see java.util.concurrent.ThreadPoolExecutor#execute(java.lang.Runnable)
     */
    @Override
    public void execute(Runnable task) {
        // Fail at the call site, as ThreadPoolExecutor.execute does, instead of letting the
        // wrapper below blow up later inside a worker thread.
        if (task == null) {
            throw new NullPointerException("task");
        }
        final long startTime = System.currentTimeMillis();
        super.execute(() -> {
            MetricsRegistryProvider.updateDurationTimer(METRIC_REGISTRY, metricsPrefix, TASK_QUEUE_WAIT_TIME, startTime);
            task.run();
        });
    }

    /**
     * Submits the task for execution. The queue wait time is recorded by
     * {@link #execute(Runnable)}, which the inherited implementation calls.
     *
     * @param task the task to submit
     * @return a Future representing pending completion of the task
     * @see java.util.concurrent.AbstractExecutorService#submit(java.lang.Runnable)
     */
    @Override
    public Future<?> submit(Runnable task) {
        return super.submit(task);
    }

    /**
     * Submits the value-returning task for execution. The queue wait time is recorded by
     * {@link #execute(Runnable)}, which the inherited implementation calls.
     *
     * @param task the task to submit
     * @return a Future representing pending completion of the task
     * @see java.util.concurrent.AbstractExecutorService#submit(java.util.concurrent.Callable)
     */
    @Override
    public <T> Future<T> submit(Callable<T> task) {
        return super.submit(task);
    }

    /**
     * Starts the TASK_EXECUTION timer for the task about to run on the given worker thread.
     */
    @Override
    protected void beforeExecute(Thread thread, Runnable task) {
        super.beforeExecute(thread, task);
        Timer timer = MetricsRegistryProvider.timer(METRIC_REGISTRY, metricsPrefix, TASK_EXECUTION);
        taskExecutionTimer.set(timer.time());
    }

    /**
     * Stops the execution timer started in {@link #beforeExecute(Thread, Runnable)}, stores the
     * elapsed whole seconds in {@link #taskExeutionTime} and clears the per-thread timer context.
     */
    @Override
    protected void afterExecute(Runnable task, Throwable throwable) {
        try {
            Timer.Context context = taskExecutionTimer.get();
            if (context != null) {
                long elapsedNanos = context.stop(); // nanoseconds
                taskExeutionTime.set(TimeUnit.NANOSECONDS.toSeconds(elapsedNanos));
            } else {
                // No timer was started on this thread; do not expose a value left by an earlier task.
                taskExeutionTime.remove();
            }
        } finally {
            // Drop the finished context so the worker thread does not keep it until its next task.
            taskExecutionTimer.remove();
            super.afterExecute(task, throwable);
        }
    }

    /**
     * Builds the thread factory shared by the constructors that do not accept one: threads are
     * named {@code <poolName>-<n>}.
     */
    private static ThreadFactory namedThreadFactory(String poolName) {
        return new ThreadFactoryBuilder().setNameFormat(poolName + "-%d").build();
    }

    /**
     * Registers this pool's gauges under {@link #metricsPrefix}. Registration is best effort: any
     * failure is logged as a warning and the pool remains usable without those gauges.
     *
     * @param threadPoolScope scope forwarded to MetricsRegistryProvider.registerThreadPool
     */
    private void registerGauges(ThreadPoolScope threadPoolScope) {
        try {
            MetricsRegistryProvider.registerThreadPool(METRIC_REGISTRY, metricsPrefix, CORE_POOL_SIZE, (Gauge<Integer>) this::getCorePoolSize, threadPoolScope);
            MetricsRegistryProvider.registerThreadPool(METRIC_REGISTRY, metricsPrefix, MAX_POOL_SIZE, (Gauge<Integer>) this::getMaximumPoolSize, threadPoolScope);
            MetricsRegistryProvider.registerThreadPool(METRIC_REGISTRY, metricsPrefix, QUEUE_SIZE, (Gauge<Integer>) () -> getQueue().size(), threadPoolScope);
            MetricsRegistryProvider.registerThreadPool(METRIC_REGISTRY, metricsPrefix, ACTIVE_THREADS, (Gauge<Integer>) this::getActiveCount, threadPoolScope);
            MetricsRegistryProvider.registerThreadPool(METRIC_REGISTRY, metricsPrefix, COMPLETED_TASKS, (Gauge<Long>) this::getCompletedTaskCount, threadPoolScope);
            MetricsRegistryProvider.registerThreadPool(METRIC_REGISTRY, metricsPrefix, SUBMITTED_TASKS, (Gauge<Long>) this::getTaskCount, threadPoolScope);
            MetricsRegistryProvider.registerThreadPool(METRIC_REGISTRY, metricsPrefix, NOT_COMPLETED_TASKS, (Gauge<Long>) (() -> this.getTaskCount() - this.getCompletedTaskCount()), threadPoolScope);
            MetricsRegistryProvider.registerThreadPool(METRIC_REGISTRY, metricsPrefix, THREAD_POOL_SCOPE, (Gauge<String>) () -> this.threadPoolScope.toString(), threadPoolScope);
        } catch (Exception e) {
            errorLogger.warn("Error while registering gauges ", e);
        }
    }
}
