diff --git a/main.go b/main.go index 38037f1..80a5e2e 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main // import "github.com/otremblay/sharethis" import ( + "archive/tar" "crypto/sha256" "encoding/gob" "flag" @@ -63,14 +64,18 @@ func main() { log.Fatalln("Need filename") } var path string + var dir bool + var name string if fs, err := os.Stat(flag.Arg(0)); err != nil { log.Fatalln("Can't read file") } else { + name = fs.Name() p, err := filepath.Abs(fs.Name()) if err != nil { log.Fatalln("Can't read file") } path = p + dir = fs.IsDir() } if *bg { @@ -132,15 +137,23 @@ func main() { // In the words of weezer, I've got my hashed path. // TODO: Get the remote URL from the remote server instead of rebuilding it locally. var fileurl string + fullpath = fmt.Sprintf("%s/%s", hashedpath, name) if *httpport == "443" { - fileurl = fmt.Sprintf("https://%s/%s", *remotehost, hashedpath) + fileurl = fmt.Sprintf("https://%s/%s", *remotehost, fullpath) } else if *httpport == "80" { - fileurl = fmt.Sprintf("http://%s/%s", *remotehost, hashedpath) + fileurl = fmt.Sprintf("http://%s/%s", *remotehost, fullpath) } else { - fileurl = fmt.Sprintf("http://%s:%s/%s", *remotehost, *httpport, hashedpath) + fileurl = fmt.Sprintf("http://%s:%s/%s", *remotehost, *httpport, fullpath) } - fmt.Println(fileurl) - err = enc.Encode(&FileReq{Path: hashedpath, ShareCount: *sharecount}) + + if dir { + fmt.Println(fileurl + ".tar") + fmt.Println(fileurl + ".zip") + } else { + fmt.Println(fileurl) + } + + err = enc.Encode(&FileReq{Path: fullpath, ShareCount: *sharecount}) if err != nil { fmt.Println(err) ch.Close() @@ -150,9 +163,28 @@ func main() { mu := &sync.Mutex{} defer ch.Close() - ncc := connection.HandleChannelOpen(hashedpath) + var getchfn func() ssh.NewChannel + if dir { + ncc := connection.HandleChannelOpen(fullpath + ".tar") + ncc2 := connection.HandleChannelOpen(fullpath + ".zip") + getchfn = func() ssh.NewChannel { + select { + case nc := <-ncc: + return nc + case nc := <-ncc2: + return nc + } + } + + } else { + ncc := connection.HandleChannelOpen(fullpath) + getchfn = func() ssh.NewChannel { + return <-ncc + } + } + for { - nc := <-ncc + nc := getchfn() ch, req, err := nc.Accept() go ssh.DiscardRequests(req) if err != nil { @@ -161,10 +193,29 @@ func main() { go func() { f, err := os.Open(path) + if err != nil { fmt.Fprintln(ch, "Sharethis error") return } + if fs, _ := f.Stat(); fs.IsDir() { + tw := tar.NewWriter(ch) + err := WriteFiles(tw, fs, ".") + if err != nil { + fmt.Fprintln(ch, "Error building tar", err) + } + tw.Close() + mu.Lock() + if *sharecount == 0 { + mu.Unlock() + os.Exit(0) + } + *sharecount-- + mu.Unlock() + ch.Close() + + return + } _, err = io.Copy(ch, f) if err != nil { fmt.Println(err) @@ -183,6 +234,40 @@ func main() { } } +func WriteFiles(tw *tar.Writer, fi os.FileInfo, path string) error { + p := path + "/" + fi.Name() + f, err := os.Open(p) + if err != nil { + return fmt.Errorf("Couldn't write files: %v", err) + } + if fs, _ := f.Stat(); fs.IsDir() { + fis, err := f.Readdir(-1) + if err != nil { + return fmt.Errorf("Couldn't enumerate dir: %v", err) + } + var cumulerror error + for _, nfi := range fis { + err := WriteFiles(tw, nfi, p) + if err != nil { + cumulerror = fmt.Errorf("%vCouldn't write files for %s: %v\n", cumulerror, nfi.Name(), err) + } + } + return cumulerror + } + hdr, err := tar.FileInfoHeader(fi, "") + hdr.Name = p + if err != nil { + return fmt.Errorf("Couldn't build tar header: %v", err) + } + + err = tw.WriteHeader(hdr) + if err != nil { + return fmt.Errorf("Couldn't write tar header: %v", err) + } + _, err = io.Copy(tw, f) + return err +} + func PublicKeyFile(file string) (ssh.AuthMethod, error) { buffer, err := ioutil.ReadFile(file) if err != nil { @@ -310,8 +395,18 @@ func runServer(host, sshport, httpport, keyfile string) { return } fp := strings.TrimPrefix(req.URL.Path, "/") + var suffix string + if strings.HasSuffix(fp, ".zip") { + suffix = ".zip" + fp = strings.TrimSuffix(fp, ".zip") + } + if strings.HasSuffix(fp, ".tar") { + suffix = ".tar" + fp = strings.TrimSuffix(fp, ".tar") + } + fp = strings.TrimSuffix(fp, ".tar") if fr, ok := mapget(fp); ok { - channel, req, err := fr.serverconn.OpenChannel(fp, nil) + channel, req, err := fr.serverconn.OpenChannel(fp+suffix, nil) if err != nil { fmt.Fprintln(os.Stderr, err) return