Skip to content

Commit

Permalink
Merge pull request #1055 from gotd/feat/updates-rework
Browse files Browse the repository at this point in the history
feat(updates): rework
  • Loading branch information
ernado authored Apr 15, 2023
2 parents c6c2a09 + 665ad92 commit b755722
Show file tree
Hide file tree
Showing 19 changed files with 557 additions and 440 deletions.
2 changes: 1 addition & 1 deletion .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ linters:
enable:
- depguard
- dogsled
- dupl
# - dupl
- errcheck
- gochecknoinits
- goconst
Expand Down
14 changes: 5 additions & 9 deletions examples/updates/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,10 @@ func run(ctx context.Context) error {
if err != nil {
return err
}

// Notify update manager about authentication.
if err := gaps.Auth(ctx, client.API(), user.ID, user.Bot, true); err != nil {
return err
}
defer func() { _ = gaps.Logout() }()

<-ctx.Done()
return ctx.Err()
return gaps.Run(ctx, client.API(), user.ID, updates.AuthOptions{
OnStart: func(ctx context.Context) {
log.Info("Gaps started")
},
})
})
}
4 changes: 2 additions & 2 deletions telegram/peers/integration.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
)

// SetChannelAccessHash implements updates.ChannelAccessHasher.
func (m *Manager) SetChannelAccessHash(userID, channelID, accessHash int64) error {
func (m *Manager) SetChannelAccessHash(ctx context.Context, userID, channelID, accessHash int64) error {
myID, ok := m.myID()
if !ok || myID != userID {
return nil
Expand All @@ -25,7 +25,7 @@ func (m *Manager) SetChannelAccessHash(userID, channelID, accessHash int64) erro
}

// GetChannelAccessHash implements updates.ChannelAccessHasher.
func (m *Manager) GetChannelAccessHash(userID, channelID int64) (accessHash int64, found bool, err error) {
func (m *Manager) GetChannelAccessHash(ctx context.Context, userID, channelID int64) (accessHash int64, found bool, err error) {
myID, ok := m.myID()
if !ok || myID != userID {
return 0, false, nil
Expand Down
20 changes: 7 additions & 13 deletions telegram/peers/manager_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ package peers_test

import (
"context"
"fmt"

"github.com/go-faster/errors"
"go.uber.org/zap"

"github.com/gotd/td/telegram"
Expand Down Expand Up @@ -38,7 +38,10 @@ func ExampleManager() {
})
h = peerManager.UpdateHook(gaps)

if err := client.Run(context.TODO(), func(ctx context.Context) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

if err := client.Run(ctx, func(ctx context.Context) error {
if err := peerManager.Init(ctx); err != nil {
return err
}
Expand All @@ -48,18 +51,9 @@ func ExampleManager() {
}

_, isBot := u.ToBot()
if err := gaps.Auth(ctx, client.API(), u.ID(), isBot, false); err != nil {
return err
if err := gaps.Run(ctx, client.API(), u.ID(), updates.AuthOptions{IsBot: isBot}); err != nil {
return errors.Wrap(err, "gaps")
}
defer gaps.Logout()

p, err := peerManager.Resolve(ctx, "durov")
if err != nil {
return err
}

username, _ := p.Username()
fmt.Println(username)
return nil
}); err != nil {
panic(err)
Expand Down
19 changes: 13 additions & 6 deletions telegram/updates/access_hash_feeder.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@ package updates

import (
"go.uber.org/zap"
"golang.org/x/net/context"

"github.com/gotd/td/tg"
)

func (s *state) saveChannelHashes(chats []tg.ChatClass) {
func (s *internalState) saveChannelHashes(ctx context.Context, chats []tg.ChatClass) {
ctx, span := s.tracer.Start(ctx, "updates.saveChannelHashes")
defer span.End()

for _, c := range chats {
switch c := c.(type) {
case *tg.Channel:
Expand All @@ -22,7 +26,7 @@ func (s *state) saveChannelHashes(chats []tg.ChatClass) {
zap.Int64("channel_id", c.ID),
zap.String("title", c.Title),
)
if err := s.hasher.SetChannelAccessHash(s.selfID, c.ID, hash); err != nil {
if err := s.hasher.SetChannelAccessHash(ctx, s.selfID, c.ID, hash); err != nil {
s.log.Error("SetChannelState error", zap.Error(err))
}
}
Expand All @@ -34,15 +38,18 @@ func (s *state) saveChannelHashes(chats []tg.ChatClass) {
zap.Int64("channel_id", c.ID),
zap.String("title", c.Title),
)
if err := s.hasher.SetChannelAccessHash(s.selfID, c.ID, c.AccessHash); err != nil {
if err := s.hasher.SetChannelAccessHash(ctx, s.selfID, c.ID, c.AccessHash); err != nil {
s.log.Error("SetChannelState error", zap.Error(err))
}
}
}
}

func (s *state) restoreAccessHash(channelID int64, date int) (accessHash int64, ok bool) {
diff, err := s.client.UpdatesGetDifference(s.ctx, &tg.UpdatesGetDifferenceRequest{
func (s *internalState) restoreAccessHash(ctx context.Context, channelID int64, date int) (accessHash int64, ok bool) {
ctx, span := s.tracer.Start(ctx, "updates.restoreAccessHash")
defer span.End()

diff, err := s.client.UpdatesGetDifference(ctx, &tg.UpdatesGetDifferenceRequest{
Pts: s.pts.State(),
Qts: s.qts.State(),
Date: date,
Expand All @@ -60,7 +67,7 @@ func (s *state) restoreAccessHash(channelID int64, date int) (accessHash int64,
chats = diff.Chats
}

s.saveChannelHashes(chats)
s.saveChannelHashes(ctx, chats)
for _, c := range chats {
switch c := c.(type) {
case *tg.Channel:
Expand Down
18 changes: 12 additions & 6 deletions telegram/updates/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@ package updates
import (
"context"

"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"

"github.com/gotd/td/telegram"
"github.com/gotd/td/tg"
)

// RawClient is the interface which contains
// Telegram RPC methods used by manager for state synchronization.
type RawClient interface {
// API is the interface which contains
// Telegram RPC methods used by manager for internalState synchronization.
type API interface {
UpdatesGetState(ctx context.Context) (*tg.UpdatesState, error)
UpdatesGetDifference(ctx context.Context, request *tg.UpdatesGetDifferenceRequest) (tg.UpdatesDifferenceClass, error)
UpdatesGetChannelDifference(ctx context.Context, request *tg.UpdatesGetChannelDifferenceRequest) (tg.UpdatesChannelDifferenceClass, error)
Expand All @@ -32,21 +33,26 @@ type Config struct {
AccessHasher ChannelAccessHasher
// Logger (optional).
Logger *zap.Logger
// TracerProvider (optional).
TracerProvider trace.TracerProvider
}

func (cfg *Config) setDefaults() {
if cfg.Handler == nil {
panic("Handler is nil")
}
if cfg.Storage == nil {
cfg.Storage = newMemStorage()
}
if cfg.AccessHasher == nil {
cfg.AccessHasher = newMemAccessHasher()
}
if cfg.Logger == nil {
cfg.Logger = zap.NewNop()
}
if cfg.TracerProvider == nil {
cfg.TracerProvider = trace.NewNoopTracerProvider()
}
if cfg.Storage == nil {
cfg.Storage = newMemStorage()
}
if cfg.OnChannelTooLong == nil {
cfg.OnChannelTooLong = func(channelID int64) {
cfg.Logger.Error("Difference too long", zap.Int64("channel_id", channelID))
Expand Down
6 changes: 3 additions & 3 deletions telegram/updates/conv_shorts.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func convertOptional(msg *tg.Message, i tg.UpdatesClass) {
}
}

func (s *state) convertShortMessage(u *tg.UpdateShortMessage) *tg.UpdateShort {
func (s *internalState) convertShortMessage(u *tg.UpdateShortMessage) *tg.UpdateShort {
msg := &tg.Message{
ID: u.ID,
PeerID: &tg.PeerUser{UserID: u.UserID},
Expand Down Expand Up @@ -75,7 +75,7 @@ func (s *state) convertShortMessage(u *tg.UpdateShortMessage) *tg.UpdateShort {
}
}

func (s *state) convertShortChatMessage(u *tg.UpdateShortChatMessage) *tg.UpdateShort {
func (s *internalState) convertShortChatMessage(u *tg.UpdateShortChatMessage) *tg.UpdateShort {
msg := &tg.Message{
ID: u.ID,
PeerID: &tg.PeerChat{ChatID: u.ChatID},
Expand All @@ -101,7 +101,7 @@ func (s *state) convertShortChatMessage(u *tg.UpdateShortChatMessage) *tg.Update
}
}

func (s *state) convertShortSentMessage(u *tg.UpdateShortSentMessage) *tg.UpdateShort {
func (s *internalState) convertShortSentMessage(u *tg.UpdateShortSentMessage) *tg.UpdateShort {
// This update should be converted by the one who called the method
// that returned this update, because we do not have any context about
// it (message text, sender/recipient, etc.)
Expand Down
6 changes: 3 additions & 3 deletions telegram/updates/doc.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Package updates provides a Telegram's state synchronization manager.
// Package updates provides a Telegram's internalState synchronization manager.
//
// It guarantees that all state-sensitive updates will be performed
// It guarantees that all internalState-sensitive updates will be performed
// in correct order.
//
// Limitations:
Expand All @@ -13,7 +13,7 @@
// of these operations. We rely on the server here.
//
// 3. Manager cannot recover the channel gap if there is a ChannelDifferenceTooLong error.
// Restoring the state in such situation is not the prerogative of this manager.
// Restoring the internalState in such situation is not the prerogative of this manager.
// See: https://core.telegram.org/constructor/updates.channelDifferenceTooLong
//
// TODO: Write implementation details.
Expand Down
62 changes: 46 additions & 16 deletions telegram/updates/internal/e2e/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ import (

func TestE2E(t *testing.T) {
testManager(t, func(s *server, storage updates.StateStorage) chan *tg.Updates {
t.Helper()

c := make(chan *tg.Updates, 10)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

var (
biba = s.peers.createUser("biba")
Expand All @@ -26,7 +30,7 @@ func TestE2E(t *testing.T) {
)

var channels []*tg.PeerChannel
require.NoError(t, storage.ForEachChannels(123, func(channelID int64, pts int) error {
require.NoError(t, storage.ForEachChannels(ctx, 123, func(ctx context.Context, channelID int64, pts int) error {
channels = append(channels, &tg.PeerChannel{
ChannelID: channelID,
})
Expand All @@ -39,7 +43,7 @@ func TestE2E(t *testing.T) {
// Biba.
go func() {
defer wg.Done()
for i := 0; i < 30; i++ {
for i := 0; i < 3; i++ {
c <- s.CreateEvent(func(ev *EventBuilder) {
ev.SendMessage(biba, chat, fmt.Sprintf("biba-%d", i))

Expand All @@ -53,7 +57,7 @@ func TestE2E(t *testing.T) {
// Boba.
go func() {
defer wg.Done()
for i := 0; i < 30; i++ {
for i := 0; i < 3; i++ {
c <- s.CreateEvent(func(ev *EventBuilder) {
ev.SendMessage(boba, chat, fmt.Sprintf("boba-%d", i))

Expand All @@ -75,6 +79,8 @@ func TestE2E(t *testing.T) {
func testManager(t *testing.T, f func(s *server, storage updates.StateStorage) chan *tg.Updates) {
t.Helper()

ctx, cancel := context.WithCancel(context.Background())

var (
log = zaptest.NewLogger(t)
s = newServer()
Expand All @@ -83,17 +89,19 @@ func testManager(t *testing.T, f func(s *server, storage updates.StateStorage) c
hasher = newMemAccessHasher()
)

require.NoError(t, storage.SetState(123, updates.State{
const uid = 123

require.NoError(t, storage.SetState(ctx, uid, updates.State{
Pts: 0,
Qts: 0,
Date: 0,
Seq: 0,
}))

for i := 0; i < 30; i++ {
for i := 0; i < 2; i++ {
c := s.peers.createChannel(fmt.Sprintf("channel-%d", i))
require.NoError(t, storage.SetChannelPts(123, c.ChannelID, 0))
require.NoError(t, hasher.SetChannelAccessHash(123, c.ChannelID, c.ChannelID*2))
require.NoError(t, storage.SetChannelPts(ctx, uid, c.ChannelID, 0))
require.NoError(t, hasher.SetChannelAccessHash(ctx, uid, c.ChannelID, c.ChannelID*2))
}

e := updates.New(updates.Config{
Expand All @@ -103,17 +111,33 @@ func testManager(t *testing.T, f func(s *server, storage updates.StateStorage) c
AccessHasher: hasher,
})

require.NoError(t, e.Auth(context.Background(), s, 123, false, false))

uchan := loss(f(s, storage))

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

g, ctx := errgroup.WithContext(ctx)
ready := make(chan struct{})
opts := updates.AuthOptions{
OnStart: func(ctx context.Context) {
t.Log("OnStart")
close(ready)
},
}
g.Go(func() error {
t.Log("Starting manager")
defer t.Log("Manager stopped")
return e.Run(ctx, s, uid, opts)
})
g.Go(func() error {
t.Log("Starting updates generator")
defer t.Log("Updates generator stopped")

defer cancel()

select {
case <-ready:
t.Log("Ready")
case <-ctx.Done():
return ctx.Err()
}

var g errgroup.Group
for i := 0; i < 2; i++ {
g.Go(func() error {
Expand All @@ -134,26 +158,32 @@ func testManager(t *testing.T, f func(s *server, storage updates.StateStorage) c
})
}

t.Log("Waiting")
if err := g.Wait(); err != nil {
return err
}

t.Log("Sending pts changed")

ups := []tg.UpdateClass{&tg.UpdatePtsChanged{}}
if err := storage.ForEachChannels(123, func(channelID int64, pts int) error {
if err := storage.ForEachChannels(ctx, uid, func(ctx context.Context, channelID int64, pts int) error {
ups = append(ups, &tg.UpdateChannelTooLong{ChannelID: channelID})
return nil
}); err != nil {
return err
}

t.Log("Handle")

return e.Handle(ctx, &tg.Updates{
Updates: ups,
})
})

require.NoError(t, g.Wait())
require.NoError(t, e.Logout())
t.Log("Waiting for shutdown")
require.ErrorIs(t, g.Wait(), context.Canceled)

t.Log("Checking")
require.Equal(t, s.messages, h.messages)
require.Equal(t, s.peers.channels, h.ents.Channels)
require.Equal(t, s.peers.chats, h.ents.Chats)
Expand Down
Loading

0 comments on commit b755722

Please sign in to comment.