Skip to content

Commit e46fa74

Browse files
committed
fixed more backend issues
1 parent cb0aca6 commit e46fa74

File tree

4 files changed

+25
-17
lines changed

4 files changed

+25
-17
lines changed

examples/playground/run_backend.py

+21-13
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,30 @@
11
import pygrank as pg
22
import torch
3+
from timeit import default_timer as time
34

45
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
56

7+
_, graph, community = next(pg.load_datasets_one_community(["youtube"], graph_api=pg, min_group_size=50))
8+
print(f"Nodes {len(graph)}, edges {graph.number_of_edges()}")
9+
10+
ppr = pg.HeatKernel(
11+
normalization="symmetric",
12+
assume_immutability=True
13+
)
14+
signal = pg.to_signal(graph, {node: 1.0 for node in community})
15+
preprocessor = ppr.preprocessor
16+
#ppr = pg.ParameterTuner(preprocessor=preprocessor)
17+
"""
18+
with pg.Backend("numpy"):
19+
preprocessor(graph)
20+
torch.cuda.synchronize() # correct timing
21+
tic = time()
22+
scores = ppr(signal)
23+
print("numpy", ppr.convergence, "actual time", time()-tic)"""
624

725
with pg.Backend("torch_sparse", device=device):
8-
_, graph, community = next(pg.load_datasets_one_community(["amazon"]))
9-
ppr = pg.PageRank(
10-
alpha=0.9,
11-
normalization="symmetric",
12-
assume_immutability=True,
13-
convergence=pg.ConvergenceManager(max_iters=38, error_type="iters"),
14-
)
15-
ppr.preprocessor(graph)
16-
signal = pg.to_signal(graph, {node: 1.0 for node in community})
26+
preprocessor(graph)
1727
torch.cuda.synchronize() # correct timing
28+
tic = time()
1829
scores = ppr(signal)
19-
print(ppr.convergence)
20-
print(scores["B00005MHUG"]) # 0.00508212111890316
21-
print(scores["B00006RGI2"]) # 0.70645672082901
22-
print(scores["0006497993"]) # 0.19633759558200836
30+
print("torch_sparse", ppr.convergence, "actual time", time()-tic)

pygrank/algorithms/postprocess/postprocess.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ class Sweep(Postprocessor):
375375
Applies a sweep procedure that divides personalized node ranks by corresponding non-personalized ones.
376376
"""
377377

378-
def __init__(self, ranker: NodeRanking = None, uniform_ranker: NodeRanking = None):
378+
def __init__(self, ranker: NodeRanking = None, uniform_ranker: NodeRanking = None, assume_immutability: bool = True):
379379
"""
380380
Initializes the sweep procedure.
381381
@@ -404,7 +404,7 @@ def __init__(self, ranker: NodeRanking = None, uniform_ranker: NodeRanking = Non
404404
super().__init__(ranker)
405405
self.uniform_ranker = ranker if uniform_ranker is None else uniform_ranker
406406
self.centrality = MethodHasher(
407-
lambda graph: self.uniform_ranker.rank(graph), assume_immutability=True
407+
lambda graph: self.uniform_ranker.rank(graph), assume_immutability=assume_immutability
408408
)
409409

410410
def _transform(self, ranks: GraphSignal, **kwargs):

pygrank/core/backend/pytorch.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def backend_init(mode="dense", device=None):
4141
return
4242
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
4343
warnings.warn(
44-
f"[pygrank.backend.pytorch] Automatically detected device to run on {device}: {torch.cuda.get_device_name(device)}"
44+
f"[pygrank.backend.pytorch] Automatically detected device to run on {device}: {torch.get_device(device)}"
4545
)
4646
if device is not None and isinstance(device, str):
4747
device = torch.device(device)

pygrank/core/backend/torch_sparse.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def backend_init(device="auto"):
6161
return
6262
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
6363
warnings.warn(
64-
f"[pygrank.backend.torch_sparse] Automatically detected device to run on {device}: {torch.cuda.get_device_name(device)}"
64+
f"[pygrank.backend.torch_sparse] Automatically detected device to run on {device}: {torch.device(device)}"
6565
)
6666
if device is not None and isinstance(device, str):
6767
device = torch.device(device)

0 commit comments

Comments
 (0)