# Generates Python bindings for ATen functions
#
# The bindings are generated as methods on python_variable or functions on the
# torch._C._nn object.
#
from collections import defaultdict
import re
from .nested_dict import nested_dict
from .gen_variable_type import should_trace
from .utils import write

try:
    from src.ATen.code_template import CodeTemplate
except ImportError:
    from tools.shared.module_loader import import_module
    CodeTemplate = import_module('code_template', 'aten/src/ATen/code_template.py').CodeTemplate

# These functions require manual Python bindings or are not exposed to Python.
#
# Every entry is a regular expression.  should_generate_python_binding()
# wraps each pattern as '^' + pattern + '$' before matching, so a pattern has
# to match the complete function name, not just a prefix of it.
SKIP_PYTHON_BINDINGS = [
    'alias', 'contiguous', 'is_cuda', 'is_sparse', 'size', 'stride',
    '.*_backward', '.*_backward_(out|input|weight|bias)', '.*_forward',
    '.*_forward_out', '_unsafe_view', 'tensor', '_?sparse_coo_tensor.*',
    '_arange.*', '_range.*', '_linspace.*', '_logspace.*',
    '_sparse_add_out', '_sparse_div.*', '_sparse_mul.*', '_sparse_sub.*',
    'index', 'unique_dim_consecutive',
    '_indexCopy_', 'max_values', 'min_values',
    '_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*',
    '_th_.*', '_thnn_.*',
    'arange.*', 'range.*', '_solve.*', '_getri.*', '_inverse.*',
    '_cholesky.*', '_triangular_solve.*',
    'slice', 'randint(_out)?',
    'item', '_local_scalar_dense', 'to',
    'copy_sparse_to_sparse_', 'copy_',
]

# These function signatures are not exposed to Python. Note that this signature
# list does not support regex: each entry is compared for exact string
# equality against 'name(simple_type, simple_type, ...)' as built by
# should_generate_python_binding() (argument types joined with ', ').
SKIP_PYTHON_BINDINGS_SIGNATURES = [
    'add(Tensor, Scalar, Scalar)', 'add_(Tensor, Scalar, Scalar)',
    'sub(Tensor, Scalar, Scalar)', 'sub_(Tensor, Scalar, Scalar)',
    'mul(Tensor, Scalar)', 'mul_(Tensor, Scalar)',
    'div(Tensor, Scalar)', 'div_(Tensor, Scalar)',
]

# C++ wrapper for a binding that goes through PythonArgParser.  It is the
# template process_function() picks for everything except the single-argument
# method case (registered with METH_VARARGS | METH_KEYWORDS).
# ${signatures} receives one quoted signature per overload and ${dispatch}
# receives the chain of `if (r.idx == N)` cases built from PY_VARIABLE_CASE.
PY_VARIABLE_METHOD_VARARGS = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    ${signatures}
  }, /*traceable=*/${traceable});
  ${unpack_self}
  ParsedArgs<${max_args}> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  ${declare_namedtuple_return_types}
  ${dispatch}
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
""")

# C++ wrapper used by process_function() when a method group has exactly one
# declaration whose only argument is `self` (registered with METH_NOARGS).
# No argument parsing is generated; the dispatch function is called directly
# with ${actuals}, which process_function() sets to ['self'].
PY_VARIABLE_METHOD_NOARGS = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self_, PyObject* args)
{
  HANDLE_TH_ERRORS
  ${declare_namedtuple_return_types}
  ${unpack_self}
  return wrap(${namedtuple_return_type}${dispatch_name}(${actuals}));
  END_HANDLE_TH_ERRORS
}
""")

# One branch of the overload chain, selected by the parser's matched
# signature index.  ${cond} is 'if' for the first overload and '} else if'
# for the following ones (see emit_dispatch), so the opening brace here is
# deliberately left unclosed: it is closed by the next case's '} else if' or
# by the final '}' that process_function() appends to the dispatch list.
PY_VARIABLE_CASE = CodeTemplate("""\
${cond} (r.idx == ${i}) {
  ${call_dispatch}
""")

# Chooses between the base overload and its `_out` variant depending on
# whether the argument at ${out_idx} was left as None by the caller.
PY_VARIABLE_OUT = CodeTemplate("""\
if (r.isNone(${out_idx})) {
  ${call_dispatch}
} else {
  ${call_dispatch_out}
}
""")

# Same as PY_VARIABLE_OUT, but emit_dispatch() uses it when the `_out`
# declaration has a 'dtype' python binding argument: before calling the out
# variant, the generated code calls check_out_type_matches() with the out
# tensor and the dtype/layout/device arguments.
PY_VARIABLE_OUT_CHECK_TYPE = CodeTemplate("""\
if (r.isNone(${out_idx})) {
  ${call_dispatch}
} else {
  check_out_type_matches(r.tensor(${out_idx}), r.scalartype(${type_idx}), r.isNone(${type_idx}),
                         r.layout(${layout_idx}), r.isNone(${layout_idx}),
                         r.device(${device_idx}), r.isNone(${device_idx}));
  ${call_dispatch_out}
}
""")

# Expression that calls the generated dispatch function with the unpacked
# arguments.
PY_VARIABLE_CALL_DISPATCH = CodeTemplate("""\
${dispatch_name}(${actuals})""")

# Wraps a dispatch call so that requires_grad is applied to its result.
# emit_single_dispatch() uses it only when a requires_grad binding argument
# was parsed and the declaration has no TensorOptions argument.
PY_VARIABLE_SET_REQUIRES_GRAD = CodeTemplate("""\
${call_dispatch}.set_requires_grad(${requires_grad})""")

# Return statement that converts the dispatch result to a Python object via
# wrap(); ${namedtuple_return_type} is either empty or '&typeN, '.
PY_VARIABLE_WRAP = CodeTemplate("""\
return wrap(${namedtuple_return_type}${call_dispatch});""")

