Skip to content

Commit 39b689c

Browse files
committed
Resolve graph in GraphSageAlgorithmFactory
This resolves the TODO, as we no longer have anonymous graphs.
1 parent 922f00a commit 39b689c

File tree

9 files changed

+48
-216
lines changed

9 files changed

+48
-216
lines changed

algo/src/main/java/org/neo4j/gds/embeddings/graphsage/algo/GraphSageAlgorithmFactory.java

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
*/
2020
package org.neo4j.gds.embeddings.graphsage.algo;
2121

22-
import org.neo4j.gds.GraphAlgorithmFactory;
23-
import org.neo4j.gds.api.Graph;
22+
import org.neo4j.gds.GraphStoreAlgorithmFactory;
23+
import org.neo4j.gds.api.GraphStore;
2424
import org.neo4j.gds.config.MutateConfig;
2525
import org.neo4j.gds.core.concurrency.Pools;
2626
import org.neo4j.gds.core.model.ModelCatalog;
@@ -38,7 +38,7 @@
3838
import static org.neo4j.gds.mem.MemoryUsage.sizeOfDoubleArray;
3939
import static org.neo4j.gds.ml.core.EmbeddingUtils.validateRelationshipWeightPropertyValue;
4040

41-
public class GraphSageAlgorithmFactory<CONFIG extends GraphSageBaseConfig> extends GraphAlgorithmFactory<GraphSage, CONFIG> {
41+
public class GraphSageAlgorithmFactory<CONFIG extends GraphSageBaseConfig> extends GraphStoreAlgorithmFactory<GraphSage, CONFIG> {
4242

4343
private final ModelCatalog modelCatalog;
4444

@@ -48,30 +48,22 @@ public GraphSageAlgorithmFactory(ModelCatalog modelCatalog) {
4848
}
4949

5050
@Override
51-
public String taskName() {
52-
return GraphSage.class.getSimpleName();
53-
}
54-
55-
@Override
56-
public GraphSage build(
57-
Graph graph,
58-
CONFIG configuration,
59-
ProgressTracker progressTracker
60-
) {
61-
51+
public GraphSage build(GraphStore graphStore, CONFIG configuration, ProgressTracker progressTracker) {
6252
var executorService = Pools.DEFAULT;
6353
var model = resolveModel(
6454
modelCatalog,
6555
configuration.username(),
6656
configuration.modelName()
6757
);
6858

69-
if(model.trainConfig().isWeighted()) {
70-
validateRelationshipWeightPropertyValue(graph, configuration.concurrency(), executorService);
71-
}
59+
var graph = graphStore.getGraph(
60+
configuration.nodeLabelIdentifiers(graphStore),
61+
configuration.internalRelationshipTypes(graphStore),
62+
model.trainConfig().relationshipWeightProperty()
63+
);
7264

73-
if (!model.trainConfig().isWeighted() && graph.hasRelationshipProperty()) {
74-
throw new IllegalStateException("Model was trained without relationship weights. Expected an unweighted graph");
65+
if(model.trainConfig().hasRelationshipWeightProperty()) {
66+
validateRelationshipWeightPropertyValue(graph, configuration.concurrency(), executorService);
7567
}
7668

7769
return new GraphSage(
@@ -83,6 +75,11 @@ public GraphSage build(
8375
);
8476
}
8577

78+
@Override
79+
public String taskName() {
80+
return GraphSage.class.getSimpleName();
81+
}
82+
8683
@Override
8784
public MemoryEstimation memoryEstimation(CONFIG config) {
8885
var model = resolveModel(modelCatalog, config.username(), config.modelName());
@@ -98,8 +95,8 @@ public MemoryEstimation memoryEstimation(CONFIG config) {
9895
}
9996

10097
@Override
101-
public Task progressTask(Graph graph, CONFIG config) {
102-
return Tasks.leaf(taskName(), graph.nodeCount());
98+
public Task progressTask(GraphStore graphStore, CONFIG config) {
99+
return Tasks.leaf(taskName(), graphStore.getGraph(config.nodeLabelIdentifiers(graphStore)).nodeCount());
103100
}
104101

105102
private MemoryEstimation withNodeCount(GraphSageTrainConfig config, long nodeCount, boolean mutate) {

algo/src/main/java/org/neo4j/gds/embeddings/graphsage/algo/GraphSageBaseConfig.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,8 @@
2121

2222
import org.neo4j.gds.config.AlgoBaseConfig;
2323
import org.neo4j.gds.config.BatchSizeConfig;
24-
import org.neo4j.gds.config.RelationshipWeightConfig;
2524
import org.neo4j.gds.model.ModelConfig;
2625

27-
public interface GraphSageBaseConfig extends AlgoBaseConfig, BatchSizeConfig, ModelConfig, RelationshipWeightConfig {
26+
public interface GraphSageBaseConfig extends AlgoBaseConfig, BatchSizeConfig, ModelConfig {
2827
long serialVersionUID = 0x42L;
2928
}

algo/src/main/java/org/neo4j/gds/embeddings/graphsage/algo/GraphSageTrainConfig.java

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -151,12 +151,6 @@ default boolean propertiesMustExistForEachNodeLabel() {
151151
return false;
152152
}
153153

154-
@Configuration.Ignore
155-
@Value.Derived
156-
default boolean isWeighted() {
157-
return relationshipWeightProperty() != null;
158-
}
159-
160154
@Configuration.Ignore
161155
default List<LayerConfig> layerConfigs(int featureDimension) {
162156
List<LayerConfig> result = new ArrayList<>(sampleSizes().size());

algo/src/test/java/org/neo4j/gds/embeddings/graphsage/GraphSageTest.java

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,16 @@
2626
import org.neo4j.gds.NodeLabel;
2727
import org.neo4j.gds.Orientation;
2828
import org.neo4j.gds.api.Graph;
29+
import org.neo4j.gds.api.GraphStore;
2930
import org.neo4j.gds.beta.generator.PropertyProducer;
3031
import org.neo4j.gds.beta.generator.RandomGraphGenerator;
3132
import org.neo4j.gds.beta.generator.RelationshipDistribution;
3233
import org.neo4j.gds.compat.Neo4jProxy;
3334
import org.neo4j.gds.config.RandomGraphGeneratorConfig;
3435
import org.neo4j.gds.core.Aggregation;
3536
import org.neo4j.gds.core.concurrency.Pools;
37+
import org.neo4j.gds.core.huge.HugeGraph;
38+
import org.neo4j.gds.core.loading.CSRGraphStoreUtil;
3639
import org.neo4j.gds.core.loading.construction.NodeLabelTokens;
3740
import org.neo4j.gds.core.model.InjectModelCatalog;
3841
import org.neo4j.gds.core.model.ModelCatalog;
@@ -52,12 +55,14 @@
5255

5356
import java.util.Arrays;
5457
import java.util.List;
58+
import java.util.Optional;
5559
import java.util.Random;
5660
import java.util.stream.LongStream;
5761

5862
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
5963
import static org.neo4j.gds.assertj.Extractors.removingThreadId;
6064
import static org.neo4j.gds.compat.TestLog.INFO;
65+
import static org.neo4j.kernel.database.NamedDatabaseId.NAMED_SYSTEM_DATABASE_ID;
6166

6267
@GdlExtension
6368
@ModelCatalogExtension
@@ -73,6 +78,9 @@ class GraphSageTest {
7378
@Inject
7479
private Graph orphanGraph;
7580

81+
@Inject
82+
private GraphStore orphanGraphStore;
83+
7684
@InjectModelCatalog
7785
private ModelCatalog modelCatalog;
7886

@@ -82,12 +90,13 @@ class GraphSageTest {
8290
private static final String MODEL_NAME = "graphSageModel";
8391

8492
private Graph graph;
93+
private GraphStore graphStore;
8594
private HugeObjectArray<double[]> features;
8695
private ImmutableGraphSageTrainConfig.Builder configBuilder;
8796

8897
@BeforeEach
8998
void setUp() {
90-
graph = RandomGraphGenerator.builder()
99+
HugeGraph randomGraph = RandomGraphGenerator.builder()
91100
.nodeCount(NODE_COUNT)
92101
.averageDegree(3)
93102
.nodeLabelProducer(nodeId -> NodeLabelTokens.of("P"))
@@ -100,7 +109,18 @@ void setUp() {
100109
.allowSelfLoops(RandomGraphGeneratorConfig.AllowSelfLoops.NO)
101110
.build().generate();
102111

112+
graph = randomGraph;
113+
103114
long nodeCount = graph.nodeCount();
115+
116+
graphStore = CSRGraphStoreUtil.createFromGraph(
117+
NAMED_SYSTEM_DATABASE_ID,
118+
randomGraph,
119+
"REL",
120+
Optional.of("weight"),
121+
4
122+
);
123+
104124
features = HugeObjectArray.newArray(double[].class, nodeCount);
105125

106126
Random random = new Random();
@@ -137,7 +157,7 @@ void shouldNotMakeNanEmbeddings(Aggregator.AggregatorType aggregator) {
137157
.build();
138158

139159
var graphSage = new GraphSageAlgorithmFactory<>(modelCatalog).build(
140-
orphanGraph,
160+
orphanGraphStore,
141161
streamConfig,
142162
ProgressTracker.NULL_TRACKER
143163
);
@@ -209,7 +229,7 @@ void testLogging() {
209229

210230
var log = Neo4jProxy.testLog();
211231
var graphSage = new GraphSageAlgorithmFactory<>(modelCatalog).build(
212-
graph,
232+
graphStore,
213233
streamConfig,
214234
log,
215235
EmptyTaskRegistryFactory.INSTANCE

proc/embeddings/src/main/java/org/neo4j/gds/embeddings/graphsage/GraphSageCompanion.java

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121

2222
import org.jetbrains.annotations.NotNull;
2323
import org.neo4j.gds.api.properties.nodes.DoubleArrayNodePropertyValues;
24-
import org.neo4j.gds.core.CypherMapWrapper;
2524
import org.neo4j.gds.core.model.Model;
2625
import org.neo4j.gds.core.model.ModelCatalog;
2726
import org.neo4j.gds.embeddings.graphsage.algo.GraphSage;
@@ -32,13 +31,8 @@
3231
import org.neo4j.gds.executor.GraphStoreValidation;
3332
import org.neo4j.gds.executor.validation.AfterLoadValidation;
3433
import org.neo4j.gds.executor.validation.ValidationConfiguration;
35-
import org.neo4j.gds.utils.StringFormatting;
3634

3735
import java.util.List;
38-
import java.util.Map;
39-
40-
import static org.neo4j.gds.config.RelationshipWeightConfig.RELATIONSHIP_WEIGHT_PROPERTY;
41-
import static org.neo4j.gds.model.ModelConfig.MODEL_NAME_KEY;
4236

4337
public final class GraphSageCompanion {
4438

@@ -83,29 +77,4 @@ public List<AfterLoadValidation<CONFIG>> afterLoadValidations() {
8377
};
8478
}
8579

86-
static Map<String, Object> getActualConfig(Object graphNameOrConfig, Map<String, Object> maybeConfig) {
87-
return graphNameOrConfig instanceof Map
88-
? (Map<String, Object>) graphNameOrConfig
89-
: maybeConfig;
90-
}
91-
92-
// FIXME:
93-
// For AlgoBaseProc to provide the correct graph, this needs to match with the `trainConfig.relationshipWeightProperty`.
94-
// Ideally we would resolve the graph from the GraphStore after the model was resolved. But as GraphSage also supports anonymous loading, this is not possible with the current AlgoBaseProc.
95-
// For now we resolve the model at proc level and set the corresponding relationshipWeightProperty (thus no default)
96-
public static void injectRelationshipWeightPropertyFromModel(Map<String, Object> configuration, ModelCatalog modelCatalog, String username) {
97-
if (configuration.containsKey(RELATIONSHIP_WEIGHT_PROPERTY)) {
98-
throw new IllegalArgumentException(StringFormatting.formatWithLocale(
99-
"The parameter `%s` cannot be overwritten during embedding computation. Instead, specify this parameter in the configuration of the model training.",
100-
RELATIONSHIP_WEIGHT_PROPERTY
101-
));
102-
}
103-
104-
String modelName = CypherMapWrapper.create(configuration).requireString(MODEL_NAME_KEY);
105-
106-
var trainProperty = GraphSageModelResolver
107-
.resolveModel(modelCatalog, username, modelName).trainConfig().relationshipWeightProperty();
108-
configuration.put(RELATIONSHIP_WEIGHT_PROPERTY, trainProperty);
109-
}
110-
11180
}

proc/embeddings/src/main/java/org/neo4j/gds/embeddings/graphsage/GraphSageMutateProc.java

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
package org.neo4j.gds.embeddings.graphsage;
2121

2222
import org.neo4j.gds.AlgorithmFactory;
23-
import org.neo4j.gds.GraphAlgorithmFactory;
23+
import org.neo4j.gds.GraphStoreAlgorithmFactory;
2424
import org.neo4j.gds.MutatePropertyProc;
2525
import org.neo4j.gds.api.properties.nodes.NodePropertyValues;
2626
import org.neo4j.gds.core.CypherMapWrapper;
@@ -44,9 +44,7 @@
4444
import java.util.stream.Stream;
4545

4646
import static org.neo4j.gds.embeddings.graphsage.GraphSageCompanion.GRAPHSAGE_DESCRIPTION;
47-
import static org.neo4j.gds.embeddings.graphsage.GraphSageCompanion.getActualConfig;
4847
import static org.neo4j.gds.embeddings.graphsage.GraphSageCompanion.getNodeProperties;
49-
import static org.neo4j.gds.embeddings.graphsage.GraphSageCompanion.injectRelationshipWeightPropertyFromModel;
5048
import static org.neo4j.gds.executor.ExecutionMode.MUTATE_NODE_PROPERTY;
5149
import static org.neo4j.procedure.Mode.READ;
5250

@@ -59,12 +57,6 @@ public Stream<MutateResult> mutate(
5957
@Name(value = "graphName") String graphName,
6058
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration
6159
) {
62-
injectRelationshipWeightPropertyFromModel(
63-
getActualConfig(graphName, configuration),
64-
modelCatalog(),
65-
username.username()
66-
);
67-
6860
ComputationResult<GraphSage, GraphSage.GraphSageResult, GraphSageMutateConfig> computationResult = compute(
6961
graphName,
7062
configuration
@@ -78,12 +70,6 @@ public Stream<MemoryEstimateResult> estimate(
7870
@Name(value = "graphNameOrConfiguration") Object graphNameOrConfiguration,
7971
@Name(value = "algoConfiguration") Map<String, Object> algoConfiguration
8072
) {
81-
injectRelationshipWeightPropertyFromModel(
82-
getActualConfig(graphNameOrConfiguration, algoConfiguration),
83-
modelCatalog(),
84-
username.username()
85-
);
86-
8773
return computeEstimate(graphNameOrConfiguration, algoConfiguration);
8874
}
8975

@@ -111,7 +97,7 @@ protected GraphSageMutateConfig newConfig(String username, CypherMapWrapper conf
11197
}
11298

11399
@Override
114-
public GraphAlgorithmFactory<GraphSage, GraphSageMutateConfig> algorithmFactory() {
100+
public GraphStoreAlgorithmFactory<GraphSage, GraphSageMutateConfig> algorithmFactory() {
115101
return new GraphSageAlgorithmFactory<>(modelCatalog());
116102
}
117103

proc/embeddings/src/main/java/org/neo4j/gds/embeddings/graphsage/GraphSageStreamProc.java

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
package org.neo4j.gds.embeddings.graphsage;
2121

2222
import org.neo4j.gds.AlgorithmFactory;
23-
import org.neo4j.gds.GraphAlgorithmFactory;
23+
import org.neo4j.gds.GraphStoreAlgorithmFactory;
2424
import org.neo4j.gds.StreamProc;
2525
import org.neo4j.gds.api.properties.nodes.NodePropertyValues;
2626
import org.neo4j.gds.core.CypherMapWrapper;
@@ -46,8 +46,6 @@
4646
import java.util.stream.Stream;
4747

4848
import static org.neo4j.gds.embeddings.graphsage.GraphSageCompanion.GRAPHSAGE_DESCRIPTION;
49-
import static org.neo4j.gds.embeddings.graphsage.GraphSageCompanion.getActualConfig;
50-
import static org.neo4j.gds.embeddings.graphsage.GraphSageCompanion.injectRelationshipWeightPropertyFromModel;
5149
import static org.neo4j.gds.executor.ExecutionMode.STREAM;
5250

5351
@GdsCallable(name = "gds.beta.graphSage.stream", description = GRAPHSAGE_DESCRIPTION, executionMode = STREAM)
@@ -59,8 +57,6 @@ public Stream<GraphSageStreamResult> stream(
5957
@Name(value = "graphName") String graphName,
6058
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration
6159
) {
62-
injectRelationshipWeightPropertyFromModel(getActualConfig(graphName, configuration), modelCatalog(), username.username());
63-
6460
return stream(compute(graphName, configuration));
6561
}
6662

@@ -70,8 +66,6 @@ public Stream<MemoryEstimateResult> estimate(
7066
@Name(value = "graphNameOrConfiguration") Object graphNameOrConfiguration,
7167
@Name(value = "algoConfiguration") Map<String, Object> algoConfiguration
7268
) {
73-
injectRelationshipWeightPropertyFromModel(getActualConfig(graphNameOrConfiguration, algoConfiguration), modelCatalog(), username.username());
74-
7569
return computeEstimate(graphNameOrConfiguration, algoConfiguration);
7670
}
7771

@@ -112,7 +106,7 @@ protected GraphSageStreamConfig newConfig(String username, CypherMapWrapper conf
112106
}
113107

114108
@Override
115-
public GraphAlgorithmFactory<GraphSage, GraphSageStreamConfig> algorithmFactory() {
109+
public GraphStoreAlgorithmFactory<GraphSage, GraphSageStreamConfig> algorithmFactory() {
116110
return new GraphSageAlgorithmFactory<>(modelCatalog());
117111
}
118112

proc/embeddings/src/main/java/org/neo4j/gds/embeddings/graphsage/GraphSageWriteProc.java

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
package org.neo4j.gds.embeddings.graphsage;
2121

2222
import org.neo4j.gds.AlgorithmFactory;
23-
import org.neo4j.gds.GraphAlgorithmFactory;
23+
import org.neo4j.gds.GraphStoreAlgorithmFactory;
2424
import org.neo4j.gds.WriteProc;
2525
import org.neo4j.gds.api.properties.nodes.NodePropertyValues;
2626
import org.neo4j.gds.core.CypherMapWrapper;
@@ -44,9 +44,7 @@
4444
import java.util.stream.Stream;
4545

4646
import static org.neo4j.gds.embeddings.graphsage.GraphSageCompanion.GRAPHSAGE_DESCRIPTION;
47-
import static org.neo4j.gds.embeddings.graphsage.GraphSageCompanion.getActualConfig;
4847
import static org.neo4j.gds.embeddings.graphsage.GraphSageCompanion.getNodeProperties;
49-
import static org.neo4j.gds.embeddings.graphsage.GraphSageCompanion.injectRelationshipWeightPropertyFromModel;
5048
import static org.neo4j.gds.executor.ExecutionMode.WRITE_NODE_PROPERTY;
5149

5250
@GdsCallable(name = "gds.beta.graphSage.write", description = GRAPHSAGE_DESCRIPTION, executionMode = WRITE_NODE_PROPERTY)
@@ -58,8 +56,6 @@ public Stream<GraphSageWriteResult> write(
5856
@Name(value = "graphName") String graphName,
5957
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration
6058
) {
61-
injectRelationshipWeightPropertyFromModel(getActualConfig(graphName, configuration), modelCatalog(), username.username());
62-
6359
return write(compute(graphName, configuration));
6460
}
6561

@@ -69,8 +65,6 @@ public Stream<MemoryEstimateResult> estimate(
6965
@Name(value = "graphNameOrConfiguration") Object graphNameOrConfiguration,
7066
@Name(value = "algoConfiguration") Map<String, Object> algoConfiguration
7167
) {
72-
injectRelationshipWeightPropertyFromModel(getActualConfig(graphNameOrConfiguration, algoConfiguration), modelCatalog(), username.username());
73-
7468
return computeEstimate(graphNameOrConfiguration, algoConfiguration);
7569
}
7670

@@ -93,7 +87,7 @@ protected GraphSageWriteConfig newConfig(String username, CypherMapWrapper confi
9387
}
9488

9589
@Override
96-
public GraphAlgorithmFactory<GraphSage, GraphSageWriteConfig> algorithmFactory() {
90+
public GraphStoreAlgorithmFactory<GraphSage, GraphSageWriteConfig> algorithmFactory() {
9791
return new GraphSageAlgorithmFactory<>(modelCatalog());
9892
}
9993

0 commit comments

Comments
 (0)