Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions examples/channelize_poly_bench.cu
Original file line number Diff line number Diff line change
Expand Up @@ -139,5 +139,7 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
// printf("Benchmarking complex<double> -> complex<double>\n");
// ChannelizePolyBench<cuda::std::complex<double>,cuda::std::complex<double>>(channel_start, channel_stop);

matx::ClearCachesAndAllocations();

MATX_EXIT_HANDLER();
}
35 changes: 31 additions & 4 deletions include/matx/core/allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,11 @@ struct MemTracker {
iter->second.stream = stream;
}

// deallocate_internal assumes that the caller has already acquired the memory_mtx mutex.
template <typename StreamType>
auto deallocate_internal(void *ptr, [[maybe_unused]] StreamType st) {
MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)

[[maybe_unused]] std::unique_lock lck(memory_mtx);
auto iter = allocationMap.find(ptr);

if (iter == allocationMap.end()) {
Expand Down Expand Up @@ -159,10 +159,12 @@ struct MemTracker {
// Tag type carrying the CUDA stream to associate with a deallocation
// (counterpart to no_stream_t for stream-less frees).
struct valid_stream_t { cudaStream_t stream; };

// Free a tracked pointer with no stream association.
// Acquiring memory_mtx here satisfies deallocate_internal's documented
// precondition that the caller already holds the lock.
auto deallocate(void *ptr) {
  [[maybe_unused]] const std::lock_guard lck(memory_mtx);
  deallocate_internal(ptr, no_stream_t{});
}

// Free a tracked pointer, associating the free with the given CUDA stream.
// Locking here satisfies deallocate_internal's precondition that memory_mtx
// is held by the caller.
auto deallocate(void *ptr, cudaStream_t stream) {
  [[maybe_unused]] const std::lock_guard lck(memory_mtx);
  deallocate_internal(ptr, valid_stream_t{stream});
}

Expand Down Expand Up @@ -256,11 +258,23 @@ struct MemTracker {
return MATX_INVALID_MEMORY;
}

~MemTracker() {
while (allocationMap.size()) {
deallocate(allocationMap.begin()->first);
// Release every allocation currently tracked by this MemTracker.
// The lock is taken once and the map is drained from the front.
void free_all() {
  [[maybe_unused]] std::unique_lock lck(memory_mtx);
  while (!allocationMap.empty()) {
    void *const head = allocationMap.begin()->first;
    deallocate_internal(head, no_stream_t{});
    // deallocate_internal normally erases the entry itself; erasing by key
    // again is a no-op in that case, but guarantees the loop makes progress
    // even if the entry was left behind.
    allocationMap.erase(head);
  }
}

// Destructor: frees any allocations still tracked at shutdown so nothing
// tracked by matxAlloc outlives the tracker.
~MemTracker() {
free_all();
}
};


Expand All @@ -271,6 +285,19 @@ __MATX_INLINE__ MemTracker &GetAllocMap() {
return tracker;
}

/**
 * @brief Free all MatX allocations.
 *
 * Frees every allocation made with matxAlloc, whether made directly by the
 * user or internally by MatX for workspaces. This function does not free the
 * caches (i.e., allocations made for FFT plans, cuBLAS handles, and other
 * state required for MatX transforms); to free those caches, use
 * matx::ClearCaches(). It is not safe to call matxFree() on user-managed
 * pointers after calling this function, so it should be called only after
 * the application has called matxFree() on any pointers for which it will
 * call matxFree().
 */
__attribute__ ((visibility ("default")))
__MATX_INLINE__ void FreeAllocations() {
GetAllocMap().free_all();
}

/**
* @brief Determine if a pointer is printable by the host
*
Expand Down
84 changes: 78 additions & 6 deletions include/matx/core/cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ struct LTOIRData {

static constexpr size_t MAX_CUDA_DEVICES_PER_SYSTEM = 16;
using CacheId = uint64_t;
// Type-erased free hook for one cache type: given the std::any that holds
// that type's cache map, clears it so the cached entries are destroyed.
struct CacheFreeHelper {
void (*free)(std::any&);
};

// Common cache parameters that every cache entry needs
struct CacheCommonParamsKey {
Expand All @@ -118,13 +121,34 @@ __attribute__ ((visibility ("default")))
inline cuda::std::atomic<CacheId> CacheIdCounter{0};
inline std::recursive_mutex cache_mtx; ///< Mutex protecting updates from map
inline std::recursive_mutex ltoir_mutex; ///< Mutex protecting LTOIR cache operations
inline std::recursive_mutex stream_alloc_mutex; ///< Mutex protecting stream allocation cache operations

inline auto& CacheRegistry() {
// Protected by cache_mtx
static std::unordered_map<CacheId, CacheFreeHelper> registry;
return registry;
}

// Returns the process-wide unique CacheId for CacheType. The ID is assigned
// exactly once via the atomic counter (thread-safe static initialization) and
// reused on subsequent calls. On first registration, a type-erased free
// callback is stored in the registry so ClearAll() can destroy this cache
// type's entries without knowing CacheType.
template<typename CacheType>
__attribute__ ((visibility ("default")))
CacheId GetCacheIdFromType()
{
static CacheId id = CacheIdCounter.fetch_add(1);

[[maybe_unused]] std::lock_guard<std::recursive_mutex> lock(cache_mtx);
auto &registry = CacheRegistry();
if (registry.find(id) != registry.end()) {
// Registry already contains this ID, so no need to insert it again
// with its CacheFreeHelper.
return id;
}
registry.emplace(id, CacheFreeHelper{
.free = [](std::any& any) -> void {
using CacheMap = std::unordered_map<CacheCommonParamsKey, CacheType, CacheCommonParamsKeyHash>;
// This clear is the unordered_map's clear, which will ultimately call the
// destructors of the cache entries.
std::any_cast<CacheMap&>(any).clear();
},
});
return id;
}

Expand All @@ -144,10 +168,7 @@ class matxCache_t {
public:
matxCache_t() {}
~matxCache_t() {
// Destroy all outstanding objects in the cache to free memory
for (auto &[k, v]: cache) {
v.reset();
}
ClearAll();
}

/**
Expand All @@ -165,6 +186,38 @@ class matxCache_t {
std::any_cast<CacheMap&>(el->second).clear();
}

// Destroys every cached entry: per-type transform caches, the per-stream
// allocation cache, and the LTOIR cache. Each cache is cleared under its own
// mutex, in separate scopes so no two locks are held at once.
void ClearAll() {
// Clear all cache entries for all cache types
{
[[maybe_unused]] std::lock_guard<std::recursive_mutex> lock(cache_mtx);
for (auto &[id, v]: cache) {
// Look up the type-erased free callback registered by GetCacheIdFromType;
// entries with no registered helper are skipped (cache.clear() below still
// destroys the std::any itself).
auto entry = CacheRegistry().find(id);
if (entry == CacheRegistry().end()) {
continue;
}
auto &info = entry->second;
info.free(v);
}
cache.clear();
}
{
[[maybe_unused]] std::lock_guard<std::recursive_mutex> lock(stream_alloc_mutex);
// Free the device workspace behind every (key, stream) entry. matxFree
// re-enters the allocator's own lock (memory_mtx), not stream_alloc_mutex,
// so there is no self-deadlock here.
for (auto &[outer_key, inner_map]: stream_alloc_cache) {
for (auto &[inner_key, value]: inner_map) {
if (value.ptr) {
matxFree(value.ptr);
}
}
inner_map.clear();
}
stream_alloc_cache.clear();
}
{
[[maybe_unused]] std::lock_guard<std::recursive_mutex> lock(ltoir_mutex);
ltoir_cache.clear();
}
}

template <typename CacheType, typename InParams, typename MakeFun, typename ExecFun, typename Executor>
void LookupAndExec(const CacheId &id, const InParams &params, const MakeFun &mfun, const ExecFun &efun, [[maybe_unused]] const Executor &exec) {
// This mutex should eventually be finer-grained so each transform doesn't get blocked by others
Expand Down Expand Up @@ -211,6 +264,8 @@ class matxCache_t {
key.thread_id = std::this_thread::get_id();
cudaGetDevice(&key.device_id);

[[maybe_unused]] std::lock_guard<std::recursive_mutex> lock(stream_alloc_mutex);

auto &common_params_cache = stream_alloc_cache[key];
auto el = common_params_cache.find(stream);
if (el == common_params_cache.end()) {
Expand Down Expand Up @@ -689,8 +744,25 @@ __MATX_INLINE__ matxCache_t &GetCache() {
return InitCache();
}

} // namespace detail

/**
 * @brief Free all MatX caches.
 *
 * Frees caches created for FFT plans, cuBLAS handles, and other state
 * required for MatX transforms. This function does not clear the allocator
 * cache (i.e., allocations made with matxAlloc other than those created to
 * support transforms). To free the allocator cache, use
 * matx::FreeAllocations().
 */
__attribute__ ((visibility ("default")))
__MATX_INLINE__ void ClearCaches() {
detail::GetCache().ClearAll();
}

/**
 * @brief Clear both MatX caches and allocations.
 *
 * Single function that can be called prior to program exit to support clean
 * shutdown (i.e., to avoid issues with the order of destruction of static
 * objects and CUDA contexts). Equivalent to calling matx::ClearCaches()
 * followed by matx::FreeAllocations().
 */
__attribute__ ((visibility ("default")))
__MATX_INLINE__ void ClearCachesAndAllocations() {
ClearCaches();
FreeAllocations();
}

} // namespace detail
}; // namespace matx
86 changes: 86 additions & 0 deletions test/00_misc/ClearCacheTests.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
////////////////////////////////////////////////////////////////////////////////
// BSD 3-Clause License
//
// Copyright (c) 2021, NVIDIA Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
/////////////////////////////////////////////////////////////////////////////////

#include "assert.h"
#include "matx.h"
#include "test_types.h"
#include "utilities.h"
#include "gtest/gtest.h"
#include <iostream>
#include <vector>
#include <unordered_map>

using namespace matx;

// Verifies that matx::ClearCachesAndAllocations() returns device memory to
// the driver: a matmul populates the cuBLAS cache (>= 4 MiB workspace) and a
// manual 4 MiB matxAlloc provides a tracked allocation, so clearing should
// release at least 8 MiB.
TEST(ClearCacheTests, TestCase) {
  MATX_ENTER_HANDLER();

  // Snapshot free device memory before any MatX work.
  size_t initial_free_mem = 0;
  size_t total_mem = 0;
  cudaError_t err = cudaMemGetInfo(&initial_free_mem, &total_mem);
  ASSERT_EQ(err, cudaSuccess);

  // The cuBLAS handle will allocate an associated workspace of 4 MiB on pre-Hopper and
  // 32 MiB on Hopper+.
  {
    auto c = matx::make_tensor<float, 2>({1024, 1024});
    auto a = matx::make_tensor<float, 2>({1024, 1024});
    auto b = matx::make_tensor<float, 2>({1024, 1024});
    (c = matx::matmul(a, b)).run();
    // Check the sync result so an async failure doesn't surface later as a
    // confusing memory-delta assertion.
    ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess);
  }

  // Manually allocate 4 MiB
  const size_t four_MiB = 4 * 1024 * 1024;
  void *ptr = nullptr;
  matxAlloc(&ptr, four_MiB, MATX_DEVICE_MEMORY);
  ASSERT_NE(ptr, nullptr);

  size_t post_alloc_free_mem = 0;
  err = cudaMemGetInfo(&post_alloc_free_mem, &total_mem);
  ASSERT_EQ(err, cudaSuccess);

  matx::ClearCachesAndAllocations();

  size_t post_clear_free_mem = 0;
  err = cudaMemGetInfo(&post_clear_free_mem, &total_mem);
  ASSERT_EQ(err, cudaSuccess);

  const ssize_t allocated = static_cast<ssize_t>(initial_free_mem) - static_cast<ssize_t>(post_alloc_free_mem);
  const ssize_t freed = static_cast<ssize_t>(post_clear_free_mem) - static_cast<ssize_t>(post_alloc_free_mem);

  // Compare against a *signed* threshold. Mixing ssize_t with the unsigned
  // 2 * four_MiB would convert a negative delta (possible if another process
  // touched device memory between samples) to a huge unsigned value and make
  // the assertion pass spuriously.
  const ssize_t threshold = static_cast<ssize_t>(2 * four_MiB);

  // The cuBLAS cache and allocator data structure should have allocated at least 8 MiB
  // in total and thus at least 8 MiB should be freed when clearing the caches/allocations.
  ASSERT_GE(allocated, threshold);
  ASSERT_GE(freed, threshold);

  MATX_EXIT_HANDLER();
}
3 changes: 3 additions & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ list(TRANSFORM OPERATOR_TEST_FILES PREPEND "00_operators/")

