# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Handles function calls, by generating compiled function names and calls.

Note: this transformer does not rename the top level object being converted;
that is the caller's responsibility.

Requires function_scopes.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gast

from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.utils import ag_logging


# TODO(mdan): Rename to FunctionCallsTransformer.


class _Function(object):

  no_root = True

  def __init__(self):
    self.context_name = None


# Module-level latch: flipped to True by CallTreeTransformer the first time it
# logs the warning about set_trace/breakpoint calls found in user code, so that
# the warning is emitted at most once.
set_trace_warned = False


class _ArgTemplateBuilder(object):
  """Constructs a tuple representing the positional arguments in a call.

  Example (yes, it's legal Python 3):

      f(*args1, b, *args2, c, d)  ->  args1 + (b,) + args2 + (c, d)
  """

  def __init__(self):
    self._arg_accumulator = []
    self._argspec = []
    self._finalized = False

  def _consume_args(self):
    if self._arg_accumulator:
      self._argspec.append(
          gast.Tuple(elts=self._arg_accumulator, ctx=gast.Load()))
      self._arg_accumulator = []

  def add_arg(self, a):
    self._arg_accumulator.append(a)

  def add_stararg(self, a):
    self._consume_args()
    self._argspec.append(
        gast.Call(
            gast.Name(
                'tuple', ctx=gast.Load(), annotation=None, type_comment=None),
            args=[a],
            keywords=()))

  def finalize(self):
    self._consume_args()
    self._finalized = True

  def to_ast(self):
    assert self._finalized
    if self._argspec:
      result = self._argspec[0]
      for i in range(1, len(self._argspec)):
        result = gast.BinOp(result, gast.Add(), self._argspec[i])
      return result
    return gast.Tuple([], gast.Load())


class CallTreeTransformer(converter.Base):
  """Rewrites call expressions into `ag__.converted_call(...)` invocations.

  A small set of calls is deliberately left untouched: calls into the `ag__`
  module, calls on the function context manager, debugger entry points and,
  depending on the conversion options, `print`.
  """

  def visit_Lambda(self, node):
    """Visits a lambda, opening a function scope when it is annotated."""
    if anno.hasanno(node, 'function_context_name'):
      with self.state[_Function] as scope:
        scope.context_name = anno.getanno(node, 'function_context_name')
        return self.generic_visit(node)
    # Lambdas synthesized during the conversion process carry no context
    # manager; they are visited without opening a new scope.
    return self.generic_visit(node)

  def visit_FunctionDef(self, node):
    """Visits a function definition inside a fresh `_Function` scope."""
    # Decorators and argument defaults belong to the outer scope, so they are
    # processed before the function's own scope is entered.
    node.decorator_list = self.visit_block(node.decorator_list)
    fn_args = node.args
    fn_args.defaults = self.visit_block(fn_args.defaults)
    kw_defaults = fn_args.kw_defaults
    for idx, default in enumerate(kw_defaults):
      if default is None:
        # Keyword-only argument without a default value.
        continue
      kw_defaults[idx] = self.visit(default)

    with self.state[_Function] as scope:
      # Note: if the conversion process ever creates helper functions, this
      # assumption will no longer hold.
      assert anno.hasanno(node, 'function_context_name'), (
          'The function_scopes converter always creates a scope for functions.')
      scope.context_name = anno.getanno(node, 'function_context_name')
      node.body = self.visit_block(node.body)
      if node.returns:
        node.returns = self.visit(node.returns)
      return node

  def visit_With(self, node):
    """Visits only the body of a `with` statement."""
    # The context manager expressions in node.items are intentionally skipped,
    # so calls appearing there are not converted.
    node.body = self.visit_block(node.body)
    return node

  def _args_to_tuple(self, node):
    """Combines all positional and *arg arguments of a call into one tuple."""
    # TODO(mdan): We could rewrite this to just a call to tuple(). Maybe better?
    # For example for
    #   f(a, b, *args)
    # instead of writing:
    #   (a, b) + args
    # just write this?
    #   tuple(a, b, *args)
    builder = _ArgTemplateBuilder()
    for arg in node.args:
      if not isinstance(arg, gast.Starred):
        builder.add_arg(arg)
      else:
        builder.add_stararg(arg.value)
    builder.finalize()
    return builder.to_ast()

  def _kwargs_to_dict(self, node):
    """Combines all keyword and **kwarg arguments of a call into one dict.

    Returns the expression `None` when the call has no keywords at all.
    """
    if not node.keywords:
      return parser.parse_expression('None')
    dict_fn = gast.Name(
        'dict', ctx=gast.Load(), annotation=None, type_comment=None)
    return gast.Call(dict_fn, args=(), keywords=node.keywords)

  def _warn_about_set_trace_once(self):
    """Logs the step-by-step debugging warning unless it was logged before."""
    global set_trace_warned
    if set_trace_warned:
      return
    # TODO(mdan): Update and shorten once available on tensorflow.org.
    ag_logging.warn(
        'Detected `pdb.set_trace()` in user code. The code'
        ' generated by AutoGraph is not optimized for step-by-step'
        ' debugging. See https://github.com/tensorflow/tensorflow/'
        'blob/master/tensorflow/python/autograph/g3doc/reference/'
        'debugging.md.')
    set_trace_warned = True

  def visit_Call(self, node):
    """Replaces a call with `ag__.converted_call`, unless it is exempt."""
    # The qualified name and the context name are read before the children
    # (including node.func and the arguments) are visited.
    full_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
    function_context_name = self.state[_Function].context_name
    node = self.generic_visit(node)

    # TODO(mdan): Refactor converted_call as a 'Call' operator.

    # Calls to the internal 'ag__' module are never converted (though their
    # arguments might be). Calls to the function context manager (inserted by
    # function_scopes) are safe as well. The checks are kept in this order so
    # the second one is only evaluated when the first does not match.
    if (full_name.startswith('ag__.') or
        full_name.startswith(function_context_name + '.')):
      return node

    # Calls to pdb.set_trace or ipdb.set_trace are never converted. We don't use
    # the normal mechanisms to bypass these literals because they are sensitive
    # to the frame they are being called from.
    # TODO(mdan): Generalize this to a "static allowlist" config.
    if full_name in ('pdb.set_trace', 'ipdb.set_trace', 'breakpoint'):
      self._warn_about_set_trace_once()
      return node

    # `print` stays as-is unless conversion of builtin functions is enabled.
    if full_name == 'print':
      if not self.ctx.user.options.uses(converter.Feature.BUILTIN_FUNCTIONS):
        return node

    template = """
      ag__.converted_call(func, args, kwargs, function_ctx)
    """
    return templates.replace_as_expression(
        template,
        func=node.func,
        args=self._args_to_tuple(node),
        kwargs=self._kwargs_to_dict(node),
        function_ctx=function_context_name)


def transform(node, ctx):
  """Transform function calls to their compiled counterparts.

  Qualified names are resolved on the tree first, then the tree is run through
  `CallTreeTransformer`.
  NOTE(review): the resolve step presumably provides the `anno.Basic.QN`
  annotations that `CallTreeTransformer.visit_Call` reads -- confirm against
  `qual_names.resolve`.

  Args:
    node: AST
    ctx: EntityContext
  Returns:
    The transformed AST. (No set of newly-generated names is returned; only
    the node itself.)
  """
  node = qual_names.resolve(node)

  node = CallTreeTransformer(ctx).visit(node)
  return node