diff --git a/bot/bot.go b/bot/bot.go index 2386478..039a7f8 100644 --- a/bot/bot.go +++ b/bot/bot.go @@ -14,18 +14,16 @@ import ( "github.com/spf13/viper" ) -var ( - C Config +var C Config - handlers []MessageCreateHandler = []MessageCreateHandler{ - NewReactionHandler(), +type ( + Bot struct { + Session *discordgo.Session + Config Config } -) -type Bot struct { - Session *discordgo.Session - Config Config -} + MessageHandler func(s *discordgo.Session, m *discordgo.MessageCreate) +) func NewBot(s *discordgo.Session, config Config) *Bot { return &Bot{Session: s, Config: config} @@ -66,6 +64,11 @@ func (b *Bot) RegisterCommands() { }) } +func (b *Bot) RegisterHandlers() { + b.Session.AddHandler(b.CommandHandler()) + b.Session.AddHandler(b.ReactionHandler()) +} + func Run() error { setupConfig() @@ -80,16 +83,10 @@ func Run() error { log.Fatalf("error creating Discord session: %v\n", err) } - for _, h := range handlers { - h.SetConfig(C) - dg.AddHandler(h.Handle) - } - b := NewBot(dg, C) + b.RegisterHandlers() b.RegisterCommands() - dg.AddHandler(NewCommandHandler(b)) - dg.Identify.Intents = discordgo.IntentsGuildMessages | discordgo.IntentsDirectMessages err = dg.Open() diff --git a/bot/command.go b/bot/command.go index 7485279..6159636 100644 --- a/bot/command.go +++ b/bot/command.go @@ -55,7 +55,7 @@ func GetCommand(name string) (*Command, bool) { return cmd, ok } -func NewCommandHandler(bot *Bot) func(s *discordgo.Session, m *discordgo.MessageCreate) { +func (b *Bot) CommandHandler() func(*discordgo.Session, *discordgo.MessageCreate) { return func(s *discordgo.Session, m *discordgo.MessageCreate) { var cmd *Command @@ -63,18 +63,18 @@ func NewCommandHandler(bot *Bot) func(s *discordgo.Session, m *discordgo.Message return } - if !lib.HasCommand(m.Content, bot.Config.Prefix) { + if !lib.HasCommand(m.Content, b.Config.Prefix) { return } - cmdName, arg := lib.SplitCommandAndArg(m.Content, bot.Config.Prefix) + cmdName, arg := lib.SplitCommandAndArg(m.Content, b.Config.Prefix) cmd, ok := GetCommand(cmdName) args := lib.SplitArgs(arg, cmd.NArgs) if ok { - cmd.Config = bot.Config + cmd.Config = b.Config log.Debugf("command: %v, args: %v, nargs: %d", cmd.Name, args, len(args)) cmd.Func(args, m) diff --git a/bot/handlers.go b/bot/handlers.go deleted file mode 100644 index 721c6af..0000000 --- a/bot/handlers.go +++ /dev/null @@ -1,10 +0,0 @@ -package bot - -import ( - "github.com/bwmarrin/discordgo" -) - -type MessageCreateHandler interface { - Handle(*discordgo.Session, *discordgo.MessageCreate) - SetConfig(Config) -} diff --git a/bot/reaction.go b/bot/reaction.go index a94e0d8..efefe55 100644 --- a/bot/reaction.go +++ b/bot/reaction.go @@ -10,55 +10,43 @@ import ( log "github.com/sirupsen/logrus" ) -type ( - ReactionHandler struct { - Config Config - } -) +func (b *Bot) ReactionHandler() func(*discordgo.Session, *discordgo.MessageCreate) { + return func(s *discordgo.Session, m *discordgo.MessageCreate) { + if m.Author.ID == s.State.User.ID { + return + } -func NewReactionHandler() *ReactionHandler { - return new(ReactionHandler) -} + emojis := b.Config.Handler.Reaction.Emojis + channels := b.Config.Handler.Reaction.Channels -func (h *ReactionHandler) SetConfig(config Config) { - h.Config = config -} + if len(emojis) == 0 { + log.Warning("emoji list is empty") + return + } -func (h *ReactionHandler) Handle(s *discordgo.Session, m *discordgo.MessageCreate) { - if m.Author.ID == s.State.User.ID { - return - } + channel, err := s.Channel(m.ChannelID) + if err != nil { + log.Fatalf("unable to get channel name: %v", err) + } - emojis := h.Config.Handler.Reaction.Emojis - channels := h.Config.Handler.Reaction.Channels + if len(channels) > 0 && !lib.Contains(channels, channel.Name) { + return + } - if len(emojis) == 0 { - log.Warning("emoji list is empty") - return - } + for _, a := range m.Attachments { + if strings.HasPrefix(a.ContentType, "image/") { + for i := 1; i <= lib.RandInt(1, len(emojis)); i++ { + r := emojis[rand.Intn(len(emojis))] + s.MessageReactionAdd(m.ChannelID, m.ID, r) + } + } + } - channel, err := s.Channel(m.ChannelID) - if err != nil { - log.Fatalf("unable to get channel name: %v", err) - } - - if len(channels) > 0 && !lib.Contains(channels, channel.Name) { - return - } - - for _, a := range m.Attachments { - if strings.HasPrefix(a.ContentType, "image/") { + for range m.Embeds { for i := 1; i <= lib.RandInt(1, len(emojis)); i++ { r := emojis[rand.Intn(len(emojis))] s.MessageReactionAdd(m.ChannelID, m.ID, r) } } } - - for range m.Embeds { - for i := 1; i <= lib.RandInt(1, len(emojis)); i++ { - r := emojis[rand.Intn(len(emojis))] - s.MessageReactionAdd(m.ChannelID, m.ID, r) - } - } }