package com.onaro.commons.metrics.executor;

import static com.onaro.commons.metrics.MetricsRegistryProvider.RegistryKey.*;
import static com.onaro.commons.metrics.MetricsRegistryProvider.ThreadPoolMonitoringKey.*;

import java.util.Collection;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.ForkJoinWorkerThread;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;

import com.codahale.metrics.Gauge;
import com.codahale.metrics.MetricRegistry;
import com.onaro.commons.metrics.MetricsRegistryProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link ForkJoinPool} that publishes metrics about its own usage.
 * <p>
 * On construction the pool registers gauges (pool size, parallelism, queued work, active threads, async mode,
 * steal count, submitted tasks and pool scope) in the {@code THREAD_POOL} registry obtained from
 * {@link MetricsRegistryProvider}. The "submitted tasks" gauge is backed by a counter that is incremented only by
 * {@link #invoke(ForkJoinTask)} and {@link #invokeAll(Collection)}; work handed to the pool through other entry
 * points (such as {@code submit} or {@code execute}) is not counted.
 */
public class ForkJoinPoolMonitorExecutor extends ForkJoinPool {

    private static final Logger errorLogger = LoggerFactory.getLogger(ForkJoinPoolMonitorExecutor.class);

    /**
     * Shared thread-pool registry fetched from {@link MetricsRegistryProvider}.
     */
    private static final MetricRegistry METRIC_REGISTRY = MetricsRegistryProvider.getMetricRegistry(THREAD_POOL);

    /**
     * Metric-name prefix built from the runtime class and the pool name; distinguishes metrics across pools.
     */
    private final String metricsPrefix;

    /**
     * Scope of this pool. Never {@code null}: a {@code null} constructor argument is replaced by
     * {@link ThreadPoolScope#GLOBAL}.
     */
    private final ThreadPoolScope threadPoolScope;

    /**
     * Number of tasks passed to {@link #invoke(ForkJoinTask)} and {@link #invokeAll(Collection)}.
     */
    private final AtomicInteger submittedTasksCount = new AtomicInteger();

    /**
     * Creates a pool and registers its gauges.
     * <p>
     * The pool is configured as follows:
     * <ul>
     * <li>parallelism equals the number of available processors;</li>
     * <li>no uncaught-exception handler;</li>
     * <li>{@code asyncMode == false};</li>
     * <li>worker threads are named {@code poolName + poolIndex}.</li>
     * </ul>
     *
     * @param poolName        prefix for worker-thread names and part of the metric names
     * @param threadPoolScope scope of the pool (pass {@code ThreadPoolScope.METHOD} for a method-scoped pool);
     *                        {@code null} is treated as {@link ThreadPoolScope#GLOBAL}
     */
    public ForkJoinPoolMonitorExecutor(String poolName, ThreadPoolScope threadPoolScope) {
        super(Runtime.getRuntime().availableProcessors(), buildThreadFactory(poolName), null, false);
        this.metricsPrefix = MetricRegistry.name(getClass(), poolName);
        // Default to GLOBAL so the THREAD_POOL_SCOPE gauge never dereferences null.
        this.threadPoolScope = threadPoolScope != null ? threadPoolScope : ThreadPoolScope.GLOBAL;
        registerGauges(this.threadPoolScope);
    }

    /**
     * Creates a monitored pool whose parallelism matches {@code new ForkJoinPool()}, i.e. the number of available
     * processors.
     *
     * @param poolName        prefix for worker-thread names and part of the metric names
     * @param threadPoolScope pass {@code ThreadPoolScope.METHOD} if the pool is method scoped;
     *                        {@code null} is treated as {@link ThreadPoolScope#GLOBAL}
     * @return a new, already-registered {@code ForkJoinPoolMonitorExecutor}
     */
    public static ForkJoinPoolMonitorExecutor defaultForkJoinPoolExecutor(String poolName, ThreadPoolScope threadPoolScope) {
        return new ForkJoinPoolMonitorExecutor(poolName, threadPoolScope);
    }

    /**
     * Builds a worker-thread factory that delegates to the JDK default factory and renames each worker.
     *
     * @param poolName prefix for worker names; each worker is named {@code poolName + poolIndex}
     * @return the naming thread factory
     */
    private static ForkJoinWorkerThreadFactory buildThreadFactory(String poolName) {
        return pool -> {
            final ForkJoinWorkerThread worker = ForkJoinPool.defaultForkJoinWorkerThreadFactory.newThread(pool);
            worker.setName(poolName + worker.getPoolIndex());
            return worker;
        };
    }

    /**
     * Counts the task as submitted, then delegates to {@link ForkJoinPool#invoke(ForkJoinTask)}.
     *
     * @see java.util.concurrent.ForkJoinPool#invoke(java.util.concurrent.ForkJoinTask)
     */
    @Override
    public <T> T invoke(ForkJoinTask<T> task) {
        submittedTasksCount.incrementAndGet();
        return super.invoke(task);
    }

    /**
     * Counts every task in {@code tasks} as submitted, then delegates to
     * {@link ForkJoinPool#invokeAll(Collection)}.
     * <p>
     * A {@code null} collection is not counted; it is passed straight to the superclass.
     */
    @Override
    public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks) {
        if (tasks != null) {
            submittedTasksCount.addAndGet(tasks.size());
        }
        return super.invokeAll(tasks);
    }

    /**
     * Registers this pool's gauges under {@link #metricsPrefix}.
     * <p>
     * Failures are logged and swallowed so that a metrics problem never prevents the pool from being created.
     * <p>
     * NOTE(review): the gauges capture {@code this}, so the registry keeps the pool reachable while they remain
     * registered. Nothing in this class unregisters them; confirm that cleanup happens elsewhere, especially for
     * method-scoped pools.
     *
     * @param threadPoolScope scope forwarded to {@link MetricsRegistryProvider#registerThreadPool}
     */
    private void registerGauges(ThreadPoolScope threadPoolScope) {
        try {
            MetricsRegistryProvider.registerThreadPool(METRIC_REGISTRY, metricsPrefix, CORE_POOL_SIZE, (Gauge<Integer>) this::getPoolSize, threadPoolScope);
            MetricsRegistryProvider.registerThreadPool(METRIC_REGISTRY, metricsPrefix, PARALLELISM_LEVEL, (Gauge<Integer>) this::getParallelism, threadPoolScope);
            // Queued work = tasks queued in worker queues + external submissions not yet picked up.
            MetricsRegistryProvider.registerThreadPool(METRIC_REGISTRY, metricsPrefix, QUEUE_SIZE, (Gauge<Long>) () -> (this.getQueuedTaskCount() + this.getQueuedSubmissionCount()), threadPoolScope);
            MetricsRegistryProvider.registerThreadPool(METRIC_REGISTRY, metricsPrefix, ACTIVE_THREADS, (Gauge<Integer>) this::getActiveThreadCount, threadPoolScope);
            MetricsRegistryProvider.registerThreadPool(METRIC_REGISTRY, metricsPrefix, ASYNC_MODE, (Gauge<Boolean>) this::getAsyncMode, threadPoolScope);
            MetricsRegistryProvider.registerThreadPool(METRIC_REGISTRY, metricsPrefix, STEAL_COUNT, (Gauge<Long>) this::getStealCount, threadPoolScope);
            MetricsRegistryProvider.registerThreadPool(METRIC_REGISTRY, metricsPrefix, SUBMITTED_TASKS, (Gauge<Integer>) submittedTasksCount::get, threadPoolScope);
            MetricsRegistryProvider.registerThreadPool(METRIC_REGISTRY, metricsPrefix, THREAD_POOL_SCOPE, (Gauge<String>) () -> this.threadPoolScope.toString(), threadPoolScope);
        } catch (Exception e) {
            errorLogger.warn("Error while registering gauges ", e);
        }
    }

}
