Skip to content

Commit a2d74dd

Browse files
authored
Update README for new API and change factors_list to all_factors
Updated README.md and made minor edits to the file directory, imports, and parameter lists
2 parents 745dc0c + eede694 commit a2d74dd

20 files changed

+136
-137
lines changed

README.md

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@ pip install pywhyllm
2525
PyWhy-LLM seamlessly integrates into your existing causal inference process. Import the necessary classes and start exploring the power of LLM-augmented causal analysis.
2626

2727
```python
28-
from pywhyllm import ModelSuggester, IdentificationSuggester, ValidationSuggester
28+
from pywhyllm.suggesters.model_suggester import ModelSuggester
29+
from pywhyllm.suggesters.identification_suggester import IdentificationSuggester
30+
from pywhyllm.suggesters.validation_suggester import ValidationSuggester
31+
from pywhyllm import RelationshipStrategy
2932

3033
```
3134

@@ -34,17 +37,20 @@ from pywhyllm import ModelSuggester, IdentificationSuggester, ValidationSuggeste
3437

3538
```python
3639
# Create instance of Modeler
37-
modeler = Modeler()
40+
modeler = ModelSuggester('gpt-4')
41+
42+
all_factors = ["smoking", "lung cancer", "exercise habits", "air pollution exposure"]
43+
treatment = "smoking"
44+
outcome = "lung cancer"
45+
46+
# Suggest a list of domain expertises
47+
domain_expertises = modeler.suggest_domain_expertises(all_factors)
3848

3949
# Suggest a set of potential confounders
40-
suggested_confounders = modeler.suggest_confounders(variables=_variables, treatment=treatment, outcome=outcome, llm=gpt4)
50+
suggested_confounders = modeler.suggest_confounders(treatment, outcome, all_factors, domain_expertises)
4151

4252
# Suggest pair-wise relationship between variables
43-
suggested_dag = modeler.suggest_relationships(variables=selected_variables, llm=gpt4)
44-
45-
plt.figure(figsize=(10, 5))
46-
nx.draw(suggested_dag, with_labels=True, node_color='lightblue')
47-
plt.show()
53+
suggested_dag = modeler.suggest_relationships(treatment, outcome, all_factors, domain_expertises, RelationshipStrategy.Pairwise)
4854
```
4955

5056

@@ -54,15 +60,13 @@ plt.show()
5460

5561
```python
5662
# Create instance of Identifier
57-
identifier = Identifier()
63+
identifier = IdentificationSuggester('gpt-4')
5864

59-
# Suggest a backdoor set, front door set, and iv set
60-
suggested_backdoor = identifier.suggest_backdoor(llm=gpt4, treatment=treatment, outcome=outcome, confounders=suggested_confounders)
61-
suggested_frontdoor = identifier.suggest_frontdoor(llm=gpt4, treatment=treatment, outcome=outcome, confounders=suggested_confounders)
62-
suggested_iv = identifier.suggest_iv(llm=gpt4, treatment=treatment, outcome=outcome, confounders=suggested_confounders)
65+
# Suggest a backdoor set, mediator set, and iv set
66+
suggested_backdoor = identifier.suggest_backdoor(treatment, outcome, all_factors, domain_expertises)
67+
suggested_mediators = identifier.suggest_mediators(treatment, outcome, all_factors, domain_expertises)
68+
suggested_iv = identifier.suggest_ivs(treatment, outcome, all_factors, domain_expertises)
6369

64-
# Suggest an estimand based on the suggester backdoor set, front door set, and iv set
65-
estimand = identifier.suggest_estimand(confounders=suggested_confounders, treatment=treatment, outcome=outcome, backdoor=suggested_backdoor, frontdoor=suggested_frontdoor, iv=suggested_iv, llm=gpt4)
6670
```
6771

6872

