/******************************************************************************* * Copyright 2017-2018 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. *******************************************************************************/ #ifndef MATH_UTILS_HPP #define MATH_UTILS_HPP #include #include #include "utils.hpp" #include "nstl.hpp" #include "mkldnn_traits.hpp" namespace mkldnn { namespace impl { namespace math { template inline typename utils::enable_if::value, typename utils::remove_reference::type>::type saturate(const acc_t &x) { return (typename utils::remove_reference::type)x; } template inline typename utils::enable_if::value, typename utils::remove_reference::type>::type saturate(const acc_t &x) { acc_t v = x; if (v < (acc_t)nstl::numeric_limits::lowest()) v = (acc_t)nstl::numeric_limits::lowest(); if (v > (acc_t)nstl::numeric_limits::max()) v = (acc_t)nstl::numeric_limits::max(); return (typename utils::remove_reference::type)v; } template double saturate(const double &x) { double v = x; if (v < (double)nstl::numeric_limits::lowest()) v = (double)nstl::numeric_limits::lowest(); if (v > (double)nstl::numeric_limits::max()) v = (double)nstl::numeric_limits::max(); return v; } template <> inline int8_t saturate(const uint8_t &x) { return x <= 127u ? x : 127; } template <> inline uint8_t saturate(const int8_t &x) { return x >= 0 ? x : 0; } template inline typename utils::enable_if::value, out_t>::type out_round(float v, round_mode_t rmode = round_mode::nearest) { return (out_t)(rmode == round_mode::down ? floorf(v) : nearbyintf(v)); } template inline typename utils::enable_if::value, out_t>::type out_round(double v, round_mode_t rmode = round_mode::nearest) { return (out_t)(rmode == round_mode::down ? floor(v) : nearbyint(v)); } template inline typename utils::enable_if::value, out_t>::type out_round(float v, round_mode_t rmode = round_mode::nearest) { UNUSED(rmode); return v; } inline int gcd(int a, int b) { a = impl::nstl::abs(a); b = impl::nstl::abs(b); if (a < b) { int x = a; a = b; b = x; } if (b == 0) return a; int r; while ((r = a % b) != 0) { a = b; b = r; } return b; } template inline bool is_pow2(const T& v) { return (v & (v - 1)) == 0; } /** returns floor(log2(v)), aka the position of the leftmost non-0 bit */ inline int ilog2q(size_t v) { if (v == 0) return -1; int p = 0; # define CP(pw) do { if (v >= (1ull << pw)) { v >>= pw; p += pw; } } while(0) CP(32); CP(16); CP(8); CP(4); CP(2); CP(1); # undef CP return p; } template ::type> inline U one_m_square(T x) { return (U)(1 - x) * (1 + x); } template ::type> inline U x_m_square(T x) { return (U)(1 - x) * x; } /* activation */ template ::type> inline U relu_fwd(T s, A alpha) { return s > 0 ? s : (U)(s * alpha); } template ::type> inline U relu_bwd(T dd, T s, A alpha) { return s > 0 ? dd : (U)(dd * alpha); } template ::type> inline U tanh_fwd(T s) { const float e = tanhf((float) s); return (U)e; } template ::type> inline U tanh_bwd(T dd, T s) { const float e = tanh_fwd((float) s); return (U)(dd * (1 - e) * (1 + e)); } template ::type> inline U elu_fwd(T s, A alpha) { return s > 0 ? s : (U)(alpha * (::expm1f((float)s))); } template ::type> inline U elu_bwd(T dd, T s, A alpha) { return (U)(dd * (s > 0 ? 1 : alpha * ::expf((float)s))); } template ::type> inline U square_fwd(T s) { return s * s; } template ::type> inline U square_bwd(T dd, T s) { return dd * 2 * s; } template ::type> inline U abs_fwd(T s) { return s > 0 ? s : -s; } template ::type> inline U abs_bwd(T dd, T s) { return s > 0 ? dd : s < 0 ? -dd : 0; } template ::type> inline U sqrt_fwd(T s) { return s > 0 ? (U)(::sqrtf((float)(s))) : 0; } template ::type> inline U sqrt_bwd(T dd, T s) { return s > 0 ? (U)(dd / (2 * ::sqrtf((float)(s)))) : 0; } template ::type> inline U linear_fwd(T s, A alpha, A beta) { return (U)(alpha * s + beta); } template ::type> inline U linear_bwd(T dd, T s, A alpha, A beta) { (void) s; (void) beta; return (U)(dd * alpha); } template ::type> inline U bounded_relu_fwd(T s, A alpha) { s = s > 0 ? s : 0; return s > alpha ? (U)(alpha) : s; } template ::type> inline U bounded_relu_bwd(T dd, T s, A alpha) { return dd * (0 < s && s < alpha ? 1 : 0); } template ::type> inline U soft_relu_fwd(T s) { float max_logf = 8.872284e+01; //::logf(FLT_MAX) return s < max_logf ? (U)(::log1pf(::expf((float)s))) : s; } template ::type> inline U soft_relu_bwd(T dd, T s) { return (U)(dd / (1 + ::expf((float)(-s)))); } template ::type> inline U logistic_fwd(T s) { U v = (U)(::expf((float) -s)); return 1 / (1 + v); } template ::type> inline U logistic_bwd(T dd, T s) { U v = logistic_fwd(s); return dd * v * (1 - v); } template ::type> inline U exp_fwd(T s) { return (U)(::expf((float)s)); } template ::type> inline U exp_bwd(T dd, T s) { return dd * exp_fwd(s); } template ::type> inline U gelu_fwd(T s) { const float sqrt_2_over_pi = 0.797884; const float fitting_const = 0.044715; float v = tanh_fwd(sqrt_2_over_pi * s * (1 + fitting_const * s * s)); return (U)(0.5 * s * (1. + v)); } template ::type> inline U gelu_bwd(T dd, T s) { const float sqrt_2_over_pi = 0.797884; const float fitting_const = 0.044715; float g = s * sqrt_2_over_pi * (1 + fitting_const * s * s); float dg = sqrt_2_over_pi * (1 + 3 * fitting_const * s * s); float v = tanh_fwd(g); return (U)(dd * 0.5 * (1. + v) * (1. + s * (1 - v) * dg)); } inline bool eltwise_fwd_preserves_zero(alg_kind_t alg, bool jit_impl = false) { using namespace alg_kind; using namespace utils; const bool preserves_zero = true && !one_of(alg, eltwise_linear, eltwise_soft_relu, eltwise_logistic, eltwise_exp) && IMPLICATION(jit_impl, !one_of(alg, eltwise_elu, eltwise_tanh)); return preserves_zero; } inline float get_bias(const char *bias, size_t offset, data_type_t data_type) { if (!bias) return 0.0f; #define CASE(dt) \ case dt: return (float)((const prec_traits
::type *)bias)[offset] switch (data_type) { CASE(data_type::s8); CASE(data_type::u8); CASE(data_type::s32); CASE(data_type::f32); default: assert(!"unimplemented"); } return 0; // never happens (should probably be a NaN) #undef CASE } } } } #endif