From b482e0338cd6d6586cda5ef3efd09c56c2c9b56e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=A8=81?= Date: Thu, 16 Jan 2025 14:07:20 +0800 Subject: [PATCH] feat: add unbuffered stream in go runtime MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 罗威 --- runtime/Go/antlr/v4/lexer.go | 2 +- runtime/Go/antlr/v4/parser.go | 6 +- runtime/Go/antlr/v4/token_source.go | 2 +- runtime/Go/antlr/v4/unbuffered_char_stream.go | 193 ++++++++++++++ .../Go/antlr/v4/unbuffered_token_stream.go | 241 ++++++++++++++++++ 5 files changed, 439 insertions(+), 5 deletions(-) create mode 100644 runtime/Go/antlr/v4/unbuffered_char_stream.go create mode 100644 runtime/Go/antlr/v4/unbuffered_token_stream.go diff --git a/runtime/Go/antlr/v4/lexer.go b/runtime/Go/antlr/v4/lexer.go index e5594b2168..c1faf81e4e 100644 --- a/runtime/Go/antlr/v4/lexer.go +++ b/runtime/Go/antlr/v4/lexer.go @@ -159,7 +159,7 @@ func (b *BaseLexer) GetTokenFactory() TokenFactory { return b.factory } -func (b *BaseLexer) setTokenFactory(f TokenFactory) { +func (b *BaseLexer) SetTokenFactory(f TokenFactory) { b.factory = f } diff --git a/runtime/Go/antlr/v4/parser.go b/runtime/Go/antlr/v4/parser.go index fb57ac15db..56155dca53 100644 --- a/runtime/Go/antlr/v4/parser.go +++ b/runtime/Go/antlr/v4/parser.go @@ -316,9 +316,9 @@ func (p *BaseParser) GetTokenFactory() TokenFactory { return p.input.GetTokenSource().GetTokenFactory() } -// setTokenFactory is used to tell our token source and error strategy about a new way to create tokens. -func (p *BaseParser) setTokenFactory(factory TokenFactory) { - p.input.GetTokenSource().setTokenFactory(factory) +// SetTokenFactory is used to tell our token source and error strategy about a new way to create tokens. +func (p *BaseParser) SetTokenFactory(factory TokenFactory) { + p.input.GetTokenSource().SetTokenFactory(factory) } // GetATNWithBypassAlts - the ATN with bypass alternatives is expensive to create, so we create it diff --git a/runtime/Go/antlr/v4/token_source.go b/runtime/Go/antlr/v4/token_source.go index a3f36eaa67..1040aa1bfc 100644 --- a/runtime/Go/antlr/v4/token_source.go +++ b/runtime/Go/antlr/v4/token_source.go @@ -12,6 +12,6 @@ type TokenSource interface { GetCharPositionInLine() int GetInputStream() CharStream GetSourceName() string - setTokenFactory(factory TokenFactory) + SetTokenFactory(factory TokenFactory) GetTokenFactory() TokenFactory } diff --git a/runtime/Go/antlr/v4/unbuffered_char_stream.go b/runtime/Go/antlr/v4/unbuffered_char_stream.go new file mode 100644 index 0000000000..322d6faff7 --- /dev/null +++ b/runtime/Go/antlr/v4/unbuffered_char_stream.go @@ -0,0 +1,193 @@ +package antlr + +import ( + "bufio" + "io" +) + +type UnbufferedCharStream struct { + data []rune + n int + p int + numMarkers int + lastChar rune + lastCharBufferStart rune + currentCharIndex int + input *bufio.Reader + name string +} + +var _ CharStream = (*UnbufferedCharStream)(nil) + +func NewUnbufferedCharStream(input io.Reader, bufferSize int) *UnbufferedCharStream { + stream := &UnbufferedCharStream{ + data: make([]rune, bufferSize), + n: 0, + p: 0, + input: bufio.NewReader(input), + } + + stream.Fill(1) + return stream +} + +func (ucs *UnbufferedCharStream) Consume() { + if ucs.LA(1) == TokenEOF { + panic("cannot consume EOF") + } + ucs.lastChar = ucs.data[ucs.p] + if ucs.p == ucs.n-1 && ucs.numMarkers == 0 { + ucs.n = 0 + ucs.p = -1 + ucs.lastCharBufferStart = ucs.lastChar + } + ucs.p++ + ucs.currentCharIndex++ + ucs.Sync(1) +} + +func (ucs *UnbufferedCharStream) Sync(want int) { + need := (ucs.p + want - 1) - ucs.n + 1 + if need > 0 { + ucs.Fill(need) + } +} + +func (ucs *UnbufferedCharStream) Fill(n int) int { + for i := 0; i < n; i++ { + if ucs.n > 0 && ucs.data[ucs.n-1] == TokenEOF { + return i + } + + c, _, err := ucs.input.ReadRune() + if err != nil { + if err != io.EOF { + panic(err) + } + + c = TokenEOF + } + ucs.Add(c) + } + return n +} + +func (ucs *UnbufferedCharStream) Add(c rune) { + if ucs.n >= len(ucs.data) { + newData := make([]rune, len(ucs.data)*2) + copy(newData, ucs.data) + ucs.data = newData + } + ucs.data[ucs.n] = c + ucs.n++ +} + +func (ucs *UnbufferedCharStream) LA(i int) int { + if i == -1 { + return int(ucs.lastChar) + } + ucs.Sync(i) + index := ucs.p + i - 1 + if index < 0 { + panic("index out of range") + } + if index >= ucs.n { + return TokenEOF + } + return int(ucs.data[index]) +} + +func (ucs *UnbufferedCharStream) Mark() int { + if ucs.numMarkers == 0 { + ucs.lastCharBufferStart = ucs.lastChar + } + mark := -ucs.numMarkers - 1 + ucs.numMarkers++ + return mark +} + +func (ucs *UnbufferedCharStream) Release(marker int) { + expectedMark := -ucs.numMarkers + if marker != expectedMark { + panic("release() called with an invalid marker.") + } + ucs.numMarkers-- + if ucs.numMarkers == 0 && ucs.p > 0 { + copy(ucs.data, ucs.data[ucs.p:ucs.n]) + ucs.n -= ucs.p + ucs.p = 0 + ucs.lastCharBufferStart = ucs.lastChar + } +} + +func (ucs *UnbufferedCharStream) Index() int { + return ucs.currentCharIndex +} + +func (ucs *UnbufferedCharStream) Seek(index int) { + if index == ucs.currentCharIndex { + return + } + if index > ucs.currentCharIndex { + ucs.Sync(index - ucs.currentCharIndex) + index = min(index, ucs.BufferStartIndex()+ucs.n-1) + } + i := index - ucs.BufferStartIndex() + if i < 0 { + panic("cannot seek to negative index") + } + if i >= ucs.n { + panic("seek to index outside buffer") + } + ucs.p = i + ucs.currentCharIndex = index + if ucs.p == 0 { + ucs.lastChar = ucs.lastCharBufferStart + } else { + ucs.lastChar = ucs.data[ucs.p-1] + } +} + +func (ucs *UnbufferedCharStream) Size() int { + panic("Unbuffered stream cannot know its size") +} + +func (ucs *UnbufferedCharStream) GetSourceName() string { + if ucs.name == "" { + return "Unknown" + } + return ucs.name +} + +func (ucs *UnbufferedCharStream) GetText(start, stop int) string { + return ucs.GetTextFromInterval(NewInterval(start, stop)) +} + +func (ucs *UnbufferedCharStream) GetTextFromTokens(start, end Token) string { + if start == nil || end == nil { + return "" + } + + return ucs.GetTextFromInterval(NewInterval(start.GetTokenIndex(), end.GetTokenIndex())) +} + +func (ucs *UnbufferedCharStream) GetTextFromInterval(interval Interval) string { + if interval.Start < 0 || interval.Stop < interval.Start-1 { + panic("invalid interval") + } + bufferStartIndex := ucs.BufferStartIndex() + if ucs.n > 0 && ucs.data[ucs.n-1] == TokenEOF { + if interval.Start+interval.Length() > bufferStartIndex+ucs.n { + panic("the interval extends past the end of the stream") + } + } + if interval.Start < bufferStartIndex || interval.Stop >= bufferStartIndex+ucs.n { + panic("interval outside buffer") + } + i := interval.Start - bufferStartIndex + return string(ucs.data[i : i+interval.Length()+1]) +} + +func (ucs *UnbufferedCharStream) BufferStartIndex() int { + return ucs.currentCharIndex - ucs.p +} diff --git a/runtime/Go/antlr/v4/unbuffered_token_stream.go b/runtime/Go/antlr/v4/unbuffered_token_stream.go new file mode 100644 index 0000000000..639de549b6 --- /dev/null +++ b/runtime/Go/antlr/v4/unbuffered_token_stream.go @@ -0,0 +1,241 @@ +package antlr + +import ( + "fmt" + "strings" +) + +// UnbufferedTokenStream 实现了 ITokenStream 接口 +type UnbufferedTokenStream struct { + tokenSource TokenSource + tokens []Token + n int + p int + numMarkers int + lastToken Token + lastTokenBufferStart Token + currentTokenIndex int +} + +var _ TokenStream = (*UnbufferedTokenStream)(nil) + +// NewUnbufferedTokenStream 创建一个新的UnbufferedTokenStream实例 +func NewUnbufferedTokenStream(tokenSource TokenSource, bufferSize int) *UnbufferedTokenStream { + stream := &UnbufferedTokenStream{ + tokenSource: tokenSource, + tokens: make([]Token, bufferSize), + n: 0, + } + stream.Fill(1) // prime the pump + return stream +} + +// Get 获取指定索引的token +func (u *UnbufferedTokenStream) Get(i int) Token { + bufferStartIndex := u.GetBufferStartIndex() + if i < bufferStartIndex || i >= bufferStartIndex+u.n { + panic(fmt.Sprintf("get(%d) outside buffer: %d..%d", i, bufferStartIndex, bufferStartIndex+u.n)) + } + return u.tokens[i-bufferStartIndex] +} + +// LT 查看前面的token +func (u *UnbufferedTokenStream) LT(i int) Token { + if i == -1 { + return u.lastToken + } + u.Sync(i) + index := u.p + i - 1 + if index < 0 { + panic(fmt.Sprintf("LT(%d) gives negative index", i)) + } + if index >= u.n { + return u.tokens[u.n-1] // return EOF token + } + return u.tokens[index] +} + +// LA 获取token类型 +func (u *UnbufferedTokenStream) LA(i int) int { + t := u.LT(i) + if t == nil { + return TokenEOF + } + return t.GetTokenType() +} + +// GetTokenSource 获取token源 +func (u *UnbufferedTokenStream) GetTokenSource() TokenSource { + return u.tokenSource +} + +func (u *UnbufferedTokenStream) SetTokenSource(tokenSource TokenSource) { + u.tokenSource = tokenSource +} + +// Consume 消费一个token +func (u *UnbufferedTokenStream) Consume() { + if u.LA(1) == TokenEOF { + panic("cannot consume EOF") + } + + u.lastToken = u.tokens[u.p] + + if u.p == u.n-1 && u.numMarkers == 0 { + u.n = 0 + u.p = -1 + u.lastTokenBufferStart = u.lastToken + } + u.p++ + u.currentTokenIndex++ + u.Sync(1) +} + +// Sync 确保缓冲区有足够的token +func (u *UnbufferedTokenStream) Sync(want int) { + need := (u.p + want - 1) - u.n + 1 + if need > 0 { + u.Fill(need) + } +} + +// Fill 填充token缓冲区 +func (u *UnbufferedTokenStream) Fill(n int) int { + for i := 0; i < n; i++ { + if u.n > 0 && u.tokens[u.n-1].GetTokenType() == TokenEOF { + return i + } + t := u.tokenSource.NextToken() + u.Add(t) + } + return n +} + +// Add 添加token到缓冲区 +func (u *UnbufferedTokenStream) Add(t Token) { + if u.n >= len(u.tokens) { + newTokens := make([]Token, len(u.tokens)*2) + copy(newTokens, u.tokens) + u.tokens = newTokens + } + + t.SetTokenIndex(u.GetBufferStartIndex() + u.n) + + u.tokens[u.n] = t + u.n++ +} + +// Mark 标记当前位置 +func (u *UnbufferedTokenStream) Mark() int { + if u.numMarkers == 0 { + u.lastTokenBufferStart = u.lastToken + } + u.numMarkers++ + return -u.numMarkers +} + +// Release 释放标记 +func (u *UnbufferedTokenStream) Release(marker int) { + expectedMark := -u.numMarkers + if marker != expectedMark { + panic("release() called with an invalid marker") + } + + u.numMarkers-- + if u.numMarkers == 0 && u.p > 0 { + copy(u.tokens, u.tokens[u.p:u.n]) + u.n = u.n - u.p + u.p = 0 + u.lastTokenBufferStart = u.lastToken + } +} + +// GetIndex 获取当前token索引 +func (u *UnbufferedTokenStream) Index() int { + return u.currentTokenIndex +} + +// Seek 跳转到指定位置 +func (u *UnbufferedTokenStream) Seek(index int) { + if index == u.currentTokenIndex { + return + } + + if index > u.currentTokenIndex { + u.Sync(index - u.currentTokenIndex) + index = min(index, u.GetBufferStartIndex()+u.n-1) + } + + i := index - u.GetBufferStartIndex() + if i < 0 { + panic(fmt.Sprintf("cannot seek to negative index %d", index)) + } + if i >= u.n { + panic(fmt.Sprintf("seek to index outside buffer: %d not in %d..%d", index, u.GetBufferStartIndex(), u.GetBufferStartIndex()+u.n)) + } + + u.p = i + u.currentTokenIndex = index + + if u.p == 0 { + u.lastToken = u.lastTokenBufferStart + } else { + u.lastToken = u.tokens[u.p-1] + } +} + +// GetText 获取文本 +func (u *UnbufferedTokenStream) GetTextFromInterval(interval Interval) string { + bufferStartIndex := u.GetBufferStartIndex() + bufferStopIndex := bufferStartIndex + len(u.tokens) - 1 + + if interval.Start < bufferStartIndex || interval.Stop > bufferStopIndex { + panic(fmt.Sprintf("interval %v not in token buffer window: %d..%d", interval, bufferStartIndex, bufferStopIndex)) + } + + a := interval.Start - bufferStartIndex + b := interval.Stop - bufferStartIndex + + var buf strings.Builder + for i := a; i <= b; i++ { + t := u.tokens[i] + buf.WriteString(t.GetText()) + } + + return buf.String() +} + +// GetBufferStartIndex 获取缓冲区起始索引 +func (u *UnbufferedTokenStream) GetBufferStartIndex() int { + return u.currentTokenIndex - u.p +} + +// GetSourceName 获取源名称 +func (u *UnbufferedTokenStream) GetSourceName() string { + return u.tokenSource.GetSourceName() +} + +// Size 获取流大小(不支持) +func (u *UnbufferedTokenStream) Size() int { + panic("Unbuffered stream cannot know its size") +} + +func (u *UnbufferedTokenStream) Reset() { + panic("cannot reset unbuffered stream") +} + +func (u *UnbufferedTokenStream) GetAllText() string { + return u.GetTextFromInterval(NewInterval(0, len(u.tokens)-1)) +} + +func (u *UnbufferedTokenStream) GetTextFromRuleContext(interval RuleContext) string { + return u.GetTextFromInterval(interval.GetSourceInterval()) +} + +func (u *UnbufferedTokenStream) GetTextFromTokens(start, end Token) string { + if start == nil || end == nil { + return "" + } + + return u.GetTextFromInterval(NewInterval(start.GetTokenIndex(), end.GetTokenIndex())) +}