@@ -72,20 +76,16 @@ estimand = identifier.suggest_estimand(confounders=suggested_confounders, treatm
7276

7377
```python
7478
# Create instance of Validator
75-
validator = Validator()
79+
validator = ValidationSuggester('gpt-4')
7680

77-
# Suggest a critique of the provided DAG
78-
suggested_critiques_dag = validator.critique_graph(graph=suggested_dag, llm=gpt4)
81+
# Suggest a critique of the edges in the provided DAG
82+
suggested_critiques_dag = validator.critique_graph(all_factors, suggested_dag, domain_expertises, RelationshipStrategy.Pairwise)
7983

8084
# Suggest latent confounders
81-
suggested_latent_confounders = validator.suggest_latent_confounders(treatment=treatment, outcome=outcome, llm=gpt4)
85+
suggested_latent_confounders = validator.suggest_latent_confounders(treatment, outcome, all_factors, domain_expertises)
8286

8387
# Suggest negative controls
84-
suggested_negative_controls = validator.suggest_negative_controls(variables=selected_variables, treatment=treatment, outcome=outcome, llm=gpt4)
85-
86-
plt.figure(figsize=(10, 5))
87-
nx.draw(suggested_critiques_dag, with_labels=True, node_color='lightblue')
88-
plt.show()
88+
suggested_negative_controls = validator.suggest_negative_controls(treatment, outcome, all_factors, domain_expertises)
8989

9090
```
9191

pywhyllm/suggesters/identification_suggester.py

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def __init__(self, llm=None):
1919
# self,
2020
# treatment: str,
2121
# outcome: str,
22-
# factors_list: list(),
22+
# all_factors: list(),
2323
# llm: guidance.models,
2424
# backdoor: Set[str] = None,
2525
# frontdoor: Set[str] = None,
@@ -41,7 +41,7 @@ def __init__(self, llm=None):
4141
# backdoor_edges, backdoor_set = self.suggest_backdoor(
4242
# treatment=treatment,
4343
# outcome=outcome,
44-
# factors_list=factors_list,
44+
# all_factors=all_factors,
4545
# llm=llm,
4646
# experts=experts,
4747
# analysis_context=analysis_context,
@@ -66,7 +66,7 @@ def __init__(self, llm=None):
6666
# frontdoor_edges, frontdoor_set = self.suggest_frontdoor(
6767
# treatment=treatment,
6868
# outcome=outcome,
69-
# factors_list=factors_list,
69+
# all_factors=all_factors,
7070
# llm=llm,
7171
# experts=experts,
7272
# analysis_context=analysis_context,
@@ -87,7 +87,7 @@ def __init__(self, llm=None):
8787
# ivs_edges, ivs_set = self.suggest_ivs(
8888
# treatment=treatment,
8989
# outcome=outcome,
90-
# factors_list=factors_list,
90+
# all_factors=all_factors,
9191
# llm=llm,
9292
# experts=experts,
9393
# analysis_context=analysis_context,
@@ -116,15 +116,15 @@ def suggest_backdoor(
116116
self,
117117
treatment: str,
118118
outcome: str,
119-
factors_list: list(),
119+
all_factors: list(),
120120
expertise_list: list(),
121-
analysis_context=CONTEXT,
121+
analysis_context: str = CONTEXT,
122122
stakeholders: list() = None
123123
):
124124
backdoor_set = self.model_suggester.suggest_confounders(
125125
treatment=treatment,
126126
outcome=outcome,
127-
factors_list=factors_list,
127+
all_factors=all_factors,
128128
expertise_list=expertise_list,
129129
analysis_context=analysis_context,
130130
stakeholders=stakeholders
@@ -136,9 +136,9 @@ def suggest_frontdoor(
136136
self,
137137
treatment: str,
138138
outcome: str,
139-
factors_list: list(),
139+
all_factors: list(),
140140
expertise_list: list(),
141-
analysis_context=CONTEXT,
141+
analysis_context: str = CONTEXT,
142142
stakeholders: list() = None
143143
):
144144
pass
@@ -147,9 +147,9 @@ def suggest_mediators(
147147
self,
148148
treatment: str,
149149
outcome: str,
150-
factors_list: list(),
150+
all_factors: list(),
151151
expertise_list: list(),
152-
analysis_context=CONTEXT,
152+
analysis_context: str = CONTEXT,
153153
stakeholders: list() = None
154154
):
155155
expert_list: List[str] = list()
@@ -164,16 +164,16 @@ def suggest_mediators(
164164
mediators_edges[(treatment, outcome)] = 1
165165

166166
edited_factors_list: List[str] = []
167-
for i in range(len(factors_list)):
168-
if factors_list[i] != treatment and factors_list[i] != outcome:
169-
edited_factors_list.append(factors_list[i])
167+
for i in range(len(all_factors)):
168+
if all_factors[i] != treatment and all_factors[i] != outcome:
169+
edited_factors_list.append(all_factors[i])
170170

171171
for expert in expert_list:
172172
mediators_edges, mediators_list = self.request_mediators(
173173
treatment=treatment,
174174
outcome=outcome,
175175
domain_expertise=expert,
176-
factors_list=edited_factors_list,
176+
all_factors=edited_factors_list,
177177
mediators_edges=mediators_edges,
178178
analysis_context=analysis_context
179179
)
@@ -187,9 +187,9 @@ def request_mediators(
187187
treatment,
188188
outcome,
189189
domain_expertise,
190-
factors_list,
190+
all_factors,
191191
mediators_edges,
192-
analysis_context=CONTEXT
192+
analysis_context: str = CONTEXT
193193
):
194194
mediators: List[str] = list()
195195

@@ -218,7 +218,7 @@ def request_mediators(
218218
on the causal chain that links the {treatment} to the {outcome}? From your perspective as an expert in
219219
{domain_expertise}, which factor(s) of the following factors, if any at all, mediates, is/are on the causal
220220
chain, that links the {treatment} to the {outcome}? Then provide your step by step chain of thoughts within
221-
the tags <thinking></thinking>. factor_names : {factors_list} Wrap the name of the factor(s), if any at all,
221+
the tags <thinking></thinking>. factor_names : {all_factors} Wrap the name of the factor(s), if any at all,
222222
that has/have a high likelihood of directly influencing and causing the assignment of the {outcome} and also
223223
has/have a high likelihood of being directly influenced and caused by the assignment of the {treatment} within
224224
the tags <mediating_factor>factor_name</mediating_factor>. Where factor_name is one of the items within the
@@ -237,7 +237,7 @@ def request_mediators(
237237
if mediating_factor:
238238
for factor in mediating_factor:
239239
# to not add it twice into the list
240-
if factor in factors_list and factor not in mediators:
240+
if factor in all_factors and factor not in mediators:
241241
mediators.append(factor)
242242
success = True
243243

@@ -262,9 +262,9 @@ def suggest_ivs(
262262
self,
263263
treatment: str,
264264
outcome: str,
265-
factors_list: list(),
265+
all_factors: list(),
266266
expertise_list: list(),
267-
analysis_context=CONTEXT,
267+
analysis_context: str = CONTEXT,
268268
stakeholders: list() = None
269269
):
270270
expert_list: List[str] = list()
@@ -279,17 +279,17 @@ def suggest_ivs(
279279
iv_edges[(treatment, outcome)] = 1
280280

281281
edited_factors_list: List[str] = []
282-
for i in range(len(factors_list)):
283-
if factors_list[i] != treatment and factors_list[i] != outcome:
284-
edited_factors_list.append(factors_list[i])
282+
for i in range(len(all_factors)):
283+
if all_factors[i] != treatment and all_factors[i] != outcome:
284+
edited_factors_list.append(all_factors[i])
285285

286286
for expert in expert_list:
287287
iv_edges, iv_list = self.request_ivs(
288288
treatment=treatment,
289289
outcome=outcome,
290290
analysis_context=analysis_context,
291291
domain_expertise=expert,
292-
factors_list=edited_factors_list,
292+
all_factors=edited_factors_list,
293293
iv_edges=iv_edges,
294294
)
295295

@@ -305,7 +305,7 @@ def request_ivs(
305305
outcome,
306306
analysis_context,
307307
domain_expertise,
308-
factors_list,
308+
all_factors,
309309
iv_edges
310310
):
311311
ivs: List[str] = list()
@@ -338,7 +338,7 @@ def request_ivs(
338338
the {outcome}? Which factor(s) of the following factors, if any at all, are (an) instrumental variable(s)
339339
to the causal relationship of the {treatment} causing the {outcome}? Be concise and keep your thinking
340340
within two paragraphs. Then provide your step by step chain of thoughts within the tags
341-
<thinking></thinking>. factor_names : {factors_list} Wrap the name of the factor(s), if there are any at
341+
<thinking></thinking>. factor_names : {all_factors} Wrap the name of the factor(s), if there are any at
342342
all, that both has/have a high likelihood of influecing and causing the {treatment} and has/have a very low
343343
likelihood of influencing and causing the {outcome}, within the tags <iv_factor>factor_name</iv_factor>.
344344
Where factor_name is one of the items within the factor_names list. If a factor does not have a high
@@ -353,7 +353,7 @@ def request_ivs(
353353

354354
if iv_factors:
355355
for factor in iv_factors:
356-
if factor in factors_list and factor not in ivs:
356+
if factor in all_factors and factor not in ivs:
357357
ivs.append(factor)
358358
success = True
359359

0 commit comments

Comments
 (0)