Files
grok/main.go

312 lines
6.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"context"
"errors"
"fmt"
"log/slog"
"os"
"os/signal"
"strconv"
"strings"
tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5"
"github.com/joho/godotenv"
"github.com/openai/openai-go/v3"
"github.com/openai/openai-go/v3/option"
"github.com/openai/openai-go/v3/responses"
"github.com/openai/openai-go/v3/shared"
)
// OpenAIPrompter answers prompts through the OpenAI Responses API. Its Prompt
// method has the signature required by the Proompter interface.
type OpenAIPrompter struct {
	// cli is the API client, pointed at Config.OpenAIBaseURL by
	// NewOpenAIProoooompter.
	cli openai.Client
	// cfg supplies the base system prompt.
	cfg *Config
	// log receives one entry per prompt request.
	log *slog.Logger
}
// composeSysPromptWithContext returns systemPrompt followed by per-request
// context. The requester's username is always appended on its own line. The
// replied-to text is appended on a further line only when
// req.ReplyToContent is non-nil. req.Question is not included.
func composeSysPromptWithContext(systemPrompt string, req PromptRequest) string {
	pieces := []string{systemPrompt, "\nИмя пользователя: ", req.Username}
	if reply := req.ReplyToContent; reply != nil {
		pieces = append(pieces, "\nПользователь отсылается на текст сообщения: ", *reply)
	}
	return strings.Join(pieces, "")
}
// Prompt asks the model req.Question and returns its text answer.
//
// The configured system prompt, enriched with the requester's username and
// the replied-to text by composeSysPromptWithContext, is sent once as the
// request's instructions. The question is the only input item, sent with the
// user role. The Responses API requires a role on every message item, and the
// original role-less system message duplicated the instructions.
//
// NOTE(review): no Model is set on the request. Presumably the endpoint at
// Config.OpenAIBaseURL supplies a default; confirm.
//
// It returns an error if the API call fails or the model produced no text.
func (p *OpenAIPrompter) Prompt(ctx context.Context, req PromptRequest) (*HighlyTrustedResponse, error) {
	p.log.Info("new prompt request",
		"req", req)

	sysPrompt := composeSysPromptWithContext(p.cfg.SystemPrompt, req)
	input := []responses.ResponseInputItemUnionParam{
		{
			OfMessage: &responses.EasyInputMessageParam{
				Content: responses.EasyInputMessageContentUnionParam{
					OfString: openai.String(req.Question),
				},
				Role: responses.EasyInputMessageRoleUser,
			},
		},
	}

	resp, err := p.cli.Responses.New(ctx, responses.ResponseNewParams{
		Instructions: openai.String(sysPrompt),
		Input: responses.ResponseNewParamsInputUnion{
			OfInputItemList: input,
		},
		Reasoning: shared.ReasoningParam{
			Effort: shared.ReasoningEffortXhigh,
		},
	})
	if err != nil {
		return nil, fmt.Errorf("creating response: %w", err)
	}

	text := resp.OutputText()
	if text == "" {
		// Callers send this text to Telegram, which rejects empty messages.
		return nil, errors.New("model returned an empty response")
	}
	return &HighlyTrustedResponse{
		Text: text,
	}, nil
}
// NewOpenAIProoooompter builds an OpenAIPrompter whose client sends requests
// to cfg.OpenAIBaseURL. All other client settings are left at
// openai.NewClient's defaults; per openai-go, that includes reading the API
// key from OPENAI_API_KEY. cfg and log are retained by pointer, not copied.
func NewOpenAIProoooompter(cfg *Config, log *slog.Logger) *OpenAIPrompter {
	client := openai.NewClient(option.WithBaseURL(cfg.OpenAIBaseURL))
	prompter := &OpenAIPrompter{cfg: cfg, log: log}
	prompter.cli = client
	return prompter
}
type (
	// HighlyTrustedResponse is a model answer ready to be posted to the chat.
	HighlyTrustedResponse struct {
		Text string // answer text
	}

	// PromptRequest is one question together with the chat context that
	// accompanies it.
	PromptRequest struct {
		Username       string  // Telegram username of the asker; may be empty
		Question       string  // message text after the trigger prefix
		ReplyToContent *string // replied-to text; nil when there is none
	}

	// Proompter turns a PromptRequest into an answer. OpenAIPrompter is the
	// implementation used by main.
	Proompter interface {
		Prompt(ctx context.Context, req PromptRequest) (*HighlyTrustedResponse, error)
	}
)
// App connects the Telegram bot to a Proompter and holds the state shared by
// all message handlers.
type App struct {
	log       *slog.Logger
	bot       *tgbotapi.BotAPI
	proompter Proompter

	// sema is a counting semaphore. Each in-flight request holds one slot,
	// and its capacity (Config.MaxConcurrentRequests) bounds concurrency.
	sema chan struct{}

	config *Config
}
// telegramMaxMessageLen is the longest text sendMessage accepts. Telegram
// counts it in UTF-16 code units.
const telegramMaxMessageLen = 4096

