@@ -67,14 +67,24 @@ def __init__(self, config):
67
67
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
68
68
# which might result in infinite recursion (#3506)
69
69
self ._writing_pyc = False
70
+ self ._basenames_to_check_rewrite = {"conftest" }
71
+ self ._marked_for_rewrite_cache = {}
72
+ self ._session_paths_checked = False
70
73
71
74
def set_session (self , session ):
72
75
self .session = session
76
+ self ._session_paths_checked = False
77
+
78
+ def _imp_find_module (self , name , path = None ):
79
+ """Indirection so we can mock calls to find_module originated from the hook during testing"""
80
+ return imp .find_module (name , path )
73
81
74
82
def find_module (self , name , path = None ):
75
83
if self ._writing_pyc :
76
84
return None
77
85
state = self .config ._assertstate
86
+ if self ._early_rewrite_bailout (name , state ):
87
+ return None
78
88
state .trace ("find_module called for: %s" % name )
79
89
names = name .rsplit ("." , 1 )
80
90
lastname = names [- 1 ]
@@ -87,7 +97,7 @@ def find_module(self, name, path=None):
87
97
pth = path [0 ]
88
98
if pth is None :
89
99
try :
90
- fd , fn , desc = imp . find_module (lastname , path )
100
+ fd , fn , desc = self . _imp_find_module (lastname , path )
91
101
except ImportError :
92
102
return None
93
103
if fd is not None :
@@ -166,6 +176,44 @@ def find_module(self, name, path=None):
166
176
self .modules [name ] = co , pyc
167
177
return self
168
178
179
+ def _early_rewrite_bailout (self , name , state ):
180
+ """
181
+ This is a fast way to get out of rewriting modules. Profiling has
182
+ shown that the call to imp.find_module (inside of the find_module
183
+ from this class) is a major slowdown, so, this method tries to
184
+ filter what we're sure won't be rewritten before getting to it.
185
+ """
186
+ if self .session is not None and not self ._session_paths_checked :
187
+ self ._session_paths_checked = True
188
+ for path in self .session ._initialpaths :
189
+ # Make something as c:/projects/my_project/path.py ->
190
+ # ['c:', 'projects', 'my_project', 'path.py']
191
+ parts = str (path ).split (os .path .sep )
192
+ # add 'path' to basenames to be checked.
193
+ self ._basenames_to_check_rewrite .add (os .path .splitext (parts [- 1 ])[0 ])
194
+
195
+ # Note: conftest already by default in _basenames_to_check_rewrite.
196
+ parts = name .split ("." )
197
+ if parts [- 1 ] in self ._basenames_to_check_rewrite :
198
+ return False
199
+
200
+ # For matching the name it must be as if it was a filename.
201
+ parts [- 1 ] = parts [- 1 ] + ".py"
202
+ fn_pypath = py .path .local (os .path .sep .join (parts ))
203
+ for pat in self .fnpats :
204
+ # if the pattern contains subdirectories ("tests/**.py" for example) we can't bail out based
205
+ # on the name alone because we need to match against the full path
206
+ if os .path .dirname (pat ):
207
+ return False
208
+ if fn_pypath .fnmatch (pat ):
209
+ return False
210
+
211
+ if self ._is_marked_for_rewrite (name , state ):
212
+ return False
213
+
214
+ state .trace ("early skip of rewriting module: %s" % (name ,))
215
+ return True
216
+
169
217
def _should_rewrite (self , name , fn_pypath , state ):
170
218
# always rewrite conftest files
171
219
fn = str (fn_pypath )
@@ -185,12 +233,20 @@ def _should_rewrite(self, name, fn_pypath, state):
185
233
state .trace ("matched test file %r" % (fn ,))
186
234
return True
187
235
188
- for marked in self ._must_rewrite :
189
- if name == marked or name .startswith (marked + "." ):
190
- state .trace ("matched marked file %r (from %r)" % (name , marked ))
191
- return True
236
+ return self ._is_marked_for_rewrite (name , state )
192
237
193
- return False
238
+ def _is_marked_for_rewrite (self , name , state ):
239
+ try :
240
+ return self ._marked_for_rewrite_cache [name ]
241
+ except KeyError :
242
+ for marked in self ._must_rewrite :
243
+ if name == marked or name .startswith (marked + "." ):
244
+ state .trace ("matched marked file %r (from %r)" % (name , marked ))
245
+ self ._marked_for_rewrite_cache [name ] = True
246
+ return True
247
+
248
+ self ._marked_for_rewrite_cache [name ] = False
249
+ return False
194
250
195
251
def mark_rewrite (self , * names ):
196
252
"""Mark import names as needing to be rewritten.
@@ -207,6 +263,7 @@ def mark_rewrite(self, *names):
207
263
):
208
264
self ._warn_already_imported (name )
209
265
self ._must_rewrite .update (names )
266
+ self ._marked_for_rewrite_cache .clear ()
210
267
211
268
def _warn_already_imported (self , name ):
212
269
self .config .warn (
@@ -241,7 +298,7 @@ def load_module(self, name):
241
298
242
299
def is_package (self , name ):
243
300
try :
244
- fd , fn , desc = imp . find_module (name )
301
+ fd , fn , desc = self . _imp_find_module (name )
245
302
except ImportError :
246
303
return False
247
304
if fd is not None :
0 commit comments