Refactor logger and use it everywhere 15/44015/3
authorKiran Kamineni <kiran.k.kamineni@intel.com>
Fri, 20 Apr 2018 04:27:01 +0000 (21:27 -0700)
committerKiran Kamineni <kiran.k.kamineni@intel.com>
Fri, 20 Apr 2018 21:48:26 +0000 (14:48 -0700)
Refactored the logger to print the right line
number. This is done by using the runtime.caller
function within the logger.output function

Issue-ID: AAF-257
Change-Id: Ie26de43ca74c71f382d3b5f93ebd4eaf6d51e2b4
Signed-off-by: Kiran Kamineni <kiran.k.kamineni@intel.com>
sms-service/doc/coverage.html
sms-service/doc/coverage.md [moved from sms-service/src/sms/coverage.md with 100% similarity]
sms-service/src/quorumclient/quorumclient.go
sms-service/src/sms/Gopkg.lock
sms-service/src/sms/auth/auth.go
sms-service/src/sms/backend/backend.go
sms-service/src/sms/backend/vault.go
sms-service/src/sms/backend/vault_test.go
sms-service/src/sms/handler/handler.go
sms-service/src/sms/handler/handler_test.go
sms-service/src/sms/log/logger.go

index d03ddde..39ee191 100644 (file)
                        <div id="nav">
                                <select id="files">
                                
-                               <option value="file0">sms/auth/auth.go (17.6%)</option>
+                               <option value="file0">sms/auth/auth.go (76.1%)</option>
                                
-                               <option value="file1">sms/backend/backend.go (66.7%)</option>
+                               <option value="file1">sms/backend/backend.go (80.0%)</option>
                                
-                               <option value="file2">sms/backend/vault.go (60.5%)</option>
+                               <option value="file2">sms/backend/vault.go (72.5%)</option>
                                
-                               <option value="file3">sms/config/config.go (90.9%)</option>
+                               <option value="file3">sms/config/config.go (78.6%)</option>
                                
-                               <option value="file4">sms/handler/handler.go (55.1%)</option>
+                               <option value="file4">sms/handler/handler.go (63.0%)</option>
                                
-                               <option value="file5">sms/log/logger.go (31.2%)</option>
+                               <option value="file5">sms/log/logger.go (65.6%)</option>
                                
-                               <option value="file6">sms/sms.go (82.6%)</option>
+                               <option value="file6">sms/sms.go (77.8%)</option>
                                
                                </select>
                        </div>