// handleMessage answers msg if it was posted in the configured chat and its
// text is "@grok " followed by a non-blank question. Every other message is
// ignored, and nil is returned for it.
//
// At most cap(a.sema) messages are processed concurrently. A message that
// arrives while all slots are taken is logged and dropped.
//
// When msg is a reply, the replied-to message's text is forwarded to the model
// as context. For media messages, its caption is used instead. The answer is
// posted as a reply to msg, split into several messages if it exceeds
// Telegram's length limit.
//
// It returns an error if prompting fails, the model returns no text, or a
// part cannot be sent. In that last case, earlier parts may already have been
// delivered.
func (a *App) handleMessage(ctx context.Context, msg *tgbotapi.Message) error {
	if msg.Chat == nil {
		return nil
	}
	chatID := msg.Chat.ID
	if chatID != a.config.ChatID {
		return nil
	}
	question, ok := strings.CutPrefix(msg.Text, "@grok ")
	if !ok || strings.TrimSpace(question) == "" {
		return nil
	}

	select {
	case a.sema <- struct{}{}:
	default:
		a.log.Info("concurrency limit hit", "msg", msg)
		return nil
	}
	defer func() { <-a.sema }()

	a.log.Info(
		"message passed all guards",
		"og_text", msg.Text,
		"transformed_text", question,
	)

	var repliedToContent *string
	if reply := msg.ReplyToMessage; reply != nil {
		a.log.Info("message was a reply")
		// Photos, documents, etc. carry their text in Caption and leave Text empty.
		content := reply.Text
		if content == "" {
			content = reply.Caption
		}
		if content != "" {
			repliedToContent = &content
		}
	}

	var username string
	if msg.From != nil {
		username = msg.From.UserName
	}

	response, err := a.proompter.Prompt(ctx, PromptRequest{
		Question:       question,
		ReplyToContent: repliedToContent,
		Username:       username,
	})
	if err != nil {
		return fmt.Errorf("prompting: %w", err)
	}
	// Telegram rejects empty messages, so report this here rather than as a
	// confusing Send failure.
	if strings.TrimSpace(response.Text) == "" {
		return errors.New("prompting: model returned empty text")
	}

	for _, part := range splitMessage(response.Text, telegramMaxMessageLen) {
		// Telegram trims message text, so a whitespace-only part would be
		// rejected as empty.
		if strings.TrimSpace(part) == "" {
			continue
		}
		_, err := a.bot.Send(tgbotapi.MessageConfig{
			BaseChat: tgbotapi.BaseChat{
				ChatID:           chatID,
				ReplyToMessageID: msg.MessageID,
			},
			Text: part,
		})
		if err != nil {
			return fmt.Errorf("responding: %w", err)
		}
	}
	return nil
}

