@@ -234,7 +234,7 @@ def _forward_unimplemented(self, *input: Any) -> None:
234
234
"""
235
235
forward : Callable [..., Any ] = _forward_unimplemented
236
236
237
- def register_buffer (self , name : str , tensor : Tensor , persistent : bool = True ) -> None :
237
+ def register_buffer (self , name : str , tensor : Optional [ Tensor ] , persistent : bool = True ) -> None :
238
238
r"""Adds a buffer to the module.
239
239
240
240
This is typically used to register a buffer that should not to be
@@ -286,7 +286,7 @@ def register_buffer(self, name: str, tensor: Tensor, persistent: bool = True) ->
286
286
else :
287
287
self ._non_persistent_buffers_set .add (name )
288
288
289
- def register_parameter (self , name : str , param : Parameter ) -> None :
289
+ def register_parameter (self , name : str , param : Optional [ Parameter ] ) -> None :
290
290
r"""Adds a parameter to the module.
291
291
292
292
The parameter can be accessed as an attribute using given name.
@@ -325,7 +325,7 @@ def register_parameter(self, name: str, param: Parameter) -> None:
325
325
else :
326
326
self ._parameters [name ] = param
327
327
328
- def add_module (self , name : str , module : 'Module' ) -> None :
328
+ def add_module (self , name : str , module : Optional [ 'Module' ] ) -> None :
329
329
r"""Adds a child module to the current module.
330
330
331
331
The module can be accessed as an attribute using the given name.
0 commit comments