diff --git a/aesara/compile/builders.py b/aesara/compile/builders.py
index 716a131f58..cca88ffbc7 100644
--- a/aesara/compile/builders.py
+++ b/aesara/compile/builders.py
@@ -2,7 +2,7 @@
 from collections import OrderedDict
 from copy import copy
 from functools import partial
-from typing import List, Optional, Sequence, cast
+from typing import Dict, List, Optional, Sequence, Tuple, cast
 
 import aesara.tensor as at
 from aesara import function
@@ -81,6 +81,81 @@ def local_traverse(out):
     return ret
 
 
+def construct_nominal_fgraph(
+    inputs: Sequence[Variable], outputs: Sequence[Variable]
+) -> Tuple[
+    FunctionGraph,
+    Sequence[Variable],
+    Dict[Variable, Variable],
+    Dict[Variable, Variable],
+]:
+    """Construct an inner-`FunctionGraph` with ordered nominal inputs."""
+    dummy_inputs = []
+    for n, inp in enumerate(inputs):
+        if (
+            not isinstance(inp, Variable)
+            or isinstance(inp, Constant)
+            or isinstance(inp, SharedVariable)
+        ):
+            raise TypeError(
+                f"Inputs and outputs must be non-Constant/shared Variable instances; got {inp}"
+            )
+
+        dummy_inputs.append(inp.type())
+
+    dummy_shared_inputs = []
+    shared_inputs = []
+    for var in graph_inputs(outputs, inputs):
+        if isinstance(var, SharedVariable):
+            # To correctly support shared variables the inner-graph should
+            # not see them; otherwise, there will be problems with
+            # gradients.
+            # That's why we collect the shared variables and replace them
+            # with dummies.
+            shared_inputs.append(var)
+            dummy_shared_inputs.append(var.type())
+        elif var not in inputs and not isinstance(var, Constant):
+            raise MissingInputError(f"OpFromGraph is missing an input: {var}")
+
+    replacements = dict(zip(inputs + shared_inputs, dummy_inputs + dummy_shared_inputs))
+
+    new = rebuild_collect_shared(
+        cast(Sequence[Variable], outputs),
+        inputs=inputs + shared_inputs,
+        replace=replacements,
+        copy_inputs_over=False,
+    )
+    (
+        local_inputs,
+        local_outputs,
+        (clone_d, update_d, update_expr, new_shared_inputs),
+    ) = new
+
+    assert len(local_inputs) == len(inputs) + len(shared_inputs)
+    assert len(local_outputs) == len(outputs)
+    assert not update_d
+    assert not update_expr
+    assert not new_shared_inputs
+
+    fgraph = FunctionGraph(local_inputs, local_outputs, clone=False)
+
+    # The inputs need to be `NominalVariable`s so that we can merge
+    # inner-graphs
+    nominal_local_inputs = tuple(
+        NominalVariable(n, var.type) for n, var in enumerate(local_inputs)
+    )
+
+    fgraph.replace_all(zip(local_inputs, nominal_local_inputs))
+
+    for i, inp in enumerate(fgraph.inputs):
+        nom_inp = nominal_local_inputs[i]
+        fgraph.inputs[i] = nom_inp
+        fgraph.clients.pop(inp, None)
+        fgraph.add_input(nom_inp)
+
+    return fgraph, shared_inputs, update_d, update_expr
+
+
 class OpFromGraph(Op, HasInnerGraph):
     r"""
     This creates an `Op` from inputs and outputs lists of variables.
@@ -338,76 +413,15 @@ def __init__(
                     f"Inputs and outputs must be Variable instances; got {out}"
                 )
 
-        dummy_inputs = []
-        for n, inp in enumerate(inputs):
-            if (
-                not isinstance(inp, Variable)
-                or isinstance(inp, Constant)
-                or isinstance(inp, SharedVariable)
-            ):
-                raise TypeError(
-                    f"Inputs and outputs must be non-Constant/shared Variable instances; got {inp}"
-                )
-
-            dummy_inputs.append(inp.type())
-
         if "updates" in kwargs or "givens" in kwargs:
             raise NotImplementedError("Updates and givens are not supported")
 
         self.is_inline = inline
 
-        dummy_shared_inputs = []
-        self.shared_inputs = []
-        for var in graph_inputs(outputs, inputs):
-            if isinstance(var, SharedVariable):
-                # To correctly support shared variables the inner-graph should
-                # not see them; otherwise, there will be problems with
-                # gradients.
-                # That's why we collect the shared variables and replace them
-                # with dummies.
-                self.shared_inputs.append(var)
-                dummy_shared_inputs.append(var.type())
-            elif var not in inputs and not isinstance(var, Constant):
-                raise MissingInputError(f"OpFromGraph is missing an input: {var}")
-
-        replacements = dict(
-            zip(inputs + self.shared_inputs, dummy_inputs + dummy_shared_inputs)
+        self.fgraph, self.shared_inputs, _, _ = construct_nominal_fgraph(
+            inputs, outputs
         )
 
-        new = rebuild_collect_shared(
-            cast(Sequence[Variable], outputs),
-            inputs=inputs + self.shared_inputs,
-            replace=replacements,
-            copy_inputs_over=False,
-        )
-        (
-            local_inputs,
-            local_outputs,
-            (clone_d, update_d, update_expr, shared_inputs),
-        ) = new
-
-        assert len(local_inputs) == len(inputs) + len(self.shared_inputs)
-        assert len(local_outputs) == len(outputs)
-        assert not update_d
-        assert not update_expr
-        assert not shared_inputs
-
-        self.fgraph = FunctionGraph(local_inputs, local_outputs, clone=False)
-
-        # The inputs need to be `NominalVariable`s so that we can merge
-        # inner-graphs
-        nominal_local_inputs = tuple(
-            NominalVariable(n, var.type) for n, var in enumerate(local_inputs)
-        )
-
-        self.fgraph.replace_all(zip(local_inputs, nominal_local_inputs))
-
-        for i, inp in enumerate(self.fgraph.inputs):
-            nom_inp = nominal_local_inputs[i]
-            self.fgraph.inputs[i] = nom_inp
-            self.fgraph.clients.pop(inp, None)
-            self.fgraph.add_input(nom_inp)
-
         self.kwargs = kwargs
         self.input_types = [inp.type for inp in inputs]
         self.output_types = [out.type for out in outputs]
diff --git a/aesara/scan/op.py b/aesara/scan/op.py
index 481eaf971d..837f8abbaf 100644
--- a/aesara/scan/op.py
+++ b/aesara/scan/op.py
@@ -55,8 +55,7 @@
 
 import aesara
 from aesara import tensor as at
-from aesara.compile import SharedVariable
-from aesara.compile.builders import infer_shape
+from aesara.compile.builders import construct_nominal_fgraph, infer_shape
 from aesara.compile.function.pfunc import pfunc
 from aesara.compile.io import In, Out
 from aesara.compile.mode import Mode, get_default_mode, get_mode
@@ -65,17 +64,13 @@
 from aesara.gradient import DisconnectedType, NullType, Rop, grad, grad_undefined
 from aesara.graph.basic import (
     Apply,
-    Constant,
-    NominalVariable,
     Variable,
     clone_replace,
     equal_computations,
     graph_inputs,
     io_connection_pattern,
-    replace_nominals_with_dummies,
 )
 from aesara.graph.features import NoOutputFromInplace
-from aesara.graph.fg import FunctionGraph
 from aesara.graph.op import HasInnerGraph, Op
 from aesara.graph.utils import InconsistencyError, MissingInputError
 from aesara.link.c.basic import CLinker
@@ -755,22 +750,12 @@ def __init__(
            If ``True``, all the shared variables used in the inner-graph must be
            provided.
""" - inputs, outputs = replace_nominals_with_dummies(inputs, outputs) + self.fgraph, shared_inputs, _, _ = construct_nominal_fgraph(inputs, outputs) - input_replacements = [] - for n, v in enumerate(inputs): - if not isinstance(v, (SharedVariable, Constant)): - input_replacements.append((v, NominalVariable(n, v.type))) - - assert not isinstance(v, NominalVariable) - - outputs = clone_replace(outputs, replace=input_replacements) - - if input_replacements: - _, inputs_ = zip(*input_replacements) - inputs = list(inputs_) - else: - inputs = [] + # The shared variables should have been removed, so, if there are + # any, it's because the user didn't specify an input. + if shared_inputs: + raise MissingInputError(f"Scan is missing inputs: {shared_inputs}") self.info = info self.truncate_gradient = truncate_gradient @@ -782,7 +767,7 @@ def __init__( # Clone mode_instance, altering "allow_gc" for the linker, # and adding a message if we profile if self.name: - message = self.name + " sub profile" + message = f"{self.name} sub profile" else: message = "Scan sub profile" @@ -805,7 +790,7 @@ def tensorConstructor(shape, dtype): while idx < info.n_mit_mot_outs: # Not that for mit_mot there are several output slices per # output sequence - o = outputs[idx] + o = self.fgraph.outputs[idx] self.output_types.append( # TODO: What can we actually say about the shape of this # added dimension? @@ -818,7 +803,7 @@ def tensorConstructor(shape, dtype): # mit_sot / sit_sot / nit_sot end = idx + info.n_mit_sot + info.n_sit_sot + info.n_nit_sot - for o in outputs[idx:end]: + for o in self.fgraph.outputs[idx:end]: self.output_types.append( # TODO: What can we actually say about the shape of this # added dimension? @@ -826,7 +811,7 @@ def tensorConstructor(shape, dtype): ) # shared outputs + possibly the ending condition - for o in outputs[end:]: + for o in self.fgraph.outputs[end:]: self.output_types.append(o.type) if info.as_while: @@ -862,8 +847,6 @@ def tensorConstructor(shape, dtype): self.n_outer_inputs = info.n_outer_inputs self.n_outer_outputs = info.n_outer_outputs - self.fgraph = FunctionGraph(inputs, outputs, clone=False) - _ = self.prepare_fgraph(self.fgraph) if any(node.op.destroy_map for node in self.fgraph.apply_nodes): @@ -871,10 +854,6 @@ def tensorConstructor(shape, dtype): "Inner-graphs must not contain in-place operations." ) - # Do the missing inputs check here to have the error early. - for var in graph_inputs(self.inner_outputs, self.inner_inputs): - if var not in self.inner_inputs and not isinstance(var, Constant): - raise MissingInputError(f"ScanOp is missing an input: {repr(var)}") self._cmodule_key = CLinker().cmodule_key_variables( self.inner_inputs, self.inner_outputs, [] ) diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py index b4af67d572..a3d022ba27 100644 --- a/tests/scan/test_basic.py +++ b/tests/scan/test_basic.py @@ -586,10 +586,6 @@ def f_rnn_shared(u_t, x_tm1, tmp_W_in, tmp_W): assert np.allclose(aesara_values, v_out) def test_oinp_iinp_iout_oout_mappings(self): - """ - Test the mapping produces by - ScanOp.get_oinp_iinp_iout_oout_mappings() - """ rng = RandomStream(123)