-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtest_lincs.py
93 lines (70 loc) · 2.85 KB
/
test_lincs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# -*- coding: utf-8 -*-
# @Author: Xiaoning Qi
# @Date: 2022-06-13 09:47:44
# @Last Modified by: Xiaoning Qi
# @Last Modified time: 2024-10-31 15:27:11
import os
# Set before scanpy and the trainer module are imported below, so the value is
# already in the process environment when those modules load.
# NOTE(review): presumably pins the run to CUDA device 0 — confirm on multi-GPU hosts.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import sys
# Prints the module search path at start-up (looks like a debugging aid for
# resolving the `trainer` import — TODO confirm it is still wanted).
print(sys.path)
import argparse
from datetime import datetime
import scanpy as sc
from trainer.PRnetTrainer import PRnetTrainer
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what the script did before this parameter existed.

    Returns:
        The parsed namespace; it carries a single attribute, ``split_key``
        (a string, ``'lincs_split'`` unless ``--split_key`` is given).
    """
    parser = argparse.ArgumentParser(description='perturbation-conditioned generative model ')
    parser.add_argument('--split_key', default='lincs_split', type=str, help='split key of data')
    # Forwarding argv lets callers and tests supply arguments explicitly
    # instead of depending on the process-wide sys.argv.
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: build a PRnetTrainer on the LINCS L1000 AnnData file
    # and run its test() method against a saved checkpoint, then report the
    # wall-clock time taken.
    args_train = parse_args()
    start_time = datetime.now()

    # Settings for this run. Note that 'n_epochs', 'lr', 'weight_decay',
    # 'scheduler_factor' and 'scheduler_patience' are defined here but are not
    # passed to PRnetTrainer anywhere in this script.
    config_kwargs = {
        'batch_size': 512,
        'comb_num': 1,
        'save_dir': './checkpoint/',
        'results_dir': './results/lincs/',
        'n_epochs': 100,
        'split_key': args_train.split_key,
        'x_dimension': 978,
        'hidden_layer_sizes': [128],
        'z_dimension': 64,
        'adaptor_layer_sizes': [128],
        'comb_dimension': 64,
        'drug_dimension': 1024,
        'dr_rate': 0.05,
        'lr': 1e-3,
        'weight_decay': 1e-8,
        'scheduler_factor': 0.5,
        'scheduler_patience': 5,
        'n_genes': 20,
        'loss': ['GUSS'],
        'obs_key': 'cov_drug_name',
    }

    # The dataset and checkpoint paths below are relative, so show where the
    # process is running from.
    print(os.getcwd())

    adata = sc.read('./dataset/Lincs_L1000.h5ad')
    # Both scanpy calls modify `adata` in place: total-count normalisation
    # followed by a log1p transform.
    sc.pp.normalize_total(adata)
    sc.pp.log1p(adata)

    trainer = PRnetTrainer(
        adata,
        batch_size=config_kwargs['batch_size'],
        comb_num=config_kwargs['comb_num'],
        split_key=config_kwargs['split_key'],
        model_save_dir=config_kwargs['save_dir'],
        results_save_dir=config_kwargs['results_dir'],
        x_dimension=config_kwargs['x_dimension'],
        hidden_layer_sizes=config_kwargs['hidden_layer_sizes'],
        z_dimension=config_kwargs['z_dimension'],
        adaptor_layer_sizes=config_kwargs['adaptor_layer_sizes'],
        comb_dimension=config_kwargs['comb_dimension'],
        drug_dimension=config_kwargs['drug_dimension'],
        dr_rate=config_kwargs['dr_rate'],
        n_genes=config_kwargs['n_genes'],
        loss=config_kwargs['loss'],
        obs_key=config_kwargs['obs_key'],
    )

    trainer.test('./checkpoint/lincs_best_epoch_all.pt')

    end_time = datetime.now()
    # total_seconds() spans the whole timedelta; the `.seconds` attribute only
    # holds the sub-day remainder (and drops microseconds), so it under-reports
    # any run lasting 24 hours or more.
    during_time = (end_time - start_time).total_seconds() / 60
    print(f'start time: {start_time} end_time: {end_time} time:{during_time} min')