Skip to content
Open

Dev #1629

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,14 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Upgrade pip
run: python -m pip install --upgrade pip==23.2
- name: Install LIT package with testing dependencies
run: python -m pip install -e '.[test]'
- name: Debug dependency tree
run: |
python -m pip install pipdeptree
pipdeptree | grep decorator -A 5 || true
- name: Test LIT
run: pytest -v
- name: Setup Node ${{ matrix.node-version }}
Expand Down
2 changes: 1 addition & 1 deletion lit_nlp/client/core/slice_module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ export class SliceModule extends LitModule {
// clang-format off
return html`
<div class="row-container">
<input type="text" id="input-box" .value=${this.sliceName}
<input type="text" id="input-box" .value=${this.sliceName ?? ''}
placeholder="Enter name" @input=${onInputChange}
@keyup=${(e: KeyboardEvent) => {onKeyUp(e);}}/>
<button class='hairline-button'
Expand Down
22 changes: 15 additions & 7 deletions lit_nlp/client/modules/annotated_text_module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@ import {customElement} from 'lit/decorators.js';
import {makeObservable, observable} from 'mobx';

import {LitModule} from '../core/lit_module';
import {type AnnotationGroups, TextSegments} from '../elements/annotated_text_vis';
import {type AnnotationGroups, type AnnotationSpec, type SegmentSpec, TextSegments} from '../elements/annotated_text_vis';
import {MultiSegmentAnnotations, TextSegment} from '../lib/lit_types';
import {styles as sharedStyles} from '../lib/shared_styles.css';
import {type IndexedInput, ModelInfoMap, Spec} from '../lib/types';
import {doesOutputSpecContain, filterToKeys, findSpecKeys} from '../lib/utils';

// This should be removed.
type AnyDuringMigration = any;

/** LIT module for model output. */
@customElement('annotated-text-gold-module')
export class AnnotatedTextGoldModule extends LitModule {
Expand Down Expand Up @@ -53,13 +56,15 @@ export class AnnotatedTextGoldModule extends LitModule {
// Text segment fields
const segmentNames = findSpecKeys(dataSpec, TextSegment);
const segments: TextSegments = filterToKeys(input.data, segmentNames);
const segmentSpec = filterToKeys(dataSpec, segmentNames);
const segmentSpec: SegmentSpec =
filterToKeys(dataSpec, segmentNames) as AnyDuringMigration;

// Annotation fields
const annotationNames = findSpecKeys(dataSpec, MultiSegmentAnnotations);
const annotations: AnnotationGroups =
filterToKeys(input.data, annotationNames);
const annotationSpec = filterToKeys(dataSpec, annotationNames);
const annotationSpec: AnnotationSpec =
filterToKeys(dataSpec, annotationNames) as AnyDuringMigration;

// If more than one model is selected, AnnotatedTextModule will be offset
// vertically due to the model name header, while this one won't be.
Expand Down Expand Up @@ -149,12 +154,15 @@ export class AnnotatedTextModule extends LitModule {
findSpecKeys(this.appState.currentDatasetSpec, TextSegment);
const segments: TextSegments =
filterToKeys(this.currentData.data, segmentNames);
const segmentSpec =
filterToKeys(this.appState.currentDatasetSpec, segmentNames);
const segmentSpec: SegmentSpec =
filterToKeys(this.appState.currentDatasetSpec, segmentNames) as
AnyDuringMigration;

const outputSpec = this.appState.getModelSpec(this.model).output;
const annotationSpec = filterToKeys(
outputSpec, findSpecKeys(outputSpec, MultiSegmentAnnotations));
const annotationSpec: AnnotationSpec =
filterToKeys(
outputSpec, findSpecKeys(outputSpec, MultiSegmentAnnotations)) as
AnyDuringMigration;
// clang-format off
return html`
<annotated-text-vis .segments=${segments}
Expand Down
10 changes: 7 additions & 3 deletions lit_nlp/client/modules/feature_attribution_module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@ import {LegendType} from '../elements/color_legend';
import {InterpreterClick, InterpreterSettings} from '../elements/interpreter_controls';
import {SortableTemplateResult, TableData} from '../elements/table';
import {FeatureSalience as FeatureSalienceLitType, LitTypeWithVocab, SingleFieldMatcher} from '../lib/lit_types';
import {IndexedInput, ModelInfoMap} from '../lib/types';
import {styles as sharedStyles} from '../lib/shared_styles.css';
import {type D3Scale, IndexedInput, ModelInfoMap} from '../lib/types';
import * as utils from '../lib/utils';
import {findSpecKeys} from '../lib/utils';
import {SignedSalienceCmap} from '../services/color_service';
import {type NumericFeatureBins} from '../services/group_service';
import {AppState, GroupService} from '../services/services';

import {styles as sharedStyles} from '../lib/shared_styles.css';
import {styles} from './feature_attribution_module.css';

const ALL_DATA = 'Entire Dataset';
Expand Down Expand Up @@ -75,6 +75,9 @@ interface VisToggles {
[name: string]: boolean;
}

// This should be removed.
type AnyDuringMigration = any;

/** Aggregate feature attribution for tabular ML models. */
@customElement('feature-attribution-module')
export class FeatureAttributionModule extends LitModule {
Expand Down Expand Up @@ -440,7 +443,8 @@ export class FeatureAttributionModule extends LitModule {
}

override renderImpl() {
const scale = (val: number) => this.colorMap.bgCmap(val);
const scale: D3Scale =
((val: number) => this.colorMap.bgCmap(val)) as AnyDuringMigration;
scale.domain = () => this.colorMap.colorScale.domain();

// clang-format off
Expand Down
9 changes: 7 additions & 2 deletions lit_nlp/client/modules/pdp_module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ interface AllPdpInfo {
// Data for bar or line charts.
type ChartInfo = Map<string|number, number>;

// This should be removed.
type AnyDuringMigration = any;

/**
* A LIT module that renders regression results.
*/
Expand Down Expand Up @@ -172,14 +175,16 @@ export class PdpModule extends LitModule {
const yRange = isClassification ? [0, 1] : [];
const renderChart = (chartData: ChartInfo) => {
if (isNumeric) {
const chartMap: Map<number, number> = chartData as AnyDuringMigration;
return html`
<line-chart height=150 width=300
.scores=${chartData} .yScale=${yRange}>
.scores=${chartMap} .yScale=${yRange}>
</line-chart>`;
} else {
const chartMap: Map<string, number> = chartData as AnyDuringMigration;
return html`
<bar-chart height=150 width=300
.scores=${chartData} .yScale=${yRange}>
.scores=${chartMap} .yScale=${yRange}>
</bar-chart>`;
}

Expand Down
1 change: 1 addition & 0 deletions lit_nlp/client/services/group_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,7 @@ export class GroupService extends LitService {
getFeatureValForInput(
bins: NumericFeatureBins, d: IndexedInput, feature: string): string | null {
const isNumerical = this.numericalFeatureNames.includes(feature);
// @ts-ignore
return isNumerical ? this.getNumericalBinForExample(bins, d, feature) :
this.dataService.getVal(d.id, feature);
}
Expand Down
46 changes: 41 additions & 5 deletions lit_nlp/components/curves_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def input_spec(self) -> lit_types.Spec:
def output_spec(self) -> lit_types.Spec:
return {
'pred': lit_types.MulticlassPreds(vocab=COLORS, parent='label'),
'aux_pred': lit_types.MulticlassPreds(vocab=COLORS, parent='label')
'aux_pred': lit_types.MulticlassPreds(vocab=COLORS, parent='label'),
}

def predict_minibatch(
Expand All @@ -64,10 +64,9 @@ def predict_example(ex: lit_types.JsonDict) -> tuple[float, float, float]:
return TEST_DATA[x].prediction

for example in inputs:
output.append({
'pred': predict_example(example),
'aux_pred': [1 / 3, 1 / 3, 1 / 3]
})
output.append(
{'pred': predict_example(example), 'aux_pred': [1 / 3, 1 / 3, 1 / 3]}
)
return output


Expand Down Expand Up @@ -148,6 +147,43 @@ def test_model_output_is_missing_in_config(self):
config={'Label': 'red'},
)

@parameterized.named_parameters(
dict(
testcase_name='red',
label='red',
exp_roc=[(0.0, 0.0), (0.0, 0.5), (1.0, 0.5), (1.0, 1.0)],
exp_pr=[(0.5, 0.5), (2 / 3, 1.0), (1.0, 0.5), (1.0, 0.0)],
),
dict(
testcase_name='blue',
label='blue',
exp_roc=[(0.0, 0.0), (0.0, 1.0), (1.0, 1.0)],
exp_pr=[
(0.3333333333333333, 1.0),
(0.5, 1.0),
(1.0, 1.0),
(1.0, 0.0),
],
),
)
def test_interpreter_honors_user_selected_label(
self, label: str, exp_roc: _Curve, exp_pr: _Curve
):
"""Tests a happy scenario when a user doesn't specify the class label."""
curves_data = self.ci.run(
inputs=self.dataset.examples,
model=self.model,
dataset=self.dataset,
config={
curves.TARGET_LABEL_KEY: label,
curves.TARGET_PREDICTION_KEY: 'pred',
},
)
self.assertIn(curves.ROC_DATA, curves_data)
self.assertIn(curves.PR_DATA, curves_data)
self.assertEqual(curves_data[curves.ROC_DATA], exp_roc)
self.assertEqual(curves_data[curves.PR_DATA], exp_pr)

def test_config_spec(self):
"""Tests that the interpreter config has correct fields of correct type."""
spec = self.ci.config_spec()
Expand Down
27 changes: 27 additions & 0 deletions lit_nlp/examples/dalle_mini/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
Dalle_Mini Demo for the Learning Interpretability Tool
=======================================================

This demo showcases how LIT can be used in text-to-image generation mode. It is
based on the mini-dalle Mini model
(https://www.piwheels.org/project/dalle-mini/).

You will need a standalone virtual environment for the Python libraries, which
you can set up using the following commands from the root of the LIT repo.

```sh
# Create the virtual environment. You may want to use python3 or python3.10
# depends on how many Python versions you have installed and their aliases.
python -m venv .dalle-mini
source .dalle-mini/bin/activate
# This requirements.txt file will also install the core LIT library deps.
pip install -r ./lit_nlp/examples/dalle_mini/requirements.txt
# The LIT web app still needs to be built in the usual way.
(cd ./lit_nlp && yarn && yarn build)
```

Once your virtual environment is setup, you can launch the demo with the
following command.

```sh
python -m lit_nlp.examples.dalle_mini.demo
```
23 changes: 23 additions & 0 deletions lit_nlp/examples/dalle_mini/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""Data loaders for dalle-mini model."""

from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import types as lit_types


class DallePrompts(lit_dataset.Dataset):
"""DallePrompts is a dataset that contains a list of prompts.

It is used to generate images using the dalle-mini model.
"""

def __init__(self, prompts: list[str]):
self._examples = []
for prompt in prompts:
self._examples.append({"prompt": prompt})

@classmethod
def init_spec(cls) -> lit_types.Spec:
return {"prompt": lit_types.TextSegment(required=True)}

def spec(self) -> lit_types.Spec:
return {"prompt": lit_types.TextSegment()}
Loading
Loading