Skip to content

Commit 98cb7b4

Browse files
authored
Merge pull request #6040 from FlorentinD/lp-gs-weighted-21
Fix usage of weighted GraphSage models to Pipelines
2 parents f47698e + 3598a41 commit 98cb7b4

File tree

16 files changed, with 281 additions and 254 deletions.

algo/src/main/java/org/neo4j/gds/embeddings/graphsage/algo/GraphSageAlgorithmFactory.java

Lines changed: 20 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -19,8 +19,8 @@
1919
*/
2020
package org.neo4j.gds.embeddings.graphsage.algo;
2121

22-
import org.neo4j.gds.GraphAlgorithmFactory;
23-
import org.neo4j.gds.api.Graph;
22+
import org.neo4j.gds.GraphStoreAlgorithmFactory;
23+
import org.neo4j.gds.api.GraphStore;
2424
import org.neo4j.gds.config.MutateConfig;
2525
import org.neo4j.gds.core.concurrency.Pools;
2626
import org.neo4j.gds.core.model.ModelCatalog;
@@ -32,13 +32,15 @@
3232
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
3333
import org.neo4j.gds.embeddings.graphsage.GraphSageHelper;
3434

35+
import java.util.Optional;
36+
3537
import static org.neo4j.gds.core.utils.mem.MemoryEstimations.RESIDENT_MEMORY;
3638
import static org.neo4j.gds.core.utils.mem.MemoryEstimations.TEMPORARY_MEMORY;
3739
import static org.neo4j.gds.embeddings.graphsage.algo.GraphSageModelResolver.resolveModel;
3840
import static org.neo4j.gds.mem.MemoryUsage.sizeOfDoubleArray;
3941
import static org.neo4j.gds.ml.core.EmbeddingUtils.validateRelationshipWeightPropertyValue;
4042