set (test_sources
00_misc/AllocatorTests.cu
00_misc/ClearCacheTests.cu
00_misc/ProfilingTests.cu
00_tensor/BasicTensorTests.cu
00_tensor/CUBTests.cu
Expand Down Expand Up @@ -141,6 +142,8 @@ endforeach()
# Number of test jobs to run in parallel
set(CTEST_PARALLEL_JOBS 4)

set_tests_properties(test_00_misc_ClearCacheTests PROPERTIES RUN_SERIAL TRUE)

# Create a legacy matx_test script for CI compatibility
configure_file(
${CMAKE_CURRENT_SOURCE_DIR}/matx_test.sh
Expand Down
8 changes: 7 additions & 1 deletion test/main.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,16 @@
#include "gtest/gtest.h"
#include <pybind11/embed.h>

#include "matx.h"

// Test-suite entry point: runs all gtest cases, then tears down MatX state
// while the CUDA context is still valid instead of relying on
// static-destruction order at exit.
int main(int argc, char **argv)
{
  printf("Running MatX unit tests. Press Ctrl+\\ (SIGQUIT) to kill tests\n");

  ::testing::InitGoogleTest(&argc, argv);
  const int rc = RUN_ALL_TESTS();

  matx::ClearCachesAndAllocations();

  return rc;
}