# Definition of the inline C++ dispatch function.  One is appended to
# py_method_dispatch per emitted declaration; ${AutoNoGIL} is empty when the
# declaration sets 'with_gil'.
PY_VARIABLE_DISPATCH = CodeTemplate("""\
inline ${simple_return_type} ${dispatch_name}(${formal_args}) {
  ${initialize_cuda}
  ${AutoNoGIL}
  return ${dispatch_call}(${dispatch_args});
}
""")

# Entry of the generated method table for one Python-visible function.
PY_VARIABLE_METHOD_DEF = CodeTemplate("""\
{"${name}", (PyCFunction)${pycname}, ${flags}, NULL},""")

# C++ that declares and lazily initializes one struct-sequence type used for a
# named-tuple return value.  emit_namedtuple_return_type_def() substitutes it
# once per declaration with named return fields; ${namedtuple_type_index}
# keeps the static names distinct when several overloads share one wrapper.
PY_RETURN_NAMEDTUPLE_DEF = CodeTemplate("""\
static PyStructSequence_Field fields${namedtuple_type_index}[] = {
  ${namedtuple_fields} {nullptr}
};
static PyStructSequence_Desc desc${namedtuple_type_index} = {
  "torch.return_types.${name}", nullptr,
  fields${namedtuple_type_index}, ${namedtuple_size}
};
static PyTypeObject type${namedtuple_type_index};
static bool namedtuple_type_initialized${namedtuple_type_index} = false;
if (!namedtuple_type_initialized${namedtuple_type_index}) {
  PyStructSequence_InitType(&type${namedtuple_type_index}, &desc${namedtuple_type_index});
  type${namedtuple_type_index}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr;
  namedtuple_type_initialized${namedtuple_type_index} = true;
}
""")

# Statement substituted for ${unpack_self} when bindings are generated as
# methods (has_self): it obtains the Tensor behind the `self_` PyObject.
UNPACK_SELF = "auto& self = reinterpret_cast<THPVariable*>(self_)->cdata;"

# Shape of a Python-level signature string, 'name(arg, arg, ...)'.
# NOTE(review): not referenced in the part of the file shown around this
# definition; presumably used when building parser signatures -- confirm.
PYTHON_FUNCTION_SIGNATURE = CodeTemplate("""\
${name}(${py_formal_args})""")

# XXX: if you got here because of an assertion failure, it doesn't mean
# it's enough to just extend the list here. Before you do this, make sure
# to add an appropriate wrap() overload in torch/csrc/autograd/utils/wrap_outputs.h.
#
# emit_single_dispatch() asserts that a declaration's return type, with any
# ' &' removed, is a member of this set.
SUPPORTED_RETURN_TYPES = {
    'Tensor',
    'std::tuple<Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor,int64_t>',
    'std::tuple<Tensor,Tensor,double,int64_t>',
    'std::vector<Tensor>',
    'Scalar', 'bool', 'int64_t', 'void*', 'void',
}

# C++ snippet that assembles a TensorOptions value named `options` from the
# parsed dtype/device/layout/requires_grad/pin_memory expressions.
# emit_single_dispatch() emits it for declarations taking TensorOptions once
# dtype, device, layout and requires_grad have all been resolved.
TENSOR_OPTIONS = CodeTemplate("""\
const auto options = TensorOptions()
    .dtype(${dtype})
    .device(${device})
    .layout(${layout}.layout)
    .requires_grad(${requires_grad})
    .pinned_memory(${pin_memory});
""")


def should_generate_python_binding(declaration):
    """Decide whether a Python binding is generated for ``declaration``.

    Returns False when the function name fully matches one of the
    SKIP_PYTHON_BINDINGS patterns, when its 'name(types...)' signature is
    listed in SKIP_PYTHON_BINDINGS_SIGNATURES, or when it takes a
    SparseTensorRef argument (except for ``sparse_mask``).  Returns True
    otherwise.
    """
    fn_name = declaration['name']
    if any(re.match('^' + pattern + '$', fn_name) for pattern in SKIP_PYTHON_BINDINGS):
        return False

    arg_types = ', '.join([arg['simple_type'] for arg in declaration['arguments']])
    if '{}({})'.format(fn_name, arg_types) in SKIP_PYTHON_BINDINGS_SIGNATURES:
        return False

    # TODO: fix handling of SparseTensor. We don't want to generate Python
    # bindings to SparseTensor overloads, such as add(Tensor, SparseTensorRef),
    # since the Tensor-based signature already dynamically dispatches correctly.
    # However, sparse_mask only has a SparseTensor signature so we need to bind
    # that function.
    takes_unbound_sparse_ref = any(
        arg['type'] == 'SparseTensorRef' and declaration['name'] != 'sparse_mask'
        for arg in declaration['arguments'])
    return not takes_unbound_sparse_ref


def get_py_variable_methods(declarations):
    """
    Select the declarations to expose as Tensor methods and return them
    grouped by name (see group_declarations_by_name).
    """
    def is_tensor_method(decl):
        # Checks are ordered exactly as a chained `and` would evaluate them.
        if not should_generate_python_binding(decl):
            return False
        if decl['mode'] == 'NN':
            return False
        if decl.get('python_module') == 'nn':
            return False
        return 'Tensor' in decl['method_of']

    return group_declarations_by_name(declarations, is_tensor_method)


def gen_py_variable_methods(out, declarations, template_path):
    """Generate the Tensor-method binding files.

    Loads the two ``python_variable_methods*`` templates from
    ``template_path``, builds the binding environment (with ``self``) for the
    declarations chosen by get_py_variable_methods, and passes each template
    together with that environment to ``write`` for directory ``out``.
    """
    cpp_template = CodeTemplate.from_file(template_path + '/python_variable_methods.cpp')
    dispatch_template = CodeTemplate.from_file(template_path + '/python_variable_methods_dispatch.h')

    bindings_env = create_python_bindings(get_py_variable_methods(declarations), has_self=True)

    outputs = (
        ('python_variable_methods.cpp', cpp_template),
        ('python_variable_methods_dispatch.h', dispatch_template),
    )
    for filename, template in outputs:
        write(out, filename, template, bindings_env)