// splitMessage splits text into consecutive parts whose concatenation is
// text. Each part is at most limit UTF-16 code units long, except that a part
// always contains at least one rune. Runes are never cut in half. When a part
// must be cut, the cut is moved back to just after its last newline if that
// newline lies in the part's second half, so paragraphs stay intact.
// Empty text yields no parts.
func splitMessage(text string, limit int) []string {
	var parts []string
	for text != "" {
		units, cut, lastNL := 0, len(text), -1
		for i, r := range text {
			w := 1
			if r >= 0x10000 { // encoded as a surrogate pair in UTF-16
				w = 2
			}
			if i > 0 && units+w > limit {
				cut = i
				break
			}
			units += w
			if r == '\n' {
				lastNL = i + 1
			}
		}
		if cut < len(text) && lastNL > cut/2 {
			cut = lastNL
		}
		parts = append(parts, text[:cut])
		text = text[cut:]
	}
	return parts
}
// HandleUpdates long-polls Telegram and passes every update that carries a
// message to handleMessage on its own goroutine. Handler errors are logged,
// not returned.
//
// It runs until one of two things happens:
//   - ctx is cancelled: it stops the bot's polling loop and returns ctx.Err().
//   - the updates channel is closed: it returns an error.
//
// It does not wait for handlers that are still running.
func (a *App) HandleUpdates(ctx context.Context) error {
	// Long-poll timeout in seconds. With 0, Telegram answers getUpdates
	// immediately (short polling), so the library would hit the API in a
	// tight loop.
	const pollTimeout = 60
	upds := a.bot.GetUpdatesChan(tgbotapi.UpdateConfig{Timeout: pollTimeout})
	for {
		select {
		case upd, ok := <-upds:
			if !ok {
				return errors.New("channel closed")
			}
			msg := upd.Message
			if msg == nil {
				continue
			}
			a.log.Info("new message", "update", upd)
			go func() {
				if err := a.handleMessage(ctx, msg); err != nil {
					a.log.Error("handling message", "msg", msg, "err", err)
				}
			}()
		case <-ctx.Done():
			// Stop the library's polling goroutine. This happens only here:
			// the library closes the updates channel only after polling has
			// already been stopped, and StopReceivingUpdates panics if called
			// twice.
			a.bot.StopReceivingUpdates()
			return ctx.Err()
		}
	}
}
// NewApp validates cfg and builds an App that answers messages through
// prompter.
//
// All local checks run before tgbotapi.NewBotAPI, which contacts Telegram to
// verify cfg.BotToken. This way a configuration mistake is reported as such
// rather than masked by, or delayed behind, a network call.
//
// It returns an error in any of these cases:
//   - cfg or prompter is nil.
//   - cfg.MaxConcurrentRequests is zero.
//   - the bot client cannot be created.
func NewApp(cfg *Config, prompter Proompter) (*App, error) {
	if cfg == nil {
		return nil, errors.New("config is nil")
	}
	if prompter == nil {
		return nil, errors.New("prompter is nil")
	}
	if cfg.MaxConcurrentRequests == 0 {
		return nil, errors.New("concurrency limit not set")
	}

	bot, err := tgbotapi.NewBotAPI(cfg.BotToken)
	if err != nil {
		return nil, fmt.Errorf("creating bot api client: %w", err)
	}

	return &App{
		log:       slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{})),
		bot:       bot,
		proompter: prompter,
		// Buffered channel used as a counting semaphore by handleMessage.
		sema:   make(chan struct{}, cfg.MaxConcurrentRequests),
		config: cfg,
	}, nil
}
// Config holds the bot's settings. LoadConfig fills it from the environment.
type Config struct {
	// SystemPrompt is the base system prompt. It is empty if none could be
	// loaded.
	SystemPrompt string
	// OpenAIBaseURL is the base URL of the OpenAI-compatible API.
	OpenAIBaseURL string
	// BotToken authenticates the Telegram bot.
	BotToken string
	// MaxConcurrentRequests bounds how many prompts are processed at once.
	MaxConcurrentRequests uint
	// ChatID is the only chat whose messages are answered.
	ChatID int64
	// ChannelID is loaded from CHANNEL_ID.
	// NOTE(review): not read anywhere else in this file.
	ChannelID int64
}
// LoadConfig fills cfg from environment variables. It first loads a ".env"
// file from the working directory; a missing file only logs a warning.
//
// Variables:
//   - OPENAI_BASE_URL (required): base URL of the OpenAI-compatible API.
//   - BOT_TOKEN (required): Telegram bot token.
//   - MAX_CONCURRENT_REQUESTS (required): unsigned integer.
//   - CHAT_ID (required): non-zero ID of the chat the bot answers in.
//   - CHANNEL_ID (optional): integer. If unset, it stays 0 and a warning is
//     logged.
//   - SYSTEM_PROMPT_PATH (optional): file whose contents become
//     cfg.SystemPrompt. A read failure only logs a warning.
//
// It returns an error for the first missing or malformed required variable,
// or for a CHANNEL_ID that is set but not an integer.
func LoadConfig(cfg *Config) error {
	if err := godotenv.Load(".env"); err != nil {
		slog.Warn("no env file loaded", "err", err)
	}

	cfg.OpenAIBaseURL = os.Getenv("OPENAI_BASE_URL")
	if cfg.OpenAIBaseURL == "" {
		return errors.New("openai base url not set")
	}

	cfg.BotToken = os.Getenv("BOT_TOKEN")
	if cfg.BotToken == "" {
		return errors.New("bot token not set")
	}

	// Parse with the platform's int size so the uint conversion cannot truncate.
	mcr, err := strconv.ParseUint(os.Getenv("MAX_CONCURRENT_REQUESTS"), 10, strconv.IntSize)
	if err != nil {
		return fmt.Errorf("parsing MAX_CONCURRENT_REQUESTS: %w", err)
	}
	cfg.MaxConcurrentRequests = uint(mcr)

	chatID, err := strconv.ParseInt(os.Getenv("CHAT_ID"), 10, 64)
	if err != nil {
		return fmt.Errorf("parsing CHAT_ID: %w", err)
	}
	if chatID == 0 {
		// Zero is indistinguishable from "unset"; refuse to guess which chat to serve.
		return errors.New("chat id must be non-zero")
	}
	cfg.ChatID = chatID

	// CHANNEL_ID is optional. An empty value leaves ChannelID at 0.
	var channelID int64
	if raw := os.Getenv("CHANNEL_ID"); raw != "" {
		channelID, err = strconv.ParseInt(raw, 10, 64)
		if err != nil {
			return fmt.Errorf("parsing CHANNEL_ID: %w", err)
		}
	}
	if channelID == 0 {
		slog.Warn("channel id is not set")
	}
	cfg.ChannelID = channelID

	sysPromptPath := os.Getenv("SYSTEM_PROMPT_PATH")
	promptBytes, err := os.ReadFile(sysPromptPath)
	if err != nil {
		// The bot can run without a system prompt, so this is not fatal.
		slog.Warn("could not load system prompt", "path", sysPromptPath, "err", err)
		return nil
	}
	cfg.SystemPrompt = string(promptBytes)
	return nil
}
// main loads the configuration, builds the OpenAI-backed prompter and the bot
// app, and then handles Telegram updates until an interrupt arrives.
//
// It exits with status 1 in any of these cases:
//   - the configuration cannot be loaded.
//   - the app cannot be initialized.
//   - update handling stops for a reason other than the interrupt.
func main() {
	ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
	defer cancel()

	log := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{}))
	log.Info("Starting GROK")

	var cfg Config
	if err := LoadConfig(&cfg); err != nil {
		log.Error("loading config", "err", err)
		os.Exit(1)
	}

	prompter := NewOpenAIProoooompter(&cfg, log)
	app, err := NewApp(&cfg, prompter)
	if err != nil {
		log.Error("initializing app", "err", err)
		os.Exit(1)
	}

	// Run the update loop on the main goroutine. main then cannot return
	// before HandleUpdates has finished its shutdown (stopping the poller).
	err = app.HandleUpdates(ctx)
	cancel()
	if err != nil && !errors.Is(err, context.Canceled) {
		log.Error("handleUpdates exited", "err", err)
		os.Exit(1)
	}
	log.Info("shutting down")
}