import numpy as np
import torch
import torch.autograd as autograd
import torch.nn as nn
from scipy import stats


class MLP(nn.Module):
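    """
    A small fully connected network with ReLU activations; instances of this
    class serve as the nonlinear functions G1 and G2 in the PNL model below.
    """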

    def __init__(self, n_inputs, n_outputs, n_layers=1, n_units=100):
        super(MLP, self).__init__()

        # create layers
        layers = [nn.Linear(n_inputs, n_units)]
        for _ in range(n_layers):
            layers.append(nn.ReLU())
            layers.append(nn.Linear(n_units, n_units))
        layers.append(nn.ReLU())
        layers.append(nn.Linear(n_units, n_outputs))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        x = self.layers(x)
        return x


class PNL(object):
7542 """
7643 Use of constrained nonlinear ICA for distinguishing cause from effect.
7744 Python Version 3.7
7845 PURPOSE:
7946 To find which one of xi (i=1,2) is the cause. In particular, this
8047 function does
81- 1) preprocessing to make xi rather clear to Gaussian,
48+ 1) preprocessing to make xi rather close to Gaussian,
8249 2) learn the corresponding 'disturbance' under each assumed causal
8350 direction, and
8451 3) performs the independence tests to see if the assumed cause if
8552 independent from the learned disturbance.
8653 """

    def __init__(self, epochs=100000):
        '''
        Construct the PNL model.

        Parameters:
        ----------
        epochs: training epochs.
        '''
        self.epochs = epochs

    def dele_abnormal(self, data_x, data_y):
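        # Outlier removal by the 3-sigma rule: drop any sample that lies more
        # than three standard deviations from the mean in either variable.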
        mean_x = np.mean(data_x)
        sigma_x = np.std(data_x)
        remove_idx_x = np.where(abs(data_x - mean_x) > 3 * sigma_x)[0]

        mean_y = np.mean(data_y)
        sigma_y = np.std(data_y)
        remove_idx_y = np.where(abs(data_y - mean_y) > 3 * sigma_y)[0]

        remove_idx = np.append(remove_idx_x, remove_idx_y)

        data_x = np.delete(data_x, remove_idx)
        data_y = np.delete(data_y, remove_idx)

        return data_x.reshape(len(data_x), 1), data_y.reshape(len(data_y), 1)

    def nica_mnd(self, X, TotalEpoch):
10584 """
106- Use of "Nonlinear ICA with MND for Matlab " for distinguishing cause from effect
107- PURPOSE: Performing nonlinear ICA with the minimal nonlinear distortion or smoothness regularization .
85+ Use of "Nonlinear ICA" for distinguishing cause from effect
86+ PURPOSE: Performing nonlinear ICA.
10887
10988 Parameters
11089 ----------
@@ -115,56 +94,54 @@ def nica_mnd(self, X, TotalEpoch, KofMix):
11594 Y (n*T): the separation result.
11695 """
        trpattern = X.T

        # --------------------------------------------------------
        x1 = torch.from_numpy(trpattern[0, :]).type(torch.FloatTensor).reshape(-1, 1)
        x2 = torch.from_numpy(trpattern[1, :]).type(torch.FloatTensor).reshape(-1, 1)
        x1.requires_grad = True
        x2.requires_grad = True

        y1 = x1

        Final_y2 = x2
        Min_loss = float('inf')

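        # Two small MLPs: G1 plays the role of f1 and G2 the role of the
        # inverse of f2 in the post-nonlinear model.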
        G1 = MLP(1, 1, n_layers=3, n_units=12)
        G2 = MLP(1, 1, n_layers=1, n_units=12)
        optimizer = torch.optim.Adam([
            {'params': G1.parameters()},
            {'params': G2.parameters()}], lr=1e-5, betas=(0.9, 0.99))

        loss_all = torch.zeros(TotalEpoch)
        loss_pdf_all = torch.zeros(TotalEpoch)
        loss_jacob_all = torch.zeros(TotalEpoch)

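        # Each epoch minimizes the negative log-likelihood of the data under a
        # standard-Gaussian disturbance: 0.5 * sum(y2 ** 2) is the Gaussian
        # term, and -sum(log|d y2 / d x2|) is the change-of-variables Jacobian
        # correction.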
        for epoch in range(TotalEpoch):
            G1.zero_grad()
            G2.zero_grad()

            y2 = G2(x2) - G1(x1)

            loss_pdf = 0.5 * torch.sum(y2 ** 2)

            jacob = autograd.grad(outputs=y2, inputs=x2, grad_outputs=torch.ones(y2.shape), create_graph=True,
                                  retain_graph=True, only_inputs=True)[0]

            loss_jacob = -torch.sum(torch.log(torch.abs(jacob) + 1e-16))

            loss = loss_jacob + loss_pdf

            # Store scalar values (.item()) so the computation graph of each
            # epoch is not retained across iterations.
            loss_all[epoch] = loss.item()
            loss_pdf_all[epoch] = loss_pdf.item()
            loss_jacob_all[epoch] = loss_jacob.item()

            if loss.item() < Min_loss:
                Min_loss = loss.item()
                Final_y2 = y2

            loss.backward()
            optimizer.step()

        return y1, Final_y2

    def cause_or_effect(self, data_x, data_y):
        '''
        Fit a PNL model in two directions and test the independence between
        the input and the estimated noise.

        Parameters
        ---------
        data_x: input data (n*1)
        data_y: output data (n*1)

        Returns
        ---------
        pval_forward: p value in the x->y direction
        pval_backward: p value in the y->x direction
        '''
        torch.manual_seed(0)

        # Delete the abnormal samples
        data_x, data_y = self.dele_abnormal(data_x, data_y)

        # Now let's see if x1 -> x2 is plausible
        data = np.concatenate((data_x, data_y), axis=1)
        # print('To see if x1 -> x2...')
        y1, y2 = self.nica_mnd(data, self.epochs)

        y1_np = y1.detach().numpy()
        y2_np = y2.detach().numpy()

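        # Note: stats.ttest_ind compares the means of the two samples; it is
        # used here as a simple proxy for the independence test between the
        # hypothesized cause and the estimated disturbance.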
        _, pval_forward = stats.ttest_ind(y1_np, y2_np)

        # Now let's see if x2 -> x1 is plausible
        # print('To see if x2 -> x1...')
        y1, y2 = self.nica_mnd(data[:, [1, 0]], self.epochs)

        y1_np = y1.detach().numpy()
        y2_np = y2.detach().numpy()

        _, pval_backward = stats.ttest_ind(y1_np, y2_np)

        return np.round(pval_forward, 3), np.round(pval_backward, 3)
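

if __name__ == '__main__':
    # A minimal usage sketch. The synthetic data-generating process below is
    # an illustrative assumption, not part of this module: x2 is produced from
    # x1 through a post-nonlinear mechanism x2 = g(f(x1) + e), so the x1 -> x2
    # direction should receive the larger p value. The reduced epoch count is
    # only to keep the demo quick; results with so few epochs are rough.
    np.random.seed(0)
    n = 500
    x1_data = np.random.uniform(-1, 1, (n, 1))
    e = 0.1 * np.random.randn(n, 1)
    x2_data = np.cbrt(2.0 * x1_data + e)  # g(z) = z ** (1/3), invertible

    pnl = PNL(epochs=2000)
    pval_forward, pval_backward = pnl.cause_or_effect(x1_data, x2_data)
    print('p value for x1 -> x2:', pval_forward)
    print('p value for x2 -> x1:', pval_backward)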