def get_py_nn_functions(declarations):
    """
    Select the declarations to expose as functions of the "nn" module and
    return them grouped by name (see group_declarations_by_name).
    """
    def is_nn_function(decl):
        if not should_generate_python_binding(decl):
            return False
        if decl['mode'] == 'NN':
            return True
        return decl.get('python_module') == 'nn'

    return group_declarations_by_name(declarations, is_nn_function)


def gen_py_nn_functions(out, declarations, template_path):
    """Generate the "nn" module binding files.

    Loads the three ``python_nn_functions*`` templates from
    ``template_path``, builds the binding environment (no ``self``, module
    functions) for the declarations chosen by get_py_nn_functions, and passes
    each template together with that environment to ``write`` for directory
    ``out``.
    """
    cpp_template = CodeTemplate.from_file(template_path + '/python_nn_functions.cpp')
    header_template = CodeTemplate.from_file(template_path + '/python_nn_functions.h')
    dispatch_template = CodeTemplate.from_file(template_path + '/python_nn_functions_dispatch.h')

    bindings_env = create_python_bindings(
        get_py_nn_functions(declarations), has_self=False, is_module=True)

    outputs = (
        ('python_nn_functions.cpp', cpp_template),
        ('python_nn_functions.h', header_template),
        ('python_nn_functions_dispatch.h', dispatch_template),
    )
    for filename, template in outputs:
        write(out, filename, template, bindings_env)


def get_py_torch_functions(declarations):
    """
    Select the declarations to expose as functions of the "torch" module and
    return them grouped by name (see group_declarations_by_name).
    """
    def is_torch_function(decl):
        # Checks are ordered exactly as a chained `and` would evaluate them.
        if not should_generate_python_binding(decl):
            return False
        if decl['mode'] == 'NN':
            return False
        if decl.get('python_module') == 'nn':
            return False
        return 'namespace' in decl['method_of']

    return group_declarations_by_name(declarations, is_torch_function)


def gen_py_torch_functions(out, declarations, template_path):
    """Generate the "torch" module binding files.

    Loads the two ``python_torch_functions*`` templates from
    ``template_path``, builds the binding environment (no ``self``) for the
    declarations chosen by get_py_torch_functions, and passes each template
    together with that environment to ``write`` for directory ``out``.
    """
    cpp_template = CodeTemplate.from_file(template_path + '/python_torch_functions.cpp')
    dispatch_template = CodeTemplate.from_file(template_path + '/python_torch_functions_dispatch.h')

    bindings_env = create_python_bindings(get_py_torch_functions(declarations), has_self=False)

    outputs = (
        ('python_torch_functions.cpp', cpp_template),
        ('python_torch_functions_dispatch.h', dispatch_template),
    )
    for filename, template in outputs:
        write(out, filename, template, bindings_env)


def group_declarations_by_name(declarations, should_bind_fn):
    """Bucket the bindable declarations by function name.

    A trailing ``_out`` is dropped from the name before bucketing, so an
    out-variant lands in the same bucket as its base function.  Declarations
    rejected by ``should_bind_fn`` are left out.  Returns a defaultdict that
    maps each name to the list of its declarations, in input order.
    """
    out_suffix = '_out'
    buckets = defaultdict(list)
    for decl in declarations:
        decl_name = decl['name']
        if not should_bind_fn(decl):
            continue
        key = decl_name[:-len(out_suffix)] if decl_name.endswith(out_suffix) else decl_name
        buckets[key].append(decl)
    return buckets


def get_type_default(declaration):
    """Return the default dtype, as a string, for a declaration's dtype argument.

    Functions whose name starts with ``randperm``, plus ``tril_indices`` and
    ``triu_indices``, default to ``'torch.int64'``; everything else gets the
    string ``'None'``.
    """
    fn_name = declaration['name']
    defaults_to_int64 = (fn_name.startswith('randperm') or
                         fn_name in ('tril_indices', 'triu_indices'))
    return 'torch.int64' if defaults_to_int64 else 'None'


