Files
grok/main.go
2026-02-01 18:06:21 +03:00

270 lines
5.6 KiB
Go

package main
import (
"context"
"errors"
"fmt"
"log/slog"
"os"
"os/signal"
"strconv"
"strings"
tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5"
"github.com/joho/godotenv"
"github.com/openai/openai-go/v3"
"github.com/openai/openai-go/v3/option"
"github.com/openai/openai-go/v3/responses"
"github.com/openai/openai-go/v3/shared"
)
// OpenAIPrompter implements Proompter on top of the OpenAI Responses API.
type OpenAIPrompter struct {
	cli openai.Client // API client; base URL comes from cfg
	cfg *Config       // supplies the system prompt and base URL
}
// Prompt sends the user's question to the Responses API and returns the
// model's answer. When req.OriginalPostContent is set, the referenced
// channel post is appended to the system instructions so the model sees
// that context.
func (p *OpenAIPrompter) Prompt(ctx context.Context, req PromptRequest) (*HighlyTrustedResponse, error) {
	instructions := p.cfg.SystemPrompt
	if req.OriginalPostContent != nil {
		instructions += fmt.Sprintf("\nThis is the channel post that user is mentioning: %s", *req.OriginalPostContent)
	}
	resp, err := p.cli.Responses.New(ctx, responses.ResponseNewParams{
		// BUG FIX: this previously passed p.cfg.SystemPrompt, silently
		// discarding the channel-post context assembled above.
		Instructions: openai.String(instructions),
		Input: responses.ResponseNewParamsInputUnion{
			OfString: openai.String(req.Question),
		},
		Reasoning: shared.ReasoningParam{
			Effort: shared.ReasoningEffortXhigh,
		},
	})
	if err != nil {
		return nil, fmt.Errorf("creating response: %w", err)
	}
	return &HighlyTrustedResponse{
		Text: resp.OutputText(),
	}, nil
}
// NewOpenAIProoooompter builds an OpenAIPrompter whose client targets the
// base URL configured in cfg.
func NewOpenAIProoooompter(cfg *Config) *OpenAIPrompter {
	client := openai.NewClient(option.WithBaseURL(cfg.OpenAIBaseURL))
	return &OpenAIPrompter{cli: client, cfg: cfg}
}
// HighlyTrustedResponse carries the model's answer text.
type HighlyTrustedResponse struct {
	Text string
}
// PromptRequest is a single question for the model, optionally accompanied
// by the channel post the user was replying to.
type PromptRequest struct {
	Question            string
	OriginalPostContent *string // nil when the message was not a reply to a channel post
}
// Proompter answers a user's question; implementations are injected into
// App so the backing model can be swapped or mocked.
type Proompter interface {
	Prompt(ctx context.Context, req PromptRequest) (*HighlyTrustedResponse, error)
}
// App wires the Telegram bot to a Proompter and bounds how many prompts
// run at once.
type App struct {
	log       *slog.Logger
	bot       *tgbotapi.BotAPI
	proompter Proompter
	sema      chan struct{} // counting semaphore limiting concurrent prompt requests
	config    *Config
}
// handleMessage processes one incoming Telegram message: it ignores anything
// outside the configured chat or without the "@grok " prefix, enforces the
// concurrency limit, asks the Proompter, and replies to the message.
func (a *App) handleMessage(ctx context.Context, msg *tgbotapi.Message) error {
	if msg.Chat == nil {
		return nil
	}
	chatID := msg.Chat.ID
	if chatID != a.config.ChatID || !strings.HasPrefix(msg.Text, "@grok ") {
		return nil
	}
	question := msg.Text[len("@grok "):]
	// Non-blocking acquire: when the limit is hit we drop the message
	// instead of queueing, so mention floods can't pile up goroutines.
	select {
	case a.sema <- struct{}{}:
	default:
		a.log.Info("concurrency limit hit", "msg", msg)
		return nil
	}
	defer func() { <-a.sema }()
	a.log.Info(
		"message passed all guards",
		"og_text", msg.Text,
		"transformed_text", question,
	)
	var ogPostContent *string
	// BUG FIX: ReplyToMessage (and its From) are nil when the message is not
	// a reply; the unconditional dereference panicked on every plain
	// "@grok ..." message.
	if msg.ReplyToMessage != nil &&
		msg.ReplyToMessage.From != nil &&
		msg.ReplyToMessage.From.ID == a.config.ChannelID {
		a.log.Info("message was a reply to channel post")
		ogPostContent = &msg.ReplyToMessage.Text
	}
	response, err := a.proompter.Prompt(ctx, PromptRequest{
		Question:            question,
		OriginalPostContent: ogPostContent,
	})
	if err != nil {
		return fmt.Errorf("prompting: %w", err)
	}
	_, err = a.bot.Send(tgbotapi.MessageConfig{
		BaseChat: tgbotapi.BaseChat{
			ChatID:           chatID,
			ReplyToMessageID: msg.MessageID,
		},
		Text: response.Text,
	})
	if err != nil {
		return fmt.Errorf("responding: %w", err)
	}
	return nil
}
// HandleUpdates consumes the bot's update stream until the context is
// cancelled or the stream closes, dispatching each message to handleMessage
// in its own goroutine.
func (a *App) HandleUpdates(ctx context.Context) error {
	updates := a.bot.GetUpdatesChan(tgbotapi.UpdateConfig{})
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case update, ok := <-updates:
			if !ok {
				return errors.New("channel closed")
			}
			if update.Message == nil {
				continue
			}
			a.log.Info("new message", "update", update)
			go func() {
				if err := a.handleMessage(ctx, update.Message); err != nil {
					a.log.Error("handling message", "msg", update.Message, "err", err)
				}
			}()
		}
	}
}
// NewApp constructs the application: logger, Telegram bot client, request
// semaphore, and configuration. The prompter is injected so it can be
// swapped or mocked in tests.
func NewApp(cfg *Config, prompter Proompter) (*App, error) {
	// Validate config before NewBotAPI, which performs a network call.
	if cfg.MaxConcurrentRequests == 0 {
		return nil, errors.New("concurrency limit not set")
	}
	var app App
	app.proompter = prompter
	app.log = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{}))
	var err error
	app.bot, err = tgbotapi.NewBotAPI(cfg.BotToken)
	if err != nil {
		return nil, fmt.Errorf("creating bot api: %w", err)
	}
	app.sema = make(chan struct{}, cfg.MaxConcurrentRequests)
	app.config = cfg
	return &app, nil
}
// Config holds all runtime settings; see LoadConfig for the environment
// variables each field is read from.
type Config struct {
	SystemPrompt          string // contents of the file at SYSTEM_PROMPT_PATH; may be empty
	OpenAIBaseURL         string // OPENAI_BASE_URL; required
	BotToken              string // BOT_TOKEN
	MaxConcurrentRequests uint   // MAX_CONCURRENT_REQUESTS; must be non-zero (checked in NewApp)
	ChatID                int64  // CHAT_ID; the only chat the bot responds in
	ChannelID             int64  // CHANNEL_ID; sender id that marks forwarded channel posts
}
// LoadConfig populates cfg from environment variables, first loading a .env
// file if one exists. It returns an error when a required value is missing
// or unparsable; a missing system prompt only produces a warning.
func LoadConfig(cfg *Config) error {
	// A missing .env is fine in environments where real env vars are set.
	if err := godotenv.Load(".env"); err != nil {
		slog.Warn("no env file loaded", "err", err)
	}
	cfg.OpenAIBaseURL = os.Getenv("OPENAI_BASE_URL")
	if cfg.OpenAIBaseURL == "" {
		return errors.New("openai base url not set")
	}
	cfg.BotToken = os.Getenv("BOT_TOKEN")
	// An empty token would only fail later inside NewBotAPI; reject it here
	// with a clear message.
	if cfg.BotToken == "" {
		return errors.New("bot token not set")
	}
	mcr, err := strconv.ParseUint(os.Getenv("MAX_CONCURRENT_REQUESTS"), 10, 64)
	if err != nil {
		return fmt.Errorf("parsing MAX_CONCURRENT_REQUESTS: %w", err)
	}
	cfg.MaxConcurrentRequests = uint(mcr)
	chatID, err := strconv.ParseInt(os.Getenv("CHAT_ID"), 10, 64)
	if err != nil {
		return fmt.Errorf("parsing CHAT_ID: %w", err)
	}
	if chatID == 0 {
		return errors.New("non-zero chat id is unsafe")
	}
	cfg.ChatID = chatID
	channelID, err := strconv.ParseInt(os.Getenv("CHANNEL_ID"), 10, 64)
	if err != nil {
		return fmt.Errorf("parsing CHANNEL_ID: %w", err)
	}
	// BUG FIX: this previously re-checked chatID (already proven non-zero
	// above), so an unset channel id was never reported.
	if channelID == 0 {
		slog.Warn("channel id is not set")
	}
	cfg.ChannelID = channelID
	sysPromptPath := os.Getenv("SYSTEM_PROMPT_PATH")
	promptBytes, err := os.ReadFile(sysPromptPath)
	if err != nil {
		slog.Warn("could not load system prompt", "path", sysPromptPath)
	} else {
		cfg.SystemPrompt = string(promptBytes)
	}
	return nil
}
// main wires configuration, the prompter, and the bot together, then blocks
// until the process is interrupted or the update loop stops.
func main() {
	ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
	defer cancel()
	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{}))
	logger.Info("Starting GROK")
	cfg := Config{}
	if err := LoadConfig(&cfg); err != nil {
		logger.Error("loading config", "err", err)
		os.Exit(1)
	}
	app, err := NewApp(&cfg, NewOpenAIProoooompter(&cfg))
	if err != nil {
		logger.Error("initializing app", "err", err)
		os.Exit(1)
	}
	go func() {
		// Cancelling here unblocks the final wait when the loop exits on
		// its own (e.g. the updates channel closes).
		defer cancel()
		if err := app.HandleUpdates(ctx); err != nil {
			logger.Error("handleUpdates exited", "err", err)
		}
	}()
	<-ctx.Done()
}