You've already forked grok
241 lines
4.9 KiB
Go
241 lines
4.9 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"os"
|
|
"os/signal"
|
|
"strconv"
|
|
"strings"
|
|
|
|
tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5"
|
|
"github.com/joho/godotenv"
|
|
"github.com/openai/openai-go/v3"
|
|
"github.com/openai/openai-go/v3/option"
|
|
"github.com/openai/openai-go/v3/responses"
|
|
"github.com/openai/openai-go/v3/shared"
|
|
)
|
|
|
|
type OpenAIPrompter struct {
|
|
cli openai.Client
|
|
cfg *Config
|
|
}
|
|
|
|
func (p *OpenAIPrompter) Prompt(ctx context.Context, question string) (*HighlyTrustedResponse, error) {
|
|
resp, err := p.cli.Responses.New(ctx, responses.ResponseNewParams{
|
|
Instructions: openai.String(p.cfg.SystemPrompt),
|
|
Input: responses.ResponseNewParamsInputUnion{
|
|
OfString: openai.String(question),
|
|
},
|
|
Reasoning: shared.ReasoningParam{
|
|
Effort: shared.ReasoningEffortXhigh,
|
|
},
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &HighlyTrustedResponse{
|
|
Text: resp.OutputText(),
|
|
}, nil
|
|
}
|
|
|
|
func NewOpenAIProoooompter(cfg *Config) *OpenAIPrompter {
|
|
return &OpenAIPrompter{
|
|
cli: openai.NewClient(
|
|
option.WithBaseURL(cfg.OpenAIBaseURL),
|
|
),
|
|
cfg: cfg,
|
|
}
|
|
}
|
|
|
|
// HighlyTrustedResponse is the answer a Proompter produces for one question.
type HighlyTrustedResponse struct {
	// Text is the reply body that gets posted back to the chat.
	Text string
}
|
|
|
|
type Proompter interface {
|
|
Prompt(ctx context.Context, question string) (*HighlyTrustedResponse, error)
|
|
}
|
|
|
|
type App struct {
|
|
log *slog.Logger
|
|
bot *tgbotapi.BotAPI
|
|
proompter Proompter
|
|
sema chan struct{}
|
|
config *Config
|
|
}
|
|
|
|
func (a *App) handleMessage(ctx context.Context, msg *tgbotapi.Message) error {
|
|
if msg.Chat == nil {
|
|
return nil
|
|
}
|
|
|
|
chatID := msg.Chat.ID
|
|
if chatID != a.config.ChatID || !strings.HasPrefix(msg.Text, "@grok") {
|
|
return nil
|
|
}
|
|
|
|
question := msg.Text[len("@grok"):]
|
|
|
|
select {
|
|
case a.sema <- struct{}{}:
|
|
default:
|
|
a.log.Info("concurrency limit hit", "msg", msg)
|
|
return nil
|
|
}
|
|
defer func() { <-a.sema }()
|
|
|
|
a.log.Info(
|
|
"message passed all guards",
|
|
"og_text", msg.Text,
|
|
"transformed_text", question,
|
|
)
|
|
|
|
response, err := a.proompter.Prompt(ctx, question)
|
|
if err != nil {
|
|
return fmt.Errorf("prompting: %w", err)
|
|
}
|
|
|
|
_, err = a.bot.Send(tgbotapi.MessageConfig{
|
|
BaseChat: tgbotapi.BaseChat{
|
|
ChatID: chatID,
|
|
ReplyToMessageID: msg.MessageID,
|
|
},
|
|
Text: response.Text,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("responding: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *App) HandleUpdates(ctx context.Context) error {
|
|
upds := a.bot.GetUpdatesChan(tgbotapi.UpdateConfig{})
|
|
for {
|
|
select {
|
|
case upd, ok := <-upds:
|
|
if !ok {
|
|
return errors.New("channel closed")
|
|
}
|
|
|
|
if upd.Message != nil {
|
|
a.log.Info("new message", "chat", upd.Message.Chat)
|
|
go func() {
|
|
if err := a.handleMessage(ctx, upd.Message); err != nil {
|
|
a.log.Error("handling message", "msg", upd.Message, "err", err)
|
|
}
|
|
}()
|
|
}
|
|
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
func NewApp(cfg *Config, prompter Proompter) (*App, error) {
|
|
var app App
|
|
var err error
|
|
|
|
app.proompter = prompter
|
|
|
|
app.log = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{}))
|
|
|
|
app.bot, err = tgbotapi.NewBotAPI(cfg.BotToken)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if cfg.MaxConcurrentRequests == 0 {
|
|
return nil, errors.New("concurrency limit not set")
|
|
}
|
|
app.sema = make(chan struct{}, cfg.MaxConcurrentRequests)
|
|
|
|
app.config = cfg
|
|
|
|
return &app, nil
|
|
}
|
|
|
|
// Config holds the runtime settings that LoadConfig reads from the
// environment.
type Config struct {
	// SystemPrompt is the contents of the file at SYSTEM_PROMPT_PATH; it may
	// be empty.
	SystemPrompt string
	// OpenAIBaseURL comes from OPENAI_BASE_URL.
	OpenAIBaseURL string
	// BotToken comes from BOT_TOKEN.
	BotToken string
	// MaxConcurrentRequests comes from MAX_CONCURRENT_REQUESTS. NewApp
	// rejects 0.
	MaxConcurrentRequests uint
	// ChatID comes from CHAT_ID. It is the only chat whose messages are
	// answered.
	ChatID int64
}
|
|
|
|
func LoadConfig(cfg *Config) error {
|
|
if err := godotenv.Load(".env"); err != nil {
|
|
slog.Warn("no env file loaded", "err", err)
|
|
}
|
|
|
|
cfg.OpenAIBaseURL = os.Getenv("OPENAI_BASE_URL")
|
|
if cfg.OpenAIBaseURL == "" {
|
|
return errors.New("openai base url not set")
|
|
}
|
|
|
|
cfg.BotToken = os.Getenv("BOT_TOKEN")
|
|
mcg, err := strconv.ParseUint(
|
|
os.Getenv("MAX_CONCURRENT_REQUESTS"),
|
|
10, 64,
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
cfg.MaxConcurrentRequests = uint(mcg)
|
|
chatID, err := strconv.ParseInt(os.Getenv("CHAT_ID"), 10, 64)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if chatID == 0 {
|
|
return errors.New("non-zero chat id is unsafe")
|
|
}
|
|
|
|
cfg.ChatID = chatID
|
|
|
|
sysPromptPath := os.Getenv("SYSTEM_PROMPT_PATH")
|
|
promptBytes, err := os.ReadFile(sysPromptPath)
|
|
if err != nil {
|
|
slog.Warn("could not load system prompt", "path", sysPromptPath)
|
|
} else {
|
|
cfg.SystemPrompt = string(promptBytes)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func main() {
|
|
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
|
|
defer cancel()
|
|
|
|
log := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{}))
|
|
log.Info("Starting GROK")
|
|
|
|
var cfg Config
|
|
if err := LoadConfig(&cfg); err != nil {
|
|
log.Error("loading config", "err", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
prompter := NewOpenAIProoooompter(&cfg)
|
|
|
|
app, err := NewApp(&cfg, prompter)
|
|
if err != nil {
|
|
log.Error("initializing app", "err", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
go func() {
|
|
if err := app.HandleUpdates(ctx); err != nil {
|
|
log.Error("handleUpdates exited", "err", err)
|
|
}
|
|
cancel()
|
|
}()
|
|
|
|
<-ctx.Done()
|
|
}
|