def create_python_bindings(python_functions, has_self, is_module=False):
    """Generates Python bindings to ATen functions

    Arguments:
        python_functions: mapping from Python-visible function name to the
            list of declarations (overloads and `_out` variants) sharing that
            name, as produced by group_declarations_by_name.
        has_self: True when generating methods; `self` is then taken from the
            unpacked `self_` object instead of the argument parser.
        is_module: when False and has_self is False, the generated method
            table entries additionally get the METH_STATIC flag.

    Returns a dict with three lists of generated C++ snippets:
    'py_methods' (wrapper functions), 'py_method_defs' (method table entries)
    and 'py_method_dispatch' (inline dispatch functions).
    """
    py_methods = []
    py_method_defs = []
    py_method_dispatch = []

    # C++ argument type -> name of the parser accessor used as `r.<name>(idx)`.
    # Types missing from this table fall back to the lower-cased type name
    # (see parse_arg).
    unpack_methods = {
        'const Tensor &': 'tensor',
        'SparseTensorRef': 'tensor',
        'Tensor &': 'tensor',
        'Generator *': 'generator',
        'Storage &': 'storage',
        'const Type &': 'scalartype',
        'const THPLayout &': 'layout',
        'const Device &': 'device',
        'c10::optional<ScalarType>': 'scalartypeOptional',
        'c10::optional<Scalar>': 'scalarOptional',
        'c10::optional<int64_t>': 'toInt64Optional',
        'c10::optional<bool>': 'toBoolOptional',
        'IntArrayRef': 'intlist',
        'int64_t': 'toInt64',
        'bool': 'toBool',
        'double': 'toDouble',
        'std::string': 'string',
    }

    # C++ argument type -> name of the parser accessor that also takes a
    # default expression, used as `r.<name>(idx, default)` for arguments that
    # carry a 'python_default_init'.
    unpack_with_default_methods = {
        'IntArrayRef': 'setDefaultIntlist',
        'Scalar': 'scalarWithDefault',
        'int64_t': 'toInt64WithDefault',
        'bool': 'setDefaultBool',
        'double': 'setDefaultDouble',
        'const Type &': 'scalartypeWithDefault',
        'const THPLayout &': 'layoutWithDefault',
        'const Device &': 'deviceWithDefault',
        'ScalarType': 'scalartypeWithDefault',
    }

    def emit_single_dispatch(declaration, out_idx, base_env):
        """Emit the call to one declaration and its dispatch function.

        Returns the list of C++ lines that unpack the parsed arguments and
        call (and wrap) the dispatch function.  As a side effect the matching
        inline dispatch function is appended to py_method_dispatch.

        out_idx is the parser index of the `out` argument when the overload
        group has an `_out` variant, otherwise None.
        """
        env = {}
        simple_return_type = declaration['return_type'].replace(' &', '')
        assert simple_return_type in SUPPORTED_RETURN_TYPES, \
            declaration['name'] + ' returns unsupported type: ' + simple_return_type

        body = []
        actuals = []
        formal_args = []
        arg_idx = 0

        def is_output(arg):
            return arg.get('output', False)

        inputs = [arg for arg in declaration['arguments'] if not is_output(arg)]
        outputs = [arg for arg in declaration['arguments'] if is_output(arg)]

        has_tensor_options = any(arg['simple_type'] == 'TensorOptions' for arg in declaration['arguments'])

        def get_type_args(args):
            return [arg for arg in args if arg['simple_type'] == 'Type']
        type_actual_args = get_type_args(declaration['arguments'])
        type_binding_args = get_type_args(declaration['python_binding_arguments'])
        assert len(type_actual_args + type_binding_args) <= 1
        if type_binding_args and len(outputs) == 0:
            # out(s) determines the dtype if it is present, so only use this if there are no outputs.
            type_args = type_binding_args
        else:
            type_args = type_actual_args

        if type_args and len(outputs) > 1:
            raise RuntimeError("Not supported: type dispatched parameter with multiple outputs")

        def parse_arg(arg, arg_index, unpack_args=False):
            """Return (expr, formal) for one argument.

            expr is the C++ expression passed to the dispatch function and
            formal is the matching parameter declaration.  With unpack_args
            the parser access is first stored in a local variable (a line is
            appended to body) and expr becomes that variable's name.
            """
            name = arg['name']
            typename = arg['type']
            # Normalize sized int lists and LongTensor to their plain types.
            if typename.startswith('IntArrayRef['):
                typename = 'IntArrayRef'
            if typename.startswith('LongTensor'):
                typename = 'Tensor'

            if arg.get('python_default_init'):
                assert typename in unpack_with_default_methods, \
                    '`{}` type is not supported in python_default_init'.format(typename)
                unpack_with_default = unpack_with_default_methods.get(typename)
                default_expr = arg.get('python_default_init')
                expr = 'r.{}({}, {})'.format(unpack_with_default, arg_index, default_expr)
            else:
                unpack = unpack_methods.get(typename, typename.lower())
                expr = 'r.{}({})'.format(unpack, arg_index)

            if unpack_args:
                body.append('auto {} = {};'.format(name, expr))
                expr = name

            if typename == 'SparseTensorRef':
                expr = 'SparseTensorRef({})'.format(expr)

            # Type used for the parameter of the generated dispatch function.
            dispatch_type = typename
            if dispatch_type == 'Tensor':
                dispatch_type = 'const Tensor &'
            elif dispatch_type == 'Tensor &':
                dispatch_type = 'Tensor'
            elif dispatch_type == 'const Device &':
                dispatch_type = 'c10::optional<int32_t>'
            formal = '{} {}'.format(dispatch_type, name)
            return expr, formal

        def append_actuals_formals(actual, formal):
            actuals.append(actual)
            formal_args.append(formal)

        # We always want to unpack when we have TensorOptions.
        unpack = has_tensor_options
        for arg in inputs:
            # Type / TensorOptions arguments are handled separately below and
            # do not consume a positional parser index here.
            if arg['simple_type'] in ['Type', 'TensorOptions']:
                continue
            # For methods, `self` comes from UNPACK_SELF, not from the parser.
            if has_self and arg['name'] == 'self':
                formal_args.append('Tensor & self')
                actuals.append('self')
                continue
            append_actuals_formals(*parse_arg(arg, arg_idx, unpack))
            arg_idx += 1

        if len(outputs) == 1:
            append_actuals_formals(*parse_arg(outputs[0], arg_idx))
        elif len(outputs) > 1:
            # Several outputs are read from a single parser slot as a
            # fixed-size tensor list.
            N = len(outputs)
            body.append('auto results = r.tensorlist_n<{}>({});'.format(N, arg_idx))
            for i, arg in enumerate(outputs):
                formal_args.append('Tensor & {}'.format(arg['name']))
                actuals.append('results[{}]'.format(i))

        layout = None
        parsed_type_args = None
        # type args go after the outputs to match the signature generation.
        arg_idx = arg_idx if out_idx is None else out_idx + 1
        for arg in type_args:
            parsed_type_args = parse_arg(arg, arg_idx, unpack)
            arg_idx += 1

        # check python_binding_arguments
        has_device_bind = False
        requires_grad = None
        python_binding_arguments = declaration.get('python_binding_arguments', [])
        # When a dtype binding argument exists and the declaration has no
        # TensorOptions, move arg_idx one slot further before computing the
        # indices of the remaining binding arguments.
        if 'dtype' in (a['name'] for a in python_binding_arguments):
            if not has_tensor_options:
                arg_idx += 1

        # Remaining binding arguments occupy consecutive slots in this order:
        # [layout,] device, pin_memory, requires_grad.
        if 'layout' in (a['name'] for a in python_binding_arguments):
            layout_idx, device_idx, pin_memory_idx, requires_grad_idx = (arg_idx, arg_idx + 1, arg_idx + 2, arg_idx + 3)
        else:
            device_idx, pin_memory_idx, requires_grad_idx = (arg_idx, arg_idx + 1, arg_idx + 2)

        device = None
        for arg in python_binding_arguments:
            if arg['name'] == 'dtype' and arg['simple_type'] == 'Type':
                pass  # already handled by type_dispatched_args
            elif arg['name'] == 'layout' and arg['simple_type'] == 'Layout':
                # out(s) determines the type and layout if it is present, so only use this if there are no outputs.
                if len(outputs) == 0:
                    layout = parse_arg(arg, layout_idx)[0]
            elif arg['name'] == 'device' and arg['simple_type'] == 'Device':
                if len(outputs) == 0:
                    assert parsed_type_args
                    assert layout
                    device, device_type = parse_arg(arg, device_idx, True)

                    if not has_tensor_options:
                        # add type, device formals and corresponding actuals.
                        # The type actual is the ATen type mapped from (ScalarType, Layout, Device)
                        # The device actual is the corresponding AutoGPU index for the Device.
                        formal_args.append(parsed_type_args[1])
                        formal_args.append(device_type)
                        actuals.append("torch::getVariableType({}, {}, {})".format(parsed_type_args[0], layout, device))
                        actuals.append('{}.index()'.format(device))

                    has_device_bind = True
            elif arg['name'] == 'requires_grad' and arg['simple_type'] == 'bool':
                requires_grad = parse_arg(arg, requires_grad_idx)[0]
            elif arg['name'] == 'pin_memory' and arg['simple_type'] == 'bool':
                pin_memory = parse_arg(arg, pin_memory_idx)[0]
            else:
                raise RuntimeError(("found {} in python_binding_arguments but only "
                                    "\"bool pin_memory\", \"bool requires_grad\", \"ScalarType dtype\", \"Layout layout\", "
                                    "\"Device device\" are supported".format(arg)))

        dtype = parsed_type_args[0] if parsed_type_args else None
        # NOTE(review): `pin_memory` is only bound when a pin_memory binding
        # argument was seen in the loop above.  get_python_binding_arguments
        # adds pin_memory together with layout and device, so it should be
        # bound whenever this condition holds -- confirm before reordering.
        if has_tensor_options and all([dtype, device, layout, requires_grad]):
            body.append(TENSOR_OPTIONS.substitute({
                'dtype': dtype,
                'layout': layout,
                'device': device,
                'requires_grad': requires_grad,
                'pin_memory': pin_memory,
            }))
            formal_args.append('const TensorOptions & options')
            actuals.append('options')

        env['unpack_args'] = []
        env['formal_args'] = formal_args
        env['actuals'] = actuals

        if has_tensor_options:
            env['initialize_cuda'] = 'maybe_initialize_cuda(options);'
        else:
            env['initialize_cuda'] = ''

        if 'call_args' in declaration:
            env['dispatch_args'] = declaration['call_args']
        else:
            env['dispatch_args'] = [arg['name'] for arg in declaration['arguments']]

        # Prefer calling as a Tensor method (self.name(...)); otherwise call
        # the namespace function, using torch:: for TensorOptions and *_like
        # functions and at:: for everything else.
        if 'Tensor' in declaration['method_of']:
            env['dispatch_args'] = [arg for arg in env['dispatch_args'] if arg != 'self']
            env['dispatch_call'] = 'self.{}'.format(declaration['name'])
        elif 'namespace' in declaration['method_of']:
            namespace = 'torch' if (has_tensor_options or declaration['name'].endswith('_like')) else 'at'
            env['dispatch_call'] = '{}::{}'.format(namespace, declaration['name'])
        else:
            raise RuntimeError('could not dispatch, neither namespace function nor Tensor method')

        env['AutoNoGIL'] = 'AutoNoGIL no_gil;' if not declaration['with_gil'] else ''

        # Use the simple_return_type (Tensor) rather than the fancy return type
        # (Tensor &).  This is important because the dispatch functions take
        # mutable arguments *by value*, not by reference.  If you then return
        # a reference to such an argument, you will now have a pointer to a
        # dangling stack entry.  Not good.
        #
        # You want:
        #
        #   Tensor dispatch_selu_(Tensor self) { return at::selu_(self); }
        #
        # *not*
        #
        #   Tensor& dispatch_selu_(Tensor self) { return at::selu_(self); }
        #
        # (NB: We can't make dispatch_selu_ take Tensor&, because the enclosing
        # codegen looks like dispatch_selu_(wrap(tensor)), and you can't take a
        # mutable reference to temporary.  Maybe we could assign it to a
        # variable itself.)
        env['simple_return_type'] = simple_return_type

        # Layer the lookups: this call's env first, then the per-function
        # base_env, then the declaration itself.
        env = nested_dict(env, nested_dict(base_env, declaration))
        call_dispatch = PY_VARIABLE_CALL_DISPATCH.substitute(env)
        if requires_grad and not has_tensor_options:
            call_dispatch = PY_VARIABLE_SET_REQUIRES_GRAD.substitute(env, call_dispatch=call_dispatch,
                                                                     requires_grad=requires_grad)
        if simple_return_type == 'void':
            body.append('{call_dispatch};'.format(call_dispatch=call_dispatch))
            body.append('Py_RETURN_NONE;')
        else:
            body.append(PY_VARIABLE_WRAP.substitute(env, call_dispatch=call_dispatch))
        py_method_dispatch.append(PY_VARIABLE_DISPATCH.substitute(env))
        return body

    def emit_dispatch(i, dictionary, base_env):
        """Emit the `if (r.idx == i) {` case for one overload group.

        dictionary has a 'base' declaration and optionally an 'out'
        declaration; when 'out' is present the case chooses between the two
        at runtime depending on whether the out argument is None.
        """
        if 'out' in dictionary:
            # The out argument follows all non-output arguments in the parser
            # signature.
            out_idx = len([arg for arg in dictionary['out']['arguments']
                           if not arg.get('output', False)])
            env = {}
            env['call_dispatch_out'] = emit_single_dispatch(dictionary['out'], out_idx, base_env)
            env['call_dispatch'] = emit_single_dispatch(dictionary['base'], out_idx, base_env)

            has_dtype_bind = 'dtype' in [d['name'] for d in dictionary['out'].get('python_binding_arguments', [])]
            if has_dtype_bind:
                body = PY_VARIABLE_OUT_CHECK_TYPE.substitute(env, out_idx=out_idx, type_idx=out_idx + 1,
                                                             layout_idx=out_idx + 2, device_idx=out_idx + 3).split('\n')
            else:
                body = PY_VARIABLE_OUT.substitute(env, out_idx=out_idx).split('\n')
        else:
            body = emit_single_dispatch(dictionary['base'], None, base_env)

        # Every case after the first closes the previous case's brace.
        cond = 'if' if i == 0 else '} else if'
        return PY_VARIABLE_CASE.substitute(i=i, cond=cond, call_dispatch=body)

    def get_python_binding_arguments(declaration):
        """Return the extra keyword-only Python arguments for a declaration.

        Depending on whether the declaration is a factory function, a
        `*_like` function or takes TensorOptions, this adds synthetic
        dtype / layout / device / pin_memory / requires_grad arguments.

        Raises ValueError if the declaration already has an input argument
        named requires_grad.
        """
        python_binding_arguments = []
        has_tensor_input_arg = False
        has_type_input_arg = False
        has_options_arg = False
        for arg in declaration['arguments']:
            if arg.get('output', False):
                continue
            typename = arg['simple_type']
            if typename in ['Tensor', 'TensorList']:
                has_tensor_input_arg = True
            if arg['simple_type'] == 'Type':
                has_type_input_arg = True
            elif arg['simple_type'] == 'TensorOptions':
                has_options_arg = True
            if arg['name'] == 'requires_grad':
                raise ValueError("argument named requires_grad not supported")

        has_tensor_return = False
        for ret in declaration['returns']:
            if ret['dynamic_type'] in ['Tensor', 'TensorList']:
                # this probably won't work if one of the returns is not a tensor, but it will
                # produce a compile-time error that is obvious
                has_tensor_return = True

        # NOTE(review): `name` is neither a parameter nor a local of this
        # function.  It is resolved from the enclosing create_python_bindings
        # scope, where it is the loop variable of the final
        # `for name in sorted(python_functions.keys())` loop, i.e. the group
        # name currently being processed rather than declaration['name'].
        is_like_function = name.endswith('_like')
        is_like_function_with_options = is_like_function and has_options_arg
        is_factory_function = has_tensor_return and not has_tensor_input_arg
        is_factory_or_like_function = has_tensor_return and (not has_tensor_input_arg or is_like_function)

        if (is_factory_function and not has_type_input_arg) or has_options_arg:
            default_type = get_type_default(declaration)
            py_default_dtype = 'self.scalar_type()' if is_like_function_with_options else None
            dtype_arg = {
                'default': default_type,
                'dynamic_type': 'Type',
                'kwarg_only': True,
                'name': 'dtype',
                'type': 'const Type &',
                'simple_type': 'Type',
                'python_default_init': py_default_dtype,
            }
            python_binding_arguments.append(dtype_arg)
        if is_factory_function or is_like_function_with_options:
            py_default_layout = '*torch::getLayout(self.type().backend())' if is_like_function_with_options else None
            layout_arg = {
                'default': 'torch.strided',
                'dynamic_type': 'Layout',
                'kwarg_only': True,
                'name': 'layout',
                'type': 'const THPLayout &',
                'simple_type': 'Layout',
                'python_default_init': py_default_layout,
            }
            python_binding_arguments.append(layout_arg)
            py_default_device = 'self.device()' if is_like_function_with_options else None
            device_arg = {
                'default': 'None',
                'default_init': 'None',
                'dynamic_type': 'Device',
                'kwarg_only': True,
                'name': 'device',
                'type': 'const Device &',
                'simple_type': 'Device',
                'python_default_init': py_default_device
            }
            python_binding_arguments.append(device_arg)
            pin_memory_arg = {
                'default': False,
                'dynamic_type': 'bool',
                'kwarg_only': True,
                'name': 'pin_memory',
                'type': 'bool',
                'simple_type': 'bool',
            }
            python_binding_arguments.append(pin_memory_arg)
        if is_factory_or_like_function:
            requires_grad_arg = {
                'default': False,
                'dynamic_type': 'bool',
                'kwarg_only': True,
                'name': 'requires_grad',
                'type': 'bool',
                'simple_type': 'bool',
            }
            python_binding_arguments.append(requires_grad_arg)
        return python_binding_arguments

    def emit_namedtuple_return_type_def(declaration, next_index):
        """Emit the C++ named-tuple type definition for a declaration.

        Returns (typedef, next_index).  typedef is '' and next_index is
        unchanged when the declaration has at most one return or none of its
        returns has a 'field_name'; otherwise typedef is the substituted
        PY_RETURN_NAMEDTUPLE_DEF and next_index is incremented.  The
        declaration is updated in place with the namedtuple_* keys that the
        templates read.

        Raises ValueError when only some of the returns are named.
        """
        returns = declaration['returns']
        if len(returns) <= 1 or all(['field_name' not in x for x in returns]):
            declaration['namedtuple_return_type'] = ''
            return '', next_index
        declaration['namedtuple_type_index'] = next_index
        declaration['namedtuple_fields'] = ''
        for x in returns:
            # See Note [field_name versus name]
            if 'field_name' not in x:
                # When building on Windows, `PyStructSequence_UnnamedField` could not be
                # resolved by the linker for some reason, which cause error in building:
                #
                # python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol
                # PyStructSequence_UnnamedField
                #
                # Thus, at this point in time, we do not support unnamed
                # fields in namedtuple; you must either name all fields,
                # or none of them.
                raise ValueError("Unnamed field is not supported by codegen")
            else:
                declaration['namedtuple_fields'] += '{"' + x['field_name'] + '", ""}, '
        declaration['namedtuple_size'] = len(returns)
        declaration['namedtuple_return_type'] = '&type{}, '.format(next_index)
        return PY_RETURN_NAMEDTUPLE_DEF.substitute(declaration), next_index + 1

    def process_function(name, declarations):
        """Generate the wrapper and method table entry for one function name.

        Appends to py_methods and py_method_defs (and, through
        emit_dispatch, to py_method_dispatch).
        """
        for declaration in declarations:
            declaration['python_binding_arguments'] = get_python_binding_arguments(declaration)

        env = {
            'name': name,
            'dispatch_name': 'dispatch_{}'.format(name),
            'pycname': 'THPVariable_{}'.format(name),
            'signatures': [],
            # Size of the ParsedArgs buffer: the largest argument count over
            # all overloads, binding arguments included.
            'max_args': max(len(o['arguments']) + len(o['python_binding_arguments']) for o in declarations),
            'unpack_self': [],
            'dispatch': [],
            'declare_namedtuple_return_types': '',
        }

        if has_self:
            env['unpack_self'] = [UNPACK_SELF]

        # generate namedtuple type declare
        next_index = 0
        for declaration in declarations:
            typedef, next_index = emit_namedtuple_return_type_def(declaration, next_index)
            env['declare_namedtuple_return_types'] += typedef

        # emit dispatch
        grouped = group_declarations(declarations)
        for i, dictionary in enumerate(grouped):
            signature = dictionary['signature']
            if has_self:
                # `self` is implicit for methods, so drop it from the
                # signature given to the argument parser.
                signature = signature.replace('Tensor self, ', '')
                signature = signature.replace('Tensor self', '')
            if not has_self:
                # Use 'input' instead of 'self' for NN functions
                signature = signature.replace('Tensor self', 'Tensor input')
            signature = signature.replace('SparseTensorRef', 'Tensor')
            if dictionary['base'].get('deprecated', False):
                signature += '|deprecated'
            env['signatures'].append('"{}",'.format(signature))
            env['dispatch'].append(emit_dispatch(i, dictionary, env))

        # Close the brace left open by the last PY_VARIABLE_CASE.
        env['dispatch'].append('}')

        env['traceable'] = 'true' if all(should_trace(d) for d in declarations) else 'false'

        # NOTE(review): this reads the 'args' key while the rest of this file
        # reads 'arguments'; presumably both are populated upstream -- confirm.
        if len(declarations) == 1 and len(declarations[0]['args']) == 1 and has_self:
            tmpl = PY_VARIABLE_METHOD_NOARGS
            env['actuals'] = ['self']
            env['flags'] = 'METH_NOARGS'
            env['namedtuple_return_type'] = declarations[0]['namedtuple_return_type']
        else:
            tmpl = PY_VARIABLE_METHOD_VARARGS
            env['flags'] = 'METH_VARARGS | METH_KEYWORDS'

        if not is_module and not has_self:
            env['flags'] += ' | METH_STATIC'

        py_methods.append(tmpl.substitute(env))
        py_method_defs.append(PY_VARIABLE_METHOD_DEF.substitute(env))

    # Process names in sorted order so the generated output is ordered by
    # function name.
    for name in sorted(python_functions.keys()):
        process_function(name, python_functions[name])

    return {
        'py_methods': py_methods,
        'py_method_defs': py_method_defs,
        'py_method_dispatch': py_method_dispatch,
    }


