Christina Theodoris
commited on
Commit
·
3d06203
1
Parent(s):
3072225
Correct order of state dict in in silico perturber stats and tensor dims of alt state emb in in silico perturber
Browse files
geneformer/in_silico_perturber.py
CHANGED
|
@@ -266,7 +266,6 @@ def quant_cos_sims(model,
|
|
| 266 |
def cos_sim_shift(original_emb, minibatch_emb, alt_emb):
|
| 267 |
cos = torch.nn.CosineSimilarity(dim=2)
|
| 268 |
original_emb = torch.mean(original_emb,dim=0,keepdim=True)[None, :]
|
| 269 |
-
alt_emb = alt_emb[None, None, :]
|
| 270 |
origin_v_end = cos(original_emb,alt_emb)
|
| 271 |
perturb_v_end = cos(torch.mean(minibatch_emb,dim=1,keepdim=True),alt_emb)
|
| 272 |
return [(perturb_v_end-origin_v_end).to("cpu")]
|
|
@@ -483,7 +482,7 @@ class InSilicoPerturber:
|
|
| 483 |
"only outputs effect on cell embeddings.")
|
| 484 |
|
| 485 |
if self.cell_states_to_model is not None:
|
| 486 |
-
if
|
| 487 |
for key,value in self.cell_states_to_model.items():
|
| 488 |
if (len(value) == 3) and isinstance(value, tuple):
|
| 489 |
if isinstance(value[0],list) and isinstance(value[1],list) and isinstance(value[2],list):
|
|
|
|
| 266 |
def cos_sim_shift(original_emb, minibatch_emb, alt_emb):
|
| 267 |
cos = torch.nn.CosineSimilarity(dim=2)
|
| 268 |
original_emb = torch.mean(original_emb,dim=0,keepdim=True)[None, :]
|
|
|
|
| 269 |
origin_v_end = cos(original_emb,alt_emb)
|
| 270 |
perturb_v_end = cos(torch.mean(minibatch_emb,dim=1,keepdim=True),alt_emb)
|
| 271 |
return [(perturb_v_end-origin_v_end).to("cpu")]
|
|
|
|
| 482 |
"only outputs effect on cell embeddings.")
|
| 483 |
|
| 484 |
if self.cell_states_to_model is not None:
|
| 485 |
+
if len(self.cell_states_to_model.items()) == 1:
|
| 486 |
for key,value in self.cell_states_to_model.items():
|
| 487 |
if (len(value) == 3) and isinstance(value, tuple):
|
| 488 |
if isinstance(value[0],list) and isinstance(value[1],list) and isinstance(value[2],list):
|
geneformer/in_silico_perturber_stats.py
CHANGED
|
@@ -108,9 +108,10 @@ def get_impact_component(test_value, gaussian_mixture_model):
|
|
| 108 |
|
| 109 |
# stats comparing cos sim shifts towards goal state of test perturbations vs random perturbations
|
| 110 |
def isp_stats_to_goal_state(cos_sims_df, dict_list, cell_states_to_model):
|
| 111 |
-
|
|
|
|
| 112 |
alt_end_state_exists = False
|
| 113 |
-
elif (len(cell_states_to_model[
|
| 114 |
alt_end_state_exists = True
|
| 115 |
|
| 116 |
random_tuples = []
|
|
@@ -120,20 +121,15 @@ def isp_stats_to_goal_state(cos_sims_df, dict_list, cell_states_to_model):
|
|
| 120 |
random_tuples += dict_i.get((token, "cell_emb"),[])
|
| 121 |
|
| 122 |
if alt_end_state_exists == False:
|
| 123 |
-
goal_end_random_megalist = [goal_end for goal_end
|
| 124 |
-
start_state_random_megalist = [start_state for goal_end,start_state in random_tuples]
|
| 125 |
elif alt_end_state_exists == True:
|
| 126 |
-
goal_end_random_megalist = [goal_end for goal_end,alt_end
|
| 127 |
-
alt_end_random_megalist = [alt_end for goal_end,alt_end
|
| 128 |
-
start_state_random_megalist = [start_state for goal_end,alt_end,start_state in random_tuples]
|
| 129 |
|
| 130 |
# downsample to improve speed of ranksums
|
| 131 |
if len(goal_end_random_megalist) > 100_000:
|
| 132 |
random.seed(42)
|
| 133 |
goal_end_random_megalist = random.sample(goal_end_random_megalist, k=100_000)
|
| 134 |
-
if len(start_state_random_megalist) > 100_000:
|
| 135 |
-
random.seed(42)
|
| 136 |
-
start_state_random_megalist = random.sample(start_state_random_megalist, k=100_000)
|
| 137 |
if alt_end_state_exists == True:
|
| 138 |
if len(alt_end_random_megalist) > 100_000:
|
| 139 |
random.seed(42)
|
|
@@ -161,10 +157,10 @@ def isp_stats_to_goal_state(cos_sims_df, dict_list, cell_states_to_model):
|
|
| 161 |
cos_shift_data += dict_i.get((token, "cell_emb"),[])
|
| 162 |
|
| 163 |
if alt_end_state_exists == False:
|
| 164 |
-
goal_end_cos_sim_megalist = [goal_end for goal_end
|
| 165 |
elif alt_end_state_exists == True:
|
| 166 |
-
goal_end_cos_sim_megalist = [goal_end for goal_end,alt_end
|
| 167 |
-
alt_end_cos_sim_megalist = [alt_end for goal_end,alt_end
|
| 168 |
mean_alt_end = np.mean(alt_end_cos_sim_megalist)
|
| 169 |
pval_alt_end = ranksums(alt_end_random_megalist,alt_end_cos_sim_megalist).pvalue
|
| 170 |
|
|
@@ -451,7 +447,7 @@ class InSilicoPerturberStats:
|
|
| 451 |
raise
|
| 452 |
|
| 453 |
if self.cell_states_to_model is not None:
|
| 454 |
-
if
|
| 455 |
for key,value in self.cell_states_to_model.items():
|
| 456 |
if (len(value) == 3) and isinstance(value, tuple):
|
| 457 |
if isinstance(value[0],list) and isinstance(value[1],list) and isinstance(value[2],list):
|
|
|
|
| 108 |
|
| 109 |
# stats comparing cos sim shifts towards goal state of test perturbations vs random perturbations
|
| 110 |
def isp_stats_to_goal_state(cos_sims_df, dict_list, cell_states_to_model):
|
| 111 |
+
cell_state_key = list(cell_states_to_model.keys())[0]
|
| 112 |
+
if cell_states_to_model[cell_state_key][2] == []:
|
| 113 |
alt_end_state_exists = False
|
| 114 |
+
elif (len(cell_states_to_model[cell_state_key][2]) > 0) and (cell_states_to_model[cell_state_key][2] != [None]):
|
| 115 |
alt_end_state_exists = True
|
| 116 |
|
| 117 |
random_tuples = []
|
|
|
|
| 121 |
random_tuples += dict_i.get((token, "cell_emb"),[])
|
| 122 |
|
| 123 |
if alt_end_state_exists == False:
|
| 124 |
+
goal_end_random_megalist = [goal_end for start_state,goal_end in random_tuples]
|
|
|
|
| 125 |
elif alt_end_state_exists == True:
|
| 126 |
+
goal_end_random_megalist = [goal_end for start_state,goal_end,alt_end in random_tuples]
|
| 127 |
+
alt_end_random_megalist = [alt_end for start_state,goal_end,alt_end in random_tuples]
|
|
|
|
| 128 |
|
| 129 |
# downsample to improve speed of ranksums
|
| 130 |
if len(goal_end_random_megalist) > 100_000:
|
| 131 |
random.seed(42)
|
| 132 |
goal_end_random_megalist = random.sample(goal_end_random_megalist, k=100_000)
|
|
|
|
|
|
|
|
|
|
| 133 |
if alt_end_state_exists == True:
|
| 134 |
if len(alt_end_random_megalist) > 100_000:
|
| 135 |
random.seed(42)
|
|
|
|
| 157 |
cos_shift_data += dict_i.get((token, "cell_emb"),[])
|
| 158 |
|
| 159 |
if alt_end_state_exists == False:
|
| 160 |
+
goal_end_cos_sim_megalist = [goal_end for start_state,goal_end in cos_shift_data]
|
| 161 |
elif alt_end_state_exists == True:
|
| 162 |
+
goal_end_cos_sim_megalist = [goal_end for start_state,goal_end,alt_end in cos_shift_data]
|
| 163 |
+
alt_end_cos_sim_megalist = [alt_end for start_state,goal_end,alt_end in cos_shift_data]
|
| 164 |
mean_alt_end = np.mean(alt_end_cos_sim_megalist)
|
| 165 |
pval_alt_end = ranksums(alt_end_random_megalist,alt_end_cos_sim_megalist).pvalue
|
| 166 |
|
|
|
|
| 447 |
raise
|
| 448 |
|
| 449 |
if self.cell_states_to_model is not None:
|
| 450 |
+
if len(self.cell_states_to_model.items()) == 1:
|
| 451 |
for key,value in self.cell_states_to_model.items():
|
| 452 |
if (len(value) == 3) and isinstance(value, tuple):
|
| 453 |
if isinstance(value[0],list) and isinstance(value[1],list) and isinstance(value[2],list):
|