[442] | 1 | package soju
|
---|
| 2 |
|
---|
| 3 | import (
|
---|
[667] | 4 | "context"
|
---|
[442] | 5 | "fmt"
|
---|
| 6 | "time"
|
---|
| 7 |
|
---|
[488] | 8 | "git.sr.ht/~sircmpwn/go-bare"
|
---|
[442] | 9 | "gopkg.in/irc.v3"
|
---|
| 10 | )
|
---|
| 11 |
|
---|
// messageRingBufferCap is the number of messages each ring buffer of
// memoryMessageStore can hold (one buffer per network/entity pair) before
// the oldest ones start being overwritten.
const messageRingBufferCap = 1 << 12 // 4096
|
---|
| 13 |
|
---|
[488] | 14 | type memoryMsgID struct {
|
---|
| 15 | Seq bare.Uint
|
---|
| 16 | }
|
---|
| 17 |
|
---|
| 18 | func (memoryMsgID) msgIDType() msgIDType {
|
---|
| 19 | return msgIDMemory
|
---|
| 20 | }
|
---|
| 21 |
|
---|
[442] | 22 | func parseMemoryMsgID(s string) (netID int64, entity string, seq uint64, err error) {
|
---|
[488] | 23 | var id memoryMsgID
|
---|
| 24 | netID, entity, err = parseMsgID(s, &id)
|
---|
[442] | 25 | if err != nil {
|
---|
| 26 | return 0, "", 0, err
|
---|
| 27 | }
|
---|
[488] | 28 | return netID, entity, uint64(id.Seq), nil
|
---|
[442] | 29 | }
|
---|
| 30 |
|
---|
| 31 | func formatMemoryMsgID(netID int64, entity string, seq uint64) string {
|
---|
[488] | 32 | id := memoryMsgID{bare.Uint(seq)}
|
---|
| 33 | return formatMsgID(netID, entity, &id)
|
---|
[442] | 34 | }
|
---|
| 35 |
|
---|
// ringBufferKey selects one ring buffer in memoryMessageStore: there is one
// buffer per (network, entity) pair.
type ringBufferKey struct {
	networkID int64  // ID of the network the messages belong to
	entity    string // entity name, exactly as passed to the store methods
}
|
---|
| 40 |
|
---|
| 41 | type memoryMessageStore struct {
|
---|
| 42 | buffers map[ringBufferKey]*messageRingBuffer
|
---|
| 43 | }
|
---|
| 44 |
|
---|
[517] | 45 | var _ messageStore = (*memoryMessageStore)(nil)
|
---|
| 46 |
|
---|
[442] | 47 | func newMemoryMessageStore() *memoryMessageStore {
|
---|
| 48 | return &memoryMessageStore{
|
---|
| 49 | buffers: make(map[ringBufferKey]*messageRingBuffer),
|
---|
| 50 | }
|
---|
| 51 | }
|
---|
| 52 |
|
---|
| 53 | func (ms *memoryMessageStore) Close() error {
|
---|
| 54 | ms.buffers = nil
|
---|
| 55 | return nil
|
---|
| 56 | }
|
---|
| 57 |
|
---|
[666] | 58 | func (ms *memoryMessageStore) get(network *Network, entity string) *messageRingBuffer {
|
---|
[442] | 59 | k := ringBufferKey{networkID: network.ID, entity: entity}
|
---|
| 60 | if rb, ok := ms.buffers[k]; ok {
|
---|
| 61 | return rb
|
---|
| 62 | }
|
---|
| 63 | rb := newMessageRingBuffer(messageRingBufferCap)
|
---|
| 64 | ms.buffers[k] = rb
|
---|
| 65 | return rb
|
---|
| 66 | }
|
---|
| 67 |
|
---|
[666] | 68 | func (ms *memoryMessageStore) LastMsgID(network *Network, entity string, t time.Time) (string, error) {
|
---|
[442] | 69 | var seq uint64
|
---|
| 70 | k := ringBufferKey{networkID: network.ID, entity: entity}
|
---|
| 71 | if rb, ok := ms.buffers[k]; ok {
|
---|
| 72 | seq = rb.cur
|
---|
| 73 | }
|
---|
| 74 | return formatMemoryMsgID(network.ID, entity, seq), nil
|
---|
| 75 | }
|
---|
| 76 |
|
---|
[666] | 77 | func (ms *memoryMessageStore) Append(network *Network, entity string, msg *irc.Message) (string, error) {
|
---|
[665] | 78 | switch msg.Command {
|
---|
| 79 | case "PRIVMSG", "NOTICE":
|
---|
| 80 | default:
|
---|
| 81 | return "", nil
|
---|
| 82 | }
|
---|
| 83 |
|
---|
[442] | 84 | k := ringBufferKey{networkID: network.ID, entity: entity}
|
---|
| 85 | rb, ok := ms.buffers[k]
|
---|
| 86 | if !ok {
|
---|
| 87 | rb = newMessageRingBuffer(messageRingBufferCap)
|
---|
| 88 | ms.buffers[k] = rb
|
---|
| 89 | }
|
---|
| 90 |
|
---|
| 91 | seq := rb.Append(msg)
|
---|
| 92 | return formatMemoryMsgID(network.ID, entity, seq), nil
|
---|
| 93 | }
|
---|
| 94 |
|
---|
[667] | 95 | func (ms *memoryMessageStore) LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) {
|
---|
[442] | 96 | _, _, seq, err := parseMemoryMsgID(id)
|
---|
| 97 | if err != nil {
|
---|
| 98 | return nil, err
|
---|
| 99 | }
|
---|
| 100 |
|
---|
| 101 | k := ringBufferKey{networkID: network.ID, entity: entity}
|
---|
| 102 | rb, ok := ms.buffers[k]
|
---|
| 103 | if !ok {
|
---|
| 104 | return nil, nil
|
---|
| 105 | }
|
---|
| 106 |
|
---|
| 107 | return rb.LoadLatestSeq(seq, limit)
|
---|
| 108 | }
|
---|
| 109 |
|
---|
| 110 | type messageRingBuffer struct {
|
---|
| 111 | buf []*irc.Message
|
---|
[444] | 112 | cur uint64
|
---|
[442] | 113 | }
|
---|
| 114 |
|
---|
| 115 | func newMessageRingBuffer(capacity int) *messageRingBuffer {
|
---|
| 116 | return &messageRingBuffer{
|
---|
| 117 | buf: make([]*irc.Message, capacity),
|
---|
| 118 | cur: 1,
|
---|
| 119 | }
|
---|
| 120 | }
|
---|
| 121 |
|
---|
| 122 | func (rb *messageRingBuffer) cap() uint64 {
|
---|
| 123 | return uint64(len(rb.buf))
|
---|
| 124 | }
|
---|
| 125 |
|
---|
| 126 | func (rb *messageRingBuffer) Append(msg *irc.Message) uint64 {
|
---|
| 127 | seq := rb.cur
|
---|
| 128 | i := int(seq % rb.cap())
|
---|
| 129 | rb.buf[i] = msg
|
---|
| 130 | rb.cur++
|
---|
| 131 | return seq
|
---|
| 132 | }
|
---|
| 133 |
|
---|
| 134 | func (rb *messageRingBuffer) LoadLatestSeq(seq uint64, limit int) ([]*irc.Message, error) {
|
---|
| 135 | if seq > rb.cur {
|
---|
| 136 | return nil, fmt.Errorf("loading messages from sequence number (%v) greater than current (%v)", seq, rb.cur)
|
---|
| 137 | } else if seq == rb.cur {
|
---|
| 138 | return nil, nil
|
---|
| 139 | }
|
---|
| 140 |
|
---|
| 141 | // The query excludes the message with the sequence number seq
|
---|
| 142 | diff := rb.cur - seq - 1
|
---|
| 143 | if diff > rb.cap() {
|
---|
| 144 | // We dropped diff - cap entries
|
---|
| 145 | diff = rb.cap()
|
---|
| 146 | }
|
---|
| 147 | if int(diff) > limit {
|
---|
| 148 | diff = uint64(limit)
|
---|
| 149 | }
|
---|
| 150 |
|
---|
| 151 | l := make([]*irc.Message, int(diff))
|
---|
| 152 | for i := 0; i < int(diff); i++ {
|
---|
| 153 | j := int((rb.cur - diff + uint64(i)) % rb.cap())
|
---|
| 154 | l[i] = rb.buf[j]
|
---|
| 155 | }
|
---|
| 156 |
|
---|
| 157 | return l, nil
|
---|
| 158 | }
|
---|