From 46286cb851e019b65620822d44d4a258ec6f5db9 Mon Sep 17 00:00:00 2001 From: Olivier Tremblay Date: Tue, 11 Dec 2018 15:14:26 -0500 Subject: [PATCH] Extract some stuff out --- channel.go | 8 ++ main.go => cmd/sharethis/main.go | 216 +++++++++++++++---------------- filereq.go | 89 +++++++++++++ 3 files changed, 199 insertions(+), 114 deletions(-) create mode 100644 channel.go rename main.go => cmd/sharethis/main.go (73%) create mode 100644 filereq.go diff --git a/channel.go b/channel.go new file mode 100644 index 0000000..42d821e --- /dev/null +++ b/channel.go @@ -0,0 +1,8 @@ +package sharethis + +import "golang.org/x/crypto/ssh" + +type NewChannel struct { + ssh.NewChannel + DoZip bool +} diff --git a/main.go b/cmd/sharethis/main.go similarity index 73% rename from main.go rename to cmd/sharethis/main.go index 80a5e2e..c57eeee 100644 --- a/main.go +++ b/cmd/sharethis/main.go @@ -1,8 +1,8 @@ -package main // import "github.com/otremblay/sharethis" +package main // import "github.com/otremblay/sharethis/cmd/sharethis" import ( "archive/tar" - "crypto/sha256" + "archive/zip" "encoding/gob" "flag" "fmt" @@ -13,26 +13,19 @@ import ( "net/http" "os" "os/user" - "path/filepath" "strings" "sync" - "syscall" + "github.com/otremblay/sharethis" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" ) -type FileReq struct { - Path string - ShareCount uint - serverconn *ssh.ServerConn -} - var authorizedKeys = os.Getenv("SHARETHIS_AUTHORIZEDKEYS") func main() { - bg := flag.Bool("bg", false, "sends the process in the background") server := flag.Bool("server", false, "makes the process an http server") + // admin := flag.String("admin", "", "determine admin user") remotehost := flag.String("remote", "share.otremblay.com", "remote server for sharethis to contact") sshport := flag.String("sshport", "2022", "the remote ssh port") httpport := flag.String("httpport", "8888", "the remote server's http port") @@ -63,28 +56,8 @@ func main() { if len(flag.Args()) < 1 { log.Fatalln("Need filename") } - var path string - var dir bool - var name string - if fs, err := os.Stat(flag.Arg(0)); err != nil { - log.Fatalln("Can't read file") - } else { - name = fs.Name() - p, err := filepath.Abs(fs.Name()) - if err != nil { - log.Fatalln("Can't read file") - } - path = p - dir = fs.IsDir() - } + path := flag.Arg(0) - if *bg { - _, err := syscall.ForkExec(os.Args[0], append([]string{os.Args[0]}, flag.Args()...), &syscall.ProcAttr{Files: []uintptr{0, 1, 2}}) - if err != nil { - log.Fatalln(err) - } - os.Exit(0) - } var username string userobj, err := user.Current() if err != nil { @@ -124,36 +97,10 @@ func main() { log.Fatalln("poop", err) } enc := gob.NewEncoder(ch) - path = flag.Arg(0) + fr := sharethis.NewFileReq(path, *httpport, *remotehost, *sharecount) + fmt.Println(fr) - hostname, err := os.Hostname() - if err != nil { - fmt.Fprintln(os.Stderr, "Could not get hostname with os.Hostname()") - hostname = "unknown" - } - fullpath := fmt.Sprintf("%s@%s:%s", username, hostname, path) - hashedpath := fmt.Sprintf("%x", sha256.Sum256([]byte(fullpath))) - - // In the words of weezer, I've got my hashed path. - // TODO: Get the remote URL from the remote server instead of rebuilding it locally. - var fileurl string - fullpath = fmt.Sprintf("%s/%s", hashedpath, name) - if *httpport == "443" { - fileurl = fmt.Sprintf("https://%s/%s", *remotehost, fullpath) - } else if *httpport == "80" { - fileurl = fmt.Sprintf("http://%s/%s", *remotehost, fullpath) - } else { - fileurl = fmt.Sprintf("http://%s:%s/%s", *remotehost, *httpport, fullpath) - } - - if dir { - fmt.Println(fileurl + ".tar") - fmt.Println(fileurl + ".zip") - } else { - fmt.Println(fileurl) - } - - err = enc.Encode(&FileReq{Path: fullpath, ShareCount: *sharecount}) + err = enc.Encode(fr) if err != nil { fmt.Println(err) ch.Close() @@ -163,23 +110,23 @@ func main() { mu := &sync.Mutex{} defer ch.Close() - var getchfn func() ssh.NewChannel - if dir { - ncc := connection.HandleChannelOpen(fullpath + ".tar") - ncc2 := connection.HandleChannelOpen(fullpath + ".zip") - getchfn = func() ssh.NewChannel { + var getchfn func() *sharethis.NewChannel + if fr.IsDir { + ncc := connection.HandleChannelOpen(fr.Path + ".tar") + ncc2 := connection.HandleChannelOpen(fr.Path + ".zip") + getchfn = func() *sharethis.NewChannel { select { case nc := <-ncc: - return nc + return &sharethis.NewChannel{NewChannel: nc} case nc := <-ncc2: - return nc + return &sharethis.NewChannel{NewChannel: nc, DoZip: true} } } } else { - ncc := connection.HandleChannelOpen(fullpath) - getchfn = func() ssh.NewChannel { - return <-ncc + ncc := connection.HandleChannelOpen(fr.Path) + getchfn = func() *sharethis.NewChannel { + return &sharethis.NewChannel{<-ncc, false} } } @@ -191,36 +138,40 @@ func main() { fmt.Fprintln(os.Stderr, err) } - go func() { - f, err := os.Open(path) + go handleFileRequest(path, ch, mu, sharecount, nc.DoZip) - if err != nil { - fmt.Fprintln(ch, "Sharethis error") - return - } - if fs, _ := f.Stat(); fs.IsDir() { - tw := tar.NewWriter(ch) - err := WriteFiles(tw, fs, ".") + } +} + +func handleFileRequest(path string, ch ssh.Channel, mu *sync.Mutex, sharecount *uint, dozip bool) { + + f, err := os.Open(path) + + if err != nil { + fmt.Fprintln(ch, "Sharethis error") + return + } + if fs, _ := f.Stat(); fs.IsDir() { + if dozip { + w := zip.NewWriter(ch) + zipfn := func(fi os.FileInfo, path string, file *os.File) error { + hdr, err := zip.FileInfoHeader(fi) + hdr.Name = path if err != nil { - fmt.Fprintln(ch, "Error building tar", err) + return fmt.Errorf("Couldn't build zip header: %v", err) } - tw.Close() - mu.Lock() - if *sharecount == 0 { - mu.Unlock() - os.Exit(0) + f, err := w.CreateHeader(hdr) + if err != nil { + return fmt.Errorf("Couldn't write zip header: %v", err) } - *sharecount-- - mu.Unlock() - ch.Close() - - return + _, err = io.Copy(f, file) + return err } - _, err = io.Copy(ch, f) + err := WriteFiles(zipfn, fs, ".") if err != nil { - fmt.Println(err) - return + fmt.Fprintln(ch, "Error building tar", err) } + w.Close() mu.Lock() if *sharecount == 0 { mu.Unlock() @@ -229,12 +180,59 @@ func main() { *sharecount-- mu.Unlock() ch.Close() - }() + return + + } else { + tw := tar.NewWriter(ch) + tarfn := func(fi os.FileInfo, path string, file *os.File) error { + hdr, err := tar.FileInfoHeader(fi, "") + hdr.Name = path + if err != nil { + return fmt.Errorf("Couldn't build tar header: %v", err) + } + + err = tw.WriteHeader(hdr) + if err != nil { + return fmt.Errorf("Couldn't write tar header: %v", err) + } + _, err = io.Copy(tw, file) + return err + } + err := WriteFiles(tarfn, fs, ".") + if err != nil { + fmt.Fprintln(ch, "Error building tar", err) + } + tw.Close() + mu.Lock() + if *sharecount == 0 { + mu.Unlock() + os.Exit(0) + } + *sharecount-- + mu.Unlock() + ch.Close() + + return + } } + _, err = io.Copy(ch, f) + if err != nil { + fmt.Println(err) + return + } + mu.Lock() + if *sharecount == 0 { + mu.Unlock() + os.Exit(0) + } + *sharecount-- + mu.Unlock() + ch.Close() + } -func WriteFiles(tw *tar.Writer, fi os.FileInfo, path string) error { +func WriteFiles(writerfn func(os.FileInfo, string, *os.File) error, fi os.FileInfo, path string) error { p := path + "/" + fi.Name() f, err := os.Open(p) if err != nil { @@ -247,25 +245,15 @@ func WriteFiles(tw *tar.Writer, fi os.FileInfo, path string) error { } var cumulerror error for _, nfi := range fis { - err := WriteFiles(tw, nfi, p) + err := WriteFiles(writerfn, nfi, p) if err != nil { cumulerror = fmt.Errorf("%vCouldn't write files for %s: %v\n", cumulerror, nfi.Name(), err) } } return cumulerror } - hdr, err := tar.FileInfoHeader(fi, "") - hdr.Name = p - if err != nil { - return fmt.Errorf("Couldn't build tar header: %v", err) - } - err = tw.WriteHeader(hdr) - if err != nil { - return fmt.Errorf("Couldn't write tar header: %v", err) - } - _, err = io.Copy(tw, f) - return err + return writerfn(fi, p, f) } func PublicKeyFile(file string) (ssh.AuthMethod, error) { @@ -297,16 +285,16 @@ func SSHAgent() (ssh.AuthMethod, error) { } func runServer(host, sshport, httpport, keyfile string) { - filemap := map[string]*FileReq{} + filemap := map[string]*sharethis.FileReq{} syncy := &sync.RWMutex{} - mapget := func(key string) (*FileReq, bool) { + mapget := func(key string) (*sharethis.FileReq, bool) { syncy.RLock() defer syncy.RUnlock() c, ok := filemap[key] return c, ok } - mapset := func(key string, fr *FileReq) { + mapset := func(key string, fr *sharethis.FileReq) { syncy.Lock() filemap[key] = fr syncy.Unlock() @@ -369,7 +357,7 @@ func runServer(host, sshport, httpport, keyfile string) { go func() { dec := gob.NewDecoder(channel) - var filereq FileReq + var filereq sharethis.FileReq for { err := dec.Decode(&filereq) @@ -378,7 +366,7 @@ func runServer(host, sshport, httpport, keyfile string) { } // TODO: also take the sharecount in the map. - filereq.serverconn = serverConn + filereq.ServerConn = serverConn mapset(filereq.Path, &filereq) go func() { serverConn.Wait(); mapdel(filereq.Path) }() return @@ -406,7 +394,7 @@ func runServer(host, sshport, httpport, keyfile string) { } fp = strings.TrimSuffix(fp, ".tar") if fr, ok := mapget(fp); ok { - channel, req, err := fr.serverconn.OpenChannel(fp+suffix, nil) + channel, req, err := fr.ServerConn.OpenChannel(fp+suffix, nil) if err != nil { fmt.Fprintln(os.Stderr, err) return diff --git a/filereq.go b/filereq.go new file mode 100644 index 0000000..9422e41 --- /dev/null +++ b/filereq.go @@ -0,0 +1,89 @@ +package sharethis // import "github.com/otremblay/sharethis" + +import ( + "crypto/sha256" + "fmt" + "log" + "os" + "os/user" + "path/filepath" + + "golang.org/x/crypto/ssh" +) + +type FileReq struct { + Path string + ShareCount uint + Username string + Hostname string + ServerConn *ssh.ServerConn + fileurl string + httpport string + remotehost string + localpath string + filename string + IsDir bool +} + +func NewFileReq(path, httpport, remotehost string, shareCount uint) *FileReq { + return getFileReq(DefaultPrefixFn, path, httpport, remotehost, shareCount) +} + +func (fr *FileReq) RebuildUrls() { + fr = getFileReq(DefaultPrefixFn, fr.localpath, fr.httpport, fr.remotehost, fr.ShareCount) +} + +func (fr *FileReq) String() string { + if fr.IsDir { + return fmt.Sprintf("%s\n%s", fr.fileurl+".tar", fr.fileurl+".zip") + } else { + return fmt.Sprintf(fr.fileurl) + } +} + +func getRemotePath(prefixFn func() string, path string) string { + fullpath := fmt.Sprintf("%s:%s", prefixFn(), path) + hashedpath := fmt.Sprintf("%x", sha256.Sum256([]byte(fullpath))) + return fmt.Sprintf("%s/%s", hashedpath, filepath.Base(path)) +} + +func getFileReq(prefixFn func() string, path string, httpport string, remotehost string, sharecount uint) *FileReq { + var dir bool + if fs, err := os.Stat(path); err != nil { + log.Fatalln("Can't read file") + } else { + p, err := filepath.Abs(fs.Name()) + if err != nil { + log.Fatalln("Can't read file") + } + dir = fs.IsDir() + path = p + } + remotepath := getRemotePath(prefixFn, path) + var fileurl string + if httpport == "443" { + fileurl = fmt.Sprintf("https://%s/%s", remotehost, remotepath) + } else if httpport == "80" { + fileurl = fmt.Sprintf("http://%s/%s", remotehost, remotepath) + } else { + fileurl = fmt.Sprintf("http://%s:%s/%s", remotehost, httpport, remotepath) + } + return &FileReq{Path: remotepath, ShareCount: sharecount, fileurl: fileurl, localpath: path, IsDir: dir} +} + +var DefaultPrefixFn = func() string { + hostname, err := os.Hostname() + if err != nil { + fmt.Fprintln(os.Stderr, "Could not get hostname with os.Hostname()") + hostname = "unknown" + } + var username string + userobj, err := user.Current() + if err != nil { + fmt.Fprintln(os.Stderr, "Could not get user with user.Current()") + username = "unknown" + } else { + username = userobj.Username + } + return fmt.Sprintf("%s@%s", username, hostname) +}