def group_declarations(declarations):
    """Pair each base declaration with its out variant.

    Returns a list of dictionaries (ordered by sort_declarations) with keys:

       "base": the regular ATen declaration (e.g. conv2d); always present
       "out": the out variant (e.g. conv2d_out); only if one exists
       "signature": the signature used for Python argument parsing

    Raises RuntimeError when an out variant has no matching base declaration.
    """
    by_signature = defaultdict(dict)

    # Pair declarations up by their signature with out arguments left out.
    for decl in declarations:
        base_signature = get_python_signature(decl, False)
        entry = by_signature[base_signature]
        if not decl['name'].endswith('_out'):
            entry['base'] = decl
            # Keep the out-aware signature if the out variant came first.
            entry.setdefault('signature', base_signature)
        else:
            entry['out'] = decl
            # prefer the signature with optional out=... arguments
            entry['signature'] = get_python_signature(decl, True)

    ordered = []
    for key in sorted(by_signature):
        entry = by_signature[key]
        if 'base' not in entry:
            raise RuntimeError("'base' not in dictionary", entry)
        ordered.append(entry)
    return sort_declarations(ordered)


# This function declares a partial order on declarations, and sorts them according
# to its linear extension. This is necessary, because there's some ambiguity in the
# choice of overload, and we want a different order.
#
# See Note[Order of overloads matters]
def sort_declarations(grouped_decls):
    """Orders grouped declarations so that larger overloads come first.

    A declaration d1 is smaller than d2 when both 'base' declarations have
    the same number of arguments, every argument pair either has the same
    (normalized) dynamic type or is Scalar in d1 and Tensor in d2, and at
    least one pair is of the Scalar/Tensor kind.

    The returned list starts with the declarations that are not smaller
    than any other one; every remaining declaration is emitted only after
    all declarations larger than it. If no declaration is smaller than
    another, `grouped_decls` itself is returned untouched.
    """

    # TODO: This is a hack!
    #
    # A Scalar argument of a native function shows up in Declarations.yaml
    # with `dynamic_type: Scalar` and `type: Scalar`, for example:
    #
    #   - default: 1
    #     dynamic_type: Scalar
    #     is_nullable: false
    #     kwarg_only: true
    #     name: alpha
    #     type: Scalar
    #
    # This is in contrast to a 'real' argument in TH Declarations.cwrap,
    # which is translated into dynamic_type: real and type: Scalar.
    # Both spellings are treated as 'Scalar' here rather than being fixed
    # at the source.
    def scalar_normalized(arg):
        dynamic_type = arg['dynamic_type']
        return 'Scalar' if dynamic_type == 'real' else dynamic_type

    def precedes(lhs, rhs):
        """Returns True if lhs < rhs in the partial order."""
        lhs_args = lhs['base']['arguments']
        rhs_args = rhs['base']['arguments']
        if len(lhs_args) != len(rhs_args):
            return False
        strictly_smaller_somewhere = False
        for lhs_arg, rhs_arg in zip(lhs_args, rhs_args):
            lhs_type = scalar_normalized(lhs_arg)
            rhs_type = scalar_normalized(rhs_arg)
            if lhs_type == 'Scalar' and rhs_type == 'Tensor':
                strictly_smaller_somewhere = True
            elif lhs_type != rhs_type:
                # Neither equal nor smaller at this position.
                return False
        return strictly_smaller_somewhere

    # For every declaration index, collect the indices of the declarations
    # that are larger than it. Indices without any larger declaration get
    # no entry at all.
    pending = {}
    for idx, decl in enumerate(grouped_decls):
        dominators = set(other_idx
                         for other_idx, other in enumerate(grouped_decls)
                         if precedes(decl, other))
        if dominators:
            pending[idx] = dominators

    if not pending:
        return grouped_decls

    # Topological sort driven by a growing worklist: `order` begins with
    # the indices that have nothing above them, and an index is appended
    # as soon as the last declaration larger than it has been processed.
    order = [idx for idx, _ in enumerate(grouped_decls) if idx not in pending]
    cursor = 0
    while cursor < len(order):
        done = order[cursor]
        cursor += 1
        # Iterate over a sorted snapshot, since entries are removed from
        # `pending` while walking it.
        for idx in sorted(pending):
            dominators = pending[idx]
            dominators.discard(done)
            if not dominators:
                del pending[idx]
                order.append(idx)

    return [grouped_decls[idx] for idx in order]


