From 27d2268b6e59263d0ec4792d0f3d504c6e246571 Mon Sep 17 00:00:00 2001 From: reus Date: Mon, 14 Jul 2025 14:41:06 +0800 Subject: [PATCH] colexec/group: add spill --- pkg/fileservice/file_service_test.go | 3 + pkg/fileservice/sub_path.go | 40 +- pkg/sql/colexec/aggexec/aggFrame_test.go | 26 +- pkg/sql/colexec/aggexec/approx_count.go | 8 +- pkg/sql/colexec/aggexec/concat.go | 4 +- pkg/sql/colexec/aggexec/count.go | 8 +- pkg/sql/colexec/aggexec/fromBytesRetBytes.go | 4 +- pkg/sql/colexec/aggexec/fromBytesRetFixed.go | 4 +- pkg/sql/colexec/aggexec/fromFixedRetBytes.go | 4 +- pkg/sql/colexec/aggexec/fromFixedRetFixed.go | 4 +- pkg/sql/colexec/aggexec/median.go | 4 +- pkg/sql/colexec/aggexec/serialize.go | 4 +- pkg/sql/colexec/aggexec/types.go | 4 +- pkg/sql/colexec/aggexec/window.go | 4 +- pkg/sql/colexec/group/container.go | 28 + pkg/sql/colexec/group/exec.go | 42 + pkg/sql/colexec/group/exec_test.go | 42 +- pkg/sql/colexec/group/execctx.go | 35 +- pkg/sql/colexec/group/group.go | 370 +++++++ pkg/sql/colexec/group/group_result_buffer.go | 28 + .../colexec/group/group_result_none_block.go | 31 + pkg/sql/colexec/group/group_test.go | 140 +++ pkg/sql/colexec/group/hashmap.go | 104 ++ pkg/sql/colexec/group/spill_manager.go | 586 +++++++++++ pkg/sql/colexec/group/spill_manager_test.go | 250 +++++ pkg/sql/colexec/group/spill_test.go | 949 ++++++++++++++++++ pkg/sql/colexec/group/testspill/spill_test.go | 101 ++ pkg/sql/colexec/group/types.go | 13 +- pkg/testutil/util_compare.go | 37 + 29 files changed, 2829 insertions(+), 48 deletions(-) create mode 100644 pkg/sql/colexec/group/container.go create mode 100644 pkg/sql/colexec/group/group.go create mode 100644 pkg/sql/colexec/group/group_result_buffer.go create mode 100644 pkg/sql/colexec/group/group_result_none_block.go create mode 100644 pkg/sql/colexec/group/group_test.go create mode 100644 pkg/sql/colexec/group/hashmap.go create mode 100644 pkg/sql/colexec/group/spill_manager.go create mode 100644 
pkg/sql/colexec/group/spill_manager_test.go create mode 100644 pkg/sql/colexec/group/spill_test.go create mode 100644 pkg/sql/colexec/group/testspill/spill_test.go diff --git a/pkg/fileservice/file_service_test.go b/pkg/fileservice/file_service_test.go index 38f8f2d6306b7..30d02d005cc28 100644 --- a/pkg/fileservice/file_service_test.go +++ b/pkg/fileservice/file_service_test.go @@ -1032,6 +1032,9 @@ func testFileService( } w, err := rwFS.NewWriter(ctx, "foo") + if moerr.IsMoErrCode(err, moerr.ErrNotSupported) { + return + } assert.Nil(t, err) assert.NotNil(t, w) _, err = w.Write([]byte("foobarbaz")) diff --git a/pkg/fileservice/sub_path.go b/pkg/fileservice/sub_path.go index 1d4360d786eaa..abf090c222e69 100644 --- a/pkg/fileservice/sub_path.go +++ b/pkg/fileservice/sub_path.go @@ -17,20 +17,24 @@ package fileservice import ( "context" "fmt" + "io" "iter" "path" "strings" + + "github.com/matrixorigin/matrixone/pkg/common/moerr" ) type subPathFS struct { - upstream FileService - path string - name string + upstream FileService + rwUpstream ReaderWriterFileService + path string + name string } // SubPath returns a FileService instance that operates at specified sub path of the upstream instance func SubPath(upstream FileService, path string) FileService { - return &subPathFS{ + ret := &subPathFS{ upstream: upstream, path: path, name: strings.Join([]string{ @@ -39,9 +43,13 @@ func SubPath(upstream FileService, path string) FileService { path, }, ","), } + if rwfs, ok := upstream.(ReaderWriterFileService); ok { + ret.rwUpstream = rwfs + } + return ret } -var _ FileService = new(subPathFS) +var _ ReaderWriterFileService = new(subPathFS) func (s *subPathFS) Name() string { return s.name @@ -143,6 +151,28 @@ func (s *subPathFS) Cost() *CostAttr { return s.upstream.Cost() } +func (s *subPathFS) NewReader(ctx context.Context, filePath string) (io.ReadCloser, error) { + p, err := s.toUpstreamPath(filePath) + if err != nil { + return nil, err + } + if s.rwUpstream != nil { + 
return s.rwUpstream.NewReader(ctx, p) + } + return nil, moerr.NewNotSupportedNoCtx("not ReaderWriterFileService") +} + +func (s *subPathFS) NewWriter(ctx context.Context, filePath string) (io.WriteCloser, error) { + p, err := s.toUpstreamPath(filePath) + if err != nil { + return nil, err + } + if s.rwUpstream != nil { + return s.rwUpstream.NewWriter(ctx, p) + } + return nil, moerr.NewNotSupportedNoCtx("not ReaderWriterFileService") +} + var _ MutableFileService = new(subPathFS) func (s *subPathFS) NewMutator(ctx context.Context, filePath string) (Mutator, error) { diff --git a/pkg/sql/colexec/aggexec/aggFrame_test.go b/pkg/sql/colexec/aggexec/aggFrame_test.go index f5f08d5a313e3..4c808da057f0b 100644 --- a/pkg/sql/colexec/aggexec/aggFrame_test.go +++ b/pkg/sql/colexec/aggexec/aggFrame_test.go @@ -699,12 +699,12 @@ func TestGroupConcatExecMarshalUnmarshal(t *testing.T) { require.NoError(t, vector.AppendBytes(v1, []byte("test1"), false, m.Mp())) require.NoError(t, exec.Fill(0, 0, []*vector.Vector{v1})) - data, err := exec.marshal() + data, err := exec.Marshal() require.NoError(t, err) require.NotNil(t, data) newExec := newGroupConcatExec(m, info, ",") - err = newExec.unmarshal(m.Mp(), nil, nil, [][]byte{[]byte(",")}) + err = newExec.Unmarshal(m.Mp(), nil, nil, [][]byte{[]byte(",")}) require.NoError(t, err) require.Equal(t, []byte(","), newExec.(*groupConcatExec).separator) @@ -722,14 +722,14 @@ func TestGroupConcatExecMarshalUnmarshal(t *testing.T) { } exec := newGroupConcatExec(m, info, "|") - data, err := exec.marshal() + data, err := exec.Marshal() require.NoError(t, err) newExec := newGroupConcatExec(m, info, ",") encoded := &EncodedAgg{} require.NoError(t, encoded.Unmarshal(data)) - err = newExec.unmarshal(m.Mp(), encoded.Result, encoded.Empties, encoded.Groups) + err = newExec.Unmarshal(m.Mp(), encoded.Result, encoded.Empties, encoded.Groups) require.NoError(t, err) require.Equal(t, []byte("|"), newExec.(*groupConcatExec).separator) @@ -753,14 +753,14 @@ func 
TestGroupConcatExecMarshalUnmarshal(t *testing.T) { require.NoError(t, vector.AppendBytes(v1, []byte("distinct1"), false, m.Mp())) require.NoError(t, exec.Fill(0, 0, []*vector.Vector{v1})) - data, err := exec.marshal() + data, err := exec.Marshal() require.NoError(t, err) newExec := newGroupConcatExec(m, info, ",") encoded := &EncodedAgg{} require.NoError(t, encoded.Unmarshal(data)) - err = newExec.unmarshal(m.Mp(), encoded.Result, encoded.Empties, encoded.Groups) + err = newExec.Unmarshal(m.Mp(), encoded.Result, encoded.Empties, encoded.Groups) require.NoError(t, err) exec.Free() @@ -803,14 +803,14 @@ func TestCountColumnExecMarshalUnmarshal(t *testing.T) { require.NoError(t, vector.AppendFixedList(v1, []int64{1, 2, 3}, nil, m.Mp())) require.NoError(t, exec.BulkFill(0, []*vector.Vector{v1})) - data, err := exec.marshal() + data, err := exec.Marshal() require.NoError(t, err) newExec := newCountColumnExecExec(m, info) encoded := &EncodedAgg{} require.NoError(t, encoded.Unmarshal(data)) - err = newExec.unmarshal(m.Mp(), encoded.Result, encoded.Empties, encoded.Groups) + err = newExec.Unmarshal(m.Mp(), encoded.Result, encoded.Empties, encoded.Groups) require.NoError(t, err) exec.Free() @@ -833,14 +833,14 @@ func TestCountColumnExecMarshalUnmarshal(t *testing.T) { require.NoError(t, vector.AppendFixedList(v1, []int64{1, 2, 1, 3}, nil, m.Mp())) require.NoError(t, exec.BulkFill(0, []*vector.Vector{v1})) - data, err := exec.marshal() + data, err := exec.Marshal() require.NoError(t, err) newExec := newCountColumnExecExec(m, info) encoded := &EncodedAgg{} require.NoError(t, encoded.Unmarshal(data)) - err = newExec.unmarshal(m.Mp(), encoded.Result, encoded.Empties, encoded.Groups) + err = newExec.Unmarshal(m.Mp(), encoded.Result, encoded.Empties, encoded.Groups) require.NoError(t, err) exec.Free() @@ -858,14 +858,14 @@ func TestCountColumnExecMarshalUnmarshal(t *testing.T) { exec := newCountColumnExecExec(m, info) require.NoError(t, exec.GroupGrow(1)) - data, err := 
exec.marshal() + data, err := exec.Marshal() require.NoError(t, err) newExec := newCountColumnExecExec(m, info) encoded := &EncodedAgg{} require.NoError(t, encoded.Unmarshal(data)) - err = newExec.unmarshal(m.Mp(), encoded.Result, encoded.Empties, encoded.Groups) + err = newExec.Unmarshal(m.Mp(), encoded.Result, encoded.Empties, encoded.Groups) require.NoError(t, err) exec.Free() @@ -882,7 +882,7 @@ func TestCountColumnExecMarshalUnmarshal(t *testing.T) { } exec := newCountColumnExecExec(m, info) - err := exec.unmarshal(m.Mp(), nil, nil, [][]byte{}) + err := exec.Unmarshal(m.Mp(), nil, nil, [][]byte{}) require.NoError(t, err) exec.Free() diff --git a/pkg/sql/colexec/aggexec/approx_count.go b/pkg/sql/colexec/aggexec/approx_count.go index fac4af9343223..c0a28e9c9b97c 100644 --- a/pkg/sql/colexec/aggexec/approx_count.go +++ b/pkg/sql/colexec/aggexec/approx_count.go @@ -35,7 +35,7 @@ func (exec *approxCountFixedExec[T]) GetOptResult() SplitResult { return &exec.ret.optSplitResult } -func (exec *approxCountFixedExec[T]) marshal() ([]byte, error) { +func (exec *approxCountFixedExec[T]) Marshal() ([]byte, error) { d := exec.singleAggInfo.getEncoded() r, em, err := exec.ret.marshalToBytes() if err != nil { @@ -60,7 +60,7 @@ func (exec *approxCountFixedExec[T]) marshal() ([]byte, error) { return encoded.Marshal() } -func (exec *approxCountFixedExec[T]) unmarshal(_ *mpool.MPool, result, empties, groups [][]byte) error { +func (exec *approxCountFixedExec[T]) Unmarshal(_ *mpool.MPool, result, empties, groups [][]byte) error { err := exec.ret.unmarshalFromBytes(result, empties) if err != nil { return err @@ -90,7 +90,7 @@ func (exec *approxCountVarExec) GetOptResult() SplitResult { return &exec.ret.optSplitResult } -func (exec *approxCountVarExec) marshal() ([]byte, error) { +func (exec *approxCountVarExec) Marshal() ([]byte, error) { d := exec.singleAggInfo.getEncoded() r, em, err := exec.ret.marshalToBytes() if err != nil { @@ -115,7 +115,7 @@ func (exec *approxCountVarExec) 
marshal() ([]byte, error) { return encoded.Marshal() } -func (exec *approxCountVarExec) unmarshal(_ *mpool.MPool, result, empties, groups [][]byte) error { +func (exec *approxCountVarExec) Unmarshal(_ *mpool.MPool, result, empties, groups [][]byte) error { err := exec.ret.unmarshalFromBytes(result, empties) if err != nil { return err diff --git a/pkg/sql/colexec/aggexec/concat.go b/pkg/sql/colexec/aggexec/concat.go index 2fd5c865a40ec..be7d3a461da0a 100644 --- a/pkg/sql/colexec/aggexec/concat.go +++ b/pkg/sql/colexec/aggexec/concat.go @@ -37,7 +37,7 @@ func (exec *groupConcatExec) GetOptResult() SplitResult { return &exec.ret.optSplitResult } -func (exec *groupConcatExec) marshal() ([]byte, error) { +func (exec *groupConcatExec) Marshal() ([]byte, error) { d := exec.multiAggInfo.getEncoded() r, em, err := exec.ret.marshalToBytes() if err != nil { @@ -59,7 +59,7 @@ func (exec *groupConcatExec) marshal() ([]byte, error) { return encoded.Marshal() } -func (exec *groupConcatExec) unmarshal(_ *mpool.MPool, result, empties, groups [][]byte) error { +func (exec *groupConcatExec) Unmarshal(_ *mpool.MPool, result, empties, groups [][]byte) error { if err := exec.SetExtraInformation(groups[0], 0); err != nil { return err } diff --git a/pkg/sql/colexec/aggexec/count.go b/pkg/sql/colexec/aggexec/count.go index 53c0badfb737a..b767457a7b91f 100644 --- a/pkg/sql/colexec/aggexec/count.go +++ b/pkg/sql/colexec/aggexec/count.go @@ -39,7 +39,7 @@ func (exec *countColumnExec) GetOptResult() SplitResult { return &exec.ret.optSplitResult } -func (exec *countColumnExec) marshal() ([]byte, error) { +func (exec *countColumnExec) Marshal() ([]byte, error) { d := exec.singleAggInfo.getEncoded() r, em, err := exec.ret.marshalToBytes() if err != nil { @@ -63,7 +63,7 @@ func (exec *countColumnExec) marshal() ([]byte, error) { return encoded.Marshal() } -func (exec *countColumnExec) unmarshal(_ *mpool.MPool, result, empties, groups [][]byte) error { +func (exec *countColumnExec) Unmarshal(_ 
*mpool.MPool, result, empties, groups [][]byte) error { if exec.IsDistinct() { if len(groups) > 0 { if err := exec.distinctHash.unmarshal(groups[0]); err != nil { @@ -286,7 +286,7 @@ func (exec *countStarExec) GetOptResult() SplitResult { return &exec.ret.optSplitResult } -func (exec *countStarExec) marshal() ([]byte, error) { +func (exec *countStarExec) Marshal() ([]byte, error) { d := exec.singleAggInfo.getEncoded() r, em, err := exec.ret.marshalToBytes() if err != nil { @@ -301,7 +301,7 @@ func (exec *countStarExec) marshal() ([]byte, error) { return encoded.Marshal() } -func (exec *countStarExec) unmarshal(_ *mpool.MPool, result, empties, _ [][]byte) error { +func (exec *countStarExec) Unmarshal(_ *mpool.MPool, result, empties, _ [][]byte) error { return exec.ret.unmarshalFromBytes(result, empties) } diff --git a/pkg/sql/colexec/aggexec/fromBytesRetBytes.go b/pkg/sql/colexec/aggexec/fromBytesRetBytes.go index 7cf43bf871c82..dd5bbe9a7805d 100644 --- a/pkg/sql/colexec/aggexec/fromBytesRetBytes.go +++ b/pkg/sql/colexec/aggexec/fromBytesRetBytes.go @@ -100,7 +100,7 @@ func (exec *aggregatorFromBytesToBytes) GetOptResult() SplitResult { return &exec.ret.optSplitResult } -func (exec *aggregatorFromBytesToBytes) marshal() ([]byte, error) { +func (exec *aggregatorFromBytesToBytes) Marshal() ([]byte, error) { d := exec.singleAggInfo.getEncoded() r, em, err := exec.ret.marshalToBytes() if err != nil { @@ -115,7 +115,7 @@ func (exec *aggregatorFromBytesToBytes) marshal() ([]byte, error) { return encoded.Marshal() } -func (exec *aggregatorFromBytesToBytes) unmarshal(_ *mpool.MPool, result, empties, groups [][]byte) error { +func (exec *aggregatorFromBytesToBytes) Unmarshal(_ *mpool.MPool, result, empties, groups [][]byte) error { exec.execContext.decodeGroupContexts(groups, exec.singleAggInfo.retType, exec.singleAggInfo.argType) return exec.ret.unmarshalFromBytes(result, empties) } diff --git a/pkg/sql/colexec/aggexec/fromBytesRetFixed.go 
b/pkg/sql/colexec/aggexec/fromBytesRetFixed.go index 0b3e1c5bdd936..e0e73c499dc11 100644 --- a/pkg/sql/colexec/aggexec/fromBytesRetFixed.go +++ b/pkg/sql/colexec/aggexec/fromBytesRetFixed.go @@ -198,7 +198,7 @@ func (exec *aggregatorFromBytesToFixed[to]) GetOptResult() SplitResult { return &exec.ret.optSplitResult } -func (exec *aggregatorFromBytesToFixed[to]) marshal() ([]byte, error) { +func (exec *aggregatorFromBytesToFixed[to]) Marshal() ([]byte, error) { d := exec.singleAggInfo.getEncoded() r, em, err := exec.ret.marshalToBytes() if err != nil { @@ -213,7 +213,7 @@ func (exec *aggregatorFromBytesToFixed[to]) marshal() ([]byte, error) { return encoded.Marshal() } -func (exec *aggregatorFromBytesToFixed[to]) unmarshal(mp *mpool.MPool, result, empties, groups [][]byte) error { +func (exec *aggregatorFromBytesToFixed[to]) Unmarshal(mp *mpool.MPool, result, empties, groups [][]byte) error { exec.execContext.decodeGroupContexts(groups, exec.singleAggInfo.retType, exec.singleAggInfo.argType) return exec.ret.unmarshalFromBytes(result, empties) } diff --git a/pkg/sql/colexec/aggexec/fromFixedRetBytes.go b/pkg/sql/colexec/aggexec/fromFixedRetBytes.go index fda4318cc8b8b..d40206d099f48 100644 --- a/pkg/sql/colexec/aggexec/fromFixedRetBytes.go +++ b/pkg/sql/colexec/aggexec/fromFixedRetBytes.go @@ -217,7 +217,7 @@ func (exec *aggregatorFromFixedToBytes[from]) GetOptResult() SplitResult { return &exec.ret.optSplitResult } -func (exec *aggregatorFromFixedToBytes[from]) marshal() ([]byte, error) { +func (exec *aggregatorFromFixedToBytes[from]) Marshal() ([]byte, error) { d := exec.singleAggInfo.getEncoded() r, em, err := exec.ret.marshalToBytes() if err != nil { @@ -232,7 +232,7 @@ func (exec *aggregatorFromFixedToBytes[from]) marshal() ([]byte, error) { return encoded.Marshal() } -func (exec *aggregatorFromFixedToBytes[from]) unmarshal(_ *mpool.MPool, result, empties, groups [][]byte) error { +func (exec *aggregatorFromFixedToBytes[from]) Unmarshal(_ *mpool.MPool, result, 
empties, groups [][]byte) error { exec.execContext.decodeGroupContexts(groups, exec.singleAggInfo.retType, exec.singleAggInfo.argType) return exec.ret.unmarshalFromBytes(result, empties) } diff --git a/pkg/sql/colexec/aggexec/fromFixedRetFixed.go b/pkg/sql/colexec/aggexec/fromFixedRetFixed.go index 2d3520b16039c..c8c054b2ad72b 100644 --- a/pkg/sql/colexec/aggexec/fromFixedRetFixed.go +++ b/pkg/sql/colexec/aggexec/fromFixedRetFixed.go @@ -271,7 +271,7 @@ func (exec *aggregatorFromFixedToFixed[from, to]) GetOptResult() SplitResult { return &exec.ret.optSplitResult } -func (exec *aggregatorFromFixedToFixed[from, to]) marshal() ([]byte, error) { +func (exec *aggregatorFromFixedToFixed[from, to]) Marshal() ([]byte, error) { d := exec.singleAggInfo.getEncoded() r, em, err := exec.ret.marshalToBytes() if err != nil { @@ -287,7 +287,7 @@ func (exec *aggregatorFromFixedToFixed[from, to]) marshal() ([]byte, error) { return encoded.Marshal() } -func (exec *aggregatorFromFixedToFixed[from, to]) unmarshal(_ *mpool.MPool, result, empties, groups [][]byte) error { +func (exec *aggregatorFromFixedToFixed[from, to]) Unmarshal(_ *mpool.MPool, result, empties, groups [][]byte) error { exec.execContext.decodeGroupContexts(groups, exec.singleAggInfo.retType, exec.singleAggInfo.argType) return exec.ret.unmarshalFromBytes(result, empties) } diff --git a/pkg/sql/colexec/aggexec/median.go b/pkg/sql/colexec/aggexec/median.go index 96f435d8b9179..cb9e4cb5b6fbd 100644 --- a/pkg/sql/colexec/aggexec/median.go +++ b/pkg/sql/colexec/aggexec/median.go @@ -54,7 +54,7 @@ func (exec *medianColumnExecSelf[T, R]) GetOptResult() SplitResult { return &exec.ret.optSplitResult } -func (exec *medianColumnExecSelf[T, R]) marshal() ([]byte, error) { +func (exec *medianColumnExecSelf[T, R]) Marshal() ([]byte, error) { d := exec.singleAggInfo.getEncoded() r, em, err := exec.ret.marshalToBytes() if err != nil { @@ -78,7 +78,7 @@ func (exec *medianColumnExecSelf[T, R]) marshal() ([]byte, error) { return 
encoded.Marshal() } -func (exec *medianColumnExecSelf[T, R]) unmarshal(mp *mpool.MPool, result, empties, groups [][]byte) error { +func (exec *medianColumnExecSelf[T, R]) Unmarshal(mp *mpool.MPool, result, empties, groups [][]byte) error { if len(groups) > 0 { exec.groups = make([]*Vectors[T], len(groups)) for i := range exec.groups { diff --git a/pkg/sql/colexec/aggexec/serialize.go b/pkg/sql/colexec/aggexec/serialize.go index 7ad2451c1bad4..75773922570e2 100644 --- a/pkg/sql/colexec/aggexec/serialize.go +++ b/pkg/sql/colexec/aggexec/serialize.go @@ -19,7 +19,7 @@ import ( ) func MarshalAggFuncExec(exec AggFuncExec) ([]byte, error) { - return exec.marshal() + return exec.Marshal() } func UnmarshalAggFuncExec( @@ -48,7 +48,7 @@ func UnmarshalAggFuncExec( mp = mg.Mp() } - if err := exec.unmarshal( + if err := exec.Unmarshal( mp, encoded.Result, encoded.Empties, encoded.Groups); err != nil { exec.Free() return nil, err diff --git a/pkg/sql/colexec/aggexec/types.go b/pkg/sql/colexec/aggexec/types.go index ac1cda5512c30..f5ea66104de51 100644 --- a/pkg/sql/colexec/aggexec/types.go +++ b/pkg/sql/colexec/aggexec/types.go @@ -66,8 +66,8 @@ func (expr AggFuncExecExpression) GetExtraConfig() []byte { // AggFuncExec is an interface to do execution for aggregation. 
type AggFuncExec interface { - marshal() ([]byte, error) - unmarshal(mp *mpool.MPool, result, empties, groups [][]byte) error + Marshal() ([]byte, error) + Unmarshal(mp *mpool.MPool, result, empties, groups [][]byte) error GetOptResult() SplitResult AggID() int64 diff --git a/pkg/sql/colexec/aggexec/window.go b/pkg/sql/colexec/aggexec/window.go index 21c856d3bb6d8..3c0510814373e 100644 --- a/pkg/sql/colexec/aggexec/window.go +++ b/pkg/sql/colexec/aggexec/window.go @@ -59,7 +59,7 @@ func (exec *singleWindowExec) GetOptResult() SplitResult { return &exec.ret.optSplitResult } -func (exec *singleWindowExec) marshal() ([]byte, error) { +func (exec *singleWindowExec) Marshal() ([]byte, error) { d := exec.singleAggInfo.getEncoded() r, em, err := exec.ret.marshalToBytes() if err != nil { @@ -81,7 +81,7 @@ func (exec *singleWindowExec) marshal() ([]byte, error) { return encoded.Marshal() } -func (exec *singleWindowExec) unmarshal(mp *mpool.MPool, result, empties, groups [][]byte) error { +func (exec *singleWindowExec) Unmarshal(mp *mpool.MPool, result, empties, groups [][]byte) error { if len(exec.groups) > 0 { exec.groups = make([][]int64, len(groups)) for i := range exec.groups { diff --git a/pkg/sql/colexec/group/container.go b/pkg/sql/colexec/group/container.go new file mode 100644 index 0000000000000..556fa23dc6a7e --- /dev/null +++ b/pkg/sql/colexec/group/container.go @@ -0,0 +1,28 @@ +// Copyright 2025 Matrix Origin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package group + +func (ctr *container) MemoryUsed() int64 { + var size int64 + if !ctr.hr.IsEmpty() { + size += ctr.hr.Size() + } + size += ctr.result1.Size() + size += ctr.result2.Size() + if ctr.spillManager != nil { + size += ctr.spillManager.Size() + } + return size +} diff --git a/pkg/sql/colexec/group/exec.go b/pkg/sql/colexec/group/exec.go index 700d5aa1ccadd..dfc8ab9a7ae08 100644 --- a/pkg/sql/colexec/group/exec.go +++ b/pkg/sql/colexec/group/exec.go @@ -57,6 +57,21 @@ func (group *Group) Prepare(proc *process.Process) (err error) { group.ctr.state = vm.Build group.ctr.dataSourceIsEmpty = true group.prepareAnalyzer() + + if group.SpillThreshold == 0 { + group.SpillThreshold = 1 * 1024 * 1024 //TODO + } + if group.ctr.spillManager == nil { + var groupByTypes []types.Type + for _, expr := range group.Exprs { + groupByTypes = append(groupByTypes, types.New(types.T(expr.Typ.Id), expr.Typ.Width, expr.Typ.Scale)) + } + group.ctr.spillManager, err = NewSpillManager(proc, groupByTypes, group.Aggs) + if err != nil { + return + } + } + if err = group.prepareAgg(proc); err != nil { return err } @@ -188,6 +203,13 @@ func (group *Group) callToGetFinalResult(proc *process.Process) (*batch.Batch, e for { if group.ctr.state == vm.Eval { + // If we have spilled data, we need to merge it + if group.ctr.spillManager != nil && group.ctr.spillManager.HasSpilledData() { + if err := group.mergeSpilledData(proc); err != nil { + return nil, err + } + } + if group.ctr.result1.IsEmpty() { group.ctr.state = vm.End return nil, nil @@ -218,6 +240,11 @@ func (group *Group) callToGetFinalResult(proc *process.Process) (*batch.Batch, e if err = group.consumeBatchToGetFinalResult(proc, res); err != nil { return nil, err } + + // Check for spill after processing each batch + if err := group.checkAndSpill(proc); err != nil { + return nil, err + } } } @@ -312,9 +339,19 @@ func (group 
*Group) consumeBatchToGetFinalResult( return err } } + + // Check for spill after processing each chunk + if err := group.checkAndSpill(proc); err != nil { + return err + } } } + // Check for spill after processing the batch + if err := group.checkAndSpill(proc); err != nil { + return err + } + return nil } @@ -492,6 +529,11 @@ func (group *Group) consumeBatchToRes( } } + // Check for spill after processing a batch + if err := group.checkAndSpill(proc); err != nil { + return false, err + } + return res.RowCount() < intermediateResultSendActionTrigger, nil } } diff --git a/pkg/sql/colexec/group/exec_test.go b/pkg/sql/colexec/group/exec_test.go index 5e53f82e8aef2..a0b351b143d3e 100644 --- a/pkg/sql/colexec/group/exec_test.go +++ b/pkg/sql/colexec/group/exec_test.go @@ -15,6 +15,8 @@ package group import ( + "testing" + "github.com/matrixorigin/matrixone/pkg/common/mpool" "github.com/matrixorigin/matrixone/pkg/container/batch" "github.com/matrixorigin/matrixone/pkg/container/types" @@ -25,14 +27,12 @@ import ( "github.com/matrixorigin/matrixone/pkg/testutil" "github.com/matrixorigin/matrixone/pkg/vm" "github.com/stretchr/testify/require" - "testing" ) // hackAggExecToTest 是一个不带任何逻辑的AggExec,主要用于单测中检查各种接口的调用次数。 type hackAggExecToTest struct { toFlush int - aggexec.AggFuncExec preAllocated int groupNumber int doFillRow int @@ -42,6 +42,44 @@ type hackAggExecToTest struct { isFree bool } +var _ aggexec.AggFuncExec = new(hackAggExecToTest) + +func (h *hackAggExecToTest) AggID() int64 { + panic("unimplemented") +} + +func (h *hackAggExecToTest) BatchMerge(next aggexec.AggFuncExec, offset int, groups []uint64) error { + panic("unimplemented") +} + +func (h *hackAggExecToTest) IsDistinct() bool { + panic("unimplemented") +} + +func (h *hackAggExecToTest) Marshal() ([]byte, error) { + return nil, nil +} + +func (h *hackAggExecToTest) Merge(next aggexec.AggFuncExec, groupIdx1 int, groupIdx2 int) error { + panic("unimplemented") +} + +func (h *hackAggExecToTest) 
SetExtraInformation(partialResult any, groupIndex int) (err error) { + panic("unimplemented") +} + +func (h *hackAggExecToTest) TypesInfo() ([]types.Type, types.Type) { + panic("unimplemented") +} + +func (h *hackAggExecToTest) Unmarshal(mp *mpool.MPool, result [][]byte, empties [][]byte, groups [][]byte) error { + return nil +} + +func (h *hackAggExecToTest) Size() int64 { + return 0 +} + func (h *hackAggExecToTest) GetOptResult() aggexec.SplitResult { return nil } diff --git a/pkg/sql/colexec/group/execctx.go b/pkg/sql/colexec/group/execctx.go index 22fc688d8392d..0a36c1ce70ef3 100644 --- a/pkg/sql/colexec/group/execctx.go +++ b/pkg/sql/colexec/group/execctx.go @@ -94,6 +94,14 @@ func (hr *ResHashRelated) GetBinaryInsertList(vals []uint64, before uint64) (ins hr.inserted = hr.inserted[:len(vals)] } + if hr.Hash == nil { + // When hash table is nil, no inserts should occur + for i := range hr.inserted { + hr.inserted[i] = 0 + } + return hr.inserted, 0 + } + insertCount = hr.Hash.GroupCount() - before last := before @@ -123,7 +131,32 @@ type GroupResultBuffer struct { } func (buf *GroupResultBuffer) IsEmpty() bool { - return cap(buf.ToPopped) == 0 + // Buffer is considered empty if there are no batches with actual data + // or if there are no aggregators + if len(buf.ToPopped) == 0 && len(buf.AggList) == 0 { + return true + } + + // If we have aggregators with data, we're not empty + // This handles the case where there are no group-by columns + if len(buf.AggList) > 0 { + for _, agg := range buf.AggList { + if agg != nil && agg.Size() > 0 { + return false + } + } + } + + // Check if any batch has actual data + for _, batch := range buf.ToPopped { + if batch != nil && batch.RowCount() > 0 { + return false + } + } + + // If we reach here, all batches are either nil or empty + // In the context of spilling, we only spill when we have group data + return true } func (buf *GroupResultBuffer) InitOnlyAgg(chunkSize int, aggList []aggexec.AggFuncExec) { diff --git 
a/pkg/sql/colexec/group/group.go b/pkg/sql/colexec/group/group.go new file mode 100644 index 0000000000000..e23bb907bdb96 --- /dev/null +++ b/pkg/sql/colexec/group/group.go @@ -0,0 +1,370 @@ +// Copyright 2025 Matrix Origin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package group + +import ( + "time" + + "github.com/matrixorigin/matrixone/pkg/common/hashmap" + "github.com/matrixorigin/matrixone/pkg/common/moerr" + "github.com/matrixorigin/matrixone/pkg/container/batch" + "github.com/matrixorigin/matrixone/pkg/logutil" + "github.com/matrixorigin/matrixone/pkg/sql/colexec/aggexec" + "github.com/matrixorigin/matrixone/pkg/vm/process" + "go.uber.org/zap" +) + +func (group *Group) MemoryUsed() int64 { + var size int64 + size += group.ctr.MemoryUsed() + return size +} + +func (group *Group) ShouldSpill() bool { + if group.SpillThreshold == 0 { + return false + } + current := group.MemoryUsed() + return current >= group.SpillThreshold +} + +func (group *Group) checkAndSpill(proc *process.Process) error { + logutil.Info("Group: check spill", + zap.Any("used", group.MemoryUsed()), + zap.Any("threshold", group.SpillThreshold), + ) + if group.ShouldSpill() { + // Trigger spilling + logutil.Infof("Group: Memory threshold exceeded, triggering spill operation. 
Current memory: %d bytes, Threshold: %d bytes",
+			group.MemoryUsed(), group.SpillThreshold)
+		return group.spillToDisk(proc)
+	}
+	return nil
+}
+
+// spillToDisk persists the buffered group batches, the aggregation states and
+// the hash table metadata through the spill manager, then replaces the hash
+// table and result buffer with fresh ones and frees the old ones.
+//
+// It returns nil without doing anything when the result buffer is empty.
+func (group *Group) spillToDisk(proc *process.Process) error {
+	if group.ctr.result1.IsEmpty() {
+		return nil
+	}
+
+	logutil.Infof("Group: Starting spill to disk operation with hash table preservation")
+	startTime := time.Now()
+	beforeMemory := group.MemoryUsed()
+
+	// Only batches that actually hold rows are handed to the spill manager.
+	var groups []*batch.Batch
+	for _, b := range group.ctr.result1.ToPopped {
+		if b != nil && b.RowCount() > 0 {
+			groups = append(groups, b)
+		}
+	}
+
+	logutil.Infof("Group: Preparing %d batches and %d aggregators for spilling", len(groups), len(group.ctr.result1.AggList))
+
+	// Aggregators without any key rows are still spilled, next to a placeholder batch.
+	if len(groups) == 0 && len(group.ctr.result1.AggList) > 0 {
+		emptyBatch := batch.NewOffHeapEmpty()
+		groups = append(groups, emptyBatch)
+	}
+
+	if len(groups) == 0 && len(group.ctr.result1.AggList) == 0 {
+		return nil
+	}
+
+	// Serialize hash table state before spilling
+	var hashTableData []byte
+	var err error
+	if !group.ctr.hr.IsEmpty() {
+		hashTableData, err = group.ctr.hr.MarshalHashTable()
+		if err != nil {
+			return moerr.NewInternalErrorNoCtxf("failed to marshal hash table: %v", err)
+		}
+	}
+
+	// Spill with hash table data
+	err = group.ctr.spillManager.SpillToDiskWithHashTable(
+		groups,
+		group.ctr.result1.AggList,
+		hashTableData,
+		group.ctr.mtyp == HStr,
+		group.ctr.keyNullable,
+		group.ctr.keyWidth,
+	)
+	if err != nil {
+		return err
+	}
+
+	// Clean up current state and create fresh state for continued processing
+	var newHr ResHashRelated
+	var newResult GroupResultBuffer
+
+	aggs, err := group.generateAggExec(proc)
+	if err != nil {
+		for _, agg := range aggs {
+			if agg != nil {
+				agg.Free()
+			}
+		}
+		return err
+	}
+
+	newResult.InitOnlyAgg(aggexec.GetMinAggregatorsChunkSize(nil, aggs), aggs)
+
+	err = newHr.BuildHashTable(true, group.ctr.mtyp == HStr, group.ctr.keyNullable, 0)
+	if err != nil {
+		for _, agg := range newResult.AggList {
+			if agg != nil {
+				agg.Free()
+			}
+		}
+		newHr.Free0()
+		return err
+	}
+
+	// Install the fresh state first, then release the spilled one.
+	oldHr := group.ctr.hr
+	oldResult := group.ctr.result1
+
+	group.ctr.hr = newHr
+	group.ctr.result1 = newResult
+
+	oldHr.Free0()
+	oldResult.Free0(proc.Mp())
+
+	duration := time.Since(startTime)
+	afterMemory := group.MemoryUsed()
+	logutil.Infof("Group: Successfully completed spill to disk operation with hash table preservation. Duration: %v, Memory before: %d bytes, Memory after: %d bytes",
+		duration, beforeMemory, afterMemory)
+
+	return nil
+}
+
+// mergeSpilledData reads every spill file back and merges its content into the
+// current in-memory state, then asks the spill manager to clean up its files.
+// It is a no-op when there is no spill manager or nothing was spilled.
+func (group *Group) mergeSpilledData(proc *process.Process) error {
+	if group.ctr.spillManager == nil || !group.ctr.spillManager.HasSpilledData() {
+		return nil
+	}
+
+	logutil.Infof("Group: Starting merge of spilled data from %d spill files with hash table preservation", len(group.ctr.spillManager.spillFiles))
+	startTime := time.Now()
+	beforeMemory := group.MemoryUsed()
+
+	// NOTE(review): on an early error return the spill files are left in place and
+	// the batches/aggregators read back are not released here - confirm who owns them.
+	for _, filePath := range group.ctr.spillManager.spillFiles {
+		logutil.Infof("Group: Merging spilled data from file %s", filePath)
+		groups, aggs, hashTableData, isStrHash, keyNullable, keyWidth, err := group.ctr.spillManager.ReadSpilledDataWithHashTable(filePath)
+		if err != nil {
+			return err
+		}
+
+		if err := group.mergeSpilledGroupsAndAggsWithHashTable(proc, groups, aggs, hashTableData, isStrHash, keyNullable, keyWidth); err != nil {
+			return err
+		}
+	}
+
+	duration := time.Since(startTime)
+	afterMemory := group.MemoryUsed()
+	logutil.Infof("Group: Successfully completed merge of spilled data with hash table preservation. Duration: %v, Memory before: %d bytes, Memory after: %d bytes",
+		duration, beforeMemory, afterMemory)
+
+	return group.ctr.spillManager.Cleanup(proc.Ctx)
+}
+
+// mergeSpilledGroupsAndAggs merges spilled groups and aggregators that carry no
+// hash table metadata.
+func (group *Group) mergeSpilledGroupsAndAggs(proc *process.Process, groups []*batch.Batch, aggs []aggexec.AggFuncExec) error {
+	return group.mergeSpilledGroupsAndAggsWithHashTable(proc, groups, aggs, nil, false, false, 0)
+}
+
+// mergeSpilledGroupsAndAggsWithHashTable merges one spill file's content into the
+// current state.
+//
+// When the result buffer is empty the spilled data simply becomes the current
+// state. Otherwise every spilled key row is re-inserted into the live hash table
+// and the spilled aggregation state of that row is merged into the group the key
+// resolved to. Returns an internal error if the live hash table or its iterator
+// is missing on the merge path.
+func (group *Group) mergeSpilledGroupsAndAggsWithHashTable(proc *process.Process, groups []*batch.Batch, aggs []aggexec.AggFuncExec, hashTableData []byte, isStrHash bool, keyNullable bool, keyWidth int) error {
+	if len(groups) == 0 && len(aggs) == 0 && len(hashTableData) == 0 {
+		return nil
+	}
+
+	if group.ctr.result1.IsEmpty() {
+		return group.restoreSpilledDataAsCurrentStateWithHashTable(proc, groups, aggs, hashTableData, isStrHash, keyNullable, keyWidth)
+	}
+
+	if group.ctr.hr.Hash == nil || group.ctr.hr.Itr == nil {
+		return moerr.NewInternalError(proc.Ctx, "hash table or iterator is nil during merge")
+	}
+
+	// If we have hash table data, try to merge it efficiently
+	if len(hashTableData) > 0 {
+		if err := group.mergeHashTableData(proc, hashTableData, isStrHash, keyNullable, keyWidth); err != nil {
+			logutil.Infof("Group: Failed to merge hash table data, falling back to group re-insertion: %v", err)
+		}
+	}
+
+	hasSpilledRows := false
+	for _, b := range groups {
+		if b != nil && b.RowCount() > 0 {
+			hasSpilledRows = true
+			break
+		}
+	}
+	if !hasSpilledRows {
+		// No key rows were spilled, so there is no row-to-group mapping to follow;
+		// keep the group-0-to-group-0 merge for this case.
+		return group.mergeSpilledAggregations(aggs)
+	}
+
+	return group.mergeSpilledGroupsWithAggs(proc, groups, aggs)
+}
+
+// restoreSpilledDataAsCurrentState adopts spilled data as the current state
+// without any hash table metadata.
+func (group *Group) restoreSpilledDataAsCurrentState(proc *process.Process, groups []*batch.Batch, aggs []aggexec.AggFuncExec) error {
+	return group.restoreSpilledDataAsCurrentStateWithHashTable(proc, groups, aggs, nil, false, false, 0)
+}
+
+// restoreSpilledDataAsCurrentStateWithHashTable makes the spilled data the current
+// state: the non-empty batches are duplicated into the result buffer, the hash
+// table is rebuilt and repopulated from those batches, and the spilled
+// aggregators replace the current ones (the current ones are freed).
+func (group *Group) restoreSpilledDataAsCurrentStateWithHashTable(proc *process.Process, groups []*batch.Batch, aggs []aggexec.AggFuncExec, hashTableData []byte, isStrHash bool, keyNullable bool, keyWidth int) error {
+	if len(groups) > 0 {
+		duplicatedBatches := make([]*batch.Batch, 0, len(groups))
+		var batchesOwnershipTransferred bool
+		// Release the copies unless the result buffer took them over.
+		defer func() {
+			if !batchesOwnershipTransferred {
+				for _, dup := range duplicatedBatches {
+					if dup != nil {
+						dup.Clean(proc.Mp())
+					}
+				}
+			}
+		}()
+
+		for _, bat := range groups {
+			if bat != nil && bat.RowCount() > 0 {
+				newBatch, err := bat.Dup(proc.Mp())
+				if err != nil {
+					return err
+				}
+				duplicatedBatches = append(duplicatedBatches, newBatch)
+			}
+		}
+
+		if len(duplicatedBatches) > 0 {
+			// Try to restore hash table from spilled data
+			if len(hashTableData) > 0 {
+				if err := group.ctr.hr.UnmarshalHashTable(hashTableData, isStrHash, keyNullable, keyWidth); err != nil {
+					logutil.Infof("Group: Failed to restore hash table from spilled data, rebuilding: %v", err)
+					if err := group.ctr.hr.BuildHashTable(true, isStrHash, keyNullable, 0); err != nil {
+						return err
+					}
+				}
+			} else {
+				// Fallback to building new hash table
+				if err := group.ctr.hr.BuildHashTable(true, group.ctr.mtyp == HStr, group.ctr.keyNullable, 0); err != nil {
+					return err
+				}
+			}
+
+			if group.ctr.hr.Itr == nil {
+				return moerr.NewInternalError(proc.Ctx, "hash table iterator is nil after rebuild")
+			}
+
+			group.ctr.result1.ToPopped = duplicatedBatches
+			batchesOwnershipTransferred = true
+
+			// Re-insert groups into hash table to restore mappings
+			for _, restored := range group.ctr.result1.ToPopped {
+				if restored == nil || restored.RowCount() == 0 {
+					continue
+				}
+
+				count := restored.RowCount()
+				for i := 0; i < count; i += hashmap.UnitLimit {
+					n := min(count-i, hashmap.UnitLimit)
+
+					_, _, err := group.ctr.hr.Itr.Insert(i, n, restored.Vecs)
+					if err != nil {
+						return err
+					}
+				}
+			}
+		}
+	}
+
+	if len(aggs) > 0 {
+		for _, agg := range group.ctr.result1.AggList {
+			if agg != nil {
+				agg.Free()
+			}
+		}
+
+		group.ctr.result1.AggList = make([]aggexec.AggFuncExec, len(aggs))
+		copy(group.ctr.result1.AggList, aggs)
+	}
+
+	return nil
+}
+
+// mergeSpilledGroups re-inserts the spilled group keys into the live hash table,
+// appending rows and growing the aggregators for keys that were not present yet.
+// It does not merge any aggregation state.
+func (group *Group) mergeSpilledGroups(proc *process.Process, groups []*batch.Batch) error {
+	return group.mergeSpilledGroupsWithAggs(proc, groups, nil)
+}
+
+// mergeSpilledGroupsWithAggs re-inserts the spilled group keys into the live hash
+// table and, chunk by chunk, merges the spilled aggregation state of every row
+// into the live group its key resolved to.
+//
+// aggs may be nil or shorter than the live aggregator list; only positions present
+// (and non-nil) on both sides are merged.
+func (group *Group) mergeSpilledGroupsWithAggs(proc *process.Process, groups []*batch.Batch, aggs []aggexec.AggFuncExec) error {
+	pairCount := min(len(aggs), len(group.ctr.result1.AggList))
+
+	// Assumes the i-th spilled aggregation group belongs to the i-th row across the
+	// spilled batches, in order - TODO confirm against GroupResultBuffer's layout.
+	rowOffset := 0
+	for _, spilledBatch := range groups {
+		if spilledBatch == nil || spilledBatch.RowCount() == 0 {
+			continue
+		}
+
+		count := spilledBatch.RowCount()
+		for i := 0; i < count; i += hashmap.UnitLimit {
+			n := min(count-i, hashmap.UnitLimit)
+
+			originGroupCount := group.ctr.hr.Hash.GroupCount()
+			vals, _, err := group.ctr.hr.Itr.Insert(i, n, spilledBatch.Vecs)
+			if err != nil {
+				return err
+			}
+
+			insertList, newGroupCount := group.ctr.hr.GetBinaryInsertList(vals[:n], originGroupCount)
+			if newGroupCount > 0 {
+				if group.ctr.result1.ToPopped == nil {
+					group.ctr.result1.ToPopped = make([]*batch.Batch, 0, 1)
+				}
+
+				_, err := group.ctr.result1.AppendBatch(proc.Mp(), spilledBatch.Vecs, i, insertList)
+				if err != nil {
+					return err
+				}
+
+				for _, agg := range group.ctr.result1.AggList {
+					if agg != nil {
+						if err := agg.GroupGrow(int(newGroupCount)); err != nil {
+							return err
+						}
+					}
+				}
+			}
+
+			// The live aggregators were grown above, so every id in vals has a slot.
+			for j := 0; j < pairCount; j++ {
+				if aggs[j] == nil || group.ctr.result1.AggList[j] == nil {
+					continue
+				}
+				if err := group.ctr.result1.AggList[j].BatchMerge(aggs[j], rowOffset+i, vals[:n]); err != nil {
+					return err
+				}
+			}
+		}
+		rowOffset += count
+	}
+
+	return nil
+}
+
+// mergeSpilledAggregations merges group 0 of each spilled aggregator into group 0
+// of the live aggregator at the same position. Positions that are nil on either
+// side, or beyond the shorter list, are skipped.
+func (group *Group) mergeSpilledAggregations(aggs []aggexec.AggFuncExec) error {
+	if len(aggs) == 0 || len(group.ctr.result1.AggList) == 0 {
+		return nil
+	}
+
+	minLen := min(len(aggs), len(group.ctr.result1.AggList))
+
+	for i := 0; i < minLen; i++ {
+		if aggs[i] != nil && group.ctr.result1.AggList[i] != nil {
+			if err := group.ctr.result1.AggList[i].Merge(aggs[i], 0, 0); err != nil {
+				return err
+			}
+		}
+	}
+
+	return nil
+}
+
+// mergeHashTableData is a placeholder: it returns nil for empty data and a
+// "not implemented" internal error otherwise.
+func (group *Group) mergeHashTableData(proc *process.Process, hashTableData []byte, isStrHash bool, keyNullable bool, keyWidth int) error {
+	if len(hashTableData) == 0 {
+		return nil
+	}
+
+	//TODO re-insert groups
+	return moerr.NewInternalErrorNoCtx("not implemented")
+}
diff --git a/pkg/sql/colexec/group/group_result_buffer.go b/pkg/sql/colexec/group/group_result_buffer.go
new file mode 100644
index 0000000000000..f02b638ce5d9b
--- /dev/null
+++ b/pkg/sql/colexec/group/group_result_buffer.go
@@ -0,0 +1,28
@@
+// Copyright 2025 Matrix Origin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package group
+
+// Size returns the sum of the allocated size of every buffered batch and the
+// reported size of every aggregator. Nil batches and nil aggregators are skipped.
+func (buf *GroupResultBuffer) Size() int64 {
+	var size int64
+	for _, b := range buf.ToPopped {
+		// ToPopped is nil-checked by its other readers, so guard here as well.
+		if b != nil {
+			size += int64(b.Allocated())
+		}
+	}
+	for _, agg := range buf.AggList {
+		if agg != nil {
+			size += agg.Size()
+		}
+	}
+	return size
+}
diff --git a/pkg/sql/colexec/group/group_result_none_block.go b/pkg/sql/colexec/group/group_result_none_block.go
new file mode 100644
index 0000000000000..da5579e33c949
--- /dev/null
+++ b/pkg/sql/colexec/group/group_result_none_block.go
@@ -0,0 +1,31 @@
+// Copyright 2025 Matrix Origin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package group
+
+// Size reports the allocated size of the result batch plus the size reported by
+// each of its non-nil aggregators; it is 0 when no result batch is set.
+func (r *GroupResultNoneBlock) Size() int64 {
+	res := r.res
+	if res == nil {
+		return 0
+	}
+
+	total := int64(res.Allocated())
+	for _, agg := range res.Aggs {
+		if agg == nil {
+			continue
+		}
+		total += agg.Size()
+	}
+	return total
+}
diff --git a/pkg/sql/colexec/group/group_test.go b/pkg/sql/colexec/group/group_test.go
new file mode 100644
index 0000000000000..5f58897218cfb
--- /dev/null
+++ b/pkg/sql/colexec/group/group_test.go
@@ -0,0 +1,140 @@
+// Copyright 2025 Matrix Origin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package group
+
+import (
+	"math/rand/v2"
+	"testing"
+
+	"github.com/matrixorigin/matrixone/pkg/container/batch"
+	"github.com/matrixorigin/matrixone/pkg/container/types"
+	"github.com/matrixorigin/matrixone/pkg/container/vector"
+	"github.com/matrixorigin/matrixone/pkg/pb/plan"
+	"github.com/matrixorigin/matrixone/pkg/sql/colexec/aggexec"
+	"github.com/matrixorigin/matrixone/pkg/sql/colexec/value_scan"
+	"github.com/matrixorigin/matrixone/pkg/testutil"
+	"github.com/matrixorigin/matrixone/pkg/vm"
+	"github.com/stretchr/testify/require"
+)
+
+// TestGroup_CountStar feeds numBatches batches, each holding the keys 1..numRows
+// in shuffled order, through a GROUP BY key / count(*) operator and checks that
+// every key is reported exactly once with a count of numBatches, and that the
+// memory pool is back to its starting size afterwards.
+func TestGroup_CountStar(t *testing.T) {
+	proc := testutil.NewProcess(t)
+	defer proc.Free()
+	mp := proc.Mp()
+	before := mp.CurrNB()
+
+	numRows := 1024
+	numBatches := 1024
+
+	allBatches := make([]*batch.Batch, 0, numBatches)
+	for i := 0; i < numBatches; i++ {
+		inputValues := make([]int64, numRows)
+		for j := 0; j < numRows; j++ {
+			inputValues[j] = int64(j + 1)
+		}
+		rand.Shuffle(len(inputValues), func(a, b int) {
+			inputValues[a], inputValues[b] = inputValues[b], inputValues[a]
+		})
+
+		inputVec := testutil.NewInt64Vector(numRows, types.T_int64.ToType(), mp, false, inputValues)
+		inputBatch := batch.NewWithSize(1)
+		inputBatch.Vecs[0] = inputVec
+		inputBatch.SetRowCount(numRows)
+
+		allBatches = append(allBatches, inputBatch)
+	}
+
+	vscan := value_scan.NewArgument()
+	vscan.Batchs = allBatches
+
+	g := &Group{
+		OperatorBase: vm.OperatorBase{},
+		NeedEval:     true,
+		Exprs:        []*plan.Expr{newColumnExpression(0)},
+		GroupingFlag: []bool{true},
+		Aggs: []aggexec.AggFuncExecExpression{
+			aggexec.MakeAggFunctionExpression(
+				aggexec.AggIdOfCountStar,
+				false,
+				[]*plan.Expr{
+					newColumnExpression(0),
+				},
+				nil,
+			),
+		},
+	}
+	g.AppendChild(vscan)
+
+	require.NoError(t, vscan.Prepare(proc))
+	require.NoError(t, g.Prepare(proc))
+
+	var outputBatch *batch.Batch
+	outputCount := 0
+	for {
+		r, err := g.Call(proc)
+		require.NoError(t, err)
+		if r.Batch == nil {
+			break
+		}
+		outputCount++
+		// Fail before duplicating so a second batch cannot overwrite the first copy.
+		require.Equal(t, 1, outputCount)
+		outputBatch, err = r.Batch.Dup(proc.Mp())
+		require.NoError(t, err)
+	}
+
+	require.NotNil(t, outputBatch)
+	require.Equal(t, 2, len(outputBatch.Vecs))
+	require.Equal(t, numRows, outputBatch.RowCount())
+	require.Equal(t, 0, len(outputBatch.Aggs))
+
+	outputValues := vector.MustFixedColNoTypeCheck[int64](outputBatch.Vecs[0])
+	countValues := vector.MustFixedColNoTypeCheck[int64](outputBatch.Vecs[1])
+	require.Equal(t, numRows, len(outputValues))
+	require.Equal(t, numRows, len(countValues))
+
+	for _, v := range countValues {
+		require.Equal(t, int64(numBatches), v, "count for each group should equal the number of input batches")
+	}
+
+	// Every key must show up exactly once in the output.
+	outputMap := make(map[int64]int64, numRows)
+	for i, v := range outputValues {
+		outputMap[v] = countValues[i]
+	}
+	require.Len(t, outputMap, numRows, "each group key should appear exactly once")
+
+	outputBatch.Clean(proc.Mp())
+	g.Free(proc, false, nil)
+	vscan.Free(proc, false, nil)
+	require.Equal(t, before, mp.CurrNB())
+}
+
+// TestBug_NilHashPointerDereference checks that GetBinaryInsertList does not
+// panic on a ResHashRelated without a hash table and reports zero new groups.
+func TestBug_NilHashPointerDereference(t *testing.T) {
+	// Create a ResHashRelated with nil Hash
+	hr := &ResHashRelated{
+		Hash: nil, // This will cause the original bug
+		Itr:  nil,
+	}
+
+	vals := []uint64{1, 2, 3}
+	before := uint64(0)
+
+	// This should not panic after the fix
+	insertList, insertCount := hr.GetBinaryInsertList(vals, before)
+
+	require.NotNil(t, insertList)
+	require.Equal(t, uint64(0), insertCount) // Should return 0 when hash is nil
+	require.Equal(t, len(vals), len(insertList))
+}
diff --git a/pkg/sql/colexec/group/hashmap.go b/pkg/sql/colexec/group/hashmap.go
new file mode 100644
index 0000000000000..dd9d21642abbd
--- /dev/null
+++ b/pkg/sql/colexec/group/hashmap.go
@@ -0,0 +1,104 @@
+// Copyright 2025 Matrix Origin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package group
+
+import "github.com/matrixorigin/matrixone/pkg/common/moerr"
+
+// Size reports the size of the underlying hash table, or 0 when none is built.
+func (hr *ResHashRelated) Size() int64 {
+	if hr.Hash != nil {
+		return hr.Hash.Size()
+	}
+	return 0
+}
+
+// HashTableMetadata is the hash table description that MarshalHashTable encodes.
+type HashTableMetadata struct {
+	IsStrHash   bool
+	KeyNullable bool
+	GroupCount  uint64
+	KeyWidth    int
+}
+
+// MarshalHashTable encodes the hash table metadata into a 32-byte buffer:
+// byte 0 is the string-hash flag, byte 1 the nullable flag, bytes 8-15 the group
+// count and bytes 16-23 the key width, both little-endian. Only the group count
+// is filled in from the hash table; the other fields are left at their zero
+// value. It returns (nil, nil) when no hash table is built.
+func (hr *ResHashRelated) MarshalHashTable() ([]byte, error) {
+	if hr.Hash == nil {
+		return nil, nil
+	}
+
+	// Flags and key width are placeholders for now; the table itself is rebuilt
+	// from the spilled group batches.
+	meta := HashTableMetadata{GroupCount: hr.Hash.GroupCount()}
+
+	const (
+		groupCountOffset = 8
+		keyWidthOffset   = 16
+		encodedSize      = 32
+	)
+	out := make([]byte, encodedSize)
+
+	if meta.IsStrHash {
+		out[0] = 1
+	}
+	if meta.KeyNullable {
+		out[1] = 1
+	}
+
+	// putLittleEndian stores v into the first 8 bytes of dst, lowest byte first.
+	putLittleEndian := func(dst []byte, v uint64) {
+		for pos := 0; pos < 8; pos++ {
+			dst[pos] = byte(v >> (pos * 8))
+		}
+	}
+	putLittleEndian(out[groupCountOffset:], meta.GroupCount)
+	putLittleEndian(out[keyWidthOffset:], uint64(int64(meta.KeyWidth)))
+
+	return out, nil
+}
+
+// UnmarshalHashTable rebuilds the hash table from marshalled metadata, using the
+// encoded group count as the build hint and the given flags as configuration.
+// Empty data is a no-op; data shorter than 24 bytes is rejected with an internal
+// error.
+func (hr *ResHashRelated) UnmarshalHashTable(data []byte, isStrHash bool, keyNullable bool, keyWidth int) error {
+	switch {
+	case len(data) == 0:
+		return nil
+	case len(data) < 24:
+		return moerr.NewInternalErrorNoCtx("invalid hash table metadata size")
+	}
+
+	// Bytes 8-15 carry the little-endian group count.
+	var groupCount uint64
+	for pos, b := range data[8:16] {
+		groupCount |= uint64(b) << (pos * 8)
+	}
+
+	// Rebuild hash table with the same configuration
+	return hr.BuildHashTable(true, isStrHash, keyNullable, groupCount)
+}
+
+// GetHashTableConfig returns the hash table configuration known at this level.
+// Only the group count is read from the hash table; both flags are always false.
+func (hr *ResHashRelated) GetHashTableConfig() (isStrHash bool, keyNullable bool, groupCount uint64) {
+	if hr.Hash != nil {
+		// The flags are not recorded when the table is created, so only the group
+		// count can be reported.
+		groupCount = hr.Hash.GroupCount()
+	}
+	return false, false, groupCount
+}
diff --git a/pkg/sql/colexec/group/spill_manager.go b/pkg/sql/colexec/group/spill_manager.go
new file mode 100644
index 0000000000000..8c216a3712f75
--- /dev/null
+++ b/pkg/sql/colexec/group/spill_manager.go
@@ -0,0 +1,586 @@
+// Copyright 2025 Matrix Origin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package group
+
+import (
+	"context"
+	"encoding/binary"
+	"errors"
+	"fmt"
+	"io"
+	"sync/atomic"
+	"time"
+
+	"github.com/matrixorigin/matrixone/pkg/common/moerr"
+	"github.com/matrixorigin/matrixone/pkg/common/mpool"
+	"github.com/matrixorigin/matrixone/pkg/container/batch"
+	"github.com/matrixorigin/matrixone/pkg/container/types"
+	"github.com/matrixorigin/matrixone/pkg/fileservice"
+	"github.com/matrixorigin/matrixone/pkg/logutil"
+	"github.com/matrixorigin/matrixone/pkg/sql/colexec/aggexec"
+	"github.com/matrixorigin/matrixone/pkg/vm/process"
+)
+
+// maxSpillRecordSize is the largest length-prefixed record accepted in a spill
+// file. The writer and the reader share it so that nothing is written that could
+// not be read back.
+const maxSpillRecordSize = 100 * 1024 * 1024
+
+// SpillManager writes group batches and aggregation states to spill files and
+// reads them back. It remembers the path of every file it wrote successfully.
+type SpillManager struct {
+	proc         *process.Process
+	fileService  fileservice.ReaderWriterFileService
+	spillFiles   []string
+	groupByTypes []types.Type
+	aggInfos     []aggexec.AggFuncExecExpression
+}
+
+// NewSpillManager creates a SpillManager on top of the process's spill file
+// service. It fails when that service cannot be obtained or does not implement
+// fileservice.ReaderWriterFileService.
+func NewSpillManager(proc *process.Process, groupByTypes []types.Type, aggInfos []aggexec.AggFuncExecExpression) (*SpillManager, error) {
+	fs, err := proc.GetSpillFileService()
+	if err != nil {
+		return nil, err
+	}
+	rwfs, ok := fs.(fileservice.ReaderWriterFileService)
+	if !ok {
+		return nil, moerr.NewInternalErrorNoCtxf("%T is not ReaderWriterFileService", fs)
+	}
+
+	return &SpillManager{
+		proc:         proc,
+		fileService:  rwfs,
+		spillFiles:   make([]string, 0),
+		groupByTypes: groupByTypes,
+		aggInfos:     aggInfos,
+	}, nil
+}
+
+// SpillToDisk spills groups and aggregators without hash table metadata.
+func (sm *SpillManager) SpillToDisk(groups []*batch.Batch, aggs []aggexec.AggFuncExec) (err error) {
+	return sm.SpillToDiskWithHashTable(groups, aggs, nil, false, false, 0)
+}
+
+// SpillToDiskWithHashTable writes one spill file made of a header, the non-empty
+// group batches, the non-nil aggregation states and, when present, the hash table
+// data. The file is recorded in the manager only after its writer was closed
+// without error; on any failure the file is deleted on a best-effort basis.
+// Nothing is written when all three inputs are empty.
+func (sm *SpillManager) SpillToDiskWithHashTable(groups []*batch.Batch, aggs []aggexec.AggFuncExec, hashTableData []byte, isStrHash bool, keyNullable bool, keyWidth int) (err error) {
+	if err := sm.validateSpillInputs(groups, aggs); err != nil {
+		return moerr.NewInternalErrorNoCtxf("spill input validation failed: %v", err)
+	}
+
+	if len(groups) == 0 && len(aggs) == 0 && len(hashTableData) == 0 {
+		return nil
+	}
+
+	logutil.Infof("SpillManager: Starting spill operation with %d groups, %d aggregations, and hash table data", len(groups), len(aggs))
+
+	if sm.fileService == nil {
+		return moerr.NewInternalErrorNoCtx("file service is not available for spilling")
+	}
+
+	filePath := sm.generateSpillFilePath()
+
+	writer, err := sm.fileService.NewWriter(context.Background(), filePath)
+	if err != nil {
+		return moerr.NewInternalErrorNoCtxf("failed to create spill file writer for %s: %v", filePath, err)
+	}
+
+	startTime := time.Now()
+	defer func() {
+		// Close is part of the write: only a cleanly closed file counts as spilled.
+		err = errors.Join(err, writer.Close())
+		if err != nil {
+			_ = sm.fileService.Delete(context.Background(), filePath)
+			return
+		}
+		sm.spillFiles = append(sm.spillFiles, filePath)
+		logutil.Infof("SpillManager: Successfully completed spill operation to %s. Duration: %v, Total spill files: %d",
+			filePath, time.Since(startTime), len(sm.spillFiles))
+	}()
+
+	logutil.Infof("SpillManager: Writing spill data to file %s", filePath)
+
+	if err := sm.writeHeaderWithHashTable(writer, groups, aggs, hashTableData, isStrHash, keyNullable, keyWidth); err != nil {
+		return moerr.NewInternalErrorNoCtxf("failed to write spill file header to %s: %v", filePath, err)
+	}
+
+	if err := sm.writeGroupBatches(writer, groups); err != nil {
+		return moerr.NewInternalErrorNoCtxf("failed to write group batches to spill file %s: %v", filePath, err)
+	}
+
+	if err := sm.writeAggStates(writer, aggs); err != nil {
+		return moerr.NewInternalErrorNoCtxf("failed to write aggregation states to spill file %s: %v", filePath, err)
+	}
+
+	if err := sm.writeHashTableData(writer, hashTableData); err != nil {
+		return moerr.NewInternalErrorNoCtxf("failed to write hash table data to spill file %s: %v", filePath, err)
+	}
+
+	return nil
+}
+
+// validateSpillInputs rejects batches with a negative or excessive row count or
+// with rows but no vectors, and aggregators reporting a negative size. Nil
+// entries are accepted.
+func (sm *SpillManager) validateSpillInputs(groups []*batch.Batch, aggs []aggexec.AggFuncExec) error {
+	for i, group := range groups {
+		if group == nil {
+			continue
+		}
+
+		if group.RowCount() < 0 {
+			return moerr.NewInternalErrorNoCtxf("group batch %d has negative row count: %d", i, group.RowCount())
+		}
+
+		const maxReasonableRowCount = 1000000
+		if group.RowCount() > maxReasonableRowCount {
+			return moerr.NewInternalErrorNoCtxf("group batch %d has excessive row count: %d", i, group.RowCount())
+		}
+
+		if group.RowCount() > 0 && len(group.Vecs) == 0 {
+			return moerr.NewInternalErrorNoCtxf("group batch %d has rows but no vectors", i)
+		}
+	}
+
+	for i, agg := range aggs {
+		if agg == nil {
+			continue
+		}
+
+		if size := agg.Size(); size < 0 {
+			return moerr.NewInternalErrorNoCtxf("aggregation %d has negative size: %d", i, size)
+		}
+	}
+
+	return nil
+}
+
+// writeHeader writes a spill file header that announces no hash table data.
+func (sm *SpillManager) writeHeader(writer io.Writer, groups []*batch.Batch, aggs []aggexec.AggFuncExec) error {
+	return sm.writeHeaderWithHashTable(writer, groups, aggs, nil, false, false, 0)
+}
+
+// writeHeaderWithHashTable writes the fixed-size big-endian header. GroupCount and
+// AggCount count exactly the records that writeGroupBatches and writeAggStates
+// emit (non-empty batches, non-nil aggregators), because the reader uses them to
+// know how many records follow.
+func (sm *SpillManager) writeHeaderWithHashTable(writer io.Writer, groups []*batch.Batch, aggs []aggexec.AggFuncExec, hashTableData []byte, isStrHash bool, keyNullable bool, keyWidth int) error {
+	nonEmptyGroupCount := int32(0)
+	for _, groupBatch := range groups {
+		if groupBatch != nil && groupBatch.RowCount() > 0 {
+			nonEmptyGroupCount++
+		}
+	}
+
+	nonNilAggCount := int32(0)
+	for _, agg := range aggs {
+		if agg != nil {
+			nonNilAggCount++
+		}
+	}
+
+	hasHashTable := int32(0)
+	if len(hashTableData) > 0 {
+		hasHashTable = 1
+	}
+
+	// Bit 0: string hash table, bit 1: nullable key.
+	hashTableFlags := int32(0)
+	if isStrHash {
+		hashTableFlags |= 1
+	}
+	if keyNullable {
+		hashTableFlags |= 2
+	}
+
+	header := struct {
+		GroupCount     int32
+		AggCount       int32
+		HasHashTable   int32
+		HashTableFlags int32
+		KeyWidth       int32
+		Reserved       int32
+	}{
+		GroupCount:     nonEmptyGroupCount,
+		AggCount:       nonNilAggCount,
+		HasHashTable:   hasHashTable,
+		HashTableFlags: hashTableFlags,
+		KeyWidth:       int32(keyWidth),
+		Reserved:       0,
+	}
+
+	return binary.Write(writer, binary.BigEndian, header)
+}
+
+// writeGroupBatches writes every non-empty batch as a length-prefixed record.
+func (sm *SpillManager) writeGroupBatches(writer io.Writer, groups []*batch.Batch) error {
+	for _, groupBatch := range groups {
+		if groupBatch == nil || groupBatch.RowCount() == 0 {
+			continue
+		}
+
+		if err := sm.writeSizedData(writer, groupBatch.MarshalBinary); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// writeAggStates writes every non-nil aggregator as a length-prefixed record.
+func (sm *SpillManager) writeAggStates(writer io.Writer, aggs []aggexec.AggFuncExec) error {
+	for _, agg := range aggs {
+		if agg == nil {
+			continue
+		}
+
+		if err := sm.writeSizedData(writer, func() ([]byte, error) {
+			return aggexec.MarshalAggFuncExec(agg)
+		}); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// writeSizedData writes the bytes produced by marshalFunc preceded by their length
+// as a big-endian int32. Records larger than maxSpillRecordSize are rejected,
+// since readSizedData would refuse to read them back.
+func (sm *SpillManager) writeSizedData(writer io.Writer, marshalFunc func() ([]byte, error)) error {
+	data, err := marshalFunc()
+	if err != nil {
+		return err
+	}
+
+	if len(data) > maxSpillRecordSize {
+		return moerr.NewInternalErrorNoCtxf("data size too large: %d bytes (max %d)", len(data), maxSpillRecordSize)
+	}
+
+	size := int32(len(data))
+	if err := binary.Write(writer, binary.BigEndian, size); err != nil {
+		return err
+	}
+
+	_, err = writer.Write(data)
+	return err
+}
+
+// spillCounter numbers spill files; it is shared by all SpillManagers.
+var spillCounter atomic.Int64
+
+// generateSpillFilePath returns a new spill file name built from spillCounter.
+func (sm *SpillManager) generateSpillFilePath() string {
+	return fmt.Sprintf("group_spill_%d", spillCounter.Add(1))
+}
+
+// ReadSpilledData reads a spill file and returns only its batches and aggregators.
+func (sm *SpillManager) ReadSpilledData(filePath string) ([]*batch.Batch, []aggexec.AggFuncExec, error) {
+	groups, aggs, _, _, _, _, err := sm.ReadSpilledDataWithHashTable(filePath)
+	return groups, aggs, err
+}
+
+// ReadSpilledDataWithHashTable reads a spill file and returns, in order: the group
+// batches, the aggregators, the raw hash table data (nil when the file has none),
+// the string-hash flag, the nullable-key flag and the key width. On error the
+// batches and aggregators read so far are released and only the error is
+// meaningful.
+func (sm *SpillManager) ReadSpilledDataWithHashTable(filePath string) ([]*batch.Batch, []aggexec.AggFuncExec, []byte, bool, bool, int, error) {
+	if filePath == "" {
+		return nil, nil, nil, false, false, 0, moerr.NewInternalErrorNoCtx("spill file path cannot be empty")
+	}
+
+	if sm.fileService == nil {
+		return nil, nil, nil, false, false, 0, moerr.NewInternalErrorNoCtx("file service is not available for reading spilled data")
+	}
+
+	reader, err := sm.fileService.NewReader(context.Background(), filePath)
+	if err != nil {
+		return nil, nil, nil, false, false, 0, moerr.NewInternalErrorNoCtxf("failed to open spill file %s: %v", filePath, err)
+	}
+	defer reader.Close()
+
+	logutil.Infof("SpillManager: Reading spilled data from file %s", filePath)
+	startTime := time.Now()
+
+	header, err := sm.readHeaderWithHashTable(reader)
+	if err != nil {
+		return nil, nil, nil, false, false, 0, moerr.NewInternalErrorNoCtxf("failed to read spill file header from %s: %v", filePath, err)
+	}
+
+	if err := sm.validateSpillHeaderWithHashTable(header, filePath); err != nil {
+		return nil, nil, nil, false, false, 0, err
+	}
+
+	var groups []*batch.Batch
+	var aggs []aggexec.AggFuncExec
+	var hashTableData []byte
+	var readErr error
+
+	// cleanup releases whatever was read before a failure.
+	cleanup := func() {
+		cleanupResources(sm.proc.Mp(), groups, aggs)
+		groups = nil
+		aggs = nil
+	}
+
+	if groups, readErr = sm.readGroupBatches(reader, header.GroupCount); readErr != nil {
+		cleanup()
+		return nil, nil, nil, false, false, 0, moerr.NewInternalErrorNoCtxf("failed to read group batches from %s: %v", filePath, readErr)
+	}
+
+	if aggs, readErr = sm.readAggStates(reader, header.AggCount); readErr != nil {
+		cleanup()
+		return nil, nil, nil, false, false, 0, moerr.NewInternalErrorNoCtxf("failed to read aggregation states from %s: %v", filePath, readErr)
+	}
+
+	if header.HasHashTable > 0 {
+		if hashTableData, readErr = sm.readHashTableData(reader); readErr != nil {
+			cleanup()
+			return nil, nil, nil, false, false, 0, moerr.NewInternalErrorNoCtxf("failed to read hash table data from %s: %v", filePath, readErr)
+		}
+	}
+
+	if int32(len(groups)) != header.GroupCount {
+		cleanup()
+		return nil, nil, nil, false, false, 0, moerr.NewInternalErrorNoCtxf("spill file %s group count mismatch: expected %d, got %d",
+			filePath, header.GroupCount, len(groups))
+	}
+
+	if int32(len(aggs)) != header.AggCount {
+		cleanup()
+		return nil, nil, nil, false, false, 0, moerr.NewInternalErrorNoCtxf("spill file %s aggregation count mismatch: expected %d, got %d",
+			filePath, header.AggCount, len(aggs))
+	}
+
+	isStrHash := (header.HashTableFlags & 1) != 0
+	keyNullable := (header.HashTableFlags & 2) != 0
+	keyWidth := int(header.KeyWidth)
+
+	duration := time.Since(startTime)
+	logutil.Infof("SpillManager: Successfully read spilled data from %s. Duration: %v, Groups: %d, Aggregations: %d, HashTable: %t",
+		filePath, duration, len(groups), len(aggs), len(hashTableData) > 0)
+
+	return groups, aggs, hashTableData, isStrHash, keyNullable, keyWidth, nil
+}
+
+// validateSpillHeader validates a header that only carries the two counts by
+// checking it as an extended header with no hash table.
+func (sm *SpillManager) validateSpillHeader(header struct{ GroupCount, AggCount int32 }, filePath string) error {
+	headerExt := struct {
+		GroupCount     int32
+		AggCount       int32
+		HasHashTable   int32
+		HashTableFlags int32
+		KeyWidth       int32
+		Reserved       int32
+	}{
+		GroupCount:     header.GroupCount,
+		AggCount:       header.AggCount,
+		HasHashTable:   0,
+		HashTableFlags: 0,
+		KeyWidth:       0,
+		Reserved:       0,
+	}
+
+	return sm.validateSpillHeaderWithHashTable(headerExt, filePath)
+}
+
+// validateSpillHeaderWithHashTable rejects headers with negative counts, a hash
+// table flag other than 0 or 1, or counts (individually or summed) above 100000.
+func (sm *SpillManager) validateSpillHeaderWithHashTable(header struct {
+	GroupCount     int32
+	AggCount       int32
+	HasHashTable   int32
+	HashTableFlags int32
+	KeyWidth       int32
+	Reserved       int32
+}, filePath string) error {
+	if header.GroupCount < 0 {
+		return moerr.NewInternalErrorNoCtxf("spill file %s has negative group count: %d", filePath, header.GroupCount)
+	}
+
+	if header.AggCount < 0 {
+		return moerr.NewInternalErrorNoCtxf("spill file %s has negative aggregation count: %d", filePath, header.AggCount)
+	}
+
+	if header.HasHashTable < 0 || header.HasHashTable > 1 {
+		return moerr.NewInternalErrorNoCtxf("spill file %s has invalid hash table flag: %d", filePath, header.HasHashTable)
+	}
+
+	const maxReasonableCount = 100000
+	if header.GroupCount > maxReasonableCount {
+		return moerr.NewInternalErrorNoCtxf("spill file %s has excessive group count: %d", filePath, header.GroupCount)
+	}
+
+	if header.AggCount > maxReasonableCount {
+		return moerr.NewInternalErrorNoCtxf("spill file %s has excessive aggregation count: %d", filePath, header.AggCount)
+	}
+
+	totalCount := int64(header.GroupCount) + int64(header.AggCount)
+	if totalCount > maxReasonableCount {
+		return moerr.NewInternalErrorNoCtxf("spill file %s has excessive total count: %d", filePath, totalCount)
+	}
+
+	return nil
+}
+
+// Cleanup deletes all recorded spill files and forgets them. Deletion is best
+// effort: its error is deliberately ignored and Cleanup always returns nil.
+func (sm *SpillManager) Cleanup(ctx context.Context) error {
+	_ = sm.fileService.Delete(ctx, sm.spillFiles...)
+	sm.spillFiles = sm.spillFiles[:0]
+	return nil
+}
+
+// HasSpilledData reports whether at least one spill file is recorded.
+func (sm *SpillManager) HasSpilledData() bool {
+	return len(sm.spillFiles) > 0
+}
+
+// GetSpillFileCount returns the number of recorded spill files.
+func (sm *SpillManager) GetSpillFileCount() int {
+	return len(sm.spillFiles)
+}
+
+// Size returns an approximate in-memory size of the manager's bookkeeping; it does
+// not include the size of the spill files themselves.
+func (sm *SpillManager) Size() int64 {
+	var size int64
+	// Account for spill file paths
+	for _, filePath := range sm.spillFiles {
+		size += int64(len(filePath))
+	}
+	// Account for groupByTypes
+	for _, typ := range sm.groupByTypes {
+		size += int64(typ.ProtoSize())
+	}
+	// Account for aggInfos (approximate)
+	size += int64(len(sm.aggInfos) * 64) // Approximate size per agg info
+	// Fixed 8-byte allowance kept from the original accounting
+	size += 8
+	return size
+}
+
+// readHeader reads the spill file header and returns only its two counts.
+func (sm *SpillManager) readHeader(reader io.Reader) (struct{ GroupCount, AggCount int32 }, error) {
+	headerExt, err := sm.readHeaderWithHashTable(reader)
+	if err != nil {
+		return struct{ GroupCount, AggCount int32 }{}, err
+	}
+
+	return struct{ GroupCount, AggCount int32 }{
+		GroupCount: headerExt.GroupCount,
+		AggCount:   headerExt.AggCount,
+	}, nil
+}
+
+// readHeaderWithHashTable reads the fixed-size big-endian header written by
+// writeHeaderWithHashTable. A failed read is returned as-is: the reader cannot be
+// rewound, so retrying with another layout would only decode misaligned bytes.
+func (sm *SpillManager) readHeaderWithHashTable(reader io.Reader) (struct {
+	GroupCount     int32
+	AggCount       int32
+	HasHashTable   int32
+	HashTableFlags int32
+	KeyWidth       int32
+	Reserved       int32
+}, error) {
+	var header struct {
+		GroupCount     int32
+		AggCount       int32
+		HasHashTable   int32
+		HashTableFlags int32
+		KeyWidth       int32
+		Reserved       int32
+	}
+
+	if err := binary.Read(reader, binary.BigEndian, &header); err != nil {
+		return header, err
+	}
+
+	return header, nil
+}
+
+// readGroupBatches reads count batch records. On failure the batches read so far
+// are cleaned and nil is returned with the error.
+func (sm *SpillManager) readGroupBatches(reader io.Reader, count int32) ([]*batch.Batch, error) {
+	groups := make([]*batch.Batch, 0, count)
+
+	for i := int32(0); i < count; i++ {
+		bat, err := sm.readSingleBatch(reader)
+		if err != nil {
+			cleanupResources(sm.proc.Mp(), groups, nil)
+			return nil, err
+		}
+		groups = append(groups, bat)
+	}
+
+	return groups, nil
+}
+
+// readAggStates reads count aggregator records. On failure the aggregators read so
+// far are freed and nil is returned with the error.
+func (sm *SpillManager) readAggStates(reader io.Reader, count int32) ([]aggexec.AggFuncExec, error) {
+	aggs := make([]aggexec.AggFuncExec, 0, count)
+
+	for i := int32(0); i < count; i++ {
+		agg, err := sm.readSingleAgg(reader)
+		if err != nil {
+			cleanupResources(sm.proc.Mp(), nil, aggs)
+			return nil, err
+		}
+		aggs = append(aggs, agg)
+	}
+
+	return aggs, nil
+}
+
+// readSingleBatch reads one length-prefixed record and unmarshals it into a batch.
+func (sm *SpillManager) readSingleBatch(reader io.Reader) (*batch.Batch, error) {
+	data, err := sm.readSizedData(reader)
+	if err != nil {
+		return nil, err
+	}
+
+	bat := batch.NewWithSize(0)
+	if err := bat.UnmarshalBinary(data); err != nil {
+		bat.Clean(sm.proc.Mp())
+		return nil, err
+	}
+
+	return bat, nil
+}
+
+// readSingleAgg reads one length-prefixed record and unmarshals it into an
+// aggregator.
+func (sm *SpillManager) readSingleAgg(reader io.Reader) (aggexec.AggFuncExec, error) {
+	data, err := sm.readSizedData(reader)
+	if err != nil {
+		return nil, err
+	}
+
+	return aggexec.UnmarshalAggFuncExec(sm.proc, data)
+}
+
+// readSizedData reads a big-endian int32 length followed by that many bytes.
+// Negative lengths and lengths above maxSpillRecordSize are rejected; a zero
+// length yields an empty, non-nil slice.
+func (sm *SpillManager) readSizedData(reader io.Reader) ([]byte, error) {
+	var size int32
+	if err := binary.Read(reader, binary.BigEndian, &size); err != nil {
+		return nil, moerr.NewInternalErrorNoCtxf("failed to read data size: %v", err)
+	}
+
+	if size < 0 {
+		return nil, moerr.NewInternalErrorNoCtxf("invalid negative data size: %d", size)
+	}
+
+	if size > maxSpillRecordSize {
+		return nil, moerr.NewInternalErrorNoCtxf("data size too large: %d bytes (max %d)", size, maxSpillRecordSize)
+	}
+
+	if size == 0 {
+		return []byte{}, nil
+	}
+
+	data := make([]byte, size)
+	if _, err := io.ReadFull(reader, data); err != nil {
+		return nil, moerr.NewInternalErrorNoCtxf("failed to read %d bytes of data: %v", size, err)
+	}
+
+	return data, nil
+}
+
+// cleanupResources cleans every non-nil batch and frees every non-nil aggregator.
+func cleanupResources(mp *mpool.MPool, groups []*batch.Batch, aggs []aggexec.AggFuncExec) {
+	for _, bat := range groups {
+		if bat != nil {
+			bat.Clean(mp)
+		}
+	}
+	for _, agg := range aggs {
+		if agg != nil {
+			agg.Free()
+		}
+	}
+}
+
+// writeHashTableData writes the hash table data as a length-prefixed record, or
+// nothing at all when the data is empty.
+func (sm *SpillManager) writeHashTableData(writer io.Writer, hashTableData []byte) error {
+	if len(hashTableData) == 0 {
+		return nil
+	}
+
+	return sm.writeSizedData(writer, func() ([]byte, error) {
+		return hashTableData, nil
+	})
+}
+
+// readHashTableData reads the length-prefixed hash table record.
+func (sm *SpillManager) readHashTableData(reader io.Reader) ([]byte, error) {
+	return sm.readSizedData(reader)
+}
diff --git a/pkg/sql/colexec/group/spill_manager_test.go b/pkg/sql/colexec/group/spill_manager_test.go
new file mode 100644
index 0000000000000..e1a0c261647f6
--- /dev/null
+++ b/pkg/sql/colexec/group/spill_manager_test.go
@@ -0,0 +1,250 @@
+// Copyright 2025 Matrix Origin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +package group + +import ( + "testing" + + "github.com/matrixorigin/matrixone/pkg/container/batch" + "github.com/matrixorigin/matrixone/pkg/container/types" + "github.com/matrixorigin/matrixone/pkg/container/vector" + "github.com/matrixorigin/matrixone/pkg/pb/plan" + "github.com/matrixorigin/matrixone/pkg/sql/colexec/aggexec" + "github.com/matrixorigin/matrixone/pkg/testutil" + "github.com/stretchr/testify/require" +) + +func TestSpillManager_CreateSpillFile(t *testing.T) { + proc := testutil.NewProcess(t) + defer proc.Free() + + groupByTypes := []types.Type{types.T_int64.ToType()} + aggInfos := []aggexec.AggFuncExecExpression{ + aggexec.MakeAggFunctionExpression( + aggexec.AggIdOfCountStar, + false, + []*plan.Expr{newColumnExpression(0)}, + nil, + ), + } + + sm, err := NewSpillManager(proc, groupByTypes, aggInfos) + require.NoError(t, err) + + // Test creating spill file + filePath := sm.generateSpillFilePath() + require.NotEmpty(t, filePath) + + // Test that we can create multiple spill files + filePath2 := sm.generateSpillFilePath() + require.NotEmpty(t, filePath2) + require.NotEqual(t, filePath, filePath2) +} + +func TestSpillManager_SpillAndRead(t *testing.T) { + proc := testutil.NewProcess(t) + defer proc.Free() + + groupByTypes := []types.Type{types.T_int64.ToType()} + aggInfos := []aggexec.AggFuncExecExpression{ + aggexec.MakeAggFunctionExpression( + aggexec.AggIdOfCountStar, + false, + []*plan.Expr{newColumnExpression(0)}, + nil, + ), + } + + sm, err := NewSpillManager(proc, groupByTypes, aggInfos) + require.NoError(t, err) + + // Create some test data + testBatch := batch.NewWithSize(1) + testBatch.Vecs[0] = testutil.NewInt64Vector(3, types.T_int64.ToType(), proc.Mp(), false, []int64{1, 2, 3}) + testBatch.SetRowCount(3) + + groups := []*batch.Batch{testBatch} + + // Create a simple agg executor for testing + aggExec, err := aggexec.MakeAgg(proc, aggexec.AggIdOfCountStar, false, types.T_int64.ToType()) + require.NoError(t, err) + require.NoError(t, 
aggExec.GroupGrow(3)) + aggs := []aggexec.AggFuncExec{aggExec} + + // Test spilling data + err = sm.SpillToDisk(groups, aggs) + require.NoError(t, err) + require.Equal(t, 1, sm.GetSpillFileCount()) + + // Test reading spilled data + spilledGroups, spilledAggs, err := sm.ReadSpilledData(sm.spillFiles[0]) + require.NoError(t, err) + require.Len(t, spilledGroups, 1) + require.Len(t, spilledAggs, 1) + require.Equal(t, 3, spilledGroups[0].RowCount()) + + // Cleanup + require.NoError(t, sm.Cleanup(t.Context())) + require.Equal(t, 0, sm.GetSpillFileCount()) +} + +func TestSpillManager_BasicOperations(t *testing.T) { + proc := testutil.NewProcess(t) + defer proc.Free() + + groupByTypes := []types.Type{types.T_int64.ToType()} + aggInfos := []aggexec.AggFuncExecExpression{ + aggexec.MakeAggFunctionExpression( + aggexec.AggIdOfCountStar, + false, + []*plan.Expr{newColumnExpression(0)}, + nil, + ), + } + + sm, err := NewSpillManager(proc, groupByTypes, aggInfos) + require.NoError(t, err) + require.NotNil(t, sm) + + // Test that we can get group by types + require.Equal(t, groupByTypes, sm.groupByTypes) + + // Test that we can get agg infos + require.Equal(t, aggInfos, sm.aggInfos) + + // Test spill file count methods + require.Equal(t, 0, sm.GetSpillFileCount()) + require.Equal(t, 0, len(sm.spillFiles)) + require.False(t, sm.HasSpilledData()) + + // Clean up + require.NoError(t, sm.Cleanup(t.Context())) +} + +func TestSpillManager_SpillToDisk(t *testing.T) { + proc := testutil.NewProcess(t) + defer proc.Free() + + groupByTypes := []types.Type{types.T_int64.ToType()} + aggInfos := []aggexec.AggFuncExecExpression{ + aggexec.MakeAggFunctionExpression( + aggexec.AggIdOfCountStar, + false, + []*plan.Expr{newColumnExpression(0)}, + nil, + ), + } + + sm, err := NewSpillManager(proc, groupByTypes, aggInfos) + require.NoError(t, err) + + // Create test data + testBatch := batch.NewWithSize(1) + testBatch.Vecs[0] = testutil.NewInt64Vector(3, types.T_int64.ToType(), proc.Mp(), false, 
[]int64{1, 2, 3}) + testBatch.SetRowCount(3) + + groups := []*batch.Batch{testBatch} + + // Create aggregator with data + aggExec, err := aggexec.MakeAgg(proc, aggexec.AggIdOfCountStar, false, types.T_int64.ToType()) + require.NoError(t, err) + require.NoError(t, aggExec.GroupGrow(3)) + + // Fill with test data + fillBatch := batch.NewWithSize(1) + fillBatch.Vecs[0] = testutil.NewInt64Vector(5, types.T_int64.ToType(), proc.Mp(), false, []int64{1, 1, 1, 1, 1}) + fillBatch.SetRowCount(5) + require.NoError(t, aggExec.BulkFill(0, []*vector.Vector{fillBatch.Vecs[0]})) + require.NoError(t, aggExec.BulkFill(1, []*vector.Vector{fillBatch.Vecs[0]})) + require.NoError(t, aggExec.BulkFill(2, []*vector.Vector{fillBatch.Vecs[0]})) + + aggs := []aggexec.AggFuncExec{aggExec} + + // Test spilling to disk + err = sm.SpillToDisk(groups, aggs) + require.NoError(t, err) + + // Verify spill file was created + require.Equal(t, 1, sm.GetSpillFileCount()) + require.True(t, sm.HasSpilledData()) + + // Test reading spilled data + spilledGroups, spilledAggs, err := sm.ReadSpilledData(sm.spillFiles[0]) + require.NoError(t, err) + require.Len(t, spilledGroups, 1) + require.Len(t, spilledAggs, 1) + + require.Equal(t, 3, spilledGroups[0].RowCount()) + + // Clean up + require.NoError(t, sm.Cleanup(t.Context())) + require.Equal(t, 0, sm.GetSpillFileCount()) + require.False(t, sm.HasSpilledData()) +} + +func TestSpillManager_ErrorHandling(t *testing.T) { + proc := testutil.NewProcess(t) + defer proc.Free() + + groupByTypes := []types.Type{types.T_int64.ToType()} + aggInfos := []aggexec.AggFuncExecExpression{ + aggexec.MakeAggFunctionExpression( + aggexec.AggIdOfCountStar, + false, + []*plan.Expr{newColumnExpression(0)}, + nil, + ), + } + + sm, err := NewSpillManager(proc, groupByTypes, aggInfos) + require.NoError(t, err) + + // Test with nil groups and aggregators (should handle gracefully) + err = sm.SpillToDisk(nil, nil) + require.NoError(t, err) + + // Test with empty groups and aggregators + 
err = sm.SpillToDisk([]*batch.Batch{}, []aggexec.AggFuncExec{}) + require.NoError(t, err) + + // Clean up + require.NoError(t, sm.Cleanup(t.Context())) +} + +func TestSpillManager_CleanupNonExistentFiles(t *testing.T) { + proc := testutil.NewProcess(t) + defer proc.Free() + + groupByTypes := []types.Type{types.T_int64.ToType()} + aggInfos := []aggexec.AggFuncExecExpression{ + aggexec.MakeAggFunctionExpression( + aggexec.AggIdOfCountStar, + false, + []*plan.Expr{newColumnExpression(0)}, + nil, + ), + } + + sm, err := NewSpillManager(proc, groupByTypes, aggInfos) + require.NoError(t, err) + + // Add a non-existent file path to the spill files list + sm.spillFiles = append(sm.spillFiles, "/tmp/non_existent_spill_file.dat") + + // Cleanup should not return an error even if file doesn't exist + err = sm.Cleanup(t.Context()) + require.NoError(t, err) + require.Equal(t, 0, len(sm.spillFiles)) +} diff --git a/pkg/sql/colexec/group/spill_test.go b/pkg/sql/colexec/group/spill_test.go new file mode 100644 index 0000000000000..8ee86783d20b8 --- /dev/null +++ b/pkg/sql/colexec/group/spill_test.go @@ -0,0 +1,949 @@ +// Copyright 2025 Matrix Origin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package group + +import ( + "testing" + + "github.com/matrixorigin/matrixone/pkg/common/mpool" + "github.com/matrixorigin/matrixone/pkg/container/batch" + "github.com/matrixorigin/matrixone/pkg/container/types" + "github.com/matrixorigin/matrixone/pkg/container/vector" + "github.com/matrixorigin/matrixone/pkg/pb/plan" + "github.com/matrixorigin/matrixone/pkg/sql/colexec/aggexec" + "github.com/matrixorigin/matrixone/pkg/testutil" + "github.com/matrixorigin/matrixone/pkg/vm" + "github.com/matrixorigin/matrixone/pkg/vm/process" + "github.com/stretchr/testify/require" +) + +func TestGroup_DiskSpill(t *testing.T) { + proc := testutil.NewProcess(t) + defer proc.Free() + + // Set a very low spill threshold to force spilling + g := &Group{ + OperatorBase: vm.OperatorBase{}, + NeedEval: true, + Exprs: []*plan.Expr{newColumnExpression(0)}, + GroupingFlag: []bool{true}, + SpillThreshold: 1024, // Very low threshold to trigger spilling + Aggs: []aggexec.AggFuncExecExpression{ + aggexec.MakeAggFunctionExpression( + aggexec.AggIdOfCountStar, + false, + []*plan.Expr{ + newColumnExpression(0), + }, + nil, + ), + }, + } + + require.NoError(t, g.Prepare(proc)) + + // Test that memory tracking works + initialMemory := g.MemoryUsed() + if initialMemory < 0 { + t.Errorf("MemoryUsed should return non-negative value, got %d", initialMemory) + } + + // Test that spill manager is properly initialized + if g.ctr.spillManager == nil { + t.Error("Spill manager should be initialized") + } + + // Test spill threshold checking + shouldSpill := g.ShouldSpill() + // Initially should not spill since memory usage is low + if shouldSpill { + t.Error("Should not spill initially with low memory usage") + } + + // Test size calculation methods + spillManagerSize := g.ctr.spillManager.Size() + if spillManagerSize < 0 { + t.Errorf("SpillManager.Size() should return non-negative value, got %d", spillManagerSize) + } +} + +func TestGroup_SpillAndMerge(t *testing.T) { + proc := testutil.NewProcess(t) + 
defer proc.Free() + + // Create a group operator with a very low spill threshold + g := &Group{ + OperatorBase: vm.OperatorBase{}, + NeedEval: true, + Exprs: []*plan.Expr{newColumnExpression(0)}, + GroupingFlag: []bool{true}, + SpillThreshold: 1024, // Very low threshold to force spilling + Aggs: []aggexec.AggFuncExecExpression{ + aggexec.MakeAggFunctionExpression( + aggexec.AggIdOfCountStar, + false, + []*plan.Expr{ + newColumnExpression(0), + }, + nil, + ), + }, + } + + require.NoError(t, g.Prepare(proc)) + + // Verify spill manager is initialized + require.NotNil(t, g.ctr.spillManager) + + // Test memory usage tracking + initialMemory := g.MemoryUsed() + require.True(t, initialMemory >= 0, "MemoryUsed should return non-negative value") + + // Test spill threshold checking + shouldSpill := g.ShouldSpill() + // Initially should not spill since memory usage is low + require.False(t, shouldSpill, "Should not spill initially with low memory usage") + + // Test spill manager size calculation + spillManagerSize := g.ctr.spillManager.Size() + require.True(t, spillManagerSize >= 0, "SpillManager.Size() should return non-negative value") + + // Clean up + g.Free(proc, false, nil) +} + +func TestGroup_SpillToDisk(t *testing.T) { + proc := testutil.NewProcess(t) + defer proc.Free() + + // Create group operator with aggregations + g := &Group{ + OperatorBase: vm.OperatorBase{}, + NeedEval: true, + Exprs: []*plan.Expr{newColumnExpression(0)}, + GroupingFlag: []bool{true}, + SpillThreshold: 1024, // Low threshold to force spilling + Aggs: []aggexec.AggFuncExecExpression{ + aggexec.MakeAggFunctionExpression( + aggexec.AggIdOfCountStar, + false, + []*plan.Expr{newColumnExpression(0)}, + nil, + ), + }, + } + + require.NoError(t, g.Prepare(proc)) + + // Initialize container with actual group data to spill + // Create simple aggregator + aggExec, err := aggexec.MakeAgg(proc, aggexec.AggIdOfCountStar, false, types.T_int64.ToType()) + require.NoError(t, err) + require.NoError(t, 
aggExec.GroupGrow(1)) + + // Initialize result buffer with aggregator only + g.ctr.result1.InitOnlyAgg(10, []aggexec.AggFuncExec{aggExec}) + + // Add actual group data to spill - create a batch with group data + groupBatch := batch.NewWithSize(1) + groupBatch.Vecs[0] = testutil.NewInt64Vector(1, types.T_int64.ToType(), proc.Mp(), false, []int64{1}) + groupBatch.SetRowCount(1) + g.ctr.result1.ToPopped[0] = groupBatch + + // Fill aggregator with data + aggExec.Fill(0, 0, []*vector.Vector{testutil.NewInt64Vector(1, types.T_int64.ToType(), proc.Mp(), false, []int64{1})}) + + // Initialize hash table with correct type + g.ctr.mtyp = HStr // Use string hash map to avoid type issues + g.ctr.keyNullable = false + require.NoError(t, g.ctr.hr.BuildHashTable(true, true, false, 0)) + + // Test spilling to disk + err = g.spillToDisk(proc) + require.NoError(t, err) + + // Verify spill file was created + require.True(t, g.ctr.spillManager.HasSpilledData()) + require.Equal(t, 1, g.ctr.spillManager.GetSpillFileCount()) + + // Verify in-memory data was cleaned up + // After spilling, the data should be cleaned up + hasData := false + for _, batch := range g.ctr.result1.ToPopped { + if batch != nil && batch.RowCount() > 0 { + hasData = true + break + } + } + require.False(t, hasData, "Data should be cleaned up after spilling") + + // Clean up + g.Free(proc, false, nil) +} + +func TestGroup_MergeSpilledData(t *testing.T) { + proc := testutil.NewProcess(t) + defer proc.Free() + + // Create group operator + g := &Group{ + OperatorBase: vm.OperatorBase{}, + NeedEval: true, + Exprs: []*plan.Expr{newColumnExpression(0)}, + GroupingFlag: []bool{true}, + SpillThreshold: 1024, + Aggs: []aggexec.AggFuncExecExpression{ + aggexec.MakeAggFunctionExpression( + aggexec.AggIdOfCountStar, + false, + []*plan.Expr{newColumnExpression(0)}, + nil, + ), + }, + } + + require.NoError(t, g.Prepare(proc)) + + // Create simple test data to spill + testBatch := batch.NewWithSize(1) + testBatch.Vecs[0] = 
testutil.NewInt64Vector(3, types.T_int64.ToType(), proc.Mp(), false, []int64{1, 2, 3}) + testBatch.SetRowCount(3) + + // Create aggregator with data + aggExec, err := aggexec.MakeAgg(proc, aggexec.AggIdOfCountStar, false, types.T_int64.ToType()) + require.NoError(t, err) + require.NoError(t, aggExec.GroupGrow(3)) + + // Fill with test data + fillBatch := batch.NewWithSize(1) + fillBatch.Vecs[0] = testutil.NewInt64Vector(5, types.T_int64.ToType(), proc.Mp(), false, []int64{1, 1, 1, 1, 1}) + fillBatch.SetRowCount(5) + require.NoError(t, aggExec.BulkFill(0, []*vector.Vector{fillBatch.Vecs[0]})) + require.NoError(t, aggExec.BulkFill(1, []*vector.Vector{fillBatch.Vecs[0]})) + require.NoError(t, aggExec.BulkFill(2, []*vector.Vector{fillBatch.Vecs[0]})) + + // Spill the data + groups := []*batch.Batch{testBatch} + aggs := []aggexec.AggFuncExec{aggExec} + + err = g.ctr.spillManager.SpillToDisk(groups, aggs) + require.NoError(t, err) + + // Verify data was spilled + require.True(t, g.ctr.spillManager.HasSpilledData()) + + // Test merging spilled data + err = g.mergeSpilledData(proc) + require.NoError(t, err) + + // Verify spill files were cleaned up + require.False(t, g.ctr.spillManager.HasSpilledData()) + + // Clean up + g.Free(proc, false, nil) +} + +func TestGroup_SpillAndMergeCycle(t *testing.T) { + proc := testutil.NewProcess(t) + defer proc.Free() + + // Create group operator + g := &Group{ + OperatorBase: vm.OperatorBase{}, + NeedEval: true, + Exprs: []*plan.Expr{newColumnExpression(0)}, + GroupingFlag: []bool{true}, + SpillThreshold: 1024, + Aggs: []aggexec.AggFuncExecExpression{ + aggexec.MakeAggFunctionExpression( + aggexec.AggIdOfCountStar, + false, + []*plan.Expr{newColumnExpression(0)}, + nil, + ), + }, + } + + require.NoError(t, g.Prepare(proc)) + + // Simplified test - just verify the methods can be called without errors + // when there's no data to spill + err := g.spillToDisk(proc) + require.NoError(t, err) + + // Test merging when there's no spilled data + 
err = g.mergeSpilledData(proc) + require.NoError(t, err) + + // Clean up + g.Free(proc, false, nil) +} + +func TestGroup_MergeSpilledDataWithEmptyResults(t *testing.T) { + proc := testutil.NewProcess(t) + defer proc.Free() + + // Create group operator + g := &Group{ + OperatorBase: vm.OperatorBase{}, + NeedEval: true, + Exprs: []*plan.Expr{newColumnExpression(0)}, + GroupingFlag: []bool{true}, + SpillThreshold: 1024, + Aggs: []aggexec.AggFuncExecExpression{ + aggexec.MakeAggFunctionExpression( + aggexec.AggIdOfCountStar, + false, + []*plan.Expr{newColumnExpression(0)}, + nil, + ), + }, + } + + require.NoError(t, g.Prepare(proc)) + + // Test merging when result buffer has nil ToPopped slice + // This should not panic + g.ctr.result1.ToPopped = nil + + // Create simple test data to spill + testBatch := batch.NewWithSize(1) + testBatch.Vecs[0] = testutil.NewInt64Vector(3, types.T_int64.ToType(), proc.Mp(), false, []int64{1, 2, 3}) + testBatch.SetRowCount(3) + + // Create aggregator with data + aggExec, err := aggexec.MakeAgg(proc, aggexec.AggIdOfCountStar, false, types.T_int64.ToType()) + require.NoError(t, err) + require.NoError(t, aggExec.GroupGrow(3)) + + // Fill with test data + fillBatch := batch.NewWithSize(1) + fillBatch.Vecs[0] = testutil.NewInt64Vector(5, types.T_int64.ToType(), proc.Mp(), false, []int64{1, 1, 1, 1, 1}) + fillBatch.SetRowCount(5) + require.NoError(t, aggExec.BulkFill(0, []*vector.Vector{fillBatch.Vecs[0]})) + require.NoError(t, aggExec.BulkFill(1, []*vector.Vector{fillBatch.Vecs[0]})) + require.NoError(t, aggExec.BulkFill(2, []*vector.Vector{fillBatch.Vecs[0]})) + + // Spill the data + groups := []*batch.Batch{testBatch} + aggs := []aggexec.AggFuncExec{aggExec} + + err = g.ctr.spillManager.SpillToDisk(groups, aggs) + require.NoError(t, err) + + // Verify data was spilled + require.True(t, g.ctr.spillManager.HasSpilledData()) + + // Test merging spilled data - this should not panic even with nil ToPopped + err = g.mergeSpilledData(proc) + 
require.NoError(t, err) + + // Clean up + g.Free(proc, false, nil) +} + +func TestGroup_MergeSpilledGroupsAndAggs(t *testing.T) { + proc := testutil.NewProcess(t) + defer proc.Free() + + // Test 1: Empty result buffer - should call restoreSpilledDataAsCurrentState + t.Run("EmptyResultBuffer", func(t *testing.T) { + g := createTestGroupOperator(t, proc) + + // Ensure result buffer is empty + require.True(t, g.ctr.result1.IsEmpty()) + + // Create test spilled data + groups, aggs := createTestSpilledData(t, proc) + + err := g.mergeSpilledGroupsAndAggs(proc, groups, aggs) + require.NoError(t, err) + + // Verify data was restored + require.False(t, g.ctr.result1.IsEmpty()) + require.Equal(t, 1, len(g.ctr.result1.ToPopped)) + require.Equal(t, 1, len(g.ctr.result1.AggList)) + + // Cleanup + cleanupTestData(proc.Mp(), groups, aggs) + }) + + // Test 2: Non-empty result buffer - should merge with existing data + t.Run("NonEmptyResultBuffer", func(t *testing.T) { + g := createTestGroupOperatorWithData(t, proc) + + // Ensure result buffer has data + require.False(t, g.ctr.result1.IsEmpty()) + initialGroupCount := len(g.ctr.result1.ToPopped) + initialAggCount := len(g.ctr.result1.AggList) + + // Create test spilled data + groups, aggs := createTestSpilledData(t, proc) + + err := g.mergeSpilledGroupsAndAggs(proc, groups, aggs) + require.NoError(t, err) + + // Verify data was merged (counts should remain the same for aggregators) + require.Equal(t, initialAggCount, len(g.ctr.result1.AggList)) + require.GreaterOrEqual(t, len(g.ctr.result1.ToPopped), initialGroupCount) + + // Cleanup + cleanupTestData(proc.Mp(), groups, aggs) + }) + + // Test 3: Empty input data + t.Run("EmptyInputData", func(t *testing.T) { + g := createTestGroupOperator(t, proc) + + err := g.mergeSpilledGroupsAndAggs(proc, nil, nil) + require.NoError(t, err) + + err = g.mergeSpilledGroupsAndAggs(proc, []*batch.Batch{}, []aggexec.AggFuncExec{}) + require.NoError(t, err) + }) + + // Test 4: Error condition - nil 
hash table + t.Run("NilHashTable", func(t *testing.T) { + g := createTestGroupOperatorWithData(t, proc) + g.ctr.hr.Hash = nil + + groups, aggs := createTestSpilledData(t, proc) + + err := g.mergeSpilledGroupsAndAggs(proc, groups, aggs) + require.Error(t, err) + require.Contains(t, err.Error(), "hash table or iterator is nil") + + // Cleanup + cleanupTestData(proc.Mp(), groups, aggs) + }) +} + +func TestGroup_RestoreSpilledDataAsCurrentState(t *testing.T) { + proc := testutil.NewProcess(t) + defer proc.Free() + + // Test 1: Restore groups only + t.Run("RestoreGroupsOnly", func(t *testing.T) { + g := createTestGroupOperator(t, proc) + + groups := createTestBatches(t, proc, 3) + + err := g.restoreSpilledDataAsCurrentState(proc, groups, nil) + require.NoError(t, err) + + // Verify groups were restored + require.False(t, g.ctr.result1.IsEmpty()) + require.Equal(t, len(groups), len(g.ctr.result1.ToPopped)) + require.NotNil(t, g.ctr.hr.Hash) + require.NotNil(t, g.ctr.hr.Itr) + + // Cleanup + cleanupTestData(proc.Mp(), groups, nil) + }) + + // Test 2: Restore aggregators only + t.Run("RestoreAggregatorsOnly", func(t *testing.T) { + g := createTestGroupOperator(t, proc) + + aggs := createTestAggregators(t, proc, 2) + + err := g.restoreSpilledDataAsCurrentState(proc, nil, aggs) + require.NoError(t, err) + + // Verify aggregators were restored + require.Equal(t, len(aggs), len(g.ctr.result1.AggList)) + + // Note: Don't cleanup aggs as ownership was transferred + }) + + // Test 3: Restore both groups and aggregators + t.Run("RestoreBothGroupsAndAggregators", func(t *testing.T) { + g := createTestGroupOperator(t, proc) + + groups := createTestBatches(t, proc, 2) + aggs := createTestAggregators(t, proc, 1) + + err := g.restoreSpilledDataAsCurrentState(proc, groups, aggs) + require.NoError(t, err) + + // Verify both were restored + require.False(t, g.ctr.result1.IsEmpty()) + require.Equal(t, len(groups), len(g.ctr.result1.ToPopped)) + require.Equal(t, len(aggs), 
len(g.ctr.result1.AggList)) + + // Cleanup groups only (aggs ownership transferred) + cleanupTestData(proc.Mp(), groups, nil) + }) + + // Test 4: Empty/nil batches filtering + t.Run("EmptyBatchFiltering", func(t *testing.T) { + g := createTestGroupOperator(t, proc) + + // Create mix of valid, empty, and nil batches + groups := []*batch.Batch{ + nil, + batch.NewOffHeapEmpty(), + createTestBatch(t, proc, 2), + nil, + createTestBatch(t, proc, 1), + } + + err := g.restoreSpilledDataAsCurrentState(proc, groups, nil) + require.NoError(t, err) + + // Should only restore non-empty batches + require.Equal(t, 2, len(g.ctr.result1.ToPopped)) + + // Cleanup valid batches + groups[2].Clean(proc.Mp()) + groups[4].Clean(proc.Mp()) + }) +} + +func TestGroup_MergeSpilledGroups(t *testing.T) { + proc := testutil.NewProcess(t) + defer proc.Free() + + // Test 1: Merge single batch with existing data + t.Run("MergeSingleBatch", func(t *testing.T) { + g := createTestGroupOperatorWithData(t, proc) + + initialGroupCount := g.ctr.hr.Hash.GroupCount() + + // Create spilled batch with new groups + spilledBatch := createTestBatch(t, proc, 2) + groups := []*batch.Batch{spilledBatch} + + err := g.mergeSpilledGroups(proc, groups) + require.NoError(t, err) + + // Verify groups were merged + newGroupCount := g.ctr.hr.Hash.GroupCount() + require.GreaterOrEqual(t, newGroupCount, initialGroupCount) + + // Cleanup + spilledBatch.Clean(proc.Mp()) + }) + + // Test 2: Merge multiple batches + t.Run("MergeMultipleBatches", func(t *testing.T) { + g := createTestGroupOperatorWithData(t, proc) + + initialGroupCount := g.ctr.hr.Hash.GroupCount() + + // Create multiple spilled batches + groups := createTestBatches(t, proc, 3) + + err := g.mergeSpilledGroups(proc, groups) + require.NoError(t, err) + + // Verify groups were merged + newGroupCount := g.ctr.hr.Hash.GroupCount() + require.GreaterOrEqual(t, newGroupCount, initialGroupCount) + + // Cleanup + cleanupTestData(proc.Mp(), groups, nil) + }) + + // Test 3: 
Handle empty and nil batches + t.Run("HandleEmptyBatches", func(t *testing.T) { + g := createTestGroupOperatorWithData(t, proc) + + initialGroupCount := g.ctr.hr.Hash.GroupCount() + + // Mix of empty, nil, and valid batches + groups := []*batch.Batch{ + nil, + batch.NewOffHeapEmpty(), + createTestBatch(t, proc, 1), + } + + err := g.mergeSpilledGroups(proc, groups) + require.NoError(t, err) + + // Should handle gracefully + newGroupCount := g.ctr.hr.Hash.GroupCount() + require.GreaterOrEqual(t, newGroupCount, initialGroupCount) + + // Cleanup + groups[2].Clean(proc.Mp()) + }) + + // Test 4: Large batch processing (multiple chunks) + t.Run("LargeBatchProcessing", func(t *testing.T) { + g := createTestGroupOperatorWithData(t, proc) + + // Create a large batch that will be processed in chunks + largeBatch := createLargeTestBatch(t, proc, 3000) // Larger than UnitLimit + groups := []*batch.Batch{largeBatch} + + err := g.mergeSpilledGroups(proc, groups) + require.NoError(t, err) + + // Cleanup + largeBatch.Clean(proc.Mp()) + }) +} + +func TestGroup_MergeSpilledAggregations(t *testing.T) { + proc := testutil.NewProcess(t) + defer proc.Free() + + // Test 1: Merge single aggregator + t.Run("MergeSingleAggregator", func(t *testing.T) { + g := createTestGroupOperatorWithData(t, proc) + + // Create spilled aggregator + spilledAggs := createTestAggregators(t, proc, 1) + + err := g.mergeSpilledAggregations(spilledAggs) + require.NoError(t, err) + + // Cleanup spilled aggs + for _, agg := range spilledAggs { + if agg != nil { + agg.Free() + } + } + }) + + // Test 2: Merge multiple aggregators + t.Run("MergeMultipleAggregators", func(t *testing.T) { + g := createTestGroupOperatorWithData(t, proc) + + // Ensure we have multiple aggregators in current state + for len(g.ctr.result1.AggList) < 2 { + agg := createTestAggregator(t, proc) + g.ctr.result1.AggList = append(g.ctr.result1.AggList, agg) + } + + // Create spilled aggregators + spilledAggs := createTestAggregators(t, proc, 2) + 
+ err := g.mergeSpilledAggregations(spilledAggs) + require.NoError(t, err) + + // Cleanup + for _, agg := range spilledAggs { + if agg != nil { + agg.Free() + } + } + }) + + // Test 3: Handle size mismatches + t.Run("HandleSizeMismatches", func(t *testing.T) { + g := createTestGroupOperatorWithData(t, proc) + + // Test with more spilled aggregators than current + spilledAggs := createTestAggregators(t, proc, 5) + + err := g.mergeSpilledAggregations(spilledAggs) + require.NoError(t, err) + + // Test with fewer spilled aggregators than current + spilledAggs2 := createTestAggregators(t, proc, 0) + + err = g.mergeSpilledAggregations(spilledAggs2) + require.NoError(t, err) + + // Cleanup + for _, agg := range spilledAggs { + if agg != nil { + agg.Free() + } + } + }) + + // Test 4: Handle nil aggregators + t.Run("HandleNilAggregators", func(t *testing.T) { + g := createTestGroupOperatorWithData(t, proc) + + // Create mix of nil and valid aggregators + spilledAggs := []aggexec.AggFuncExec{ + nil, + createTestAggregator(t, proc), + nil, + } + + err := g.mergeSpilledAggregations(spilledAggs) + require.NoError(t, err) + + // Cleanup non-nil aggregators + for _, agg := range spilledAggs { + if agg != nil { + agg.Free() + } + } + }) + + // Test 5: Empty aggregator lists + t.Run("EmptyAggregatorLists", func(t *testing.T) { + g := createTestGroupOperatorWithData(t, proc) + + err := g.mergeSpilledAggregations(nil) + require.NoError(t, err) + + err = g.mergeSpilledAggregations([]aggexec.AggFuncExec{}) + require.NoError(t, err) + }) +} + +// Helper functions for test setup and data creation + +func createTestGroupOperator(t *testing.T, proc *process.Process) *Group { + g := &Group{ + OperatorBase: vm.OperatorBase{}, + NeedEval: true, + Exprs: []*plan.Expr{newColumnExpression(0)}, + GroupingFlag: []bool{true}, + SpillThreshold: 1024, + Aggs: []aggexec.AggFuncExecExpression{ + aggexec.MakeAggFunctionExpression( + aggexec.AggIdOfCountStar, + false, + 
[]*plan.Expr{newColumnExpression(0)}, + nil, + ), + }, + } + + require.NoError(t, g.Prepare(proc)) + return g +} + +func createTestGroupOperatorWithData(t *testing.T, proc *process.Process) *Group { + g := createTestGroupOperator(t, proc) + + // Initialize with some data + aggExec, err := aggexec.MakeAgg(proc, aggexec.AggIdOfCountStar, false, types.T_int64.ToType()) + require.NoError(t, err) + require.NoError(t, aggExec.GroupGrow(1)) + + g.ctr.result1.InitOnlyAgg(10, []aggexec.AggFuncExec{aggExec}) + + // Add a test batch + testBatch := createTestBatch(t, proc, 1) + g.ctr.result1.ToPopped[0] = testBatch + + // Initialize hash table + g.ctr.mtyp = H8 + g.ctr.keyNullable = false + require.NoError(t, g.ctr.hr.BuildHashTable(true, false, false, 0)) + + return g +} + +func createTestBatch(t *testing.T, proc *process.Process, rowCount int) *batch.Batch { + values := make([]int64, rowCount) + for i := 0; i < rowCount; i++ { + values[i] = int64(i + 1) + } + + testBatch := batch.NewWithSize(1) + testBatch.Vecs[0] = testutil.NewInt64Vector(rowCount, types.T_int64.ToType(), proc.Mp(), false, values) + testBatch.SetRowCount(rowCount) + + return testBatch +} + +func createLargeTestBatch(t *testing.T, proc *process.Process, rowCount int) *batch.Batch { + values := make([]int64, rowCount) + for i := 0; i < rowCount; i++ { + values[i] = int64(i%100 + 1) // Create groups with some repetition + } + + testBatch := batch.NewWithSize(1) + testBatch.Vecs[0] = testutil.NewInt64Vector(rowCount, types.T_int64.ToType(), proc.Mp(), false, values) + testBatch.SetRowCount(rowCount) + + return testBatch +} + +func createTestBatches(t *testing.T, proc *process.Process, count int) []*batch.Batch { + batches := make([]*batch.Batch, count) + for i := 0; i < count; i++ { + batches[i] = createTestBatch(t, proc, i+1) + } + return batches +} + +func createTestAggregator(t *testing.T, proc *process.Process) aggexec.AggFuncExec { + agg, err := aggexec.MakeAgg(proc, aggexec.AggIdOfCountStar, false, 
types.T_int64.ToType())
+	require.NoError(t, err)
+	require.NoError(t, agg.GroupGrow(1))
+	return agg
+}
+
+// createTestAggregators returns count aggregators, each one built by
+// createTestAggregator.
+func createTestAggregators(t *testing.T, proc *process.Process, count int) []aggexec.AggFuncExec {
+	aggs := make([]aggexec.AggFuncExec, count)
+	for i := range aggs {
+		aggs[i] = createTestAggregator(t, proc)
+	}
+	return aggs
+}
+
+// createTestSpilledData returns the group batches from createTestBatches and
+// the aggregators from createTestAggregators, both called with 1 as their
+// last argument.
+func createTestSpilledData(t *testing.T, proc *process.Process) ([]*batch.Batch, []aggexec.AggFuncExec) {
+	groups := createTestBatches(t, proc, 1)
+	aggs := createTestAggregators(t, proc, 1)
+	return groups, aggs
+}
+
+// cleanupTestData cleans every non-nil batch in groups with mp and frees
+// every non-nil aggregator in aggs.
+func cleanupTestData(mp *mpool.MPool, groups []*batch.Batch, aggs []aggexec.AggFuncExec) {
+	// The loop variable is named bat so it does not shadow the batch package.
+	for _, bat := range groups {
+		if bat != nil {
+			bat.Clean(mp)
+		}
+	}
+	for _, agg := range aggs {
+		if agg != nil {
+			agg.Free()
+		}
+	}
+}
+
+// TestGroup_HashTableSpilling exercises hash table handling on the spill path
+// in three subtests: a marshal/unmarshal round trip on ResHashRelated, a
+// spill and read-back through SpillManager, and spillToDisk followed by
+// mergeSpilledData on a Group operator.
+//
+// Resources are released with defer so they are freed even when a require
+// assertion aborts the subtest.
+func TestGroup_HashTableSpilling(t *testing.T) {
+	proc := testutil.NewProcess(t)
+	defer proc.Free()
+
+	// Test hash table serialization
+	t.Run("HashTableSerialization", func(t *testing.T) {
+		hr := &ResHashRelated{}
+
+		// Test with nil hash table
+		data, err := hr.MarshalHashTable()
+		require.NoError(t, err)
+		require.Nil(t, data)
+
+		// Test with actual hash table
+		err = hr.BuildHashTable(true, false, false, 10)
+		require.NoError(t, err)
+		defer hr.Free0()
+
+		data, err = hr.MarshalHashTable()
+		require.NoError(t, err)
+		require.NotNil(t, data)
+		require.NotEmpty(t, data)
+
+		// Test unmarshaling
+		hr2 := &ResHashRelated{}
+		err = hr2.UnmarshalHashTable(data, false, false, 8)
+		require.NoError(t, err)
+		require.NotNil(t, hr2.Hash)
+		defer hr2.Free0()
+	})
+
+	// Test spill manager with hash table
+	t.Run("SpillManagerWithHashTable", func(t *testing.T) {
+		groupByTypes := []types.Type{types.T_int64.ToType()}
+		aggInfos := []aggexec.AggFuncExecExpression{
+			aggexec.MakeAggFunctionExpression(
+				aggexec.AggIdOfCountStar,
+				false,
+				[]*plan.Expr{newColumnExpression(0)},
+				nil,
+			),
+		}
+
+		sm, err := NewSpillManager(proc, groupByTypes, aggInfos)
+		require.NoError(t, err)
+		// Deferred rather than registered with t.Cleanup: t.Context() is
+		// canceled just before Cleanup callbacks run, so it would already be
+		// done by then. Deferring also removes spill files when an assertion
+		// below fails.
+		defer func() {
+			if err := sm.Cleanup(t.Context()); err != nil {
+				t.Errorf("spill manager cleanup: %v", err)
+			}
+		}()
+
+		// Create test data
+		// NOTE(review): testBatch, aggExec and the read-back groups/aggs are
+		// never released here; confirm who owns them before adding frees.
+		testBatch := batch.NewWithSize(1)
+		testBatch.Vecs[0] = testutil.NewInt64Vector(3, types.T_int64.ToType(), proc.Mp(), false, []int64{1, 2, 3})
+		testBatch.SetRowCount(3)
+
+		groups := []*batch.Batch{testBatch}
+
+		// Create aggregator
+		aggExec, err := aggexec.MakeAgg(proc, aggexec.AggIdOfCountStar, false, types.T_int64.ToType())
+		require.NoError(t, err)
+		require.NoError(t, aggExec.GroupGrow(3))
+		aggs := []aggexec.AggFuncExec{aggExec}
+
+		// Create hash table data
+		hr := &ResHashRelated{}
+		err = hr.BuildHashTable(true, false, false, 3)
+		require.NoError(t, err)
+		// Registered after the spill manager's defer, so it runs first.
+		defer hr.Free0()
+
+		hashTableData, err := hr.MarshalHashTable()
+		require.NoError(t, err)
+
+		// Test spilling with hash table
+		err = sm.SpillToDiskWithHashTable(groups, aggs, hashTableData, false, false, 8)
+		require.NoError(t, err)
+		require.Equal(t, 1, sm.GetSpillFileCount())
+
+		// Test reading with hash table
+		readGroups, readAggs, readHashTableData, isStrHash, keyNullable, keyWidth, err := sm.ReadSpilledDataWithHashTable(sm.spillFiles[0])
+		require.NoError(t, err)
+		require.Len(t, readGroups, 1)
+		require.Len(t, readAggs, 1)
+		require.NotNil(t, readHashTableData)
+		require.False(t, isStrHash)
+		require.False(t, keyNullable)
+		require.Equal(t, 8, keyWidth)
+	})
+
+	// Test group operator with hash table spilling
+	t.Run("GroupOperatorHashTableSpill", func(t *testing.T) {
+		g := createTestGroupOperator(t, proc)
+		// Free the operator even if an assertion below fails.
+		defer g.Free(proc, false, nil)
+
+		// Initialize with data that includes hash table
+		aggExec, err := aggexec.MakeAgg(proc, aggexec.AggIdOfCountStar, false, types.T_int64.ToType())
+		require.NoError(t, err)
+		require.NoError(t, aggExec.GroupGrow(3))
+
+		g.ctr.result1.InitOnlyAgg(10, []aggexec.AggFuncExec{aggExec})
+
+		// Add test batches
+		testBatch := createTestBatch(t, proc, 3)
+		g.ctr.result1.ToPopped[0] = testBatch
+
+		// Initialize hash table with test data
+		g.ctr.mtyp = H8
+		g.ctr.keyNullable = false
+		g.ctr.keyWidth = 8
+		require.NoError(t, g.ctr.hr.BuildHashTable(true, false, false, 3))
+
+		// Insert test groups into hash table
+		_, _, err = g.ctr.hr.Itr.Insert(0, 3, testBatch.Vecs)
+		require.NoError(t, err)
+
+		// Test spilling preserves hash table
+		err = g.spillToDisk(proc)
+		require.NoError(t, err)
+
+		// Verify spill file was created
+		require.True(t, g.ctr.spillManager.HasSpilledData())
+		require.Equal(t, 1, g.ctr.spillManager.GetSpillFileCount())
+
+		// Test merging spilled data restores hash table
+		err = g.mergeSpilledData(proc)
+		require.NoError(t, err)
+
+		// Verify hash table was restored
+		require.NotNil(t, g.ctr.hr.Hash)
+		require.False(t, g.ctr.spillManager.HasSpilledData())
+	})
+}
diff --git a/pkg/sql/colexec/group/testspill/spill_test.go b/pkg/sql/colexec/group/testspill/spill_test.go
new file mode 100644
index 0000000000000..9ae0e3f0dc306
--- /dev/null
+++ b/pkg/sql/colexec/group/testspill/spill_test.go
@@ -0,0 +1,101 @@
+// Copyright 2025 Matrix Origin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testspill
+
+import (
+	"database/sql"
+	"fmt"
+	"testing"
+
+	"github.com/matrixorigin/matrixone/pkg/embed"
+	"github.com/stretchr/testify/require"
+)
+
+// TestSpill starts a 3-CN embedded cluster, loads 2000*10000 generated sales rows and drains a GROUP BY over three columns.
+func TestSpill(t *testing.T) {
+	// start cluster
+	cluster, err := embed.NewCluster(
+		embed.WithCNCount(3),
+		embed.WithTesting(),
+		embed.WithPreStart(func(service embed.ServiceOperator) {
+		}),
+	)
+	require.NoError(t, err)
+	err = cluster.Start()
+	require.NoError(t, err)
+	defer cluster.Close()
+
+	cn0, err := cluster.GetCNService(0)
+	require.NoError(t, err)
+	dsn := fmt.Sprintf("dump:111@tcp(127.0.0.1:%d)/",
+		cn0.GetServiceConfig().CN.Frontend.Port,
+	)
+
+	db, err := sql.Open("mysql", dsn)
+	require.NoError(t, err)
+	defer db.Close()
+
+	// database
+	_, err = db.Exec(` create database test `)
+	require.NoError(t, err)
+
+	// table
+	_, err = db.Exec(`
+  use test;
+  create table sales (
+    id int,
+    product_id int,
+    customer_id int,
+    sale_date date,
+    amount decimal(10, 2)
+  );
+  `)
+	require.NoError(t, err)
+
+	// data: one row per value produced by generate_series(2000 * 10000)
+	_, err = db.Exec(`
+  use test;
+  insert into sales (id, product_id, customer_id, sale_date, amount)
+  select g.result as id,
+  floor(1 + (rand() * 10000)) as product_id,
+  floor(1 + (rand() * 10000)) as customer_id,
+  current_date - interval floor(rand() * 365) day as sale_date,
+  floor(rand() * 1000) as amount
+  from generate_series(2000 * 10000) g
+  `,
+	)
+	require.NoError(t, err)
+
+	// query, large number of groups
+	rows, err := db.Query(`
+  select
+  product_id,
+  customer_id,
+  sale_date,
+  sum(amount) as total_sales
+  from sales
+  group by product_id, customer_id, sale_date
+  order by total_sales desc;
+  `,
+	)
+	require.NoError(t, err)
+	n := 0
+	for rows.Next() {
+		n++
+	}
+	require.NoError(t, rows.Err()) // Next also returns false on an iteration error, not only at end of rows
+	require.NoError(t, rows.Close())
+	t.Logf("results: %d", n)
+}
diff --git a/pkg/sql/colexec/group/types.go b/pkg/sql/colexec/group/types.go
index 0243e4cda69d2..cba9ceca2e93c 100644
--- a/pkg/sql/colexec/group/types.go
+++ b/pkg/sql/colexec/group/types.go
@@ -93,6 +93,9 @@ type Group
struct {
 	GroupingFlag []bool
 	// agg info and agg column.
 	Aggs []aggexec.AggFuncExecExpression
+
+	// SpillThreshold is the memory threshold for spilling in bytes
+	SpillThreshold int64
 }
 
 func (group *Group) evaluateGroupByAndAgg(proc *process.Process, bat *batch.Batch) (err error) {
@@ -161,6 +164,8 @@ type container struct {
 	result1 GroupResultBuffer
 	// result if NeedEval is false.
 	result2 GroupResultNoneBlock
+
+	spillManager *SpillManager
 }
 
 func (ctr *container) isDataSourceEmpty() bool {
@@ -168,8 +173,14 @@ func (ctr *container) isDataSourceEmpty() bool {
 }
 
 func (group *Group) Free(proc *process.Process, _ bool, _ error) {
-	group.freeCannotReuse(proc.Mp())
+	// Use defer to ensure spill manager cleanup happens even if earlier operations panic
+	defer func() {
+		if group.ctr.spillManager != nil {
+			group.ctr.spillManager.Cleanup(proc.Ctx)
+		}
+	}()
 
+	group.freeCannotReuse(proc.Mp())
 	group.ctr.freeGroupEvaluate()
 	group.ctr.freeAggEvaluate()
 	group.FreeProjection(proc)
diff --git a/pkg/testutil/util_compare.go b/pkg/testutil/util_compare.go
index 52db2ae175b03..f589ee2768f1b 100644
--- a/pkg/testutil/util_compare.go
+++ b/pkg/testutil/util_compare.go
@@ -17,9 +17,13 @@ package testutil
 import (
 	"bytes"
 	"reflect"
+	"testing"
 
+	"github.com/matrixorigin/matrixone/pkg/container/batch"
 	"github.com/matrixorigin/matrixone/pkg/container/nulls"
 	"github.com/matrixorigin/matrixone/pkg/container/vector"
+	"github.com/matrixorigin/matrixone/pkg/sql/colexec/aggexec"
+	"github.com/stretchr/testify/require"
 )
 
 func CompareVectors(expected *vector.Vector, got *vector.Vector) bool {
@@ -85,3 +89,43 @@ func CompareVectors(expected *vector.Vector, got *vector.Vector) bool {
 		}
 	}
 }
+
+// CompareBatches fails the test unless actual matches expected. Two nil
+// batches are treated as equal, and exactly one nil batch is fatal. Otherwise
+// it checks, in order: row count, vector count, attribute count and names,
+// vector contents (via CompareVectors), aggregator count, and the bytes
+// produced by aggexec.MarshalAggFuncExec for each aggregator pair.
+func CompareBatches(t *testing.T, expected, actual *batch.Batch) {
+	// Report failures at the caller's line instead of inside this helper.
+	t.Helper()
+
+	if expected == nil && actual == nil {
+		return
+	}
+	if expected == nil || actual == nil {
+		t.Fatalf("one batch is nil, the other is not. Expected: %v, Actual: %v", expected, actual)
+	}
+
+	require.Equal(t, expected.RowCount(), actual.RowCount(), "row count mismatch")
+
+	require.Equal(t, len(expected.Vecs), len(actual.Vecs), "vector count mismatch")
+
+	require.Equal(t, len(expected.Attrs), len(actual.Attrs), "attribute count mismatch")
+
+	for i := range expected.Attrs {
+		require.Equal(t, expected.Attrs[i], actual.Attrs[i], "attribute name mismatch at index %d", i)
+	}
+
+	for i := range expected.Vecs {
+		require.True(t, CompareVectors(expected.Vecs[i], actual.Vecs[i]), "vector content mismatch at index %d", i)
+	}
+
+	require.Equal(t, len(expected.Aggs), len(actual.Aggs), "aggregator count mismatch")
+	for i := range expected.Aggs {
+		expectedBytes, err := aggexec.MarshalAggFuncExec(expected.Aggs[i])
+		require.NoError(t, err)
+		actualBytes, err := aggexec.MarshalAggFuncExec(actual.Aggs[i])
+		require.NoError(t, err)
+		require.Equal(t, expectedBytes, actualBytes, "aggregator state mismatch at index %d", i)
+	}
+}