Enable folder sharing

This commit is contained in:
Olivier Tremblay 2018-10-24 08:07:04 -04:00
parent 1744d4281a
commit 8675e6815a
No known key found for this signature in database
GPG key ID: D1C73ACB855E3A6D

111
main.go
View file

@ -1,6 +1,7 @@
package main // import "github.com/otremblay/sharethis" package main // import "github.com/otremblay/sharethis"
import ( import (
"archive/tar"
"crypto/sha256" "crypto/sha256"
"encoding/gob" "encoding/gob"
"flag" "flag"
@ -63,14 +64,18 @@ func main() {
log.Fatalln("Need filename") log.Fatalln("Need filename")
} }
var path string var path string
var dir bool
var name string
if fs, err := os.Stat(flag.Arg(0)); err != nil { if fs, err := os.Stat(flag.Arg(0)); err != nil {
log.Fatalln("Can't read file") log.Fatalln("Can't read file")
} else { } else {
name = fs.Name()
p, err := filepath.Abs(fs.Name()) p, err := filepath.Abs(fs.Name())
if err != nil { if err != nil {
log.Fatalln("Can't read file") log.Fatalln("Can't read file")
} }
path = p path = p
dir = fs.IsDir()
} }
if *bg { if *bg {
@ -132,15 +137,23 @@ func main() {
// In the words of weezer, I've got my hashed path. // In the words of weezer, I've got my hashed path.
// TODO: Get the remote URL from the remote server instead of rebuilding it locally. // TODO: Get the remote URL from the remote server instead of rebuilding it locally.
var fileurl string var fileurl string
fullpath = fmt.Sprintf("%s/%s", hashedpath, name)
if *httpport == "443" { if *httpport == "443" {
fileurl = fmt.Sprintf("https://%s/%s", *remotehost, hashedpath) fileurl = fmt.Sprintf("https://%s/%s", *remotehost, fullpath)
} else if *httpport == "80" { } else if *httpport == "80" {
fileurl = fmt.Sprintf("http://%s/%s", *remotehost, hashedpath) fileurl = fmt.Sprintf("http://%s/%s", *remotehost, fullpath)
} else { } else {
fileurl = fmt.Sprintf("http://%s:%s/%s", *remotehost, *httpport, hashedpath) fileurl = fmt.Sprintf("http://%s:%s/%s", *remotehost, *httpport, fullpath)
} }
fmt.Println(fileurl)
err = enc.Encode(&FileReq{Path: hashedpath, ShareCount: *sharecount}) if dir {
fmt.Println(fileurl + ".tar")
fmt.Println(fileurl + ".zip")
} else {
fmt.Println(fileurl)
}
err = enc.Encode(&FileReq{Path: fullpath, ShareCount: *sharecount})
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
ch.Close() ch.Close()
@ -150,9 +163,28 @@ func main() {
mu := &sync.Mutex{} mu := &sync.Mutex{}
defer ch.Close() defer ch.Close()
ncc := connection.HandleChannelOpen(hashedpath) var getchfn func() ssh.NewChannel
if dir {
ncc := connection.HandleChannelOpen(fullpath + ".tar")
ncc2 := connection.HandleChannelOpen(fullpath + ".zip")
getchfn = func() ssh.NewChannel {
select {
case nc := <-ncc:
return nc
case nc := <-ncc2:
return nc
}
}
} else {
ncc := connection.HandleChannelOpen(fullpath)
getchfn = func() ssh.NewChannel {
return <-ncc
}
}
for { for {
nc := <-ncc nc := getchfn()
ch, req, err := nc.Accept() ch, req, err := nc.Accept()
go ssh.DiscardRequests(req) go ssh.DiscardRequests(req)
if err != nil { if err != nil {
@ -161,10 +193,29 @@ func main() {
go func() { go func() {
f, err := os.Open(path) f, err := os.Open(path)
if err != nil { if err != nil {
fmt.Fprintln(ch, "Sharethis error") fmt.Fprintln(ch, "Sharethis error")
return return
} }
if fs, _ := f.Stat(); fs.IsDir() {
tw := tar.NewWriter(ch)
err := WriteFiles(tw, fs, ".")
if err != nil {
fmt.Fprintln(ch, "Error building tar", err)
}
tw.Close()
mu.Lock()
if *sharecount == 0 {
mu.Unlock()
os.Exit(0)
}
*sharecount--
mu.Unlock()
ch.Close()
return
}
_, err = io.Copy(ch, f) _, err = io.Copy(ch, f)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
@ -183,6 +234,40 @@ func main() {
} }
} }
func WriteFiles(tw *tar.Writer, fi os.FileInfo, path string) error {
p := path + "/" + fi.Name()
f, err := os.Open(p)
if err != nil {
return fmt.Errorf("Couldn't write files: %v", err)
}
if fs, _ := f.Stat(); fs.IsDir() {
fis, err := f.Readdir(-1)
if err != nil {
return fmt.Errorf("Couldn't enumerate dir: %v", err)
}
var cumulerror error
for _, nfi := range fis {
err := WriteFiles(tw, nfi, p)
if err != nil {
cumulerror = fmt.Errorf("%vCouldn't write files for %s: %v\n", cumulerror, nfi.Name(), err)
}
}
return cumulerror
}
hdr, err := tar.FileInfoHeader(fi, "")
hdr.Name = p
if err != nil {
return fmt.Errorf("Couldn't build tar header: %v", err)
}
err = tw.WriteHeader(hdr)
if err != nil {
return fmt.Errorf("Couldn't write tar header: %v", err)
}
_, err = io.Copy(tw, f)
return err
}
func PublicKeyFile(file string) (ssh.AuthMethod, error) { func PublicKeyFile(file string) (ssh.AuthMethod, error) {
buffer, err := ioutil.ReadFile(file) buffer, err := ioutil.ReadFile(file)
if err != nil { if err != nil {
@ -310,8 +395,18 @@ func runServer(host, sshport, httpport, keyfile string) {
return return
} }
fp := strings.TrimPrefix(req.URL.Path, "/") fp := strings.TrimPrefix(req.URL.Path, "/")
var suffix string
if strings.HasSuffix(fp, ".zip") {
suffix = ".zip"
fp = strings.TrimSuffix(fp, ".zip")
}
if strings.HasSuffix(fp, ".tar") {
suffix = ".tar"
fp = strings.TrimSuffix(fp, ".tar")
}
fp = strings.TrimSuffix(fp, ".tar")
if fr, ok := mapget(fp); ok { if fr, ok := mapget(fp); ok {
channel, req, err := fr.serverconn.OpenChannel(fp, nil) channel, req, err := fr.serverconn.OpenChannel(fp+suffix, nil)
if err != nil { if err != nil {
fmt.Fprintln(os.Stderr, err) fmt.Fprintln(os.Stderr, err)
return return