Skip to content

Commit 445a418

Browse files
Speed up RAAPTOR implementation
1 parent 96a51a9 commit 445a418

File tree

4 files changed

+1104
-57
lines changed

4 files changed

+1104
-57
lines changed

Algorithms/CSA/CSA.h

Lines changed: 201 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,13 @@
1717

1818
namespace CSA {
1919

20-
template<bool PATH_RETRIEVAL = true, int ENABLE_PRUNING = 0, typename PROFILER = NoProfiler>
20+
template<bool PATH_RETRIEVAL = true, typename PROFILER = NoProfiler>
2121
class CSA {
2222

2323
public:
2424
constexpr static bool PathRetrieval = PATH_RETRIEVAL;
2525
using Profiler = PROFILER;
26-
constexpr static int EnablePruning = ENABLE_PRUNING;
27-
using Type = CSA<PathRetrieval, EnablePruning, Profiler>;
26+
using Type = CSA<PathRetrieval, Profiler>;
2827
using TripFlag = std::conditional_t<PathRetrieval, ConnectionId, bool>;
2928

3029
private:
@@ -136,11 +135,6 @@ class CSA {
136135
const Connection& connection = data.connections[i];
137136
if (targetStop != noStop && connection.departureTime > arrivalTime[targetStop]) break;
138137
if (connectionIsReachable(connection, i)) {
139-
if constexpr (EnablePruning == 1) {
140-
if (connection.arrivalTime > arrivalTime[targetStop]) {
141-
continue;
142-
}
143-
}
144138
profiler.countMetric(METRIC_CONNECTIONS);
145139
arrivalByTrip(connection.arrivalStopId, connection.arrivalTime, connection.tripId);
146140
}
@@ -186,10 +180,206 @@ class CSA {
186180
profiler.countMetric(METRIC_EDGES);
187181
const StopId toStop = StopId(data.transferGraph.get(ToVertex, edge));
188182
const int newArrivalTime = time + data.transferGraph.get(TravelTime, edge);
189-
if constexpr (EnablePruning == 1) {
190-
if (newArrivalTime > arrivalTime[targetStop]) {
191-
break;
183+
arrivalByTransfer(toStop, newArrivalTime, stop, edge);
184+
}
185+
}
186+
187+
inline void arrivalByTransfer(const StopId stop, const int time, const StopId parent, const Edge edge) noexcept {
188+
if (arrivalTime[stop] <= time) return;
189+
profiler.countMetric(METRIC_STOPS_BY_TRANSFER);
190+
arrivalTime[stop] = time;
191+
if constexpr (PathRetrieval) {
192+
parentLabel[stop].parent = parent;
193+
parentLabel[stop].reachedByTransfer = true;
194+
parentLabel[stop].transferId = edge;
195+
}
196+
}
197+
198+
private:
199+
const Data& data;
200+
201+
StopId sourceStop;
202+
StopId targetStop;
203+
204+
std::vector<TripFlag> tripReached;
205+
std::vector<int> arrivalTime;
206+
std::vector<ParentLabel> parentLabel;
207+
208+
Profiler profiler;
209+
};
210+
template<bool PATH_RETRIEVAL = true, typename PROFILER = NoProfiler>
211+
class CSA_prune {
212+
213+
public:
214+
constexpr static bool PathRetrieval = PATH_RETRIEVAL;
215+
using Profiler = PROFILER;
216+
using Type = CSA_prune<PathRetrieval, Profiler>;
217+
using TripFlag = std::conditional_t<PathRetrieval, ConnectionId, bool>;
218+
219+
private:
220+
struct ParentLabel {
221+
ParentLabel(const StopId parent = noStop, const bool reachedByTransfer = false, const TripId tripId = noTripId) :
222+
parent(parent),
223+
reachedByTransfer(reachedByTransfer),
224+
tripId(tripId) {
225+
}
226+
227+
StopId parent;
228+
bool reachedByTransfer;
229+
union {
230+
TripId tripId;
231+
Edge transferId;
232+
};
233+
};
234+
235+
public:
236+
CSA_prune(const Data& data, const Profiler& profilerTemplate = Profiler()) :
237+
data(data),
238+
sourceStop(noStop),
239+
targetStop(noStop),
240+
tripReached(data.numberOfTrips(), TripFlag()),
241+
arrivalTime(data.numberOfStops(), never),
242+
parentLabel(PathRetrieval ? data.numberOfStops() : 0),
243+
profiler(profilerTemplate) {
244+
Assert(Vector::isSorted(data.connections), "Connections must be sorted in ascending order!");
245+
profiler.registerPhases({PHASE_CLEAR, PHASE_INITIALIZATION, PHASE_CONNECTION_SCAN});
246+
profiler.registerMetrics({METRIC_CONNECTIONS, METRIC_EDGES, METRIC_STOPS_BY_TRIP, METRIC_STOPS_BY_TRANSFER});
247+
profiler.initialize();
248+
}
249+
250+
inline void run(const StopId source, const int departureTime, const StopId target = noStop) noexcept {
251+
profiler.start();
252+
253+
profiler.startPhase();
254+
Assert(data.isStop(source), "Source stop " << source << " is not a valid stop!");
255+
clear();
256+
profiler.donePhase(PHASE_CLEAR);
257+
258+
profiler.startPhase();
259+
sourceStop = source;
260+
targetStop = target;
261+
arrivalTime[sourceStop] = departureTime;
262+
const int targetArrivalTime = (targetStop != noStop)
263+
? arrivalTime[targetStop]
264+
: never;
265+
relaxEdges(sourceStop, departureTime, targetArrivalTime);
266+
const ConnectionId firstConnection = firstReachableConnection(departureTime);
267+
profiler.donePhase(PHASE_INITIALIZATION);
268+
269+
profiler.startPhase();
270+
scanConnections(firstConnection, ConnectionId(data.connections.size()));
271+
profiler.donePhase(PHASE_CONNECTION_SCAN);
272+
273+
profiler.done();
274+
}
275+
276+
inline bool reachable(const StopId stop) const noexcept {
277+
return arrivalTime[stop] < never;
278+
}
279+
280+
inline int getEarliestArrivalTime(const StopId stop) const noexcept {
281+
return arrivalTime[stop];
282+
}
283+
284+
inline Journey getJourney() const noexcept requires PathRetrieval {
285+
return getJourney(targetStop);
286+
}
287+
288+
inline Journey getJourney(StopId stop) const noexcept requires PathRetrieval {
289+
Journey journey;
290+
if (!reachable(stop)) return journey;
291+
while (stop != sourceStop) {
292+
const ParentLabel& label = parentLabel[stop];
293+
if (label.reachedByTransfer) {
294+
const int travelTime = data.transferGraph.get(TravelTime, label.transferId);
295+
journey.emplace_back(label.parent, stop, arrivalTime[stop] - travelTime, arrivalTime[stop], label.transferId);
296+
} else {
297+
journey.emplace_back(label.parent, stop, data.connections[tripReached[label.tripId]].departureTime, arrivalTime[stop], label.tripId);
298+
}
299+
stop = label.parent;
300+
}
301+
Vector::reverse(journey);
302+
return journey;
303+
}
304+
305+
inline const Profiler& getProfiler() const noexcept {
306+
return profiler;
307+
}
308+
309+
private:
310+
inline void clear() {
311+
sourceStop = noStop;
312+
targetStop = noStop;
313+
Vector::fill(arrivalTime, never);
314+
Vector::fill(tripReached, TripFlag());
315+
if constexpr (PathRetrieval) {
316+
Vector::fill(parentLabel, ParentLabel());
317+
}
318+
}
319+
320+
inline ConnectionId firstReachableConnection(const int departureTime) const noexcept {
321+
return ConnectionId(Vector::lowerBound(data.connections, departureTime, [](const Connection& connection, const int time) {
322+
return connection.departureTime < time;
323+
}));
324+
}
325+
326+
inline void scanConnections(const ConnectionId begin, const ConnectionId end) noexcept {
327+
for (ConnectionId i = begin; i < end; i++) {
328+
329+
const Connection& connection = data.connections[i];
330+
if (targetStop != noStop && connection.departureTime > arrivalTime[targetStop]) break;
331+
if (connectionIsReachable(connection, i)) {
332+
if (connection.arrivalTime > arrivalTime[targetStop]) {
333+
continue;
192334
}
335+
336+
profiler.countMetric(METRIC_CONNECTIONS);
337+
arrivalByTrip(connection.arrivalStopId, connection.arrivalTime, connection.tripId, arrivalTime[targetStop]);
338+
}
339+
}
340+
}
341+
342+
inline bool connectionIsReachableFromStop(const Connection& connection) const noexcept {
343+
return arrivalTime[connection.departureStopId] <= connection.departureTime - data.minTransferTime(connection.departureStopId);
344+
}
345+
346+
inline bool connectionIsReachableFromTrip(const Connection& connection) const noexcept {
347+
return tripReached[connection.tripId] != TripFlag();
348+
}
349+
350+
inline bool connectionIsReachable(const Connection& connection, const ConnectionId id) noexcept {
351+
if (connectionIsReachableFromTrip(connection)) return true;
352+
if (connectionIsReachableFromStop(connection)) {
353+
if constexpr (PathRetrieval) {
354+
tripReached[connection.tripId] = id;
355+
} else {
356+
suppressUnusedParameterWarning(id);
357+
tripReached[connection.tripId] = true;
358+
}
359+
return true;
360+
}
361+
return false;
362+
}
363+
364+
inline void arrivalByTrip(const StopId stop, const int time, const TripId trip, const int targetArrivalTime) noexcept {
365+
if (arrivalTime[stop] <= time) return;
366+
profiler.countMetric(METRIC_STOPS_BY_TRIP);
367+
arrivalTime[stop] = time;
368+
if constexpr (PathRetrieval) {
369+
parentLabel[stop].parent = data.connections[tripReached[trip]].departureStopId;
370+
parentLabel[stop].reachedByTransfer = false;
371+
parentLabel[stop].tripId = trip;
372+
}
373+
relaxEdges(stop, time, targetArrivalTime);
374+
}
375+
376+
inline void relaxEdges(const StopId stop, const int time, const int targetArrivalTime) noexcept {
377+
for (const Edge edge : data.transferGraph.edgesFrom(stop)) {
378+
profiler.countMetric(METRIC_EDGES);
379+
const StopId toStop = StopId(data.transferGraph.get(ToVertex, edge));
380+
const int newArrivalTime = time + data.transferGraph.get(TravelTime, edge);
381+
if (newArrivalTime > targetArrivalTime) {
382+
break;
193383
}
194384
arrivalByTransfer(toStop, newArrivalTime, stop, edge);
195385
}

0 commit comments

Comments
 (0)