diff --git a/internal/db/migrations/2023122700.sql b/internal/db/migrations/2023122700.sql new file mode 100644 index 0000000..0ee5bb9 --- /dev/null +++ b/internal/db/migrations/2023122700.sql @@ -0,0 +1 @@ +ALTER TABLE reactions ADD COLUMN excluded_channels TEXT NOT NULL DEFAULT ''; \ No newline at end of file diff --git a/internal/db/reactions.go b/internal/db/reactions.go index 11400d2..b662569 100644 --- a/internal/db/reactions.go +++ b/internal/db/reactions.go @@ -33,17 +33,18 @@ const ( ) type Reaction struct { - GuildID string `db:"guild_id"` - MatchType MatchType `db:"match_type"` - Match string `db:"match"` - ReactionType ReactionType `db:"reaction_type"` - Reaction string `db:"reaction"` - Chance int `db:"chance"` + GuildID string `db:"guild_id"` + MatchType MatchType `db:"match_type"` + Match string `db:"match"` + ReactionType ReactionType `db:"reaction_type"` + Reaction StringSlice `db:"reaction"` + Chance int `db:"chance"` + ExcludedChannels StringSlice `db:"excluded_channels"` } func AddReaction(guildID string, r Reaction) error { r.GuildID = guildID - _, err := db.NamedExec("INSERT INTO reactions VALUES (:guild_id, :match_type, :match, :reaction_type, :reaction, :chance)", r) + _, err := db.NamedExec("INSERT INTO reactions VALUES (:guild_id, :match_type, :match, :reaction_type, :reaction, :chance, :excluded_channels)", r) return err } @@ -56,3 +57,21 @@ func Reactions(guildID string) (rs []Reaction, err error) { err = db.Select(&rs, "SELECT * FROM reactions WHERE guild_id = ?", guildID) return rs, err } + +func ReactionsExclude(guildID, match, channelID string) (err error) { + if match == "" { + _, err = db.Exec("UPDATE reactions SET excluded_channels = trim(excluded_channels || X'1F' || ?, X'1F') WHERE guild_id = ?", channelID, guildID) + } else { + _, err = db.Exec("UPDATE reactions SET excluded_channels = trim(excluded_channels || X'1F' || ?, X'1F') WHERE guild_id = ? AND match = ?", channelID, guildID, match) + } + return err +} + +func ReactionsUnexclude(guildID, match, channelID string) (err error) { + if match == "" { + _, err = db.Exec("UPDATE reactions SET excluded_channels = trim(replace(replace(excluded_channels, ?, ''), X'1F1F', X'1F'), X'1F') WHERE guild_id = ?", channelID, guildID) + } else { + _, err = db.Exec("UPDATE reactions SET excluded_channels = trim(replace(replace(excluded_channels, ?, ''), X'1F1F', X'1F'), X'1F') WHERE guild_id = ? AND match = ?", channelID, guildID, match) + } + return err +} diff --git a/internal/db/roles.go b/internal/db/roles.go index fc24f36..6acf6d5 100644 --- a/internal/db/roles.go +++ b/internal/db/roles.go @@ -27,12 +27,12 @@ import ( ) type ReactionRoleCategory struct { - MsgID string `db:"msg_id"` - ChannelID string `db:"channel_id"` - Name string `db:"name"` - Description string `db:"description"` - Emoji []string `db:"emoji"` - Roles []string `db:"roles"` + MsgID string `db:"msg_id"` + ChannelID string `db:"channel_id"` + Name string `db:"name"` + Description string `db:"description"` + Emoji StringSlice `db:"emoji"` + Roles StringSlice `db:"roles"` } func AddReactionRoleCategory(channelID string, rrc ReactionRoleCategory) error { @@ -42,31 +42,16 @@ func AddReactionRoleCategory(channelID string, rrc ReactionRoleCategory) error { channelID, rrc.Name, rrc.Description, - strings.Join(rrc.Emoji, "\x1F"), - strings.Join(rrc.Roles, "\x1F"), + rrc.Emoji, + rrc.Roles, ) return err } func GetReactionRoleCategory(channelID, name string) (*ReactionRoleCategory, error) { - var msgID, description, emoji, roles string - err := db.QueryRow( - "SELECT msg_id, description, emoji, roles FROM reaction_role_categories WHERE channel_id = ? AND name = ?", - channelID, - name, - ).Scan(&msgID, &description, &emoji, &roles) - if err != nil { - return nil, err - } - - return &ReactionRoleCategory{ - MsgID: msgID, - ChannelID: channelID, - Name: name, - Description: description, - Emoji: splitOptions(emoji), - Roles: splitOptions(roles), - }, nil + out := &ReactionRoleCategory{} + err := db.QueryRowx("SELECT * FROM reaction_role_categories WHERE channel_id = ? AND name = ?", channelID, name).StructScan(out) + return out, err } func DeleteReactionRoleCategory(channelID, name string) error { @@ -74,25 +59,24 @@ func DeleteReactionRoleCategory(channelID, name string) error { return err } -func AddReactionRole(channelID, category, emoji string, role *discordgo.Role) error { - if strings.Contains(category, "\x1F") || strings.Contains(emoji, "\x1F") { +func AddReactionRole(channelID, category, emojiStr string, role *discordgo.Role) error { + if strings.Contains(category, "\x1F") || strings.Contains(emojiStr, "\x1F") { return errors.New("reaction roles cannot contain unit separator") } - var oldEmoji, oldRoles string - err := db.QueryRow("SELECT emoji, roles FROM reaction_role_categories WHERE name = ? AND channel_id = ?", category, channelID).Scan(&oldEmoji, &oldRoles) + var emoji, roles StringSlice + err := db.QueryRow("SELECT emoji, roles FROM reaction_role_categories WHERE name = ? AND channel_id = ?", category, channelID).Scan(&emoji, &roles) if err != nil { return err } - splitEmoji, splitRoles := splitOptions(oldEmoji), splitOptions(oldRoles) - splitEmoji = append(splitEmoji, strings.TrimSpace(emoji)) - splitRoles = append(splitRoles, role.ID) + emoji = append(emoji, strings.TrimSpace(emojiStr)) + roles = append(roles, role.ID) _, err = db.Exec( "UPDATE reaction_role_categories SET emoji = ?, roles = ? WHERE name = ? AND channel_id = ?", - strings.Join(splitEmoji, "\x1F"), - strings.Join(splitRoles, "\x1F"), + emoji, + roles, category, channelID, ) @@ -100,24 +84,23 @@ func AddReactionRole(channelID, category, emoji string, role *discordgo.Role) er } func DeleteReactionRole(channelID, category string, role *discordgo.Role) error { - var oldEmoji, oldRoles string - err := db.QueryRow("SELECT emoji, roles FROM reaction_role_categories WHERE name = ? AND channel_id = ?", category, channelID).Scan(&oldEmoji, &oldRoles) + var emoji, roles StringSlice + err := db.QueryRow("SELECT emoji, roles FROM reaction_role_categories WHERE name = ? AND channel_id = ?", category, channelID).Scan(&emoji, &roles) if err != nil { return err } - splitEmoji, splitRoles := splitOptions(oldEmoji), splitOptions(oldRoles) - if i := slices.Index(splitRoles, role.ID); i == -1 { + if i := slices.Index(roles, role.ID); i == -1 { return nil } else { - splitEmoji = append(splitEmoji[:i], splitEmoji[i+1:]...) - splitRoles = append(splitRoles[:i], splitRoles[i+1:]...) + emoji = append(emoji[:i], emoji[i+1:]...) + roles = append(roles[:i], roles[i+1:]...) } _, err = db.Exec( "UPDATE reaction_role_categories SET emoji = ?, roles = ? WHERE name = ? AND channel_id = ?", - strings.Join(splitEmoji, "\x1F"), - strings.Join(splitRoles, "\x1F"), + emoji, + roles, category, channelID, ) diff --git a/internal/db/slice.go b/internal/db/slice.go index c363ff5..bf1bb31 100644 --- a/internal/db/slice.go +++ b/internal/db/slice.go @@ -9,7 +9,7 @@ import ( type StringSlice []string func (s StringSlice) String() string { - return strings.Join(s, ",") + return strings.Join(s, ", ") } func (s StringSlice) Value() (driver.Value, error) { @@ -30,4 +30,4 @@ func splitOptions(s string) []string { return nil } return strings.Split(s, "\x1F") -} \ No newline at end of file +} diff --git a/internal/systems/reactions/commands.go b/internal/systems/reactions/commands.go index 3a2b873..05cdc95 100644 --- a/internal/systems/reactions/commands.go +++ b/internal/systems/reactions/commands.go @@ -41,6 +41,10 @@ func reactionsCmd(s *discordgo.Session, i *discordgo.InteractionCreate) error { return reactionsListCmd(s, i) case "delete": return reactionsDeleteCmd(s, i) + case "exclude": + return reactionsExcludeCmd(s, i) + case "unexclude": + return reactionsUnexcludeCmd(s, i) default: return fmt.Errorf("unknown reactions subcommand: %s", name) } @@ -54,7 +58,6 @@ func reactionsAddCmd(s *discordgo.Session, i *discordgo.InteractionCreate) error MatchType: db.MatchType(args[0].StringValue()), Match: strings.TrimSpace(args[1].StringValue()), ReactionType: db.ReactionType(args[2].StringValue()), - Reaction: strings.TrimSpace(args[3].StringValue()), Chance: 100, } @@ -73,12 +76,16 @@ func reactionsAddCmd(s *discordgo.Session, i *discordgo.InteractionCreate) error reaction.Match = strings.ToLower(reaction.Match) } - if reaction.ReactionType == db.ReactionTypeEmoji { + switch reaction.ReactionType { + case db.ReactionTypeEmoji: + // Convert comma-separated emoji into a StringSlice value + reaction.Reaction = db.StringSlice(strings.Split(strings.TrimSpace(args[3].StringValue()), ",")) if err := validateEmoji(reaction.Reaction); err != nil { return err } - // Use the correct delimeter for the DB - reaction.Reaction = strings.ReplaceAll(reaction.Reaction, ",", "\x1F") + case db.ReactionTypeText: + // Create a StringSlice with the desired text inside + reaction.Reaction = db.StringSlice{args[3].StringValue()} } err := db.AddReaction(i.GuildID, reaction) @@ -95,9 +102,11 @@ func reactionsListCmd(s *discordgo.Session, i *discordgo.InteractionCreate) erro return err } + fmt.Println(reactions) var sb strings.Builder sb.WriteString("**Reactions:**\n") for _, reaction := range reactions { + fmt.Println(reaction.Reaction) sb.WriteString("- _[") if reaction.Chance < 100 { sb.WriteString(strconv.Itoa(reaction.Chance)) @@ -107,7 +116,7 @@ func reactionsListCmd(s *discordgo.Session, i *discordgo.InteractionCreate) erro sb.WriteString("]_ `") sb.WriteString(reaction.Match) sb.WriteString("`: \"") - sb.WriteString(strings.ReplaceAll(reaction.Reaction, "\x1F", ",")) + sb.WriteString(reaction.Reaction.String()) sb.WriteString("\" _(") sb.WriteString(string(reaction.ReactionType)) sb.WriteString(")_\n") @@ -134,19 +143,61 @@ func reactionsDeleteCmd(s *discordgo.Session, i *discordgo.InteractionCreate) er return util.RespondEphemeral(s, i.Interaction, "Successfully removed reaction") } -func validateEmoji(s string) error { - if strings.Contains(s, ",") { - split := strings.Split(s, ",") - for _, emojiStr := range split { - if _, ok := emoji.Parse(emojiStr); !ok { - return fmt.Errorf("invalid reaction emoji: %s", emojiStr) - } - } - } else if strings.Contains(s, "\x1F") { - return fmt.Errorf("emoji string cannot contain unit separator") - } else { - if _, ok := emoji.Parse(s); !ok { - return fmt.Errorf("invalid reaction emoji: %s", s) +func reactionsExcludeCmd(s *discordgo.Session, i *discordgo.InteractionCreate) error { + // Make sure the user has the manage expressions permission + // in case a role/member override allows someone else to use it + if i.Member.Permissions&discordgo.PermissionManageEmojis == 0 { + return errors.New("you do not have permission to exclude channels") + } + + data := i.ApplicationCommandData() + args := data.Options[0].Options + + channel := args[0].ChannelValue(s) + + var match string + if len(args) == 2 { + match = args[1].StringValue() + } + + err := db.ReactionsExclude(i.GuildID, match, channel.ID) + if err != nil { + return err + } + + return util.RespondEphemeral(s, i.Interaction, fmt.Sprintf("Successfully excluded %s from receiving reactions", channel.Mention())) +} + +func reactionsUnexcludeCmd(s *discordgo.Session, i *discordgo.InteractionCreate) error { + // Make sure the user has the manage expressions permission + // in case a role/member override allows someone else to use it + if i.Member.Permissions&discordgo.PermissionManageEmojis == 0 { + return errors.New("you do not have permission to unexclude channels") + } + + data := i.ApplicationCommandData() + args := data.Options[0].Options + + channel := args[0].ChannelValue(s) + + var match string + if len(args) == 2 { + match = args[1].StringValue() + } + + err := db.ReactionsUnexclude(i.GuildID, match, channel.ID) + if err != nil { + return err + } + + return util.RespondEphemeral(s, i.Interaction, fmt.Sprintf("Successfully unexcluded %s from receiving reactions", channel.Mention())) +} + +func validateEmoji(s db.StringSlice) error { + for i := range s { + s[i] = strings.TrimSpace(s[i]) + if _, ok := emoji.Parse(s[i]); !ok { + return fmt.Errorf("invalid reaction emoji: %s", s[i]) } } return nil diff --git a/internal/systems/reactions/reactions.go b/internal/systems/reactions/reactions.go index a2430c9..bda1f66 100644 --- a/internal/systems/reactions/reactions.go +++ b/internal/systems/reactions/reactions.go @@ -21,6 +21,7 @@ package reactions import ( "fmt" "math/rand" + "slices" "strconv" "strings" "sync" @@ -120,6 +121,50 @@ func Init(s *discordgo.Session) error { }, }, }, + { + Type: discordgo.ApplicationCommandOptionSubCommand, + Name: "exclude", + Description: "Exclude a channel from having reactions", + Options: []*discordgo.ApplicationCommandOption{ + { + Name: "channel", + Description: "The channel which shouldn't receive reactions", + Type: discordgo.ApplicationCommandOptionChannel, + ChannelTypes: []discordgo.ChannelType{ + discordgo.ChannelTypeGuildText, + discordgo.ChannelTypeGuildForum, + }, + Required: true, + }, + { + Name: "match", + Description: "The match value to exclude", + Type: discordgo.ApplicationCommandOptionString, + }, + }, + }, + { + Type: discordgo.ApplicationCommandOptionSubCommand, + Name: "unexclude", + Description: "Unexclude a channel from having reactions", + Options: []*discordgo.ApplicationCommandOption{ + { + Name: "channel", + Description: "The channel which should receive reactions", + Type: discordgo.ApplicationCommandOptionChannel, + ChannelTypes: []discordgo.ChannelType{ + discordgo.ChannelTypeGuildText, + discordgo.ChannelTypeGuildForum, + }, + Required: true, + }, + { + Name: "match", + Description: "The match value to unexclude", + Type: discordgo.ApplicationCommandOptionString, + }, + }, + }, }, }) @@ -138,6 +183,10 @@ func onMessage(s *discordgo.Session, mc *discordgo.MessageCreate) { } for _, reaction := range reactions { + if slices.Contains(reaction.ExcludedChannels, mc.ChannelID) { + continue + } + switch reaction.MatchType { case db.MatchTypeContains: if strings.Contains(strings.ToLower(mc.Content), reaction.Match) { @@ -154,7 +203,7 @@ func onMessage(s *discordgo.Session, mc *discordgo.MessageCreate) { continue } - var content string + content := reaction.Reaction switch reaction.ReactionType { case db.ReactionTypeText: submatch := re.FindSubmatch([]byte(mc.Content)) @@ -163,7 +212,9 @@ func onMessage(s *discordgo.Session, mc *discordgo.MessageCreate) { for i, match := range submatch { replacements[strconv.Itoa(i)] = match } - content = fasttemplate.ExecuteStringStd(reaction.Reaction, "{", "}", replacements) + content = db.StringSlice{ + fasttemplate.ExecuteStringStd(reaction.Reaction[0], "{", "}", replacements), + } } else if len(submatch) == 1 { content = reaction.Reaction } @@ -173,7 +224,7 @@ func onMessage(s *discordgo.Session, mc *discordgo.MessageCreate) { } } - if content != "" { + if content[0] != "" { err = performReaction(s, reaction, content, mc) if err != nil { log.Error("Error performing reaction").Err(err).Send() @@ -189,7 +240,7 @@ var ( rng = rand.New(rand.NewSource(time.Now().UnixNano())) ) -func performReaction(s *discordgo.Session, reaction db.Reaction, content string, mc *discordgo.MessageCreate) error { +func performReaction(s *discordgo.Session, reaction db.Reaction, content db.StringSlice, mc *discordgo.MessageCreate) error { if reaction.Chance < 100 { rngMtx.Lock() randNum := rng.Intn(100) + 1 @@ -201,19 +252,12 @@ func performReaction(s *discordgo.Session, reaction db.Reaction, content string, switch reaction.ReactionType { case db.ReactionTypeText: - _, err := s.ChannelMessageSendReply(mc.ChannelID, content, mc.Reference()) + _, err := s.ChannelMessageSendReply(mc.ChannelID, content[0], mc.Reference()) if err != nil { return err } case db.ReactionTypeEmoji: - var emojis []string - if strings.Contains(content, "\x1F") { - emojis = strings.Split(content, "\x1F") - } else { - emojis = []string{content} - } - - for _, emojiStr := range emojis { + for _, emojiStr := range content { e, ok := emoji.Parse(emojiStr) if !ok { return fmt.Errorf("invalid emoji: %s", emojiStr)