def get_python_signature(declaration, include_out):
    """Builds the Python signature string used for argument parsing.

    The format is the one specified in torch/csrc/utils/python_arg_parser.h.
    WARNING: this is NOT the same type signature as specified by PEP 484
    as understood by mypy; the format was developed independently and has
    some quirks of its own.

    For a translation to mypy-valid type signatures, see tools/gen_pyi.py.
    If you change any logic here, please check that file too.

    `declaration` supplies 'name', 'arguments' and
    'python_binding_arguments'. When `include_out` is true and the
    declaration has output arguments, an `out=None` style keyword-only
    parameter is added to the signature.
    """
    formal_args = []
    out_args = []
    scalar_type_args = []

    def ensure_keyword_only():
        # The '*' marker separates positional from keyword-only parameters
        # and must appear at most once. Real parameters always contain a
        # space, so none of them can be mistaken for the marker.
        if '*' not in formal_args:
            formal_args.append('*')

    def render(arg):
        typename = arg['simple_type']
        if typename == 'Type':
            typename = 'ScalarType'

        # TODO: remove this and make optional types in simple_type to be consistent across
        # tensor and other types after make Tensor? be optional instead of undefined
        if arg.get('is_nullable') and '?' not in typename:
            typename += '?'

        size = arg.get('size')
        if size is not None:
            typename = '{}[{}]'.format(typename, size)

        rendered = typename + ' ' + arg['name']
        default = arg.get('default')
        if default is None:
            return rendered
        # These C++ spellings of "no value" are all written as None.
        if default in ('nullptr', 'nullopt', '{}'):
            default = 'None'
        return rendered + '=' + str(default)

    for arg in declaration['arguments']:
        simple_type = None if arg.get('output', False) else arg['simple_type']
        if simple_type is None:
            # Output arguments are handled separately below.
            out_args.append(arg)
        elif simple_type == 'Type':
            scalar_type_args.append(arg)
        elif simple_type != 'TensorOptions':
            # `TensorOptions` is skipped in Python, as it is only used on
            # the C++ side.
            if arg.get('kwarg_only', False):
                ensure_keyword_only()
            formal_args.append(render(arg))

    # The Python-facing name never carries the `_out` suffix.
    name = declaration['name']
    if name.endswith('_out'):
        name = name[:-len('_out')]

    # add output arguments
    if out_args and include_out:
        assert declaration['name'].endswith('_out')
        ensure_keyword_only()
        if len(out_args) == 1:
            # The nn module bindings are often not exposed to the user directly
            # but via torch.nn modules and functionals.
            only_out = out_args[0]
            formal_args.append(only_out['simple_type'] + ' ' + only_out['name'] + '=None')
        else:
            # NB: For more than 1 output args the type name is a TensorList
            # and as such we don't (yet) need to consider the naming.
            formal_args.append('TensorList[{}]'.format(len(out_args)) + ' out=None')

    # This is done after the loop over the arguments on purpose: both the
    # type dispatched args and the python binding arguments have to come
    # after the out argument. That matches the case where there is a python
    # binding argument dtype, which is necessary to match the function
    # signatures between the out and non-out variant.
    assert len(scalar_type_args) <= 1
    for arg in scalar_type_args:
        # assume type_args should be kwarg_only.
        ensure_keyword_only()
        formal_args.append(render(arg))

    for arg in declaration['python_binding_arguments']:
        if arg.get('kwarg_only', False):
            ensure_keyword_only()
        formal_args.append(render(arg))

    # Python function signature.
    # This is the string that we give to FunctionParameter, which is
    # then parsed into the actual structure which we do parsing
    # with.
    return PYTHON_FUNCTION_SIGNATURE.substitute(name=name, py_formal_args=formal_args)