Skip to content

Commit 43b1e09

Browse files
committed
Derive relationshipWeightProperty from trained models
This allows weighted GraphSage models to be used to generate embeddings in a nodePropertyStep.
1 parent 39b689c commit 43b1e09

File tree

8 files changed

+236
-40
lines changed

8 files changed

+236
-40
lines changed

algo/src/main/java/org/neo4j/gds/embeddings/graphsage/algo/GraphSageAlgorithmFactory.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
3333
import org.neo4j.gds.embeddings.graphsage.GraphSageHelper;
3434

35+
import java.util.Optional;
36+
3537
import static org.neo4j.gds.core.utils.mem.MemoryEstimations.RESIDENT_MEMORY;
3638
import static org.neo4j.gds.core.utils.mem.MemoryEstimations.TEMPORARY_MEMORY;
3739
import static org.neo4j.gds.embeddings.graphsage.algo.GraphSageModelResolver.resolveModel;
@@ -59,10 +61,10 @@ public GraphSage build(GraphStore graphStore, CONFIG configuration, ProgressTrac
5961
var graph = graphStore.getGraph(
6062
configuration.nodeLabelIdentifiers(graphStore),
6163
configuration.internalRelationshipTypes(graphStore),
62-
model.trainConfig().relationshipWeightProperty()
64+
Optional.ofNullable(model.trainConfig().relationshipWeightProperty())
6365
);
6466

65-
if(model.trainConfig().hasRelationshipWeightProperty()) {
67+
if(graph.hasRelationshipProperty()) {
6668
validateRelationshipWeightPropertyValue(graph, configuration.concurrency(), executorService);
6769
}
6870

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/TrainingPipeline.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ private int numberOfTrainerConfigs() {
9191
.sum();
9292
}
9393

94-
public void addNodePropertyStep(NodePropertyStep step) {
94+
public void addNodePropertyStep(ExecutableNodePropertyStep step) {
9595
validateUniqueMutateProperty(step);
9696
this.nodePropertySteps.add(step);
9797
}
@@ -140,7 +140,7 @@ public void setAutoTuningConfig(AutoTuningConfig autoTuningConfig) {
140140
this.autoTuningConfig = autoTuningConfig;
141141
}
142142

143-
private void validateUniqueMutateProperty(NodePropertyStep step) {
143+
private void validateUniqueMutateProperty(ExecutableNodePropertyStep step) {
144144
this.nodePropertySteps.forEach(nodePropertyStep -> {
145145
var newMutatePropertyName = step.config().get(MUTATE_PROPERTY_KEY);
146146
var existingMutatePropertyName = nodePropertyStep.config().get(MUTATE_PROPERTY_KEY);

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/LinkPredictionTrainingPipeline.java

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,22 @@
2121

2222
import org.neo4j.gds.api.GraphStore;
2323
import org.neo4j.gds.config.AlgoBaseConfig;
24+
import org.neo4j.gds.config.RelationshipWeightConfig;
2425
import org.neo4j.gds.config.ToMapConvertible;
26+
import org.neo4j.gds.core.model.Model;
27+
import org.neo4j.gds.executor.ExecutionContext;
28+
import org.neo4j.gds.ml.pipeline.ExecutableNodePropertyStep;
2529
import org.neo4j.gds.ml.pipeline.TrainingPipeline;
2630

2731
import java.util.ArrayList;
2832
import java.util.HashMap;
2933
import java.util.List;
3034
import java.util.Map;
35+
import java.util.Objects;
3136
import java.util.Optional;
3237

3338
import static org.neo4j.gds.config.RelationshipWeightConfig.RELATIONSHIP_WEIGHT_PROPERTY;
39+
import static org.neo4j.gds.model.ModelConfig.MODEL_NAME_KEY;
3440

3541
public class LinkPredictionTrainingPipeline extends TrainingPipeline<LinkFeatureStep> {
3642

@@ -81,23 +87,41 @@ public void specificValidateBeforeExecution(GraphStore graphStore, AlgoBaseConfi
8187
}
8288
}
8389

84-
public Map<String, List<String>> tasksByRelationshipProperty() {
90+
public Map<String, List<String>> tasksByRelationshipProperty(ExecutionContext executionContext) {
8591
Map<String, List<String>> tasksByRelationshipProperty = new HashMap<>();
86-
nodePropertySteps().forEach(existingStep -> {
92+
93+
for (ExecutableNodePropertyStep existingStep : nodePropertySteps()) {
8794
if (existingStep.config().containsKey(RELATIONSHIP_WEIGHT_PROPERTY)) {
8895
var existingProperty = (String) existingStep.config().get(RELATIONSHIP_WEIGHT_PROPERTY);
8996
var tasks = tasksByRelationshipProperty.computeIfAbsent(
9097
existingProperty,
9198
key -> new ArrayList<>()
9299
);
93100
tasks.add(existingStep.procName());
101+
} else if (existingStep.config().containsKey(MODEL_NAME_KEY)) {
102+
Optional.ofNullable(executionContext.modelCatalog().getUntyped(
103+
executionContext.username(),
104+
((String) existingStep.config().get(MODEL_NAME_KEY))
105+
))
106+
.map(Model::trainConfig)
107+
.filter(config -> config instanceof RelationshipWeightConfig)
108+
.map(config -> ((RelationshipWeightConfig) config).relationshipWeightProperty())
109+
.filter(Objects::nonNull)
110+
.ifPresent(property -> {
111+
var tasks = tasksByRelationshipProperty.computeIfAbsent(
112+
property,
113+
key -> new ArrayList<>()
114+
);
115+
tasks.add(existingStep.procName());
116+
});
94117
}
95-
});
118+
}
119+
96120
return tasksByRelationshipProperty;
97121
}
98122

99-
public Optional<String> relationshipWeightProperty() {
100-
var relationshipWeightPropertySet = tasksByRelationshipProperty().entrySet();
123+
public Optional<String> relationshipWeightProperty(ExecutionContext executionContext) {
124+
var relationshipWeightPropertySet = tasksByRelationshipProperty(executionContext).entrySet();
101125
return relationshipWeightPropertySet.isEmpty()
102126
? Optional.empty()
103127
: Optional.of(relationshipWeightPropertySet.iterator().next().getKey());

pipeline/src/test/java/org/neo4j/gds/ml/pipeline/linkPipeline/LinkPredictionPipelineTest.java renamed to pipeline/src/test/java/org/neo4j/gds/ml/pipeline/linkPipeline/LinkPredictionTrainingPipelineTest.java

Lines changed: 142 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,24 +22,37 @@
2222
import org.assertj.core.api.InstanceOfAssertFactories;
2323
import org.junit.jupiter.api.Nested;
2424
import org.junit.jupiter.api.Test;
25+
import org.neo4j.gds.NodeLabel;
26+
import org.neo4j.gds.RelationshipType;
27+
import org.neo4j.gds.api.schema.GraphSchema;
28+
import org.neo4j.gds.core.model.Model;
29+
import org.neo4j.gds.core.model.ModelCatalog;
30+
import org.neo4j.gds.core.model.OpenModelCatalog;
31+
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
32+
import org.neo4j.gds.executor.ExecutionContext;
2533
import org.neo4j.gds.executor.GdsCallableFinder;
34+
import org.neo4j.gds.executor.ImmutableExecutionContext;
2635
import org.neo4j.gds.ml.models.TrainingMethod;
2736
import org.neo4j.gds.ml.models.automl.TunableTrainerConfig;
2837
import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionTrainConfig;
2938
import org.neo4j.gds.ml.models.randomforest.RandomForestClassifierTrainerConfigImpl;
3039
import org.neo4j.gds.ml.pipeline.AutoTuningConfig;
40+
import org.neo4j.gds.ml.pipeline.ExecutableNodePropertyStep;
3141
import org.neo4j.gds.ml.pipeline.NodePropertyStep;
3242
import org.neo4j.gds.ml.pipeline.TestGdsCallableFinder;
3343
import org.neo4j.gds.ml.pipeline.linkPipeline.linkfunctions.CosineFeatureStep;
3444
import org.neo4j.gds.ml.pipeline.linkPipeline.linkfunctions.HadamardFeatureStep;
45+
import org.neo4j.gds.model.catalog.TestTrainConfigImpl;
46+
import org.neo4j.gds.model.catalog.TestWeightedTrainConfigImpl;
3547

48+
import java.util.Collection;
3649
import java.util.List;
3750
import java.util.Map;
3851
import java.util.stream.Collectors;
3952

4053
import static org.assertj.core.api.Assertions.assertThat;
4154

42-
class LinkPredictionPipelineTest {
55+
class LinkPredictionTrainingPipelineTest {
4356

4457
@Test
4558
void canCreateEmptyPipeline() {
@@ -153,6 +166,134 @@ void overridesTheSplitConfig() {
153166
.returns(splitConfigOverride, LinkPredictionTrainingPipeline::splitConfig);
154167
}
155168

169+
@Test
170+
void deriveRelationshipWeightProperty() {
171+
var executionContext = ImmutableExecutionContext.builder()
172+
.username("")
173+
.build();
174+
175+
var pipeline = new LinkPredictionTrainingPipeline();
176+
177+
assertThat(pipeline.relationshipWeightProperty(executionContext)).isEmpty();
178+
179+
var step = new TestNodePropertyStep(Map.of("relationshipWeightProperty", "myWeight"));
180+
181+
pipeline.addNodePropertyStep(step);
182+
183+
assertThat(pipeline.relationshipWeightProperty(executionContext)).isPresent().get().isEqualTo("myWeight");
184+
}
185+
186+
@Test
187+
void deriveRelationshipWeightPropertyFromTrainedModel() {
188+
var modelCatalog = new OpenModelCatalog();
189+
190+
String modelName = "myModel";
191+
modelCatalog.set(Model.of(
192+
"",
193+
modelName,
194+
"myAlgo",
195+
GraphSchema.empty(),
196+
1L,
197+
TestWeightedTrainConfigImpl.builder()
198+
.username("")
199+
.modelName(modelName)
200+
.relationshipWeightProperty("derivedWeight").build(),
201+
Map::of
202+
));
203+
204+
var executionContext = ImmutableExecutionContext.builder()
205+
.username("")
206+
.modelCatalog(modelCatalog)
207+
.build();
208+
209+
var pipeline = new LinkPredictionTrainingPipeline();
210+
211+
assertThat(pipeline.relationshipWeightProperty(executionContext)).isEmpty();
212+
213+
var step = new TestNodePropertyStep(Map.of("modelName", modelName));
214+
215+
pipeline.addNodePropertyStep(step);
216+
217+
assertThat(pipeline.relationshipWeightProperty(executionContext)).isPresent().get().isEqualTo("derivedWeight");
218+
}
219+
220+
@Test
221+
void notDerivePropertyFromUnweightedTrainedModel() {
222+
var modelCatalog = new OpenModelCatalog();
223+
224+
String modelName = "myModel";
225+
modelCatalog.set(Model.of(
226+
"",
227+
modelName,
228+
"myAlgo",
229+
GraphSchema.empty(),
230+
1L,
231+
TestTrainConfigImpl.builder()
232+
.username("")
233+
.modelName(modelName)
234+
.build(),
235+
Map::of
236+
));
237+
238+
var executionContext = ImmutableExecutionContext.builder()
239+
.username("")
240+
.modelCatalog(modelCatalog)
241+
.build();
242+
243+
var pipeline = new LinkPredictionTrainingPipeline();
244+
245+
assertThat(pipeline.relationshipWeightProperty(executionContext)).isEmpty();
246+
247+
var step = new TestNodePropertyStep(Map.of("modelName", modelName));
248+
249+
pipeline.addNodePropertyStep(step);
250+
251+
assertThat(pipeline.relationshipWeightProperty(executionContext)).isEmpty();
252+
}
253+
254+
private static class TestNodePropertyStep implements ExecutableNodePropertyStep {
255+
private final Map<String, Object> config;
256+
257+
public TestNodePropertyStep(Map<String, Object> config) {
258+
this.config = config;
259+
}
260+
261+
@Override
262+
public void execute(
263+
ExecutionContext executionContext,
264+
String graphName,
265+
Collection<NodeLabel> nodeLabels,
266+
Collection<RelationshipType> relTypes
267+
) {
268+
269+
}
270+
271+
@Override
272+
public Map<String, Object> config() {
273+
return config;
274+
}
275+
276+
@Override
277+
public String procName() {
278+
return null;
279+
}
280+
281+
@Override
282+
public MemoryEstimation estimate(
283+
ModelCatalog modelCatalog,
284+
String username,
285+
List<String> nodeLabels,
286+
List<String> relTypes
287+
) {
288+
return null;
289+
}
290+
291+
@Override
292+
public Map<String, Object> toMap() {
293+
return config;
294+
}
295+
}
296+
156297
@Nested
157298
class ToMapTest {
158299

proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/LinkPredictionPipelineAddStepProcs.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,13 @@ private void validateRelationshipProperty(
7777
Map<String, Object> procedureConfig
7878
) {
7979
if (!procedureConfig.containsKey(RELATIONSHIP_WEIGHT_PROPERTY)) return;
80-
var maybeRelationshipProperty = pipeline.relationshipWeightProperty();
80+
var maybeRelationshipProperty = pipeline.relationshipWeightProperty(executionContext());
8181
if (maybeRelationshipProperty.isEmpty()) return;
8282
var relationshipProperty = maybeRelationshipProperty.get();
8383
var property = (String) procedureConfig.get(RELATIONSHIP_WEIGHT_PROPERTY);
8484
if (relationshipProperty.equals(property)) return;
8585

86-
String tasks = pipeline.tasksByRelationshipProperty()
86+
String tasks = pipeline.tasksByRelationshipProperty(executionContext())
8787
.get(relationshipProperty)
8888
.stream()
8989
.map(s -> "`" + s + "`")

proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/train/LinkPredictionTrainPipelineExecutor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ public Map<DatasetSplits, PipelineExecutor.GraphFilter> splitDataset() {
131131
config.relationshipTypes(),
132132
config.nodeLabels(),
133133
config.randomSeed(),
134-
pipeline.relationshipWeightProperty()
134+
pipeline.relationshipWeightProperty(executionContext)
135135
);
136136

137137
var splitConfig = pipeline.splitConfig();

proc/machine-learning/src/test/java/org/neo4j/gds/ml/linkmodels/pipeline/LinkPredictionPipelineIntegrationTest.java

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -78,23 +78,23 @@ public class LinkPredictionPipelineIntegrationTest extends BaseProcTest {
7878
@Neo4jGraph
7979
static String GRAPH =
8080
NODES +
81-
"(a)-[:REL]->(b), " +
82-
"(a)-[:REL]->(c), " +
83-
"(a)-[:REL]->(d), " +
84-
"(b)-[:REL]->(c), " +
85-
"(b)-[:REL]->(d), " +
86-
"(c)-[:REL]->(d), " +
87-
"(e)-[:REL]->(f), " +
88-
"(e)-[:REL]->(g), " +
89-
"(f)-[:REL]->(g), " +
90-
"(h)-[:REL]->(i), " +
91-
"(j)-[:REL]->(k), " +
92-
"(j)-[:REL]->(l), " +
93-
"(k)-[:REL]->(l), " +
94-
"(m)-[:REL]->(n), " +
95-
"(m)-[:REL]->(o), " +
96-
"(n)-[:REL]->(o), " +
97-
"(a)-[:REL]->(p), " +
81+
"(a)-[:REL {weight: 5.0}]->(b), " +
82+
"(a)-[:REL {weight: 2.0}]->(c), " +
83+
"(a)-[:REL {weight: 4.0}]->(d), " +
84+
"(b)-[:REL {weight: 5.0}]->(c), " +
85+
"(b)-[:REL {weight: 3.0}]->(d), " +
86+
"(c)-[:REL {weight: 2.0}]->(d), " +
87+
"(e)-[:REL {weight: 1.0}]->(f), " +
88+
"(e)-[:REL {weight: 2.0}]->(g), " +
89+
"(f)-[:REL {weight: 5.0}]->(g), " +
90+
"(h)-[:REL {weight: 5.0}]->(i), " +
91+
"(j)-[:REL {weight: 5.0}]->(k), " +
92+
"(j)-[:REL {weight: 4.0}]->(l), " +
93+
"(k)-[:REL {weight: 5.0}]->(l), " +
94+
"(m)-[:REL {weight: 4.0}]->(n), " +
95+
"(m)-[:REL {weight: 5.0}]->(o), " +
96+
"(n)-[:REL {weight: 2.0}]->(o), " +
97+
"(a)-[:REL {weight: 5.0}]->(p), " +
9898

9999
"(a)-[:IGNORED]->(e), " +
100100
"(m)-[:IGNORED]->(a), " +
@@ -245,18 +245,19 @@ protected String getUsername() {
245245

246246
@Test
247247
void runWithGraphSage() {
248-
runQueryWithUser("CALL gds.graph.project('g_2',{ N: { properties: ['z']}}, {REL: {orientation: 'UNDIRECTED'}})");
248+
runQueryWithUser("CALL gds.graph.project('g_2',{ N: { properties: ['z']}}, {REL: {orientation: 'UNDIRECTED', properties: ['weight']}})");
249249

250250
runQueryWithUser("CALL gds.beta.graphSage.train(" +
251-
" 'g_2'," +
252-
" {" +
253-
" modelName: 'exampleTrainModel'," +
254-
" featureProperties: ['z']," +
255-
" aggregator: 'mean'," +
256-
" activationFunction: 'sigmoid'," +
257-
" randomSeed: 1337," +
258-
" sampleSizes: [25, 10]" +
259-
" })");
251+
" 'g_2'," +
252+
" {" +
253+
" modelName: 'exampleTrainModel'," +
254+
" relationshipWeightProperty: 'weight'," +
255+
" featureProperties: ['z']," +
256+
" aggregator: 'mean'," +
257+
" activationFunction: 'sigmoid'," +
258+
" randomSeed: 1337," +
259+
" sampleSizes: [25, 10]" +
260+
" })");
260261

261262
runQueryWithUser("CALL gds.beta.pipeline.linkPrediction.create('myPipe') ");
262263
runQueryWithUser("CALL gds.beta.pipeline.linkPrediction.configureSplit('myPipe', {validationFolds: 2, testFraction: 0.3, trainFraction: 0.3})");

0 commit comments

Comments
 (0)