diff --git a/main.go b/main.go index 5cb947c..b80eb74 100644 --- a/main.go +++ b/main.go @@ -27,6 +27,8 @@ type FileReq struct { serverconn *ssh.ServerConn } +var authorizedKeys = os.Getenv("SHARETHIS_AUTHORIZEDKEYS") + func main() { bg := flag.Bool("bg", false, "sends the process in the background") server := flag.Bool("server", false, "makes the process an http server") @@ -34,12 +36,28 @@ func main() { sshport := flag.String("sshport", "2022", "the remote ssh port") httpport := flag.String("httpport", "8888", "the remote server's http port") sharecount := flag.Uint("count", 1, "Amount of times you want to share this file") + serverkey := flag.String("serverkey", "id_rsa", "Path to the server private key") flag.Parse() + if envsshport := os.Getenv("SHARETHIS_SSHPORT"); envsshport != "" { + *sshport = envsshport + } + if envhttpport := os.Getenv("SHARETHIS_HTTPPORT"); envhttpport != "" { + *httpport = envhttpport + } + if envremotehost := os.Getenv("SHARETHIS_REMOTEHOST"); envremotehost != "" { + *remotehost = envremotehost + } + if authorizedKeys == "" { + authorizedKeys = fmt.Sprintf("%s/.ssh/authorized_keys", os.ExpandEnv("HOME")) + } if *sharecount > 0 { *sharecount-- } if *server { - runServer("0.0.0.0", *sshport, *httpport, "id_rsa") + if envserverkey := os.Getenv("SHARETHIS_SERVERKEY"); envserverkey != "" { + *serverkey = envserverkey + } + runServer("0.0.0.0", *sshport, *httpport, *serverkey) } if len(flag.Args()) < 1 { log.Fatalln("Need filename") @@ -62,12 +80,6 @@ func main() { } os.Exit(0) } - keypath := fmt.Sprintf("%s/.ssh/st_rsa", os.Getenv("HOME")) - auth, err := PublicKeyFile(keypath) - if err != nil { - fmt.Println(err) - auth = SSHAgent() - } var username string userobj, err := user.Current() if err != nil { @@ -76,15 +88,26 @@ func main() { } else { username = userobj.Username } + sshConfig := &ssh.ClientConfig{ User: username, - Auth: []ssh.AuthMethod{ - auth, - }, + Auth: []ssh.AuthMethod{}, HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { return nil }, } + if agent, err := SSHAgent(); err == nil { + sshConfig.Auth = append(sshConfig.Auth, agent) + } + + keypath := fmt.Sprintf("%s/.ssh/st_rsa", os.Getenv("HOME")) + auth, err := PublicKeyFile(keypath) + if err != nil { + fmt.Println(err) + } else { + sshConfig.Auth = append(sshConfig.Auth, auth) + } + connection, err := ssh.Dial("tcp", fmt.Sprintf("%s:%s", *remotehost, *sshport), sshConfig) if err != nil { log.Fatalln("Failed to dial: %s", err) @@ -93,7 +116,7 @@ func main() { ch, reqch, err := connection.OpenChannel("Nope", nil) go ssh.DiscardRequests(reqch) if err != nil { - log.Fatalln(err) + log.Fatalln("poop", err) } enc := gob.NewEncoder(ch) path = flag.Arg(0) @@ -108,7 +131,6 @@ func main() { // In the words of weezer, I've got my hashed path. // TODO: Get the remote URL from the remote server instead of rebuilding it locally. - // TODO: Clean up the port from the URL if it's 80 or 443 var fileurl string if *httpport == "443" { fileurl = fmt.Sprintf("https://%s/%s", *remotehost, hashedpath) @@ -174,14 +196,19 @@ func PublicKeyFile(file string) (ssh.AuthMethod, error) { return ssh.PublicKeys(key), nil } -func SSHAgent() ssh.AuthMethod { +func SSHAgent() (ssh.AuthMethod, error) { if sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil { - return ssh.PublicKeysCallback(agent.NewClient(sshAgent).Signers) + a := agent.NewClient(sshAgent) + signers, _ := a.Signers() + if len(signers) == 0 { + return nil, fmt.Errorf("No signer found") + } + return ssh.PublicKeysCallback(a.Signers), nil } else { fmt.Println(err) - os.Exit(1) + return nil, err } - return nil + return nil, nil } func runServer(host, sshport, httpport, keyfile string) { @@ -227,13 +254,15 @@ func runServer(host, sshport, httpport, keyfile string) { for { nConn, err := listener.Accept() if err != nil { - log.Fatal("failed to accept incoming connection: ", err) + log.Println("failed to accept incoming connection: ", err) + continue } serverConn, chans, reqs, err := ssh.NewServerConn(nConn, cfg) if err != nil { - log.Fatal("failed to handshake: ", err) + log.Println("failed to handshake: ", err) + continue } // The incoming Request channel must be serviced. go ssh.DiscardRequests(reqs) @@ -243,7 +272,8 @@ func runServer(host, sshport, httpport, keyfile string) { for newChannel := range chans { channel, requests, err := newChannel.Accept() if err != nil { - log.Fatalf("Could not accept channel: %v", err) + log.Println("Could not accept channel: ", err) + continue } go func(in <-chan *ssh.Request) { @@ -308,7 +338,7 @@ func runServer(host, sshport, httpport, keyfile string) { } func buildCfg() *ssh.ServerConfig { - authorizedKeysBytes, err := ioutil.ReadFile("authorized_keys") + authorizedKeysBytes, err := ioutil.ReadFile(authorizedKeys) if err != nil { log.Fatalf("Failed to load authorized_keys, err: %v", err) } @@ -318,7 +348,7 @@ func buildCfg() *ssh.ServerConfig { pubKey, comment, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes) fmt.Println(comment) if err != nil { - log.Fatal(err) + log.Fatal("keypoop", err) } authorizedKeysMap[string(pubKey.Marshal())] = comment @@ -330,7 +360,8 @@ func buildCfg() *ssh.ServerConfig { cfg.SetDefaults() cfg.PasswordCallback = func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) { return nil, fmt.Errorf("Public key only") } cfg.PublicKeyCallback = func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { - if _, ok := authorizedKeysMap[string(key.Marshal())]; ok { + if user, ok := authorizedKeysMap[string(key.Marshal())]; ok { + fmt.Println("Key used:", user) return nil, nil } return nil, fmt.Errorf("unknown public key for %q", conn.User())