# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Testing utilities."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re
import sys
import types
import unittest

from tensorflow.python.eager import def_function
from tensorflow.python.framework import op_callbacks
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class AutoGraphTestCase(test.TestCase):
  """Tests specialized for AutoGraph, which run as tf.functions.

  These tests use a staged programming-like approach: most of the test code runs
  as-is inside a tf.function, but the assertions are lifted outside the
  function, and run with the corresponding function values instead.

  For example, the test:

      def test_foo(self):
        baz = bar()
        self.assertEqual(baz, value)

  is equivalent to writing:

      def test_foo(self):
        @tf.function
        def test_fn():
          baz = bar()
          return baz, value

        baz_actual, value_actual = test_fn()
        self.assertEqual(baz_actual, value_actual)

  Only assertions that require evaluation outside the function are lifted
  outside the function scope. The rest execute inline, at function creation
  time.

  Errors expected while the function is traced or evaluated are declared with
  `assertRaisesRuntime`. The test fails if the declared error is not raised,
  and any error that does not match the declaration is propagated.
  """

  def __new__(cls, *args, **kwargs):
    """Creates the instance and wraps the test methods defined on `cls`.

    Only callables found directly in `cls.__dict__` (i.e. not inherited ones)
    whose name starts with the unittest test method prefix are wrapped. The
    wrapper is set on the instance, shadowing the class attribute.

    Args:
      *args: Unused here; Python passes the same arguments on to `__init__`.
      **kwargs: Unused here; Python passes the same arguments on to `__init__`.

    Returns:
      The new test case instance.
    """
    del args, kwargs  # Only consumed by __init__.
    obj = super().__new__(cls)

    for name in cls.__dict__:
      if not name.startswith(unittest.TestLoader.testMethodPrefix):
        continue
      m = getattr(obj, name)
      if callable(m):
        wrapper = obj._run_as_tf_function(m)
        setattr(obj, name, types.MethodType(wrapper, obj))

    return obj

  def _op_callback(
      self, op_type, inputs, attrs, outputs, op_name=None, graph=None):
    """Appends `op_type` to the trace log; registered as op callback in setUp.

    Args:
      op_type: Type of the op; the only argument that is recorded.
      inputs: Unused.
      attrs: Unused.
      outputs: Unused.
      op_name: Unused.
      graph: Unused.
    """
    self.trace_log.append(op_type)

  def _run_as_tf_function(self, fn):
    """Builds a replacement for the test method `fn`.

    Args:
      fn: Zero-argument callable holding the test body (a bound test method).

    Returns:
      A function taking `self` which runs `fn` inside a tf.function, runs the
      registered graph assertions against the function's graph, evaluates the
      arguments of the deferred assertions and finally runs those assertions
      on the evaluated values.
    """

    def wrapper(self):
      @def_function.function(autograph=False)  # Testing autograph itself.
      def fn_wrapper():
        # Reset the per-test bookkeeping that the assert* overrides write to.
        self.assertions = []
        self.raises_cm = None
        self.graph_assertions = []
        self.trace_log = []
        fn()
        # The arguments of the deferred assertions become function outputs.
        return [args for _, args in self.assertions]

      try:
        tensors = fn_wrapper()

        if self.graph_assertions:
          # Fetch the graph once; it is the same for every graph assertion.
          graph = fn_wrapper.get_concrete_function().graph
          for assertion in self.graph_assertions:
            assertion(graph)

        actuals = self.evaluate(tensors)

      except:  # pylint:disable=bare-except
        if self.raises_cm is None:
          raise
        # Note: Yes, the Raises and function contexts cross.
        if not self.raises_cm.__exit__(*sys.exc_info()):
          # The context manager declined to handle the exception (it is not of
          # the expected type), so it must not be swallowed.
          raise
        return
      else:
        if self.raises_cm is not None:
          # An error was expected but nothing was raised. Exiting the context
          # without an exception makes it report the failure. This is done
          # outside the `try` body so that the failure is not caught above.
          self.raises_cm.__exit__(None, None, None)

      for (assertion, _), values in zip(self.assertions, actuals):
        assertion(*values)

    return wrapper

  def variable(self, name, value, dtype):
    """Returns the variable registered under `name`, creating it if needed.

    Creation and initialization run under `ops.init_scope()`, since this is
    meant to be called from test code that executes inside a tf.function.

    Args:
      name: Key identifying the variable within the current test.
      value: Initial value, used only when the variable is first created.
      dtype: Dtype passed to the variable constructor on first creation.

    Returns:
      The `variables.Variable` stored under `name`.
    """
    with ops.init_scope():
      if name not in self.variables:
        self.variables[name] = variables.Variable(value, dtype=dtype)
        self.evaluate(self.variables[name].initializer)
    return self.variables[name]

  def setUp(self):
    """Resets the per-test state and starts recording created op types."""
    super().setUp()
    self.variables = {}
    self.trace_log = []
    self.raises_cm = None
    op_callbacks.add_op_callback(self._op_callback)

  def tearDown(self):
    """Stops recording created op types and drops the per-test state."""
    op_callbacks.remove_op_callback(self._op_callback)
    self.trace_log = None
    self.variables = None
    super().tearDown()

  def assertGraphContains(self, op_regex, n):
    """Registers a check that exactly `n` graph nodes match `op_regex`.

    The check runs after the test body has been traced. It counts the nodes of
    the function graph and the nodes of every function in the graph's library
    whose name matches `op_regex` with `re.match` (anchored at the start).

    Args:
      op_regex: Regular expression matched against node names.
      n: Expected number of matching nodes.
    """
    def assertion(graph):
      # Serialize the graph once and reuse it for both scans.
      graph_def = graph.as_graph_def()
      matches = [
          node for node in graph_def.node if re.match(op_regex, node.name)
      ]
      for fn in graph_def.library.function:
        matches.extend(
            node_def for node_def in fn.node_def
            if re.match(op_regex, node_def.name))
      self.assertLen(matches, n)

    self.graph_assertions.append(assertion)

  def assertOpCreated(self, op_type):
    """Asserts inline that an op of type `op_type` was recorded so far."""
    self.assertIn(op_type, self.trace_log)

  def assertOpsNotCreated(self, op_types):
    """Asserts inline that none of `op_types` was recorded so far."""
    self.assertEmpty(set(op_types) & set(self.trace_log))

  def assertNoOpsCreated(self):
    """Asserts inline that no op was recorded so far."""
    self.assertEmpty(self.trace_log)

  def assertEqual(self, *args):
    """Defers `assertEqual` until `args` have been evaluated.

    Args:
      *args: Arguments for the parent `assertEqual`. They are returned from
        the tf.function and replaced by their evaluated values before the
        parent assertion is called.
    """
    self.assertions.append((super().assertEqual, list(args)))

  def assertDictEqual(self, *args):
    """Defers `assertDictEqual` until `args` have been evaluated.

    Args:
      *args: Arguments for the parent `assertDictEqual`. They are returned
        from the tf.function and replaced by their evaluated values before the
        parent assertion is called.
    """
    self.assertions.append((super().assertDictEqual, list(args)))

  def assertRaisesRuntime(self, *args):
    """Declares that running the test function raises a matching error.

    Unlike `assertRaises`, this is called as a plain statement rather than
    used in a `with` block: the context manager is entered here and exited by
    the test wrapper once the function has been traced and evaluated.

    Args:
      *args: Arguments forwarded to `assertRaisesRegex`.

    Raises:
      ValueError: If called more than once in the same test.
    """
    if self.raises_cm is not None:
      raise ValueError('cannot use more than one assertRaisesRuntime in a test')
    self.raises_cm = self.assertRaisesRegex(*args)
    self.raises_cm.__enter__()