@@ -109,6 +109,7 @@ package auth
 
 import (
         "bytes"
+        "crypto"
         "crypto/tls"
         "crypto/x509"
         "encoding/base64"
@@ -119,63 +120,63 @@ import (
         smslogger "sms/log"
 )
 
-var tlsConfig *tls.Config
-
 // GetTLSConfig initializes a tlsConfig using the CA's certificate
 // This config is then used to enable the server for mutual TLS
 func GetTLSConfig(caCertFile string) (*tls.Config, error) <span class="cov10" title="3">{
+
         // Initialize tlsConfig once
-        if tlsConfig == nil </span><span class="cov10" title="3">{
-                caCert, err := ioutil.ReadFile(caCertFile)
+        caCert, err := ioutil.ReadFile(caCertFile)
 
-                if err != nil </span><span class="cov1" title="1">{
-                        return nil, err
-                }</span>
+        if err != nil </span><span class="cov1" title="1">{
+                return nil, err
+        }</span>
 
-                <span class="cov6" title="2">caCertPool := x509.NewCertPool()
-                caCertPool.AppendCertsFromPEM(caCert)
+        <span class="cov6" title="2">caCertPool := x509.NewCertPool()
+        caCertPool.AppendCertsFromPEM(caCert)
 
-                tlsConfig = &amp;tls.Config{
-                        ClientAuth: tls.RequireAndVerifyClientCert,
-                        ClientCAs:  caCertPool,
-                        MinVersion: tls.VersionTLS12,
-                }
-                tlsConfig.BuildNameToCertificate()</span>
+        tlsConfig := &amp;tls.Config{
+                // Change to RequireAndVerify once we have mandatory certs
+                ClientAuth: tls.VerifyClientCertIfGiven,
+                ClientCAs:  caCertPool,
+                MinVersion: tls.VersionTLS12,
         }
-        <span class="cov6" title="2">return tlsConfig, nil</span>
+        tlsConfig.BuildNameToCertificate()
+        return tlsConfig, nil</span>
 }
 
 // GeneratePGPKeyPair produces a PGP key pair and returns
 // two things:
 // A base64 encoded form of the public part of the entity
 // A base64 encoded form of the private key
-func GeneratePGPKeyPair() (string, string, error) <span class="cov0" title="0">{
+func GeneratePGPKeyPair() (string, string, error) <span class="cov10" title="3">{
+
         var entity *openpgp.Entity
-        entity, err := openpgp.NewEntity("aaf.sms.init", "PGP Key for unsealing", "", nil)
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
+        config := &amp;packet.Config{
+                DefaultHash: crypto.SHA256,
+        }
+
+        entity, err := openpgp.NewEntity("aaf.sms.init", "PGP Key for unsealing", "", config)
+        if smslogger.CheckError(err, "Create Entity") != nil </span><span class="cov0" title="0">{
                 return "", "", err
         }</span>
 
         // Sign the identity in the entity
-        <span class="cov0" title="0">for _, id := range entity.Identities </span><span class="cov0" title="0">{
+        <span class="cov10" title="3">for _, id := range entity.Identities </span><span class="cov10" title="3">{
                 err = id.SelfSignature.SignUserId(id.UserId.Id, entity.PrimaryKey, entity.PrivateKey, nil)
-                if err != nil </span><span class="cov0" title="0">{
-                        smslogger.WriteError(err.Error())
+                if smslogger.CheckError(err, "Sign Entity") != nil </span><span class="cov0" title="0">{
                         return "", "", err
                 }</span>
         }
 
         // Sign the subkey in the entity
-        <span class="cov0" title="0">for _, subkey := range entity.Subkeys </span><span class="cov0" title="0">{
+        <span class="cov10" title="3">for _, subkey := range entity.Subkeys </span><span class="cov10" title="3">{
                 err := subkey.Sig.SignKey(subkey.PublicKey, entity.PrivateKey, nil)
-                if err != nil </span><span class="cov0" title="0">{
-                        smslogger.WriteError(err.Error())
+                if smslogger.CheckError(err, "Sign Subkey") != nil </span><span class="cov0" title="0">{
                         return "", "", err
                 }</span>
         }
 
-        <span class="cov0" title="0">buffer := new(bytes.Buffer)
+        <span class="cov10" title="3">buffer := new(bytes.Buffer)
         entity.Serialize(buffer)
         pbkey := base64.StdEncoding.EncodeToString(buffer.Bytes())
 
@@ -186,40 +187,96 @@ func GeneratePGPKeyPair() (string, string, error) <span class="cov0" title="0">{
         return pbkey, prkey, nil</span>
 }
 
-// DecryptPGPBytes decrypts a PGP encoded input string and returns
+// EncryptPGPString takes data and a public key and encrypts using that
+// public key
+func EncryptPGPString(data string, pbKey string) (string, error) <span class="cov6" title="2">{
+
+        pbKeyBytes, err := base64.StdEncoding.DecodeString(pbKey)
+        if smslogger.CheckError(err, "Decoding Base64 Public Key") != nil </span><span class="cov0" title="0">{
+                return "", err
+        }</span>
+
+        <span class="cov6" title="2">dataBytes := []byte(data)
+
+        pbEntity, err := openpgp.ReadEntity(packet.NewReader(bytes.NewBuffer(pbKeyBytes)))
+        if smslogger.CheckError(err, "Reading entity from PGP key") != nil </span><span class="cov0" title="0">{
+                return "", err
+        }</span>
+
+        // encrypt string
+        <span class="cov6" title="2">buf := new(bytes.Buffer)
+        out, err := openpgp.Encrypt(buf, []*openpgp.Entity{pbEntity}, nil, nil, nil)
+        if smslogger.CheckError(err, "Creating Encryption Pipe") != nil </span><span class="cov0" title="0">{
+                return "", err
+        }</span>
+
+        <span class="cov6" title="2">_, err = out.Write(dataBytes)
+        if smslogger.CheckError(err, "Writing to Encryption Pipe") != nil </span><span class="cov0" title="0">{
+                return "", err
+        }</span>
+
+        <span class="cov6" title="2">err = out.Close()
+        if smslogger.CheckError(err, "Closing Encryption Pipe") != nil </span><span class="cov0" title="0">{
+                return "", err
+        }</span>
+
+        <span class="cov6" title="2">crp := base64.StdEncoding.EncodeToString(buf.Bytes())
+        return crp, nil</span>
+}
+
+// DecryptPGPString decrypts a PGP encoded input string and returns
 // a base64 representation of the decoded string
-func DecryptPGPBytes(data string, prKey string) (string, error) <span class="cov0" title="0">{
+func DecryptPGPString(data string, prKey string) (string, error) <span class="cov1" title="1">{
+
         // Convert private key to bytes from base64
         prKeyBytes, err := base64.StdEncoding.DecodeString(prKey)
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError("Error Decoding base64 private key: " + err.Error())
+        if smslogger.CheckError(err, "Decoding Base64 Private Key") != nil </span><span class="cov0" title="0">{
                 return "", err
         }</span>
 
-        <span class="cov0" title="0">dataBytes, err := base64.StdEncoding.DecodeString(data)
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError("Error Decoding base64 data: " + err.Error())
+        <span class="cov1" title="1">dataBytes, err := base64.StdEncoding.DecodeString(data)
+        if smslogger.CheckError(err, "Decoding base64 data") != nil </span><span class="cov0" title="0">{
                 return "", err
         }</span>
 
-        <span class="cov0" title="0">prEntity, err := openpgp.ReadEntity(packet.NewReader(bytes.NewBuffer(prKeyBytes)))
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError("Error reading entity from PGP key: " + err.Error())
+        <span class="cov1" title="1">prEntity, err := openpgp.ReadEntity(packet.NewReader(bytes.NewBuffer(prKeyBytes)))
+        if smslogger.CheckError(err, "Read Entity") != nil </span><span class="cov0" title="0">{
                 return "", err
         }</span>
 
-        <span class="cov0" title="0">prEntityList := &amp;openpgp.EntityList{prEntity}
+        <span class="cov1" title="1">prEntityList := &amp;openpgp.EntityList{prEntity}
         message, err := openpgp.ReadMessage(bytes.NewBuffer(dataBytes), prEntityList, nil, nil)
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError("Error Decrypting message: " + err.Error())
+        if smslogger.CheckError(err, "Decrypting Message") != nil </span><span class="cov0" title="0">{
                 return "", err
         }</span>
 
-        <span class="cov0" title="0">var retBuf bytes.Buffer
+        <span class="cov1" title="1">var retBuf bytes.Buffer
         retBuf.ReadFrom(message.UnverifiedBody)
 
         return retBuf.String(), nil</span>
 }
+
+// ReadFromFile reads a file and loads the PGP key into
+// a string
+func ReadFromFile(fileName string) (string, error) <span class="cov6" title="2">{
+
+        data, err := ioutil.ReadFile(fileName)
+        if smslogger.CheckError(err, "Read from file") != nil </span><span class="cov0" title="0">{
+                return "", err
+        }</span>
+        <span class="cov6" title="2">return string(data), nil</span>
+}
+
+// WriteToFile writes a PGP key into a file.
+// It will truncate the file if it exists
+func WriteToFile(data string, fileName string) error <span class="cov0" title="0">{
+
+        err := ioutil.WriteFile(fileName, []byte(data), 0600)
+        if smslogger.CheckError(err, "Write to file") != nil </span><span class="cov0" title="0">{
+                return err
+        }</span>
+        <span class="cov0" title="0">return nil</span>
+}
 </pre>
                
                <pre class="file" id="file1" style="display: none">/*
@@ -264,6 +321,7 @@ type SecretBackend interface {
         Init() error
         GetStatus() (bool, error)
         Unseal(shard string) error
+        RegisterQuorum(pgpkey string) (string, error)
 
         GetSecret(dom string, sec string) (Secret, error)
         ListSecret(dom string) ([]string, error)
@@ -276,19 +334,18 @@ type SecretBackend interface {
 }
 
 // InitSecretBackend returns an interface implementation
-func InitSecretBackend() (SecretBackend, error) <span class="cov10" title="2">{
+func InitSecretBackend() (SecretBackend, error) <span class="cov8" title="1">{
         backendImpl := &amp;Vault{
-                vaultAddress: smsconfig.SMSConfig.VaultAddress,
+                vaultAddress: smsconfig.SMSConfig.BackendAddress,
                 vaultToken:   smsconfig.SMSConfig.VaultToken,
         }
 
         err := backendImpl.Init()
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
+        if smslogger.CheckError(err, "InitSecretBackend") != nil </span><span class="cov0" title="0">{
                 return nil, err
         }</span>
 
-        <span class="cov10" title="2">return backendImpl, nil</span>
+        <span class="cov8" title="1">return backendImpl, nil</span>
 }
 
 // LoginBackend Interface that will be implemented for various login backends
@@ -330,68 +387,108 @@ import (
 // Vault is the main Struct used in Backend to initialize the struct
 type Vault struct {
         sync.Mutex
-        engineType        string
-        initRoleDone      bool
-        policyName        string
-        roleID            string
-        secretID          string
-        vaultAddress      string
-        vaultClient       *vaultapi.Client
-        vaultMount        string
-        vaultTempTokenTTL time.Time
-        vaultToken        string
-        unsealShards      []string
-        rootToken         string
-        pgpPub            string
-        pgpPr             string
+        initRoleDone          bool
+        policyName            string
+        roleID                string
+        secretID              string
+        vaultAddress          string
+        vaultClient           *vaultapi.Client
+        vaultMountPrefix      string
+        internalDomain        string
+        internalDomainMounted bool
+        vaultTempTokenTTL     time.Time
+        vaultToken            string
+        shards                []string
+        prkey                 string
 }
 
-// Init will initialize the vault connection
-// It will also create the initial policy if it does not exist
-// TODO: Check to see if we need to wait for vault to be running
-func (v *Vault) Init() error <span class="cov4" title="3">{
+// initVaultClient will create the initial
+// Vault strcuture and populate it with the
+// right values and it will also create
+// a vault client
+func (v *Vault) initVaultClient() error <span class="cov6" title="11">{
+
         vaultCFG := vaultapi.DefaultConfig()
         vaultCFG.Address = v.vaultAddress
         client, err := vaultapi.NewClient(vaultCFG)
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
-                return errors.New("Unable to create new vault client")
+        if smslogger.CheckError(err, "Create new vault client") != nil </span><span class="cov0" title="0">{
+                return err
         }</span>
 
-        <span class="cov4" title="3">v.engineType = "kv"
-        v.initRoleDone = false
+        <span class="cov6" title="11">v.initRoleDone = false
         v.policyName = "smsvaultpolicy"
         v.vaultClient = client
-        v.vaultMount = "sms"
+        v.vaultMountPrefix = "sms"
+        v.internalDomain = "smsinternaldomain"
+        v.internalDomainMounted = false
+        v.prkey = ""
+        return nil</span>
+}
 
-        err = v.initRole()
-        if err != nil </span><span class="cov2" title="2">{
-                smslogger.WriteError(err.Error())
+// Init will initialize the vault connection
+// It will also initialize vault if it is not
+// already initialized.
+// The initial policy will also be created
+func (v *Vault) Init() error <span class="cov1" title="1">{
+
+        v.initVaultClient()
+        // Initialize vault if it is not already
+        // Returns immediately if it is initialized
+        v.initializeVault()
+
+        err := v.initRole()
+        if smslogger.CheckError(err, "InitRole First Attempt") != nil </span><span class="cov0" title="0">{
                 smslogger.WriteInfo("InitRole will try again later")
         }</span>
 
-        <span class="cov4" title="3">return nil</span>
+        <span class="cov1" title="1">return nil</span>
 }
 
 // GetStatus returns the current seal status of vault
-func (v *Vault) GetStatus() (bool, error) <span class="cov4" title="3">{
+func (v *Vault) GetStatus() (bool, error) <span class="cov3" title="3">{
+
         sys := v.vaultClient.Sys()
         sealStatus, err := sys.SealStatus()
-        if err != nil </span><span class="cov1" title="1">{
-                smslogger.WriteError(err.Error())
+        if smslogger.CheckError(err, "Getting Status") != nil </span><span class="cov0" title="0">{
                 return false, errors.New("Error getting status")
         }</span>
 
-        <span class="cov2" title="2">return sealStatus.Sealed, nil</span>
+        <span class="cov3" title="3">return sealStatus.Sealed, nil</span>
+}
+
+// RegisterQuorum registers the PGP public key for a quorum client
+// We will return a shard to the client that is registering
+func (v *Vault) RegisterQuorum(pgpkey string) (string, error) <span class="cov0" title="0">{
+
+        v.Lock()
+        defer v.Unlock()
+
+        if v.shards == nil </span><span class="cov0" title="0">{
+                smslogger.WriteError("Invalid operation in RegisterQuorum")
+                return "", errors.New("Invalid operation")
+        }</span>
+        // Pop the slice
+        <span class="cov0" title="0">var sh string
+        sh, v.shards = v.shards[len(v.shards)-1], v.shards[:len(v.shards)-1]
+        if len(v.shards) == 0 </span><span class="cov0" title="0">{
+                v.shards = nil
+        }</span>
+
+        // Decrypt with SMS pgp Key
+        <span class="cov0" title="0">sh, _ = smsauth.DecryptPGPString(sh, v.prkey)
+        // Encrypt with Quorum client pgp key
+        sh, _ = smsauth.EncryptPGPString(sh, pgpkey)
+
+        return sh, nil</span>
 }
 
 // Unseal is a passthrough API that allows any
 // unseal or initialization processes for the backend
 func (v *Vault) Unseal(shard string) error <span class="cov0" title="0">{
+
         sys := v.vaultClient.Sys()
         _, err := sys.Unseal(shard)
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
+        if smslogger.CheckError(err, "Unseal Operation") != nil </span><span class="cov0" title="0">{
                 return errors.New("Unable to execute unseal operation with specified shard")
         }</span>
 
@@ -401,80 +498,140 @@ func (v *Vault) Unseal(shard string) error <span class="cov0" title="0">{
 // GetSecret returns a secret mounted on a particular domain name
 // The secret itself is referenced via its name which translates to
 // a mount path in vault
-func (v *Vault) GetSecret(dom string, name string) (Secret, error) <span class="cov6" title="6">{
+func (v *Vault) GetSecret(dom string, name string) (Secret, error) <span class="cov5" title="7">{
+
         err := v.checkToken()
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
+        if smslogger.CheckError(err, "Tocken Check") != nil </span><span class="cov0" title="0">{
                 return Secret{}, errors.New("Token check failed")
         }</span>
 
-        <span class="cov6" title="6">dom = v.vaultMount + "/" + dom
+        <span class="cov5" title="7">dom = v.vaultMountPrefix + "/" + dom
 
         sec, err := v.vaultClient.Logical().Read(dom + "/" + name)
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
+        if smslogger.CheckError(err, "Read Secret") != nil </span><span class="cov0" title="0">{
                 return Secret{}, errors.New("Unable to read Secret at provided path")
         }</span>
 
         // sec and err are nil in the case where a path does not exist
-        <span class="cov6" title="6">if sec == nil </span><span class="cov0" title="0">{
+        <span class="cov5" title="7">if sec == nil </span><span class="cov0" title="0">{
                 smslogger.WriteWarn("Vault read was empty. Invalid Path")
                 return Secret{}, errors.New("Secret not found at the provided path")
         }</span>
 
-        <span class="cov6" title="6">return Secret{Name: name, Values: sec.Data}, nil</span>
+        <span class="cov5" title="7">return Secret{Name: name, Values: sec.Data}, nil</span>
 }
 
 // ListSecret returns a list of secret names on a particular domain
 // The values of the secret are not returned
-func (v *Vault) ListSecret(dom string) ([]string, error) <span class="cov2" title="2">{
+func (v *Vault) ListSecret(dom string) ([]string, error) <span class="cov3" title="3">{
+
         err := v.checkToken()
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
+        if smslogger.CheckError(err, "Token Check") != nil </span><span class="cov0" title="0">{
                 return nil, errors.New("Token check failed")
         }</span>
 
-        <span class="cov2" title="2">dom = v.vaultMount + "/" + dom
+        <span class="cov3" title="3">dom = v.vaultMountPrefix + "/" + dom
 
         sec, err := v.vaultClient.Logical().List(dom)
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
+        if smslogger.CheckError(err, "Read Secret") != nil </span><span class="cov0" title="0">{
                 return nil, errors.New("Unable to read Secret at provided path")
         }</span>
 
         // sec and err are nil in the case where a path does not exist
-        <span class="cov2" title="2">if sec == nil </span><span class="cov0" title="0">{
+        <span class="cov3" title="3">if sec == nil </span><span class="cov0" title="0">{
                 smslogger.WriteWarn("Vaultclient returned empty data")
                 return nil, errors.New("Secret not found at the provided path")
         }</span>
 
-        <span class="cov2" title="2">val, ok := sec.Data["keys"].([]interface{})
+        <span class="cov3" title="3">val, ok := sec.Data["keys"].([]interface{})
         if !ok </span><span class="cov0" title="0">{
                 smslogger.WriteError("Secret not found at the provided path")
                 return nil, errors.New("Secret not found at the provided path")
         }</span>
 
-        <span class="cov2" title="2">retval := make([]string, len(val))
-        for i, v := range val </span><span class="cov6" title="6">{
+        <span class="cov3" title="3">retval := make([]string, len(val))
+        for i, v := range val </span><span class="cov5" title="7">{
                 retval[i] = fmt.Sprint(v)
         }</span>
 
-        <span class="cov2" title="2">return retval, nil</span>
+        <span class="cov3" title="3">return retval, nil</span>
+}
+
+// Mounts the internal Domain if its not already mounted
+func (v *Vault) mountInternalDomain(name string) error <span class="cov5" title="8">{
+
+        if v.internalDomainMounted </span><span class="cov1" title="1">{
+                return nil
+        }</span>
+
+        <span class="cov5" title="7">name = strings.TrimSpace(name)
+        mountPath := v.vaultMountPrefix + "/" + name
+        mountInput := &amp;vaultapi.MountInput{
+                Type:        "kv",
+                Description: "Mount point for domain: " + name,
+                Local:       false,
+                SealWrap:    false,
+                Config:      vaultapi.MountConfigInput{},
+        }
+
+        err := v.vaultClient.Sys().Mount(mountPath, mountInput)
+        if smslogger.CheckError(err, "Mount internal Domain") != nil </span><span class="cov1" title="1">{
+                if strings.Contains(err.Error(), "existing mount") </span><span class="cov1" title="1">{
+                        // It is already mounted
+                        v.internalDomainMounted = true
+                        return nil
+                }</span>
+                // Ran into some other error mounting it.
+                <span class="cov0" title="0">return errors.New("Unable to mount internal Domain")</span>
+        }
+
+        <span class="cov5" title="6">v.internalDomainMounted = true
+        return nil</span>
+}
+
+// Stores the UUID created for secretdomain in vault
+// under v.vaultMountPrefix / smsinternal domain
+func (v *Vault) storeUUID(uuid string, name string) error <span class="cov5" title="8">{
+
+        // Check if token is still valid
+        err := v.checkToken()
+        if smslogger.CheckError(err, "Token Check") != nil </span><span class="cov0" title="0">{
+                return errors.New("Token Check failed")
+        }</span>
+
+        <span class="cov5" title="8">err = v.mountInternalDomain(v.internalDomain)
+        if smslogger.CheckError(err, "Mount Internal Domain") != nil </span><span class="cov0" title="0">{
+                return err
+        }</span>
+
+        <span class="cov5" title="8">secret := Secret{
+                Name: name,
+                Values: map[string]interface{}{
+                        "uuid": uuid,
+                },
+        }
+
+        err = v.CreateSecret(v.internalDomain, secret)
+        if smslogger.CheckError(err, "Write UUID to domain") != nil </span><span class="cov0" title="0">{
+                return err
+        }</span>
+
+        <span class="cov5" title="8">return nil</span>
 }
 
 // CreateSecretDomain mounts the kv backend on a path with the given name
-func (v *Vault) CreateSecretDomain(name string) (SecretDomain, error) <span class="cov2" title="2">{
+func (v *Vault) CreateSecretDomain(name string) (SecretDomain, error) <span class="cov5" title="8">{
+
         // Check if token is still valid
         err := v.checkToken()
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
+        if smslogger.CheckError(err, "Token Check") != nil </span><span class="cov0" title="0">{
                 return SecretDomain{}, errors.New("Token Check failed")
         }</span>
 
-        <span class="cov2" title="2">name = strings.TrimSpace(name)
-        mountPath := v.vaultMount + "/" + name
+        <span class="cov5" title="8">name = strings.TrimSpace(name)
+        mountPath := v.vaultMountPrefix + "/" + name
         mountInput := &amp;vaultapi.MountInput{
-                Type:        v.engineType,
+                Type:        "kv",
                 Description: "Mount point for domain: " + name,
                 Local:       false,
                 SealWrap:    false,
@@ -482,171 +639,212 @@ func (v *Vault) CreateSecretDomain(name string) (SecretDomain, error) <span clas
         }
 
         err = v.vaultClient.Sys().Mount(mountPath, mountInput)
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
+        if smslogger.CheckError(err, "Create Domain") != nil </span><span class="cov0" title="0">{
                 return SecretDomain{}, errors.New("Unable to create Secret Domain")
         }</span>
 
-        <span class="cov2" title="2">uuid, _ := uuid.GenerateUUID()
-        return SecretDomain{uuid, name}, nil</span>
+        <span class="cov5" title="8">uuid, _ := uuid.GenerateUUID()
+        err = v.storeUUID(uuid, name)
+        if smslogger.CheckError(err, "Store UUID") != nil </span><span class="cov0" title="0">{
+                // Mount was successful at this point.
+                // Rollback the mount operation since we could not
+                // store the UUID for the mount.
+                v.vaultClient.Sys().Unmount(mountPath)
+                return SecretDomain{}, errors.New("Unable to store Secret Domain UUID. Retry")
+        }</span>
+
+        <span class="cov5" title="8">return SecretDomain{uuid, name}, nil</span>
 }
 
 // CreateSecret creates a secret mounted on a particular domain name
 // The secret itself is mounted on a path specified by name
-func (v *Vault) CreateSecret(dom string, sec Secret) error <span class="cov6" title="6">{
+func (v *Vault) CreateSecret(dom string, sec Secret) error <span class="cov7" title="18">{
+
         err := v.checkToken()
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
+        if smslogger.CheckError(err, "Token Check") != nil </span><span class="cov0" title="0">{
                 return errors.New("Token check failed")
         }</span>
 
-        <span class="cov6" title="6">dom = v.vaultMount + "/" + dom
+        <span class="cov7" title="18">dom = v.vaultMountPrefix + "/" + dom
 
         // Vault return is empty on successful write
         // TODO: Check if values is not empty
         _, err = v.vaultClient.Logical().Write(dom+"/"+sec.Name, sec.Values)
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
+        if smslogger.CheckError(err, "Create Secret") != nil </span><span class="cov0" title="0">{
                 return errors.New("Unable to create Secret at provided path")
         }</span>
 
-        <span class="cov6" title="6">return nil</span>
+        <span class="cov7" title="18">return nil</span>
 }
 
 // DeleteSecretDomain deletes a secret domain which translates to
 // an unmount operation on the given path in Vault
-func (v *Vault) DeleteSecretDomain(name string) error <span class="cov2" title="2">{
+func (v *Vault) DeleteSecretDomain(name string) error <span class="cov3" title="3">{
+
         err := v.checkToken()
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
+        if smslogger.CheckError(err, "Token Check") != nil </span><span class="cov0" title="0">{
                 return errors.New("Token Check Failed")
         }</span>
 
-        <span class="cov2" title="2">name = strings.TrimSpace(name)
-        mountPath := v.vaultMount + "/" + name
+        <span class="cov3" title="3">name = strings.TrimSpace(name)
+        mountPath := v.vaultMountPrefix + "/" + name
 
         err = v.vaultClient.Sys().Unmount(mountPath)
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
+        if smslogger.CheckError(err, "Delete Domain") != nil </span><span class="cov0" title="0">{
                 return errors.New("Unable to delete domain specified")
         }</span>
 
-        <span class="cov2" title="2">return nil</span>
+        <span class="cov3" title="3">return nil</span>
 }
 
 // DeleteSecret deletes a secret mounted on the path provided
-func (v *Vault) DeleteSecret(dom string, name string) error <span class="cov6" title="6">{
+func (v *Vault) DeleteSecret(dom string, name string) error <span class="cov5" title="7">{
+
         err := v.checkToken()
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
+        if smslogger.CheckError(err, "Token Check") != nil </span><span class="cov0" title="0">{
                 return errors.New("Token check failed")
         }</span>
 
-        <span class="cov6" title="6">dom = v.vaultMount + "/" + dom
+        <span class="cov5" title="7">dom = v.vaultMountPrefix + "/" + dom
 
         // Vault return is empty on successful delete
         _, err = v.vaultClient.Logical().Delete(dom + "/" + name)
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
+        if smslogger.CheckError(err, "Delete Secret") != nil </span><span class="cov0" title="0">{
                 return errors.New("Unable to delete Secret at provided path")
         }</span>
 
-        <span class="cov6" title="6">return nil</span>
+        <span class="cov5" title="7">return nil</span>
 }
 
-// initRole is called only once during the service bring up
-func (v *Vault) initRole() error <span class="cov4" title="3">{
+// initRole is called only once during SMS bring up
+// It initially creates a role and secret id associated with
+// that role. Later restarts will use the existing role-id
+// and secret-id stored on disk
+func (v *Vault) initRole() error <span class="cov10" title="56">{
+
+        if v.initRoleDone </span><span class="cov9" title="48">{
+                return nil
+        }</span>
+
         // Use the root token once here
-        v.vaultClient.SetToken(v.vaultToken)
+        <span class="cov5" title="8">v.vaultClient.SetToken(v.vaultToken)
         defer v.vaultClient.ClearToken()
 
-        rules := `path "sms/*" { capabilities = ["create", "read", "update", "delete", "list"] }
+        // Check if roleID and secretID has already been created
+        rID, error := smsauth.ReadFromFile("auth/role")
+        if error != nil </span><span class="cov5" title="7">{
+                smslogger.WriteWarn("Unable to find RoleID. Generating...")
+        }</span><span class="cov1" title="1"> else {
+                sID, error := smsauth.ReadFromFile("auth/secret")
+                if error != nil </span><span class="cov0" title="0">{
+                        smslogger.WriteWarn("Unable to find secretID. Generating...")
+                }</span><span class="cov1" title="1"> else {
+                        v.roleID = rID
+                        v.secretID = sID
+                        v.initRoleDone = true
+                        return nil
+                }</span>
+        }
+
+        <span class="cov5" title="7">rules := `path "sms/*" { capabilities = ["create", "read", "update", "delete", "list"] }
                         path "sys/mounts/sms*" { capabilities = ["update","delete","create"] }`
         err := v.vaultClient.Sys().PutPolicy(v.policyName, rules)
-        if err != nil </span><span class="cov2" title="2">{
-                smslogger.WriteError(err.Error())
+        if smslogger.CheckError(err, "Creating Policy") != nil </span><span class="cov0" title="0">{
                 return errors.New("Unable to create policy for approle creation")
         }</span>
 
-        <span class="cov1" title="1">rName := v.vaultMount + "-role"
-        data := map[string]interface{}{
-                "token_ttl": "60m",
-                "policies":  [2]string{"default", v.policyName},
-        }
-
         //Check if applrole is mounted
-        authMounts, err := v.vaultClient.Sys().ListAuth()
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
+        <span class="cov5" title="7">authMounts, err := v.vaultClient.Sys().ListAuth()
+        if smslogger.CheckError(err, "Mount Auth Backend") != nil </span><span class="cov0" title="0">{
                 return errors.New("Unable to get mounted auth backends")
         }</span>
 
-        <span class="cov1" title="1">approleMounted := false
-        for k, v := range authMounts </span><span class="cov1" title="1">{
-                if v.Type == "approle" &amp;&amp; k == "approle/" </span><span class="cov1" title="1">{
+        <span class="cov5" title="7">approleMounted := false
+        for k, v := range authMounts </span><span class="cov5" title="7">{
+                if v.Type == "approle" &amp;&amp; k == "approle/" </span><span class="cov0" title="0">{
                         approleMounted = true
                         break</span>
                 }
         }
 
         // Mount approle in case its not already mounted
-        <span class="cov1" title="1">if !approleMounted </span><span class="cov0" title="0">{
+        <span class="cov5" title="7">if !approleMounted </span><span class="cov5" title="7">{
                 v.vaultClient.Sys().EnableAuth("approle", "approle", "")
         }</span>
 
+        <span class="cov5" title="7">rName := v.vaultMountPrefix + "-role"
+        data := map[string]interface{}{
+                "token_ttl": "60m",
+                "policies":  [2]string{"default", v.policyName},
+        }
+
         // Create a role-id
-        <span class="cov1" title="1">v.vaultClient.Logical().Write("auth/approle/role/"+rName, data)
+        v.vaultClient.Logical().Write("auth/approle/role/"+rName, data)
         sec, err := v.vaultClient.Logical().Read("auth/approle/role/" + rName + "/role-id")
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
+        if smslogger.CheckError(err, "Create RoleID") != nil </span><span class="cov0" title="0">{
                 return errors.New("Unable to create role ID for approle")
         }</span>
-        <span class="cov1" title="1">v.roleID = sec.Data["role_id"].(string)
+        <span class="cov5" title="7">v.roleID = sec.Data["role_id"].(string)
 
         // Create a secret-id to go with it
         sec, err = v.vaultClient.Logical().Write("auth/approle/role/"+rName+"/secret-id",
                 map[string]interface{}{})
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
+        if smslogger.CheckError(err, "Create SecretID") != nil </span><span class="cov0" title="0">{
                 return errors.New("Unable to create secret ID for role")
         }</span>
 
-        <span class="cov1" title="1">v.secretID = sec.Data["secret_id"].(string)
+        <span class="cov5" title="7">v.secretID = sec.Data["secret_id"].(string)
         v.initRoleDone = true
+        /*
+        * Revoke the Root token.
+        * If a new Root Token is needed, it will need to be created
+        * using the unseal shards.
+         */
+        err = v.vaultClient.Auth().Token().RevokeSelf(v.vaultToken)
+        if smslogger.CheckError(err, "Revoke Root Token") != nil </span><span class="cov0" title="0">{
+                smslogger.WriteWarn("Unable to Revoke Token")
+        }</span><span class="cov5" title="7"> else {
+                // Revoked successfully and clear it
+                v.vaultToken = ""
+        }</span>
+
+        // Store the role-id and secret-id
+        // We will need this if SMS restarts
+        <span class="cov5" title="7">smsauth.WriteToFile(v.roleID, "auth/role")
+        smsauth.WriteToFile(v.secretID, "auth/secret")
+
         return nil</span>
 }
 
 // Function checkToken() gets called multiple times to create
 // temporary tokens
-func (v *Vault) checkToken() error <span class="cov10" title="24">{
+func (v *Vault) checkToken() error <span class="cov9" title="54">{
+
         v.Lock()
         defer v.Unlock()
 
         // Init Role if it is not yet done
         // Role needs to be created before token can be created
-        if v.initRoleDone == false </span><span class="cov0" title="0">{
-                err := v.initRole()
-                if err != nil </span><span class="cov0" title="0">{
-                        smslogger.WriteError(err.Error())
-                        return errors.New("Unable to initRole in checkToken")
-                }</span>
-        }
+        err := v.initRole()
+        if err != nil </span><span class="cov0" title="0">{
+                smslogger.WriteError(err.Error())
+                return errors.New("Unable to initRole in checkToken")
+        }</span>
 
         // Return immediately if token still has life
-        <span class="cov10" title="24">if v.vaultClient.Token() != "" &amp;&amp;
-                time.Since(v.vaultTempTokenTTL) &lt; time.Minute*50 </span><span class="cov9" title="23">{
+        <span class="cov9" title="54">if v.vaultClient.Token() != "" &amp;&amp;
+                time.Since(v.vaultTempTokenTTL) &lt; time.Minute*50 </span><span class="cov9" title="47">{
                 return nil
         }</span>
 
         // Create a temporary token using our roleID and secretID
-        <span class="cov1" title="1">out, err := v.vaultClient.Logical().Write("auth/approle/login",
+        <span class="cov5" title="7">out, err := v.vaultClient.Logical().Write("auth/approle/login",
                 map[string]interface{}{"role_id": v.roleID, "secret_id": v.secretID})
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
+        if smslogger.CheckError(err, "Create Temp Token") != nil </span><span class="cov0" title="0">{
                 return errors.New("Unable to create Temporary Token for Role")
         }</span>
 
-        <span class="cov1" title="1">tok, err := out.TokenID()
+        <span class="cov5" title="7">tok, err := out.TokenID()
 
         v.vaultTempTokenTTL = time.Now()
         v.vaultClient.SetToken(tok)
@@ -655,31 +853,53 @@ func (v *Vault) checkToken() error <span class="cov10" title="24">{
 
 // vaultInit() is used to initialize the vault in cases where it is not
 // initialized. This happens once during intial bring up.
-func (v *Vault) initializeVault() error <span class="cov0" title="0">{
-        initReq := &amp;vaultapi.InitRequest{
-                SecretShares:    5,
+func (v *Vault) initializeVault() error <span class="cov2" title="2">{
+
+        // Check for vault init status and don't exit till it is initialized
+        for </span><span class="cov2" title="2">{
+                init, err := v.vaultClient.Sys().InitStatus()
+                if smslogger.CheckError(err, "Get Vault Init Status") != nil </span><span class="cov0" title="0">{
+                        smslogger.WriteInfo("Trying again in 10s...")
+                        time.Sleep(time.Second * 10)
+                        continue</span>
+                }
+                // Did not get any error
+                <span class="cov2" title="2">if init == true </span><span class="cov1" title="1">{
+                        smslogger.WriteInfo("Vault is already Initialized")
+                        return nil
+                }</span>
+
+                // init status is false
+                // break out of loop and finish initialization
+                <span class="cov1" title="1">smslogger.WriteInfo("Vault is not initialized. Initializing...")
+                break</span>
+        }
+
+        // Hardcoded this to 3. We should make this configurable
+        // in the future
+        <span class="cov1" title="1">initReq := &amp;vaultapi.InitRequest{
+                SecretShares:    3,
                 SecretThreshold: 3,
         }
 
         pbkey, prkey, err := smsauth.GeneratePGPKeyPair()
-        if err != nil </span><span class="cov0" title="0">{
+
+        if smslogger.CheckError(err, "Generating PGP Keys") != nil </span><span class="cov0" title="0">{
                 smslogger.WriteError("Error Generating PGP Keys. Vault Init will not use encryption!")
-        }</span><span class="cov0" title="0"> else {
-                initReq.PGPKeys = []string{pbkey, pbkey, pbkey, pbkey, pbkey}
+        }</span><span class="cov1" title="1"> else {
+                initReq.PGPKeys = []string{pbkey, pbkey, pbkey}
                 initReq.RootTokenPGPKey = pbkey
-                v.pgpPub = pbkey
-                v.pgpPr = prkey
         }</span>
 
-        <span class="cov0" title="0">resp, err := v.vaultClient.Sys().Init(initReq)
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
+        <span class="cov1" title="1">resp, err := v.vaultClient.Sys().Init(initReq)
+        if smslogger.CheckError(err, "Initialize Vault") != nil </span><span class="cov0" title="0">{
                 return errors.New("FATAL: Unable to initialize Vault")
         }</span>
 
-        <span class="cov0" title="0">if resp != nil </span><span class="cov0" title="0">{
-                v.unsealShards = resp.KeysB64
-                v.rootToken = resp.RootToken
+        <span class="cov1" title="1">if resp != nil </span><span class="cov1" title="1">{
+                v.prkey = prkey
+                v.shards = resp.KeysB64
+                v.vaultToken, _ = smsauth.DecryptPGPString(resp.RootToken, prkey)
                 return nil
         }</span>
 
@@ -708,6 +928,7 @@ package config
 import (
         "encoding/json"
         "os"
+        smslogger "sms/log"
 )
 
 // SMSConfiguration loads up all the values that are used to configure
@@ -718,8 +939,10 @@ type SMSConfiguration struct {
         ServerCert string `json:"servercert"`
         ServerKey  string `json:"serverkey"`
 
-        VaultAddress string `json:"vaultaddress"`
-        VaultToken   string `json:"vaulttoken"`
+        BackendAddress            string `json:"smsdbaddress"`
+        VaultToken                string `json:"vaulttoken"`
+        DisableTLS                bool   `json:"disable_tls"`
+        BackendAddressEnvVariable string `json:"smsdburlenv"`
 }
 
 // SMSConfig is the structure that stores the configuration
@@ -734,12 +957,19 @@ func ReadConfigFile(file string) (*SMSConfiguration, error) <span class="cov10"
                 }</span>
                 <span class="cov6" title="2">defer f.Close()
 
-                SMSConfig = &amp;SMSConfiguration{}
+                // Default behaviour is to enable TLS
+                SMSConfig = &amp;SMSConfiguration{DisableTLS: false}
                 decoder := json.NewDecoder(f)
                 err = decoder.Decode(SMSConfig)
                 if err != nil </span><span class="cov0" title="0">{
                         return nil, err
                 }</span>
+
+                <span class="cov6" title="2">if SMSConfig.BackendAddress == "" &amp;&amp; SMSConfig.BackendAddressEnvVariable != "" </span><span class="cov0" title="0">{
+                        // Get the value from ENV variable
+                        smslogger.WriteInfo("Using Environment Variable: " + SMSConfig.BackendAddressEnvVariable)
+                        SMSConfig.BackendAddress = os.Getenv(SMSConfig.BackendAddressEnvVariable)
+                }</span>
         }
 
         <span class="cov6" title="2">return SMSConfig, nil</span>
@@ -785,29 +1015,24 @@ func (h handler) createSecretDomainHandler(w http.ResponseWriter, r *http.Reques
         var d smsbackend.SecretDomain
 
         err := json.NewDecoder(r.Body).Decode(&amp;d)
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
+        if smslogger.CheckError(err, "CreateSecretDomainHandler") != nil </span><span class="cov0" title="0">{
                 http.Error(w, err.Error(), http.StatusBadRequest)
                 return
         }</span>
 
         <span class="cov6" title="3">dom, err := h.secretBackend.CreateSecretDomain(d.Name)
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
+        if smslogger.CheckError(err, "CreateSecretDomainHandler") != nil </span><span class="cov0" title="0">{
                 http.Error(w, err.Error(), http.StatusInternalServerError)
                 return
         }</span>
 
-        <span class="cov6" title="3">jdata, err := json.Marshal(dom)
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
+        <span class="cov6" title="3">w.Header().Set("Content-Type", "application/json")
+        w.WriteHeader(http.StatusCreated)
+        err = json.NewEncoder(w).Encode(dom)
+        if smslogger.CheckError(err, "CreateSecretDomainHandler") != nil </span><span class="cov0" title="0">{
                 http.Error(w, err.Error(), http.StatusInternalServerError)
                 return
         }</span>
-
-        <span class="cov6" title="3">w.Header().Set("Content-Type", "application/json")
-        w.WriteHeader(http.StatusCreated)
-        w.Write(jdata)</span>
 }
 
 // deleteSecretDomainHandler deletes a secret domain with the name provided
@@ -816,8 +1041,7 @@ func (h handler) deleteSecretDomainHandler(w http.ResponseWriter, r *http.Reques
         domName := vars["domName"]
 
         err := h.secretBackend.DeleteSecretDomain(domName)
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
+        if smslogger.CheckError(err, "DeleteSecretDomainHandler") != nil </span><span class="cov0" title="0">{
                 http.Error(w, err.Error(), http.StatusInternalServerError)
                 return
         }</span>
@@ -834,15 +1058,13 @@ func (h handler) createSecretHandler(w http.ResponseWriter, r *http.Request) <sp
         // Get secrets to be stored from body
         var b smsbackend.Secret
         err := json.NewDecoder(r.Body).Decode(&amp;b)
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
+        if smslogger.CheckError(err, "CreateSecretHandler") != nil </span><span class="cov0" title="0">{
                 http.Error(w, err.Error(), http.StatusBadRequest)
                 return
         }</span>
 
         <span class="cov10" title="7">err = h.secretBackend.CreateSecret(domName, b)
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
+        if smslogger.CheckError(err, "CreateSecretHandler") != nil </span><span class="cov0" title="0">{
                 http.Error(w, err.Error(), http.StatusInternalServerError)
                 return
         }</span>
@@ -857,21 +1079,17 @@ func (h handler) getSecretHandler(w http.ResponseWriter, r *http.Request) <span
         secName := vars["secretName"]
 
         sec, err := h.secretBackend.GetSecret(domName, secName)
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
+        if smslogger.CheckError(err, "GetSecretHandler") != nil </span><span class="cov0" title="0">{
                 http.Error(w, err.Error(), http.StatusInternalServerError)
                 return
         }</span>
 
-        <span class="cov10" title="7">jdata, err := json.Marshal(sec)
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
+        <span class="cov10" title="7">w.Header().Set("Content-Type", "application/json")
+        err = json.NewEncoder(w).Encode(sec)
+        if smslogger.CheckError(err, "GetSecretHandler") != nil </span><span class="cov0" title="0">{
                 http.Error(w, err.Error(), http.StatusInternalServerError)
                 return
         }</span>
-
-        <span class="cov10" title="7">w.Header().Set("Content-Type", "application/json")
-        w.Write(jdata)</span>
 }
 
 // listSecretHandler handles listing all secrets under a particular domain name
@@ -880,8 +1098,7 @@ func (h handler) listSecretHandler(w http.ResponseWriter, r *http.Request) <span
         domName := vars["domName"]
 
         secList, err := h.secretBackend.ListSecret(domName)
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
+        if smslogger.CheckError(err, "ListSecretHandler") != nil </span><span class="cov0" title="0">{
                 http.Error(w, err.Error(), http.StatusInternalServerError)
                 return
         }</span>
@@ -893,15 +1110,12 @@ func (h handler) listSecretHandler(w http.ResponseWriter, r *http.Request) <span
                 secList,
         }
 
-        jdata, err := json.Marshal(retStruct)
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
+        w.Header().Set("Content-Type", "application/json")
+        err = json.NewEncoder(w).Encode(retStruct)
+        if smslogger.CheckError(err, "ListSecretHandler") != nil </span><span class="cov0" title="0">{
                 http.Error(w, err.Error(), http.StatusInternalServerError)
                 return
         }</span>
-
-        <span class="cov6" title="3">w.Header().Set("Content-Type", "application/json")
-        w.Write(jdata)</span>
 }
 
 // deleteSecretHandler handles deleting a secret by given domain name and secret name
@@ -911,37 +1125,34 @@ func (h handler) deleteSecretHandler(w http.ResponseWriter, r *http.Request) <sp
         secName := vars["secretName"]
 
         err := h.secretBackend.DeleteSecret(domName, secName)
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
+        if smslogger.CheckError(err, "DeleteSecretHandler") != nil </span><span class="cov0" title="0">{
                 http.Error(w, err.Error(), http.StatusInternalServerError)
                 return
         }</span>
-}
 
-// struct that tracks various status items for SMS and backend
-type backendStatus struct {
-        Seal bool `json:"sealstatus"`
+        <span class="cov10" title="7">w.WriteHeader(http.StatusNoContent)</span>
 }
 
 // statusHandler returns information related to SMS and SMS backend services
-func (h handler) statusHandler(w http.ResponseWriter, r *http.Request) <span class="cov6" title="3">{
+func (h handler) statusHandler(w http.ResponseWriter, r *http.Request) <span class="cov7" title="4">{
         s, err := h.secretBackend.GetStatus()
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
+        if smslogger.CheckError(err, "StatusHandler") != nil </span><span class="cov0" title="0">{
                 http.Error(w, err.Error(), http.StatusInternalServerError)
                 return
         }</span>
 
-        <span class="cov6" title="3">status := backendStatus{Seal: s}
-        jdata, err := json.Marshal(status)
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
+        <span class="cov7" title="4">status := struct {
+                Seal bool `json:"sealstatus"`
+        }{
+                s,
+        }
+
+        w.Header().Set("Content-Type", "application/json")
+        err = json.NewEncoder(w).Encode(status)
+        if smslogger.CheckError(err, "StatusHandler") != nil </span><span class="cov0" title="0">{
                 http.Error(w, err.Error(), http.StatusInternalServerError)
                 return
         }</span>
-
-        <span class="cov6" title="3">w.Header().Set("Content-Type", "application/json")
-        w.Write(jdata)</span>
 }
 
 // loginHandler handles login via password and username
@@ -961,15 +1172,53 @@ func (h handler) unsealHandler(w http.ResponseWriter, r *http.Request) <span cla
         decoder := json.NewDecoder(r.Body)
         decoder.DisallowUnknownFields()
         err := decoder.Decode(&amp;inp)
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
+        if smslogger.CheckError(err, "UnsealHandler") != nil </span><span class="cov0" title="0">{
                 http.Error(w, "Bad input JSON", http.StatusBadRequest)
                 return
         }</span>
 
         <span class="cov0" title="0">err = h.secretBackend.Unseal(inp.UnsealShard)
-        if err != nil </span><span class="cov0" title="0">{
-                smslogger.WriteError(err.Error())
+        if smslogger.CheckError(err, "UnsealHandler") != nil </span><span class="cov0" title="0">{
+                http.Error(w, err.Error(), http.StatusInternalServerError)
+                return
+        }</span>
+}
+
+// registerHandler allows the quorum clients to register with SMS
+// with their PGP public keys that are then used by sms for backend
+// initialization
+func (h handler) registerHandler(w http.ResponseWriter, r *http.Request) <span class="cov1" title="1">{
+        // Get shards to be used for unseal
+        type registerStruct struct {
+                PGPKey   string `json:"pgpkey"`
+                QuorumID string `json:"quorumid"`
+        }
+
+        var inp registerStruct
+        decoder := json.NewDecoder(r.Body)
+        decoder.DisallowUnknownFields()
+        err := decoder.Decode(&amp;inp)
+        if smslogger.CheckError(err, "RegisterHandler") != nil </span><span class="cov0" title="0">{
+                http.Error(w, "Bad input JSON", http.StatusBadRequest)
+                return
+        }</span>
+
+        <span class="cov1" title="1">sh, err := h.secretBackend.RegisterQuorum(inp.PGPKey)
+        if smslogger.CheckError(err, "RegisterHandler") != nil </span><span class="cov0" title="0">{
+                http.Error(w, err.Error(), http.StatusInternalServerError)
+                return
+        }</span>
+
+        // Creating a struct for return data
+        <span class="cov1" title="1">shStruct := struct {
+                Shard string `json:"shard"`
+        }{
+                sh,
+        }
+
+        w.Header().Set("Content-Type", "application/json")
+        err = json.NewEncoder(w).Encode(shStruct)
+        if smslogger.CheckError(err, "RegisterHandler") != nil </span><span class="cov0" title="0">{
                 http.Error(w, err.Error(), http.StatusInternalServerError)
                 return
         }</span>
@@ -987,8 +1236,9 @@ func CreateRouter(b smsbackend.SecretBackend) http.Handler <span class="cov4" ti
 
         // Initialization APIs which will be used by quorum client
         // to unseal and to provide root token to sms service
-        router.HandleFunc("/v1/sms/status", h.statusHandler).Methods("GET")
-        router.HandleFunc("/v1/sms/unseal", h.unsealHandler).Methods("POST")
+        router.HandleFunc("/v1/sms/quorum/status", h.statusHandler).Methods("GET")
+        router.HandleFunc("/v1/sms/quorum/unseal", h.unsealHandler).Methods("POST")
+        router.HandleFunc("/v1/sms/quorum/register", h.registerHandler).Methods("POST")
 
         router.HandleFunc("/v1/sms/domain", h.createSecretDomainHandler).Methods("POST")
         router.HandleFunc("/v1/sms/domain/{domName}", h.deleteSecretDomainHandler).Methods("DELETE")
@@ -1021,53 +1271,85 @@ func CreateRouter(b smsbackend.SecretBackend) http.Handler <span class="cov4" ti
 package log
 
 import (
+        "fmt"
         "log"
         "os"
 )
 
-var errLogger *log.Logger
-var warnLogger *log.Logger
-var infoLogger *log.Logger
+var errL, warnL, infoL *log.Logger
+var stdErr, stdWarn, stdInfo *log.Logger
 
 // Init will be called by sms.go before any other packages use it
-func Init(filePath string) <span class="cov8" title="1">{
-        f, err := os.Create(filePath)
+func Init(filePath string) <span class="cov1" title="1">{
+
+        stdErr = log.New(os.Stderr, "ERROR: ", log.Lshortfile|log.LstdFlags)
+        stdWarn = log.New(os.Stdout, "WARNING: ", log.Lshortfile|log.LstdFlags)
+        stdInfo = log.New(os.Stdout, "INFO: ", log.Lshortfile|log.LstdFlags)
+
+        if filePath == "" </span><span class="cov0" title="0">{
+                // We will just to std streams
+                return
+        }</span>
+
+        <span class="cov1" title="1">f, err := os.Create(filePath)
         if err != nil </span><span class="cov0" title="0">{
-                log.Println("Unable to create a log file")
-                log.Println(err)
-                errLogger = log.New(os.Stderr, "ERROR: ", log.Lshortfile|log.LstdFlags)
-                warnLogger = log.New(os.Stdout, "WARNING: ", log.Lshortfile|log.LstdFlags)
-                infoLogger = log.New(os.Stdout, "INFO: ", log.Lshortfile|log.LstdFlags)
-        }</span><span class="cov8" title="1"> else {
-                errLogger = log.New(f, "ERROR: ", log.Lshortfile|log.LstdFlags)
-                warnLogger = log.New(f, "WARNING: ", log.Lshortfile|log.LstdFlags)
-                infoLogger = log.New(f, "INFO: ", log.Lshortfile|log.LstdFlags)
+                stdErr.Println("Unable to create log file: " + err.Error())
+                return
         }</span>
+
+        <span class="cov1" title="1">errL = log.New(f, "ERROR: ", log.Lshortfile|log.LstdFlags)
+        warnL = log.New(f, "WARNING: ", log.Lshortfile|log.LstdFlags)
+        infoL = log.New(f, "INFO: ", log.Lshortfile|log.LstdFlags)</span>
 }
 
 // WriteError writes output to the writer we have
-// defined durint its creation with ERROR prefix
+// defined during its creation with ERROR prefix
 func WriteError(msg string) <span class="cov0" title="0">{
-        if errLogger != nil </span><span class="cov0" title="0">{
-                errLogger.Println(msg)
+        if errL != nil </span><span class="cov0" title="0">{
+                errL.Output(2, fmt.Sprintln(msg))
+        }</span>
+        <span class="cov0" title="0">if stdErr != nil </span><span class="cov0" title="0">{
+                stdErr.Output(2, fmt.Sprintln(msg))
         }</span>
 }
 
 // WriteWarn writes output to the writer we have
-// defined durint its creation with WARNING prefix
+// defined during its creation with WARNING prefix
 func WriteWarn(msg string) <span class="cov0" title="0">{
-        if warnLogger != nil </span><span class="cov0" title="0">{
-                warnLogger.Println(msg)
+        if warnL != nil </span><span class="cov0" title="0">{
+                warnL.Output(2, fmt.Sprintln(msg))
+        }</span>
+        <span class="cov0" title="0">if stdWarn != nil </span><span class="cov0" title="0">{
+                stdWarn.Output(2, fmt.Sprintln(msg))
         }</span>
 }
 
 // WriteInfo writes output to the writer we have
-// defined durint its creation with INFO prefix
-func WriteInfo(msg string) <span class="cov0" title="0">{
-        if infoLogger != nil </span><span class="cov0" title="0">{
-                infoLogger.Println(msg)
+// defined during its creation with INFO prefix
+func WriteInfo(msg string) <span class="cov1" title="1">{
+        if infoL != nil </span><span class="cov1" title="1">{
+                infoL.Output(2, fmt.Sprintln(msg))
+        }</span>
+        <span class="cov1" title="1">if stdInfo != nil </span><span class="cov1" title="1">{
+                stdInfo.Output(2, fmt.Sprintln(msg))
         }</span>
 }
+
+//CheckError is a helper function to reduce
+//repitition of error checkign blocks of code
+func CheckError(err error, topic string) error <span class="cov10" title="116">{
+        if err != nil </span><span class="cov1" title="1">{
+                msg := topic + ": " + err.Error()
+                if errL != nil </span><span class="cov1" title="1">{
+                        errL.Output(2, fmt.Sprintln(msg))
+                }</span>
+                <span class="cov1" title="1">if stdErr != nil </span><span class="cov1" title="1">{
+                        stdErr.Output(2, fmt.Sprintln(msg))
+                }</span>
+                <span class="cov1" title="1">return err</span>
+        }
+        <span class="cov9" title="115">return nil</span>
+}
 </pre>
                
                <pre class="file" id="file6" style="display: none">/*
@@ -1119,16 +1401,9 @@ func main() <span class="cov8" title="1">{
 
         <span class="cov8" title="1">httpRouter := smshandler.CreateRouter(backendImpl)
 
-        // TODO: Use CA certificate from AAF
-        tlsConfig, err := smsauth.GetTLSConfig(smsConf.CAFile)
-        if err != nil </span><span class="cov0" title="0">{
-                log.Fatal(err)
-        }</span>
-
-        <span class="cov8" title="1">httpServer := &amp;http.Server{
-                Handler:   httpRouter,
-                Addr:      ":10443",
-                TLSConfig: tlsConfig,
+        httpServer := &amp;http.Server{
+                Handler: httpRouter,
+                Addr:    ":10443",
         }
 
         // Listener for SIGINT so that it returns cleanly
@@ -1141,8 +1416,22 @@ func main() <span class="cov8" title="1">{
                 close(connectionsClose)
         }</span>()
 
-        <span class="cov8" title="1">err = httpServer.ListenAndServeTLS(smsConf.ServerCert, smsConf.ServerKey)
-        if err != nil &amp;&amp; err != http.ErrServerClosed </span><span class="cov0" title="0">{
+        // Start in TLS mode by default
+        <span class="cov8" title="1">if smsConf.DisableTLS == true </span><span class="cov0" title="0">{
+                smslogger.WriteWarn("TLS is Disabled")
+                err = httpServer.ListenAndServe()
+        }</span><span class="cov8" title="1"> else {
+                // TODO: Use CA certificate from AAF
+                tlsConfig, err := smsauth.GetTLSConfig(smsConf.CAFile)
+                if err != nil </span><span class="cov0" title="0">{
+                        log.Fatal(err)
+                }</span>
+
+                <span class="cov8" title="1">httpServer.TLSConfig = tlsConfig
+                err = httpServer.ListenAndServeTLS(smsConf.ServerCert, smsConf.ServerKey)</span>
+        }
+
+        <span class="cov8" title="1">if err != nil &amp;&amp; err != http.ErrServerClosed </span><span class="cov0" title="0">{
                 log.Fatal(err)
         }</span>
 
index 05fc967..dfa1a26 100644 (file)
@@ -37,13 +37,13 @@ func loadPGPKeys(prKeyPath string, pbKeyPath string) (string, string, error) {
        var pbkey, prkey string
        generated := false
        prkey, err := smsauth.ReadFromFile(prKeyPath)
-       if err != nil {
-               smslogger.WriteWarn("No Private Key found. Generating...")
+       if smslogger.CheckError(err, "LoadPGP Private Key") != nil {
+               smslogger.WriteInfo("No Private Key found. Generating...")
                pbkey, prkey, _ = smsauth.GeneratePGPKeyPair()
                generated = true
        } else {
                pbkey, err = smsauth.ReadFromFile(pbKeyPath)
-               if err != nil {
+               if smslogger.CheckError(err, "LoadPGP Public Key") != nil {
                        smslogger.WriteWarn("No Public Key found. Generating...")
                        pbkey, prkey, _ = smsauth.GeneratePGPKeyPair()
                        generated = true
@@ -70,7 +70,7 @@ func main() {
        prKeyPath := filepath.Join("auth", podName, "prkey")
        shardPath := filepath.Join("auth", podName, "shard")
 
-       smslogger.Init("")
+       smslogger.Init("quorum.log")
        smslogger.WriteInfo("Starting Log for Quorum Client")
 
        /*
@@ -80,7 +80,7 @@ func main() {
                In Kubernetes, pod restarts will also change the hostname
        */
        myID, err := smsauth.ReadFromFile(idFilePath)
-       if err != nil {
+       if smslogger.CheckError(err, "Read ID") != nil {
                smslogger.WriteWarn("Unable to find an ID for this client. Generating...")
                myID, _ = uuid.GenerateUUID()
                smsauth.WriteToFile(myID, idFilePath)
@@ -93,7 +93,7 @@ func main() {
        */
        registrationDone := true
        myShard, err := smsauth.ReadFromFile(shardPath)
-       if err != nil {
+       if smslogger.CheckError(err, "Read Shard") != nil {
                smslogger.WriteWarn("Unable to find a shard file. Registering with SMS...")
                registrationDone = false
        }
@@ -160,8 +160,7 @@ func main() {
 
                //URL and Port is configured in config file
                response, err := client.Get(cfg.BackEndURL + "/v1/sms/quorum/status")
-               if err != nil {
-                       smslogger.WriteError("Unable to connect to SMS. Retrying...")
+               if smslogger.CheckError(err, "Connect to SMS") != nil {
                        continue
                }
 
@@ -178,8 +177,7 @@ func main() {
                        if !registrationDone {
                                body := strings.NewReader(`{"pgpkey":"` + pbkey + `","quorumid":"` + myID + `"}`)
                                res, err := client.Post(cfg.BackEndURL+"/v1/sms/quorum/register", "application/json", body)
-                               if err != nil {
-                                       smslogger.WriteError("Ran into error during registration. Retrying...")
+                               if smslogger.CheckError(err, "Register with SMS") != nil {
                                        continue
                                }
                                registrationDone = true
@@ -195,8 +193,8 @@ func main() {
                        body := strings.NewReader(`{"unsealshard":"` + decShard + `"}`)
                        //URL and PORT is configured via config file
                        response, err = client.Post(cfg.BackEndURL+"/v1/sms/quorum/unseal", "application/json", body)
-                       if err != nil {
-                               smslogger.WriteError("Error unsealing vault. Retrying... " + err.Error())
+                       if smslogger.CheckError(err, "Unsealing Vault") != nil {
+                               continue
                        }
                }
        }
index c7684c7..2c09256 100644 (file)
 [solve-meta]
   analyzer-name = "dep"
   analyzer-version = 1
-  inputs-digest = "d19e17a023506ab731b0f26c6fcfebe581d4d5194af094aecea5e72daddd3ead"
+  inputs-digest = "8280cde72a3ab78ad00d13c192de5920d188f3052f45884563896cab659469f9"
   solver-name = "gps-cdcl"
   solver-version = 1
index 7172505..038e31d 100644 (file)
@@ -29,39 +29,27 @@ import (
        smslogger "sms/log"
 )
 
-var tlsConfig *tls.Config
-
-func checkError(err error, topic string) error {
-       if err != nil {
-               smslogger.WriteError(topic + ": " + err.Error())
-               return err
-       }
-
-       return nil
-}
-
 // GetTLSConfig initializes a tlsConfig using the CA's certificate
 // This config is then used to enable the server for mutual TLS
 func GetTLSConfig(caCertFile string) (*tls.Config, error) {
+
        // Initialize tlsConfig once
-       if tlsConfig == nil {
-               caCert, err := ioutil.ReadFile(caCertFile)
+       caCert, err := ioutil.ReadFile(caCertFile)
 
-               if err != nil {
-                       return nil, err
-               }
+       if err != nil {
+               return nil, err
+       }
 
-               caCertPool := x509.NewCertPool()
-               caCertPool.AppendCertsFromPEM(caCert)
+       caCertPool := x509.NewCertPool()
+       caCertPool.AppendCertsFromPEM(caCert)
 
-               tlsConfig = &tls.Config{
-                       // Change to RequireAndVerify once we have mandatory certs
-                       ClientAuth: tls.VerifyClientCertIfGiven,
-                       ClientCAs:  caCertPool,
-                       MinVersion: tls.VersionTLS12,
-               }
-               tlsConfig.BuildNameToCertificate()
+       tlsConfig := &tls.Config{
+               // Change to RequireAndVerify once we have mandatory certs
+               ClientAuth: tls.VerifyClientCertIfGiven,
+               ClientCAs:  caCertPool,
+               MinVersion: tls.VersionTLS12,
        }
+       tlsConfig.BuildNameToCertificate()
        return tlsConfig, nil
 }
 
@@ -70,22 +58,21 @@ func GetTLSConfig(caCertFile string) (*tls.Config, error) {
 // A base64 encoded form of the public part of the entity
 // A base64 encoded form of the private key
 func GeneratePGPKeyPair() (string, string, error) {
+
        var entity *openpgp.Entity
        config := &packet.Config{
                DefaultHash: crypto.SHA256,
        }
 
        entity, err := openpgp.NewEntity("aaf.sms.init", "PGP Key for unsealing", "", config)
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "Create Entity") != nil {
                return "", "", err
        }
 
        // Sign the identity in the entity
        for _, id := range entity.Identities {
                err = id.SelfSignature.SignUserId(id.UserId.Id, entity.PrimaryKey, entity.PrivateKey, nil)
-               if err != nil {
-                       smslogger.WriteError(err.Error())
+               if smslogger.CheckError(err, "Sign Entity") != nil {
                        return "", "", err
                }
        }
@@ -93,8 +80,7 @@ func GeneratePGPKeyPair() (string, string, error) {
        // Sign the subkey in the entity
        for _, subkey := range entity.Subkeys {
                err := subkey.Sig.SignKey(subkey.PublicKey, entity.PrivateKey, nil)
-               if err != nil {
-                       smslogger.WriteError(err.Error())
+               if smslogger.CheckError(err, "Sign Subkey") != nil {
                        return "", "", err
                }
        }
@@ -113,32 +99,33 @@ func GeneratePGPKeyPair() (string, string, error) {
 // EncryptPGPString takes data and a public key and encrypts using that
 // public key
 func EncryptPGPString(data string, pbKey string) (string, error) {
+
        pbKeyBytes, err := base64.StdEncoding.DecodeString(pbKey)
-       if checkError(err, "Decoding Base64 Public Key") != nil {
+       if smslogger.CheckError(err, "Decoding Base64 Public Key") != nil {
                return "", err
        }
 
        dataBytes := []byte(data)
 
        pbEntity, err := openpgp.ReadEntity(packet.NewReader(bytes.NewBuffer(pbKeyBytes)))
-       if checkError(err, "Reading entity from PGP key") != nil {
+       if smslogger.CheckError(err, "Reading entity from PGP key") != nil {
                return "", err
        }
 
        // encrypt string
        buf := new(bytes.Buffer)
        out, err := openpgp.Encrypt(buf, []*openpgp.Entity{pbEntity}, nil, nil, nil)
-       if checkError(err, "Creating Encryption Pipe") != nil {
+       if smslogger.CheckError(err, "Creating Encryption Pipe") != nil {
                return "", err
        }
 
        _, err = out.Write(dataBytes)
-       if checkError(err, "Writing to Encryption Pipe") != nil {
+       if smslogger.CheckError(err, "Writing to Encryption Pipe") != nil {
                return "", err
        }
 
        err = out.Close()
-       if checkError(err, "Closing Encryption Pipe") != nil {
+       if smslogger.CheckError(err, "Closing Encryption Pipe") != nil {
                return "", err
        }
 
@@ -149,29 +136,26 @@ func EncryptPGPString(data string, pbKey string) (string, error) {
 // DecryptPGPString decrypts a PGP encoded input string and returns
 // a base64 representation of the decoded string
 func DecryptPGPString(data string, prKey string) (string, error) {
+
        // Convert private key to bytes from base64
        prKeyBytes, err := base64.StdEncoding.DecodeString(prKey)
-       if err != nil {
-               smslogger.WriteError("Error Decoding base64 private key: " + err.Error())
+       if smslogger.CheckError(err, "Decoding Base64 Private Key") != nil {
                return "", err
        }
 
        dataBytes, err := base64.StdEncoding.DecodeString(data)
-       if err != nil {
-               smslogger.WriteError("Error Decoding base64 data: " + err.Error())
+       if smslogger.CheckError(err, "Decoding base64 data") != nil {
                return "", err
        }
 
        prEntity, err := openpgp.ReadEntity(packet.NewReader(bytes.NewBuffer(prKeyBytes)))
-       if err != nil {
-               smslogger.WriteError("Error reading entity from PGP key: " + err.Error())
+       if smslogger.CheckError(err, "Read Entity") != nil {
                return "", err
        }
 
        prEntityList := &openpgp.EntityList{prEntity}
        message, err := openpgp.ReadMessage(bytes.NewBuffer(dataBytes), prEntityList, nil, nil)
-       if err != nil {
-               smslogger.WriteError("Error Decrypting message: " + err.Error())
+       if smslogger.CheckError(err, "Decrypting Message") != nil {
                return "", err
        }
 
@@ -186,13 +170,10 @@ func DecryptPGPString(data string, prKey string) (string, error) {
 func ReadFromFile(fileName string) (string, error) {
 
        data, err := ioutil.ReadFile(fileName)
-       if err != nil {
-               smslogger.WriteError(err.Error())
-               smslogger.WriteError("Cannot read file: " + fileName)
+       if smslogger.CheckError(err, "Read from file") != nil {
                return "", err
        }
        return string(data), nil
-
 }
 
 // WriteToFile writes a PGP key into a file.
@@ -200,11 +181,8 @@ func ReadFromFile(fileName string) (string, error) {
 func WriteToFile(data string, fileName string) error {
 
        err := ioutil.WriteFile(fileName, []byte(data), 0600)
-       if err != nil {
-               smslogger.WriteError(err.Error())
-               smslogger.WriteError("Cannot write to file: " + fileName)
+       if smslogger.CheckError(err, "Write to file") != nil {
                return err
        }
        return nil
-
 }
index c137636..d7662ef 100644 (file)
@@ -60,8 +60,7 @@ func InitSecretBackend() (SecretBackend, error) {
        }
 
        err := backendImpl.Init()
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "InitSecretBackend") != nil {
                return nil, err
        }
 
index e26baff..7fee097 100644 (file)
@@ -56,9 +56,8 @@ func (v *Vault) initVaultClient() error {
        vaultCFG := vaultapi.DefaultConfig()
        vaultCFG.Address = v.vaultAddress
        client, err := vaultapi.NewClient(vaultCFG)
-       if err != nil {
-               smslogger.WriteError(err.Error())
-               return errors.New("Unable to create new vault client")
+       if smslogger.CheckError(err, "Create new vault client") != nil {
+               return err
        }
 
        v.initRoleDone = false
@@ -69,7 +68,6 @@ func (v *Vault) initVaultClient() error {
        v.internalDomainMounted = false
        v.prkey = ""
        return nil
-
 }
 
 // Init will initialize the vault connection
@@ -84,8 +82,7 @@ func (v *Vault) Init() error {
        v.initializeVault()
 
        err := v.initRole()
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "InitRole First Attempt") != nil {
                smslogger.WriteInfo("InitRole will try again later")
        }
 
@@ -94,10 +91,10 @@ func (v *Vault) Init() error {
 
 // GetStatus returns the current seal status of vault
 func (v *Vault) GetStatus() (bool, error) {
+
        sys := v.vaultClient.Sys()
        sealStatus, err := sys.SealStatus()
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "Getting Status") != nil {
                return false, errors.New("Error getting status")
        }
 
@@ -112,7 +109,7 @@ func (v *Vault) RegisterQuorum(pgpkey string) (string, error) {
        defer v.Unlock()
 
        if v.shards == nil {
-               smslogger.WriteError("Invalid operation")
+               smslogger.WriteError("Invalid operation in RegisterQuorum")
                return "", errors.New("Invalid operation")
        }
        // Pop the slice
@@ -133,10 +130,10 @@ func (v *Vault) RegisterQuorum(pgpkey string) (string, error) {
 // Unseal is a passthrough API that allows any
 // unseal or initialization processes for the backend
 func (v *Vault) Unseal(shard string) error {
+
        sys := v.vaultClient.Sys()
        _, err := sys.Unseal(shard)
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "Unseal Operation") != nil {
                return errors.New("Unable to execute unseal operation with specified shard")
        }
 
@@ -147,17 +144,16 @@ func (v *Vault) Unseal(shard string) error {
 // The secret itself is referenced via its name which translates to
 // a mount path in vault
 func (v *Vault) GetSecret(dom string, name string) (Secret, error) {
+
        err := v.checkToken()
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "Tocken Check") != nil {
                return Secret{}, errors.New("Token check failed")
        }
 
        dom = v.vaultMountPrefix + "/" + dom
 
        sec, err := v.vaultClient.Logical().Read(dom + "/" + name)
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "Read Secret") != nil {
                return Secret{}, errors.New("Unable to read Secret at provided path")
        }
 
@@ -173,17 +169,16 @@ func (v *Vault) GetSecret(dom string, name string) (Secret, error) {
 // ListSecret returns a list of secret names on a particular domain
 // The values of the secret are not returned
 func (v *Vault) ListSecret(dom string) ([]string, error) {
+
        err := v.checkToken()
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "Token Check") != nil {
                return nil, errors.New("Token check failed")
        }
 
        dom = v.vaultMountPrefix + "/" + dom
 
        sec, err := v.vaultClient.Logical().List(dom)
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "Read Secret") != nil {
                return nil, errors.New("Unable to read Secret at provided path")
        }
 
@@ -209,6 +204,7 @@ func (v *Vault) ListSecret(dom string) ([]string, error) {
 
 // Mounts the internal Domain if its not already mounted
 func (v *Vault) mountInternalDomain(name string) error {
+
        if v.internalDomainMounted {
                return nil
        }
@@ -224,14 +220,13 @@ func (v *Vault) mountInternalDomain(name string) error {
        }
 
        err := v.vaultClient.Sys().Mount(mountPath, mountInput)
-       if err != nil {
+       if smslogger.CheckError(err, "Mount internal Domain") != nil {
                if strings.Contains(err.Error(), "existing mount") {
                        // It is already mounted
                        v.internalDomainMounted = true
                        return nil
                }
                // Ran into some other error mounting it.
-               smslogger.WriteError(err.Error())
                return errors.New("Unable to mount internal Domain")
        }
 
@@ -242,16 +237,15 @@ func (v *Vault) mountInternalDomain(name string) error {
 // Stores the UUID created for secretdomain in vault
 // under v.vaultMountPrefix / smsinternal domain
 func (v *Vault) storeUUID(uuid string, name string) error {
+
        // Check if token is still valid
        err := v.checkToken()
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "Token Check") != nil {
                return errors.New("Token Check failed")
        }
 
        err = v.mountInternalDomain(v.internalDomain)
-       if err != nil {
-               smslogger.WriteError("Could not mount internal domain")
+       if smslogger.CheckError(err, "Mount Internal Domain") != nil {
                return err
        }
 
@@ -263,8 +257,7 @@ func (v *Vault) storeUUID(uuid string, name string) error {
        }
 
        err = v.CreateSecret(v.internalDomain, secret)
-       if err != nil {
-               smslogger.WriteError("Unable to write UUID to internal domain")
+       if smslogger.CheckError(err, "Write UUID to domain") != nil {
                return err
        }
 
@@ -273,10 +266,10 @@ func (v *Vault) storeUUID(uuid string, name string) error {
 
 // CreateSecretDomain mounts the kv backend on a path with the given name
 func (v *Vault) CreateSecretDomain(name string) (SecretDomain, error) {
+
        // Check if token is still valid
        err := v.checkToken()
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "Token Check") != nil {
                return SecretDomain{}, errors.New("Token Check failed")
        }
 
@@ -291,14 +284,13 @@ func (v *Vault) CreateSecretDomain(name string) (SecretDomain, error) {
        }
 
        err = v.vaultClient.Sys().Mount(mountPath, mountInput)
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "Create Domain") != nil {
                return SecretDomain{}, errors.New("Unable to create Secret Domain")
        }
 
        uuid, _ := uuid.GenerateUUID()
        err = v.storeUUID(uuid, name)
-       if err != nil {
+       if smslogger.CheckError(err, "Store UUID") != nil {
                // Mount was successful at this point.
                // Rollback the mount operation since we could not
                // store the UUID for the mount.
@@ -312,9 +304,9 @@ func (v *Vault) CreateSecretDomain(name string) (SecretDomain, error) {
 // CreateSecret creates a secret mounted on a particular domain name
 // The secret itself is mounted on a path specified by name
 func (v *Vault) CreateSecret(dom string, sec Secret) error {
+
        err := v.checkToken()
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "Token Check") != nil {
                return errors.New("Token check failed")
        }
 
@@ -323,8 +315,7 @@ func (v *Vault) CreateSecret(dom string, sec Secret) error {
        // Vault return is empty on successful write
        // TODO: Check if values is not empty
        _, err = v.vaultClient.Logical().Write(dom+"/"+sec.Name, sec.Values)
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "Create Secret") != nil {
                return errors.New("Unable to create Secret at provided path")
        }
 
@@ -334,9 +325,9 @@ func (v *Vault) CreateSecret(dom string, sec Secret) error {
 // DeleteSecretDomain deletes a secret domain which translates to
 // an unmount operation on the given path in Vault
 func (v *Vault) DeleteSecretDomain(name string) error {
+
        err := v.checkToken()
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "Token Check") != nil {
                return errors.New("Token Check Failed")
        }
 
@@ -344,8 +335,7 @@ func (v *Vault) DeleteSecretDomain(name string) error {
        mountPath := v.vaultMountPrefix + "/" + name
 
        err = v.vaultClient.Sys().Unmount(mountPath)
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "Delete Domain") != nil {
                return errors.New("Unable to delete domain specified")
        }
 
@@ -356,8 +346,7 @@ func (v *Vault) DeleteSecretDomain(name string) error {
 func (v *Vault) DeleteSecret(dom string, name string) error {
 
        err := v.checkToken()
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "Token Check") != nil {
                return errors.New("Token check failed")
        }
 
@@ -365,8 +354,7 @@ func (v *Vault) DeleteSecret(dom string, name string) error {
 
        // Vault return is empty on successful delete
        _, err = v.vaultClient.Logical().Delete(dom + "/" + name)
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "Delete Secret") != nil {
                return errors.New("Unable to delete Secret at provided path")
        }
 
@@ -406,15 +394,13 @@ func (v *Vault) initRole() error {
        rules := `path "sms/*" { capabilities = ["create", "read", "update", "delete", "list"] }
                        path "sys/mounts/sms*" { capabilities = ["update","delete","create"] }`
        err := v.vaultClient.Sys().PutPolicy(v.policyName, rules)
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "Creating Policy") != nil {
                return errors.New("Unable to create policy for approle creation")
        }
 
        //Check if applrole is mounted
        authMounts, err := v.vaultClient.Sys().ListAuth()
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "Mount Auth Backend") != nil {
                return errors.New("Unable to get mounted auth backends")
        }
 
@@ -440,8 +426,7 @@ func (v *Vault) initRole() error {
        // Create a role-id
        v.vaultClient.Logical().Write("auth/approle/role/"+rName, data)
        sec, err := v.vaultClient.Logical().Read("auth/approle/role/" + rName + "/role-id")
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "Create RoleID") != nil {
                return errors.New("Unable to create role ID for approle")
        }
        v.roleID = sec.Data["role_id"].(string)
@@ -449,8 +434,7 @@ func (v *Vault) initRole() error {
        // Create a secret-id to go with it
        sec, err = v.vaultClient.Logical().Write("auth/approle/role/"+rName+"/secret-id",
                map[string]interface{}{})
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "Create SecretID") != nil {
                return errors.New("Unable to create secret ID for role")
        }
 
@@ -462,8 +446,7 @@ func (v *Vault) initRole() error {
        * using the unseal shards.
         */
        err = v.vaultClient.Auth().Token().RevokeSelf(v.vaultToken)
-       if err != nil {
-               smslogger.WriteWarn(err.Error())
+       if smslogger.CheckError(err, "Revoke Root Token") != nil {
                smslogger.WriteWarn("Unable to Revoke Token")
        } else {
                // Revoked successfully and clear it
@@ -481,6 +464,7 @@ func (v *Vault) initRole() error {
 // Function checkToken() gets called multiple times to create
 // temporary tokens
 func (v *Vault) checkToken() error {
+
        v.Lock()
        defer v.Unlock()
 
@@ -501,8 +485,7 @@ func (v *Vault) checkToken() error {
        // Create a temporary token using our roleID and secretID
        out, err := v.vaultClient.Logical().Write("auth/approle/login",
                map[string]interface{}{"role_id": v.roleID, "secret_id": v.secretID})
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "Create Temp Token") != nil {
                return errors.New("Unable to create Temporary Token for Role")
        }
 
@@ -516,11 +499,12 @@ func (v *Vault) checkToken() error {
 // vaultInit() is used to initialize the vault in cases where it is not
 // initialized. This happens once during intial bring up.
 func (v *Vault) initializeVault() error {
+
        // Check for vault init status and don't exit till it is initialized
        for {
                init, err := v.vaultClient.Sys().InitStatus()
-               if err != nil {
-                       smslogger.WriteError("Unable to get initStatus, trying again in 10s: " + err.Error())
+               if smslogger.CheckError(err, "Get Vault Init Status") != nil {
+                       smslogger.WriteInfo("Trying again in 10s...")
                        time.Sleep(time.Second * 10)
                        continue
                }
@@ -545,7 +529,7 @@ func (v *Vault) initializeVault() error {
 
        pbkey, prkey, err := smsauth.GeneratePGPKeyPair()
 
-       if err != nil {
+       if smslogger.CheckError(err, "Generating PGP Keys") != nil {
                smslogger.WriteError("Error Generating PGP Keys. Vault Init will not use encryption!")
        } else {
                initReq.PGPKeys = []string{pbkey, pbkey, pbkey}
@@ -553,8 +537,7 @@ func (v *Vault) initializeVault() error {
        }
 
        resp, err := v.vaultClient.Sys().Init(initReq)
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "Initialize Vault") != nil {
                return errors.New("FATAL: Unable to initialize Vault")
        }
 
index 484c395..4862665 100644 (file)
 package backend
 
 import (
+       vaultapi "github.com/hashicorp/vault/api"
        credAppRole "github.com/hashicorp/vault/builtin/credential/approle"
        vaulthttp "github.com/hashicorp/vault/http"
        vaultlogical "github.com/hashicorp/vault/logical"
+       vaultinmem "github.com/hashicorp/vault/physical/inmem"
        vaulttesting "github.com/hashicorp/vault/vault"
        "reflect"
        smslog "sms/log"
@@ -229,3 +231,39 @@ func TestDeleteSecret(t *testing.T) {
                t.Fatal("DeleteSecret: Error Creating secret")
        }
 }
+
+func TestInitializeVault(t *testing.T) {
+
+       inm, err := vaultinmem.NewInmem(nil, nil)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       core, err := vaulttesting.NewCore(&vaulttesting.CoreConfig{
+               DisableMlock: true,
+               DisableCache: true,
+               Physical:     inm,
+       })
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       ln, addr := vaulthttp.TestServer(t, core)
+       defer ln.Close()
+
+       client, err := vaultapi.NewClient(&vaultapi.Config{
+               Address: addr,
+       })
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       v := &Vault{}
+       v.initVaultClient()
+       v.vaultClient = client
+
+       err = v.initializeVault()
+       if err != nil {
+               t.Fatal("InitializeVault: Error initializing Vault")
+       }
+}
index dbf3f93..7ce9e01 100644 (file)
@@ -37,15 +37,13 @@ func (h handler) createSecretDomainHandler(w http.ResponseWriter, r *http.Reques
        var d smsbackend.SecretDomain
 
        err := json.NewDecoder(r.Body).Decode(&d)
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "CreateSecretDomainHandler") != nil {
                http.Error(w, err.Error(), http.StatusBadRequest)
                return
        }
 
        dom, err := h.secretBackend.CreateSecretDomain(d.Name)
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "CreateSecretDomainHandler") != nil {
                http.Error(w, err.Error(), http.StatusInternalServerError)
                return
        }
@@ -53,8 +51,7 @@ func (h handler) createSecretDomainHandler(w http.ResponseWriter, r *http.Reques
        w.Header().Set("Content-Type", "application/json")
        w.WriteHeader(http.StatusCreated)
        err = json.NewEncoder(w).Encode(dom)
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "CreateSecretDomainHandler") != nil {
                http.Error(w, err.Error(), http.StatusInternalServerError)
                return
        }
@@ -66,8 +63,7 @@ func (h handler) deleteSecretDomainHandler(w http.ResponseWriter, r *http.Reques
        domName := vars["domName"]
 
        err := h.secretBackend.DeleteSecretDomain(domName)
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "DeleteSecretDomainHandler") != nil {
                http.Error(w, err.Error(), http.StatusInternalServerError)
                return
        }
@@ -84,15 +80,13 @@ func (h handler) createSecretHandler(w http.ResponseWriter, r *http.Request) {
        // Get secrets to be stored from body
        var b smsbackend.Secret
        err := json.NewDecoder(r.Body).Decode(&b)
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "CreateSecretHandler") != nil {
                http.Error(w, err.Error(), http.StatusBadRequest)
                return
        }
 
        err = h.secretBackend.CreateSecret(domName, b)
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "CreateSecretHandler") != nil {
                http.Error(w, err.Error(), http.StatusInternalServerError)
                return
        }
@@ -107,16 +101,14 @@ func (h handler) getSecretHandler(w http.ResponseWriter, r *http.Request) {
        secName := vars["secretName"]
 
        sec, err := h.secretBackend.GetSecret(domName, secName)
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "GetSecretHandler") != nil {
                http.Error(w, err.Error(), http.StatusInternalServerError)
                return
        }
 
        w.Header().Set("Content-Type", "application/json")
        err = json.NewEncoder(w).Encode(sec)
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "GetSecretHandler") != nil {
                http.Error(w, err.Error(), http.StatusInternalServerError)
                return
        }
@@ -128,8 +120,7 @@ func (h handler) listSecretHandler(w http.ResponseWriter, r *http.Request) {
        domName := vars["domName"]
 
        secList, err := h.secretBackend.ListSecret(domName)
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "ListSecretHandler") != nil {
                http.Error(w, err.Error(), http.StatusInternalServerError)
                return
        }
@@ -143,8 +134,7 @@ func (h handler) listSecretHandler(w http.ResponseWriter, r *http.Request) {
 
        w.Header().Set("Content-Type", "application/json")
        err = json.NewEncoder(w).Encode(retStruct)
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "ListSecretHandler") != nil {
                http.Error(w, err.Error(), http.StatusInternalServerError)
                return
        }
@@ -157,8 +147,7 @@ func (h handler) deleteSecretHandler(w http.ResponseWriter, r *http.Request) {
        secName := vars["secretName"]
 
        err := h.secretBackend.DeleteSecret(domName, secName)
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "DeleteSecretHandler") != nil {
                http.Error(w, err.Error(), http.StatusInternalServerError)
                return
        }
@@ -169,8 +158,7 @@ func (h handler) deleteSecretHandler(w http.ResponseWriter, r *http.Request) {
 // statusHandler returns information related to SMS and SMS backend services
 func (h handler) statusHandler(w http.ResponseWriter, r *http.Request) {
        s, err := h.secretBackend.GetStatus()
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "StatusHandler") != nil {
                http.Error(w, err.Error(), http.StatusInternalServerError)
                return
        }
@@ -183,8 +171,7 @@ func (h handler) statusHandler(w http.ResponseWriter, r *http.Request) {
 
        w.Header().Set("Content-Type", "application/json")
        err = json.NewEncoder(w).Encode(status)
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "StatusHandler") != nil {
                http.Error(w, err.Error(), http.StatusInternalServerError)
                return
        }
@@ -207,15 +194,13 @@ func (h handler) unsealHandler(w http.ResponseWriter, r *http.Request) {
        decoder := json.NewDecoder(r.Body)
        decoder.DisallowUnknownFields()
        err := decoder.Decode(&inp)
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "UnsealHandler") != nil {
                http.Error(w, "Bad input JSON", http.StatusBadRequest)
                return
        }
 
        err = h.secretBackend.Unseal(inp.UnsealShard)
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "UnsealHandler") != nil {
                http.Error(w, err.Error(), http.StatusInternalServerError)
                return
        }
@@ -235,15 +220,13 @@ func (h handler) registerHandler(w http.ResponseWriter, r *http.Request) {
        decoder := json.NewDecoder(r.Body)
        decoder.DisallowUnknownFields()
        err := decoder.Decode(&inp)
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "RegisterHandler") != nil {
                http.Error(w, "Bad input JSON", http.StatusBadRequest)
                return
        }
 
        sh, err := h.secretBackend.RegisterQuorum(inp.PGPKey)
-       if err != nil {
-               smslogger.WriteError(err.Error())
+       if smslogger.CheckError(err, "RegisterHandler") != nil {
                http.Error(w, err.Error(), http.StatusInternalServerError)
                return
        }
@@ -257,8 +240,7 @@ func (h handler) registerHandler(w http.ResponseWriter, r *http.Request) {
 
        w.Header().Set("Content-Type", "application/json")
        err = json.NewEncoder(w).Encode(shStruct)
-       if err != nil {
-               smslogger.WriteError("Unable to encode response: " + err.Error())
+       if smslogger.CheckError(err, "RegisterHandler") != nil {
                http.Error(w, err.Error(), http.StatusInternalServerError)
                return
        }
index 52637f3..c1e55ed 100644 (file)
@@ -48,7 +48,7 @@ func (b *TestBackend) Unseal(shard string) error {
 }
 
 func (b *TestBackend) RegisterQuorum(pgpkey string) (string, error) {
-       return "", nil
+       return "N8z4eD2Zgv0eDJrgkkUq3Lh5n2p6Y1Zsui1NIHePlLU=", nil
 }
 
 func (b *TestBackend) GetSecret(dom string, sec string) (smsbackend.Secret, error) {
@@ -127,8 +127,49 @@ func TestStatusHandler(t *testing.T) {
        }
 }
 
+func TestRegisterHandler(t *testing.T) {
+       body := `{
+               "pgpkey":"asdasdasdasdgkjgljoiwera",
+               "quorumid":"123e4567-e89b-12d3-a456-426655440000"
+       }`
+       reader := strings.NewReader(body)
+       req, err := http.NewRequest("POST", "/v1/sms/quorum/register", reader)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       rr := httptest.NewRecorder()
+       hr := http.HandlerFunc(h.registerHandler)
+
+       hr.ServeHTTP(rr, req)
+
+       ret := rr.Code
+       if ret != http.StatusOK {
+               t.Errorf("registerHandler returned wrong status code: %v vs %v",
+                       ret, http.StatusOK)
+       }
+
+       expected := struct {
+               Shard string `json:"shard"`
+       }{
+               "N8z4eD2Zgv0eDJrgkkUq3Lh5n2p6Y1Zsui1NIHePlLU=",
+       }
+       got := struct {
+               Shard string `json:"shard"`
+       }{}
+
+       json.NewDecoder(rr.Body).Decode(&got)
+
+       if reflect.DeepEqual(expected, got) == false {
+               t.Errorf("statusHandler returned unexpected body: got %v vs %v",
+                       rr.Body.String(), expected)
+       }
+}
+
 func TestUnsealHandler(t *testing.T) {
-       req, err := http.NewRequest("GET", "/v1/sms/quorum/unseal", nil)
+       body := `{"unsealshard":"N8z4eD2Zgv0eDJrgkkUq3Lh5n2p6Y1Zsui1NIHePlLU="}`
+       reader := strings.NewReader(body)
+       req, err := http.NewRequest("POST", "/v1/sms/quorum/unseal", reader)
        if err != nil {
                t.Fatal(err)
        }
index 25da593..660f1ce 100644 (file)
 package log
 
 import (
+       "fmt"
        "log"
        "os"
 )
 
-var errLogger *log.Logger
-var warnLogger *log.Logger
-var infoLogger *log.Logger
+var errL, warnL, infoL *log.Logger
+var stdErr, stdWarn, stdInfo *log.Logger
 
 // Init will be called by sms.go before any other packages use it
 func Init(filePath string) {
+
+       stdErr = log.New(os.Stderr, "ERROR: ", log.Lshortfile|log.LstdFlags)
+       stdWarn = log.New(os.Stdout, "WARNING: ", log.Lshortfile|log.LstdFlags)
+       stdInfo = log.New(os.Stdout, "INFO: ", log.Lshortfile|log.LstdFlags)
+
        if filePath == "" {
-               errLogger = log.New(os.Stderr, "ERROR: ", log.Lshortfile|log.LstdFlags)
-               warnLogger = log.New(os.Stdout, "WARNING: ", log.Lshortfile|log.LstdFlags)
-               infoLogger = log.New(os.Stdout, "INFO: ", log.Lshortfile|log.LstdFlags)
+               // We will just to std streams
                return
        }
 
        f, err := os.Create(filePath)
        if err != nil {
-               log.Println("Unable to create a log file")
-               log.Println(err)
-               errLogger = log.New(os.Stderr, "ERROR: ", log.Lshortfile|log.LstdFlags)
-               warnLogger = log.New(os.Stdout, "WARNING: ", log.Lshortfile|log.LstdFlags)
-               infoLogger = log.New(os.Stdout, "INFO: ", log.Lshortfile|log.LstdFlags)
-       } else {
-               errLogger = log.New(f, "ERROR: ", log.Lshortfile|log.LstdFlags)
-               warnLogger = log.New(f, "WARNING: ", log.Lshortfile|log.LstdFlags)
-               infoLogger = log.New(f, "INFO: ", log.Lshortfile|log.LstdFlags)
+               stdErr.Println("Unable to create log file: " + err.Error())
+               return
        }
+
+       errL = log.New(f, "ERROR: ", log.Lshortfile|log.LstdFlags)
+       warnL = log.New(f, "WARNING: ", log.Lshortfile|log.LstdFlags)
+       infoL = log.New(f, "INFO: ", log.Lshortfile|log.LstdFlags)
 }
 
 // WriteError writes output to the writer we have
-// defined durint its creation with ERROR prefix
+// defined during its creation with ERROR prefix
 func WriteError(msg string) {
-       if errLogger != nil {
-               errLogger.Println(msg)
+       if errL != nil {
+               errL.Output(2, fmt.Sprintln(msg))
+       }
+       if stdErr != nil {
+               stdErr.Output(2, fmt.Sprintln(msg))
        }
 }
 
 // WriteWarn writes output to the writer we have
-// defined durint its creation with WARNING prefix
+// defined during its creation with WARNING prefix
 func WriteWarn(msg string) {
-       if warnLogger != nil {
-               warnLogger.Println(msg)
+       if warnL != nil {
+               warnL.Output(2, fmt.Sprintln(msg))
+       }
+       if stdWarn != nil {
+               stdWarn.Output(2, fmt.Sprintln(msg))
        }
 }
 
 // WriteInfo writes output to the writer we have
-// defined durint its creation with INFO prefix
+// defined during its creation with INFO prefix
 func WriteInfo(msg string) {
-       if infoLogger != nil {
-               infoLogger.Println(msg)
+       if infoL != nil {
+               infoL.Output(2, fmt.Sprintln(msg))
+       }
+       if stdInfo != nil {
+               stdInfo.Output(2, fmt.Sprintln(msg))
+       }
+}
+
+//CheckError is a helper function to reduce
+//repitition of error checkign blocks of code
+func CheckError(err error, topic string) error {
+       if err != nil {
+               msg := topic + ": " + err.Error()
+               if errL != nil {
+                       errL.Output(2, fmt.Sprintln(msg))
+               }
+               if stdErr != nil {
+                       stdErr.Output(2, fmt.Sprintln(msg))
+               }
+               return err
        }
+       return nil
 }