41-
public class GraphSageAlgorithmFactory<CONFIG extends GraphSageBaseConfig> extends GraphAlgorithmFactory<GraphSage, CONFIG> {
43+
public class GraphSageAlgorithmFactory<CONFIG extends GraphSageBaseConfig> extends GraphStoreAlgorithmFactory<GraphSage, CONFIG> {
4244

4345
private final ModelCatalog modelCatalog;
4446

@@ -48,30 +50,22 @@ public GraphSageAlgorithmFactory(ModelCatalog modelCatalog) {
4850
}
4951

5052
@Override
51-
public String taskName() {
52-
return GraphSage.class.getSimpleName();
53-
}
54-
55-
@Override
56-
public GraphSage build(
57-
Graph graph,
58-
CONFIG configuration,
59-
ProgressTracker progressTracker
60-
) {
61-
53+
public GraphSage build(GraphStore graphStore, CONFIG configuration, ProgressTracker progressTracker) {
6254
var executorService = Pools.DEFAULT;
6355
var model = resolveModel(
6456
modelCatalog,
6557
configuration.username(),
6658
configuration.modelName()
6759
);
6860

69-
if(model.trainConfig().isWeighted()) {
70-
validateRelationshipWeightPropertyValue(graph, configuration.concurrency(), executorService);
71-
}
61+
var graph = graphStore.getGraph(
62+
configuration.nodeLabelIdentifiers(graphStore),
63+
configuration.internalRelationshipTypes(graphStore),
64+
Optional.ofNullable(model.trainConfig().relationshipWeightProperty())
65+
);
7266

73-
if (!model.trainConfig().isWeighted() && graph.hasRelationshipProperty()) {
74-
throw new IllegalStateException("Model was trained without relationship weights. Expected an unweighted graph");
67+
if(graph.hasRelationshipProperty()) {
68+
validateRelationshipWeightPropertyValue(graph, configuration.concurrency(), executorService);
7569
}
7670

7771
return new GraphSage(
@@ -83,6 +77,11 @@ public GraphSage build(
8377
);
8478
}
8579

80+
@Override
81+
public String taskName() {
82+
return GraphSage.class.getSimpleName();
83+
}
84+
8685
@Override
8786
public MemoryEstimation memoryEstimation(CONFIG config) {
8887
var model = resolveModel(modelCatalog, config.username(), config.modelName());
@@ -98,8 +97,8 @@ public MemoryEstimation memoryEstimation(CONFIG config) {
9897
}
9998

10099
@Override
101-
public Task progressTask(Graph graph, CONFIG config) {
102-
return Tasks.leaf(taskName(), graph.nodeCount());
100+
public Task progressTask(GraphStore graphStore, CONFIG config) {
101+
return Tasks.leaf(taskName(), graphStore.getGraph(config.nodeLabelIdentifiers(graphStore)).nodeCount());
103102
}
104103

105104
private MemoryEstimation withNodeCount(GraphSageTrainConfig config, long nodeCount, boolean mutate) {

algo/src/main/java/org/neo4j/gds/embeddings/graphsage/algo/GraphSageBaseConfig.java

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -21,9 +21,8 @@
2121

2222
import org.neo4j.gds.config.AlgoBaseConfig;
2323
import org.neo4j.gds.config.BatchSizeConfig;
24-
import org.neo4j.gds.config.RelationshipWeightConfig;
2524
import org.neo4j.gds.model.ModelConfig;
2625

27-
public interface GraphSageBaseConfig extends AlgoBaseConfig, BatchSizeConfig, ModelConfig, RelationshipWeightConfig {
26+
public interface GraphSageBaseConfig extends AlgoBaseConfig, BatchSizeConfig, ModelConfig {
2827
long serialVersionUID = 0x42L;
2928
}

algo/src/main/java/org/neo4j/gds/embeddings/graphsage/algo/GraphSageTrainConfig.java

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -151,12 +151,6 @@ default boolean propertiesMustExistForEachNodeLabel() {
151151
return false;
152152
}
153153

154-
@Configuration.Ignore
155-
@Value.Derived
156-
default boolean isWeighted() {
157-
return relationshipWeightProperty() != null;
158-
}
159-
160154
@Configuration.Ignore
161155
default List<LayerConfig> layerConfigs(int featureDimension) {
162156
List<LayerConfig> result = new ArrayList<>(sampleSizes().size());

algo/src/test/java/org/neo4j/gds/embeddings/graphsage/GraphSageTest.java

Lines changed: 22 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -26,13 +26,16 @@
2626
import org.neo4j.gds.NodeLabel;
2727
import org.neo4j.gds.Orientation;
2828
import org.neo4j.gds.api.Graph;
29+
import org.neo4j.gds.api.GraphStore;
2930
import org.neo4j.gds.beta.generator.PropertyProducer;
3031
import org.neo4j.gds.beta.generator.RandomGraphGenerator;
3132
import org.neo4j.gds.beta.generator.RelationshipDistribution;
3233
import org.neo4j.gds.compat.Neo4jProxy;
3334
import org.neo4j.gds.config.RandomGraphGeneratorConfig;
3435
import org.neo4j.gds.core.Aggregation;
3536
import org.neo4j.gds.core.concurrency.Pools;
37+
import org.neo4j.gds.core.huge.HugeGraph;
38+
import org.neo4j.gds.core.loading.CSRGraphStoreUtil;
3639
import org.neo4j.gds.core.loading.construction.NodeLabelTokens;
3740
import org.neo4j.gds.core.model.InjectModelCatalog;
3841
import org.neo4j.gds.core.model.ModelCatalog;
@@ -52,6 +55,7 @@
5255

5356
import java.util.Arrays;
5457
import java.util.List;
58+
import java.util.Optional;
5559
import java.util.Random;
5660
import java.util.stream.LongStream;
5761

@@ -73,6 +77,9 @@ class GraphSageTest {
7377
@Inject
7478
private Graph orphanGraph;
7579

80+
@Inject
81+
private GraphStore orphanGraphStore;
82+
7683
@InjectModelCatalog
7784
private ModelCatalog modelCatalog;
7885

@@ -82,12 +89,13 @@ class GraphSageTest {
8289
private static final String MODEL_NAME = "graphSageModel";
8390

8491
private Graph graph;
92+
private GraphStore graphStore;
8593
private HugeObjectArray<double[]> features;
8694
private ImmutableGraphSageTrainConfig.Builder configBuilder;
8795

8896
@BeforeEach
8997
void setUp() {
90-
graph = RandomGraphGenerator.builder()
98+
HugeGraph randomGraph = RandomGraphGenerator.builder()
9199
.nodeCount(NODE_COUNT)
92100
.averageDegree(3)
93101
.nodeLabelProducer(nodeId -> NodeLabelTokens.of("P"))
@@ -100,7 +108,18 @@ void setUp() {
100108
.allowSelfLoops(RandomGraphGeneratorConfig.AllowSelfLoops.NO)
101109
.build().generate();
102110

111+
graph = randomGraph;
112+
103113
long nodeCount = graph.nodeCount();
114+
115+
graphStore = CSRGraphStoreUtil.createFromGraph(
116+
Neo4jProxy.randomDatabaseId(),
117+
randomGraph,
118+
"REL",
119+
Optional.of("weight"),
120+
4
121+
);
122+
104123
features = HugeObjectArray.newArray(double[].class, nodeCount);
105124

106125
Random random = new Random();
@@ -137,7 +156,7 @@ void shouldNotMakeNanEmbeddings(Aggregator.AggregatorType aggregator) {
137156
.build();
138157

139158
var graphSage = new GraphSageAlgorithmFactory<>(modelCatalog).build(
140-
orphanGraph,
159+
orphanGraphStore,
141160
streamConfig,
142161
ProgressTracker.NULL_TRACKER
143162
);
@@ -209,7 +228,7 @@ void testLogging() {
209228

210229
var log = Neo4jProxy.testLog();
211230
var graphSage = new GraphSageAlgorithmFactory<>(modelCatalog).build(
212-
graph,
231+
graphStore,
213232
streamConfig,
214233
log,
215234
EmptyTaskRegistryFactory.INSTANCE

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/TrainingPipeline.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -91,7 +91,7 @@ private int numberOfTrainerConfigs() {
9191
.sum();
9292
}
9393

94-
public void addNodePropertyStep(NodePropertyStep step) {
94+
public void addNodePropertyStep(ExecutableNodePropertyStep step) {
9595
validateUniqueMutateProperty(step);
9696
this.nodePropertySteps.add(step);
9797
}
@@ -140,7 +140,7 @@ public void setAutoTuningConfig(AutoTuningConfig autoTuningConfig) {
140140
this.autoTuningConfig = autoTuningConfig;
141141
}
142142

143-
private void validateUniqueMutateProperty(NodePropertyStep step) {
143+
private void validateUniqueMutateProperty(ExecutableNodePropertyStep step) {
144144
this.nodePropertySteps.forEach(nodePropertyStep -> {
145145
var newMutatePropertyName = step.config().get(MUTATE_PROPERTY_KEY);
146146
var existingMutatePropertyName = nodePropertyStep.config().get(MUTATE_PROPERTY_KEY);

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/LinkPredictionTrainingPipeline.java

Lines changed: 29 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -21,16 +21,22 @@
2121

2222
import org.neo4j.gds.api.GraphStore;
2323
import org.neo4j.gds.config.AlgoBaseConfig;
24+
import org.neo4j.gds.config.RelationshipWeightConfig;
2425
import org.neo4j.gds.config.ToMapConvertible;
26+
import org.neo4j.gds.core.model.Model;
27+
import org.neo4j.gds.executor.ExecutionContext;
28+
import org.neo4j.gds.ml.pipeline.ExecutableNodePropertyStep;
2529
import org.neo4j.gds.ml.pipeline.TrainingPipeline;
2630

2731
import java.util.ArrayList;
2832
import java.util.HashMap;
2933
import java.util.List;
3034
import java.util.Map;
35+
import java.util.Objects;
3136
import java.util.Optional;
3237

3338
import static org.neo4j.gds.config.RelationshipWeightConfig.RELATIONSHIP_WEIGHT_PROPERTY;
39+
import static org.neo4j.gds.model.ModelConfig.MODEL_NAME_KEY;
3440

3541
public class LinkPredictionTrainingPipeline extends TrainingPipeline<LinkFeatureStep> {
3642

@@ -81,23 +87,41 @@ public void specificValidateBeforeExecution(GraphStore graphStore, AlgoBaseConfi
8187
}
8288
}
8389

84-
public Map<String, List<String>> tasksByRelationshipProperty() {
90+
public Map<String, List<String>> tasksByRelationshipProperty(ExecutionContext executionContext) {
8591
Map<String, List<String>> tasksByRelationshipProperty = new HashMap<>();
86-
nodePropertySteps().forEach(existingStep -> {
92+
93+
for (ExecutableNodePropertyStep existingStep : nodePropertySteps()) {
8794
if (existingStep.config().containsKey(RELATIONSHIP_WEIGHT_PROPERTY)) {
8895
var existingProperty = (String) existingStep.config().get(RELATIONSHIP_WEIGHT_PROPERTY);
8996
var tasks = tasksByRelationshipProperty.computeIfAbsent(
9097
existingProperty,
9198
key -> new ArrayList<>()
9299
);
93100
tasks.add(existingStep.procName());
101+
} else if (existingStep.config().containsKey(MODEL_NAME_KEY)) {
102+
Optional.ofNullable(executionContext.modelCatalog().getUntyped(
103+
executionContext.username(),
104+
((String) existingStep.config().get(MODEL_NAME_KEY))
105+
))
106+
.map(Model::trainConfig)
107+
.filter(config -> config instanceof RelationshipWeightConfig)
108+
.map(config -> ((RelationshipWeightConfig) config).relationshipWeightProperty())
109+
.filter(Objects::nonNull)
110+
.ifPresent(property -> {
111+
var tasks = tasksByRelationshipProperty.computeIfAbsent(
112+
property,
113+
key -> new ArrayList<>()
114+
);
115+
tasks.add(existingStep.procName());
116+
});
94117
}
95-
});
118+
}
119+
96120
return tasksByRelationshipProperty;
97121
}
98122

99-
public Optional<String> relationshipWeightProperty() {
100-
var relationshipWeightPropertySet = tasksByRelationshipProperty().entrySet();
123+
public Optional<String> relationshipWeightProperty(ExecutionContext executionContext) {
124+
var relationshipWeightPropertySet = tasksByRelationshipProperty(executionContext).entrySet();
101125
return relationshipWeightPropertySet.isEmpty()
102126
? Optional.empty()
103127
: Optional.of(relationshipWeightPropertySet.iterator().next().getKey());

0 commit comments

Comments
 (0)