A demo of the OpenAI API (https://beta.openai.com)
https://ask.mills.io/
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
170 lines
4.0 KiB
170 lines
4.0 KiB
package main
|
|
|
|
import (
|
|
"context"
|
|
"flag"
|
|
"fmt"
|
|
"html/template"
|
|
"io"
|
|
"io/ioutil"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
|
|
"github.com/PullRequestInc/go-gpt3"
|
|
)
|
|
|
|
const indexTemplate = `
|
|
<!DOCTYPE html>
|
|
<html lang="en">
|
|
<head>
|
|
<link rel="stylesheet" href="https://cdn.simplecss.org/simple.min.css">
|
|
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
|
<title>GPT3 Demo</title>
|
|
</head>
|
|
<body>
|
|
<header>
|
|
<nav>
|
|
<a href="/">Home</a>
|
|
<a href="https://git.mills.io/prologic/gpt">Source</a>
|
|
</nav>
|
|
<h1>GPT3 Demo</h1>
|
|
<p>A GPT3 Demo Web Application and API using the <a href="https://beta.openai.com/">Open AI</a> service.</p>
|
|
</header>
|
|
<main>
|
|
<h2>Ask me anything!</h2>
|
|
<form action="/" method="POST">
|
|
<input type="text" name="prompt" placeholder="Enter your prompt here..." required>
|
|
<button type="submit">Go!</button>
|
|
</form>
|
|
<p>{{ .Response }}</p>
|
|
</main>
|
|
<footer>
|
|
<p>Licensed under the <a href="https://git.mills.io/prologic/gpt/blob/master/LICENSE">WTFPL License</a></p>
|
|
</footer>
|
|
</body>
|
|
</html>
|
|
`
|
|
|
|
func render(name, tmpl string, ctx interface{}, w io.Writer) error {
|
|
t, err := template.New(name).Parse(tmpl)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return t.Execute(w, ctx)
|
|
}
|
|
|
|
type gptHandler struct {
|
|
cli gpt3.Client
|
|
engine string
|
|
}
|
|
|
|
func (h gptHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method == http.MethodHead || r.Method == http.MethodGet && r.URL.Path == "/health" {
|
|
w.Header().Set("Content-Type", "text/plain")
|
|
fmt.Fprintf(w, "OK\n")
|
|
return
|
|
}
|
|
|
|
if r.Method == http.MethodHead || r.Method == http.MethodGet || r.Method == http.MethodPost {
|
|
switch r.URL.Path {
|
|
case "/":
|
|
ctx := struct {
|
|
Response string
|
|
}{}
|
|
|
|
if r.Method == http.MethodHead {
|
|
return
|
|
}
|
|
|
|
if r.Method == http.MethodPost {
|
|
prompt := r.FormValue("prompt")
|
|
|
|
if prompt == "" {
|
|
ctx.Response = "Error: no prompt entered"
|
|
} else {
|
|
response, err := h.cli.CompletionWithEngine(r.Context(), h.engine, gpt3.CompletionRequest{
|
|
Prompt: []string{prompt},
|
|
MaxTokens: gpt3.IntPtr(150),
|
|
Echo: false,
|
|
Temperature: gpt3.Float32Ptr(0.7),
|
|
TopP: gpt3.Float32Ptr(0),
|
|
FrequencyPenalty: 0,
|
|
PresencePenalty: 0,
|
|
})
|
|
if err != nil {
|
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
for _, choice := range response.Choices {
|
|
text := strings.TrimSpace(choice.Text)
|
|
if len(text) > 0 {
|
|
ctx.Response += text
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if err := render("index", indexTemplate, ctx, w); err != nil {
|
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
|
}
|
|
return
|
|
default:
|
|
http.Error(w, "Not Found", http.StatusNotFound)
|
|
return
|
|
}
|
|
}
|
|
|
|
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
|
}
|
|
|
|
func main() {
|
|
var (
|
|
engine = flag.String("engine", "text-davinci-002", "The engine to use")
|
|
server = flag.Bool("server", false, "Run in server mode")
|
|
bind = flag.String("bind", "0.0.0.0:8000", "[interface]:<port> to bind to in server mode")
|
|
)
|
|
|
|
flag.Parse()
|
|
|
|
apiKey := os.Getenv("OPENAI_API_KEY")
|
|
if apiKey == "" {
|
|
log.Fatalln("missing openai key")
|
|
}
|
|
|
|
cli := gpt3.NewClient(apiKey)
|
|
|
|
if !*server {
|
|
stdin, err := ioutil.ReadAll(os.Stdin)
|
|
if err != nil {
|
|
log.Fatalln(err)
|
|
}
|
|
input := strings.TrimSpace(string(stdin))
|
|
|
|
ctx := context.Background()
|
|
response, err := cli.CompletionWithEngine(ctx, *engine, gpt3.CompletionRequest{
|
|
Prompt: []string{input},
|
|
MaxTokens: gpt3.IntPtr(150),
|
|
Echo: false,
|
|
Temperature: gpt3.Float32Ptr(0.7),
|
|
TopP: gpt3.Float32Ptr(0),
|
|
FrequencyPenalty: 0,
|
|
PresencePenalty: 0,
|
|
})
|
|
if err != nil {
|
|
log.Fatalln(err)
|
|
}
|
|
|
|
for _, choice := range response.Choices {
|
|
text := strings.TrimSpace(choice.Text)
|
|
if len(text) > 0 {
|
|
fmt.Println(text)
|
|
}
|
|
}
|
|
} else {
|
|
log.Fatal(http.ListenAndServe(*bind, gptHandler{cli: cli, engine: *engine}))
|
|
}
|
|
}
|
|
|