using System;

using Org.BouncyCastle.Crypto.Parameters;
using Org.BouncyCastle.Crypto.Utilities;

namespace Org.BouncyCastle.Crypto.Generators
{
    /// <summary>Generator for MGF1 as defined in PKCS #1 v2.</summary>
    /// <remarks>
    /// Output is produced block by block: each block is the digest of the seed followed by a 4-byte counter
    /// that starts at zero and is incremented per block. The counter restarts at zero on every call to
    /// <c>GenerateBytes</c>, so each call generates from the beginning of the mask.
    /// </remarks>
    public sealed class Mgf1BytesGenerator
        : IDerivationFunction
    {
        private readonly IDigest m_digest;
        private readonly int m_hLen;

        // Working buffer, allocated by Init (null until then). Layout:
        //   [0, SeedLength)                 seed
        //   [SeedLength, SeedLength + 4)    block counter
        //   [SeedLength + 4, end)           m_hLen bytes of scratch space for a truncated final block
        private byte[] m_buffer;

        /// <summary>Create an MGF1 generator based on the given digest.</summary>
        /// <param name="digest">the digest to be used as the source of generated bytes</param>
        /// <exception cref="ArgumentNullException">if <paramref name="digest"/> is null.</exception>
        public Mgf1BytesGenerator(IDigest digest)
        {
            // Fail with a descriptive exception rather than a NullReferenceException from GetDigestSize().
            if (digest == null)
                throw new ArgumentNullException(nameof(digest));

            m_digest = digest;
            m_hLen = digest.GetDigestSize();
        }

        /// <summary>Initialise the generator with the seed carried by the given parameters.</summary>
        /// <param name="parameters">must be an instance of <c>MgfParameters</c>.</param>
        /// <exception cref="ArgumentException">if <paramref name="parameters"/> is not an
        /// <c>MgfParameters</c> instance (this includes null).</exception>
        public void Init(IDerivationParameters parameters)
        {
            if (!(parameters is MgfParameters mgfParameters))
                throw new ArgumentException("MGF parameters required for MGF1Generator");

            // Room for the seed, the 4-byte counter, and one digest-sized scratch area.
            m_buffer = new byte[mgfParameters.SeedLength + 4 + m_hLen];
            mgfParameters.GetSeed(m_buffer, 0);
        }

        /// <summary>the underlying digest.</summary>
        public IDigest Digest => m_digest;

        /// <summary>Fill <c>length</c> bytes of the output buffer, starting at <c>outOff</c>, with bytes
        /// generated from the derivation function.</summary>
        /// <param name="output">the buffer to write the generated bytes to.</param>
        /// <param name="outOff">the offset in <paramref name="output"/> to start writing at.</param>
        /// <param name="length">the number of bytes to generate.</param>
        /// <returns><paramref name="length"/>.</returns>
        /// <exception cref="InvalidOperationException">if <c>Init</c> has not been called.</exception>
        public int GenerateBytes(byte[] output, int outOff, int length)
        {
            Check.OutputLength(output, outOff, length, "output buffer too short");

#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
            return GenerateBytes(output.AsSpan(outOff, length));
#else
            CheckInitialised();

            int hashPos = m_buffer.Length - m_hLen;
            int counterPos = hashPos - 4;
            uint counter = 0;

            m_digest.Reset();

            int end = outOff + length;
            int limit = end - m_hLen;

            // Whole blocks: the digest is written directly into the output.
            while (outOff <= limit)
            {
                Pack.UInt32_To_BE(counter++, m_buffer, counterPos);

                m_digest.BlockUpdate(m_buffer, 0, hashPos);
                m_digest.DoFinal(output, outOff);

                outOff += m_hLen;
            }

            // Trailing partial block: digest into the scratch area, then copy only the bytes still needed.
            if (outOff < end)
            {
                Pack.UInt32_To_BE(counter, m_buffer, counterPos);

                m_digest.BlockUpdate(m_buffer, 0, hashPos);
                m_digest.DoFinal(m_buffer, hashPos);

                Array.Copy(m_buffer, hashPos, output, outOff, end - outOff);
            }

            return length;
#endif
        }

#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER
        /// <summary>Fill the whole of <paramref name="output"/> with bytes generated from the derivation
        /// function.</summary>
        /// <param name="output">the span to write the generated bytes to.</param>
        /// <returns>the length of <paramref name="output"/>.</returns>
        /// <exception cref="InvalidOperationException">if <c>Init</c> has not been called.</exception>
        public int GenerateBytes(Span<byte> output)
        {
            CheckInitialised();

            int hashPos = m_buffer.Length - m_hLen;
            int counterPos = hashPos - 4;
            uint counter = 0;

            m_digest.Reset();

            int pos = 0, length = output.Length, limit = length - m_hLen;

            // Whole blocks: the digest is written directly into the output.
            while (pos <= limit)
            {
                Pack.UInt32_To_BE(counter++, m_buffer.AsSpan(counterPos));

                m_digest.BlockUpdate(m_buffer.AsSpan(0, hashPos));
                m_digest.DoFinal(output[pos..]);

                pos += m_hLen;
            }

            // Trailing partial block: digest into the scratch area, then copy only the bytes still needed.
            if (pos < length)
            {
                Pack.UInt32_To_BE(counter, m_buffer.AsSpan(counterPos));

                m_digest.BlockUpdate(m_buffer.AsSpan(0, hashPos));
                m_digest.DoFinal(m_buffer.AsSpan(hashPos));
                m_buffer.AsSpan(hashPos, length - pos).CopyTo(output[pos..]);
            }

            return length;
        }
#endif

        // Guards against use before Init, which would otherwise surface as a NullReferenceException on m_buffer.
        private void CheckInitialised()
        {
            if (m_buffer == null)
                throw new InvalidOperationException("MGF1 generator not initialised");
        }
    }
}
