Extract some stuff out

This commit is contained in:
Olivier Tremblay 2018-12-11 15:14:26 -05:00
parent 8675e6815a
commit 46286cb851
No known key found for this signature in database
GPG key ID: D1C73ACB855E3A6D
3 changed files with 199 additions and 114 deletions

8
channel.go Normal file
View file

@ -0,0 +1,8 @@
package sharethis
import "golang.org/x/crypto/ssh"
type NewChannel struct {
ssh.NewChannel
DoZip bool
}

View file

@ -1,8 +1,8 @@
package main // import "github.com/otremblay/sharethis"
package main // import "github.com/otremblay/sharethis/cmd/sharethis"
import (
"archive/tar"
"crypto/sha256"
"archive/zip"
"encoding/gob"
"flag"
"fmt"
@ -13,26 +13,19 @@ import (
"net/http"
"os"
"os/user"
"path/filepath"
"strings"
"sync"
"syscall"
"github.com/otremblay/sharethis"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)
type FileReq struct {
Path string
ShareCount uint
serverconn *ssh.ServerConn
}
var authorizedKeys = os.Getenv("SHARETHIS_AUTHORIZEDKEYS")
func main() {
bg := flag.Bool("bg", false, "sends the process in the background")
server := flag.Bool("server", false, "makes the process an http server")
// admin := flag.String("admin", "", "determine admin user")
remotehost := flag.String("remote", "share.otremblay.com", "remote server for sharethis to contact")
sshport := flag.String("sshport", "2022", "the remote ssh port")
httpport := flag.String("httpport", "8888", "the remote server's http port")
@ -63,28 +56,8 @@ func main() {
if len(flag.Args()) < 1 {
log.Fatalln("Need filename")
}
var path string
var dir bool
var name string
if fs, err := os.Stat(flag.Arg(0)); err != nil {
log.Fatalln("Can't read file")
} else {
name = fs.Name()
p, err := filepath.Abs(fs.Name())
if err != nil {
log.Fatalln("Can't read file")
}
path = p
dir = fs.IsDir()
}
path := flag.Arg(0)
if *bg {
_, err := syscall.ForkExec(os.Args[0], append([]string{os.Args[0]}, flag.Args()...), &syscall.ProcAttr{Files: []uintptr{0, 1, 2}})
if err != nil {
log.Fatalln(err)
}
os.Exit(0)
}
var username string
userobj, err := user.Current()
if err != nil {
@ -124,36 +97,10 @@ func main() {
log.Fatalln("poop", err)
}
enc := gob.NewEncoder(ch)
path = flag.Arg(0)
fr := sharethis.NewFileReq(path, *httpport, *remotehost, *sharecount)
fmt.Println(fr)
hostname, err := os.Hostname()
if err != nil {
fmt.Fprintln(os.Stderr, "Could not get hostname with os.Hostname()")
hostname = "unknown"
}
fullpath := fmt.Sprintf("%s@%s:%s", username, hostname, path)
hashedpath := fmt.Sprintf("%x", sha256.Sum256([]byte(fullpath)))
// In the words of weezer, I've got my hashed path.
// TODO: Get the remote URL from the remote server instead of rebuilding it locally.
var fileurl string
fullpath = fmt.Sprintf("%s/%s", hashedpath, name)
if *httpport == "443" {
fileurl = fmt.Sprintf("https://%s/%s", *remotehost, fullpath)
} else if *httpport == "80" {
fileurl = fmt.Sprintf("http://%s/%s", *remotehost, fullpath)
} else {
fileurl = fmt.Sprintf("http://%s:%s/%s", *remotehost, *httpport, fullpath)
}
if dir {
fmt.Println(fileurl + ".tar")
fmt.Println(fileurl + ".zip")
} else {
fmt.Println(fileurl)
}
err = enc.Encode(&FileReq{Path: fullpath, ShareCount: *sharecount})
err = enc.Encode(fr)
if err != nil {
fmt.Println(err)
ch.Close()
@ -163,23 +110,23 @@ func main() {
mu := &sync.Mutex{}
defer ch.Close()
var getchfn func() ssh.NewChannel
if dir {
ncc := connection.HandleChannelOpen(fullpath + ".tar")
ncc2 := connection.HandleChannelOpen(fullpath + ".zip")
getchfn = func() ssh.NewChannel {
var getchfn func() *sharethis.NewChannel
if fr.IsDir {
ncc := connection.HandleChannelOpen(fr.Path + ".tar")
ncc2 := connection.HandleChannelOpen(fr.Path + ".zip")
getchfn = func() *sharethis.NewChannel {
select {
case nc := <-ncc:
return nc
return &sharethis.NewChannel{NewChannel: nc}
case nc := <-ncc2:
return nc
return &sharethis.NewChannel{NewChannel: nc, DoZip: true}
}
}
} else {
ncc := connection.HandleChannelOpen(fullpath)
getchfn = func() ssh.NewChannel {
return <-ncc
ncc := connection.HandleChannelOpen(fr.Path)
getchfn = func() *sharethis.NewChannel {
return &sharethis.NewChannel{<-ncc, false}
}
}
@ -191,7 +138,13 @@ func main() {
fmt.Fprintln(os.Stderr, err)
}
go func() {
go handleFileRequest(path, ch, mu, sharecount, nc.DoZip)
}
}
func handleFileRequest(path string, ch ssh.Channel, mu *sync.Mutex, sharecount *uint, dozip bool) {
f, err := os.Open(path)
if err != nil {
@ -199,8 +152,54 @@ func main() {
return
}
if fs, _ := f.Stat(); fs.IsDir() {
if dozip {
w := zip.NewWriter(ch)
zipfn := func(fi os.FileInfo, path string, file *os.File) error {
hdr, err := zip.FileInfoHeader(fi)
hdr.Name = path
if err != nil {
return fmt.Errorf("Couldn't build zip header: %v", err)
}
f, err := w.CreateHeader(hdr)
if err != nil {
return fmt.Errorf("Couldn't write zip header: %v", err)
}
_, err = io.Copy(f, file)
return err
}
err := WriteFiles(zipfn, fs, ".")
if err != nil {
fmt.Fprintln(ch, "Error building tar", err)
}
w.Close()
mu.Lock()
if *sharecount == 0 {
mu.Unlock()
os.Exit(0)
}
*sharecount--
mu.Unlock()
ch.Close()
return
} else {
tw := tar.NewWriter(ch)
err := WriteFiles(tw, fs, ".")
tarfn := func(fi os.FileInfo, path string, file *os.File) error {
hdr, err := tar.FileInfoHeader(fi, "")
hdr.Name = path
if err != nil {
return fmt.Errorf("Couldn't build tar header: %v", err)
}
err = tw.WriteHeader(hdr)
if err != nil {
return fmt.Errorf("Couldn't write tar header: %v", err)
}
_, err = io.Copy(tw, file)
return err
}
err := WriteFiles(tarfn, fs, ".")
if err != nil {
fmt.Fprintln(ch, "Error building tar", err)
}
@ -216,6 +215,7 @@ func main() {
return
}
}
_, err = io.Copy(ch, f)
if err != nil {
fmt.Println(err)
@ -229,12 +229,10 @@ func main() {
*sharecount--
mu.Unlock()
ch.Close()
}()
}
}
func WriteFiles(tw *tar.Writer, fi os.FileInfo, path string) error {
func WriteFiles(writerfn func(os.FileInfo, string, *os.File) error, fi os.FileInfo, path string) error {
p := path + "/" + fi.Name()
f, err := os.Open(p)
if err != nil {
@ -247,25 +245,15 @@ func WriteFiles(tw *tar.Writer, fi os.FileInfo, path string) error {
}
var cumulerror error
for _, nfi := range fis {
err := WriteFiles(tw, nfi, p)
err := WriteFiles(writerfn, nfi, p)
if err != nil {
cumulerror = fmt.Errorf("%vCouldn't write files for %s: %v\n", cumulerror, nfi.Name(), err)
}
}
return cumulerror
}
hdr, err := tar.FileInfoHeader(fi, "")
hdr.Name = p
if err != nil {
return fmt.Errorf("Couldn't build tar header: %v", err)
}
err = tw.WriteHeader(hdr)
if err != nil {
return fmt.Errorf("Couldn't write tar header: %v", err)
}
_, err = io.Copy(tw, f)
return err
return writerfn(fi, p, f)
}
func PublicKeyFile(file string) (ssh.AuthMethod, error) {
@ -297,16 +285,16 @@ func SSHAgent() (ssh.AuthMethod, error) {
}
func runServer(host, sshport, httpport, keyfile string) {
filemap := map[string]*FileReq{}
filemap := map[string]*sharethis.FileReq{}
syncy := &sync.RWMutex{}
mapget := func(key string) (*FileReq, bool) {
mapget := func(key string) (*sharethis.FileReq, bool) {
syncy.RLock()
defer syncy.RUnlock()
c, ok := filemap[key]
return c, ok
}
mapset := func(key string, fr *FileReq) {
mapset := func(key string, fr *sharethis.FileReq) {
syncy.Lock()
filemap[key] = fr
syncy.Unlock()
@ -369,7 +357,7 @@ func runServer(host, sshport, httpport, keyfile string) {
go func() {
dec := gob.NewDecoder(channel)
var filereq FileReq
var filereq sharethis.FileReq
for {
err := dec.Decode(&filereq)
@ -378,7 +366,7 @@ func runServer(host, sshport, httpport, keyfile string) {
}
// TODO: also take the sharecount in the map.
filereq.serverconn = serverConn
filereq.ServerConn = serverConn
mapset(filereq.Path, &filereq)
go func() { serverConn.Wait(); mapdel(filereq.Path) }()
return
@ -406,7 +394,7 @@ func runServer(host, sshport, httpport, keyfile string) {
}
fp = strings.TrimSuffix(fp, ".tar")
if fr, ok := mapget(fp); ok {
channel, req, err := fr.serverconn.OpenChannel(fp+suffix, nil)
channel, req, err := fr.ServerConn.OpenChannel(fp+suffix, nil)
if err != nil {
fmt.Fprintln(os.Stderr, err)
return

89
filereq.go Normal file
View file

@ -0,0 +1,89 @@
package sharethis // import "github.com/otremblay/sharethis"
import (
"crypto/sha256"
"fmt"
"log"
"os"
"os/user"
"path/filepath"
"golang.org/x/crypto/ssh"
)
type FileReq struct {
Path string
ShareCount uint
Username string
Hostname string
ServerConn *ssh.ServerConn
fileurl string
httpport string
remotehost string
localpath string
filename string
IsDir bool
}
func NewFileReq(path, httpport, remotehost string, shareCount uint) *FileReq {
return getFileReq(DefaultPrefixFn, path, httpport, remotehost, shareCount)
}
func (fr *FileReq) RebuildUrls() {
fr = getFileReq(DefaultPrefixFn, fr.localpath, fr.httpport, fr.remotehost, fr.ShareCount)
}
func (fr *FileReq) String() string {
if fr.IsDir {
return fmt.Sprintf("%s\n%s", fr.fileurl+".tar", fr.fileurl+".zip")
} else {
return fmt.Sprintf(fr.fileurl)
}
}
func getRemotePath(prefixFn func() string, path string) string {
fullpath := fmt.Sprintf("%s:%s", prefixFn(), path)
hashedpath := fmt.Sprintf("%x", sha256.Sum256([]byte(fullpath)))
return fmt.Sprintf("%s/%s", hashedpath, filepath.Base(path))
}
func getFileReq(prefixFn func() string, path string, httpport string, remotehost string, sharecount uint) *FileReq {
var dir bool
if fs, err := os.Stat(path); err != nil {
log.Fatalln("Can't read file")
} else {
p, err := filepath.Abs(fs.Name())
if err != nil {
log.Fatalln("Can't read file")
}
dir = fs.IsDir()
path = p
}
remotepath := getRemotePath(prefixFn, path)
var fileurl string
if httpport == "443" {
fileurl = fmt.Sprintf("https://%s/%s", remotehost, remotepath)
} else if httpport == "80" {
fileurl = fmt.Sprintf("http://%s/%s", remotehost, remotepath)
} else {
fileurl = fmt.Sprintf("http://%s:%s/%s", remotehost, httpport, remotepath)
}
return &FileReq{Path: remotepath, ShareCount: sharecount, fileurl: fileurl, localpath: path, IsDir: dir}
}
var DefaultPrefixFn = func() string {
hostname, err := os.Hostname()
if err != nil {
fmt.Fprintln(os.Stderr, "Could not get hostname with os.Hostname()")
hostname = "unknown"
}
var username string
userobj, err := user.Current()
if err != nil {
fmt.Fprintln(os.Stderr, "Could not get user with user.Current()")
username = "unknown"
} else {
username = userobj.Username
}
return fmt.Sprintf("%s@%s", username, hostname)
}