Skip to content

Commit 4b9d8e5

Browse files
committed
Some cleanup in error handling, better logging, and more logical program exit behaviour
1 parent cec127d commit 4b9d8e5

File tree

1 file changed

+73
-46
lines changed

1 file changed

+73
-46
lines changed

cf-tlsa-acmesh.go

Lines changed: 73 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -50,30 +50,26 @@ func main() {
5050
requiredEnvVars := []string{"KEY_FILE", "KEY_FILE_NEXT", "ZONE_ID", "API_TOKEN", "DOMAIN"}
5151
for _, envVar := range requiredEnvVars {
5252
if os.Getenv(envVar) == "" {
53-
log.Println("Error:", envVar, "environment variable is not defined")
54-
os.Exit(1)
53+
log.Fatalln("Fatal:", envVar, "environment variable is not defined")
5554
}
5655
}
5756

5857
cert, err := generateCert(os.Getenv("KEY_FILE"))
5958
if err != nil {
60-
log.Println("Error generating cert:", err)
61-
os.Exit(1)
59+
log.Fatalln("Fatal: failed to generate current cert:", err)
6260
}
6361

6462
certNext, err := generateCert(os.Getenv("KEY_FILE_NEXT"))
6563
if err != nil {
66-
log.Println("Error generating next cert:", err)
67-
os.Exit(1)
64+
log.Fatalln("Fatal: failed to generate next cert:", err)
6865
}
6966

7067
log.Println("Current cert:", cert)
7168
log.Println("Next cert:", certNext)
7269

7370
tlsaRecords, err := getTLSARecords()
7471
if err != nil {
75-
log.Println("Error:", err)
76-
return
72+
log.Fatalln("Fatal: failed to get TLSA records:", err)
7773
}
7874

7975
for i, record := range tlsaRecords {
@@ -82,27 +78,49 @@ func main() {
8278

8379
if len(tlsaRecords) != 2 {
8480
log.Println("Incorrect number of DNS entries. Deleting them and generating new ones.")
85-
deleteAll(tlsaRecords)
86-
addRequest(certNext)
87-
addRequest(cert)
88-
return
81+
82+
err = deleteAll(tlsaRecords)
83+
if err != nil {
84+
log.Fatalln("Fatal: failed to delete all TLSA recors:", err)
85+
}
86+
87+
err = addRequest(certNext)
88+
if err != nil {
89+
log.Fatalln("Fatal: failed to add TLSA record for current cert:", err)
90+
}
91+
92+
err = addRequest(cert)
93+
if err != nil {
94+
log.Fatalln("Fatal: failed to add TLSA record for next cert:", err)
95+
}
96+
97+
os.Exit(0)
8998
}
9099

91-
if (checkData(tlsaRecords[0], cert) && checkData(tlsaRecords[1], certNext)) ||
92-
(checkData(tlsaRecords[0], certNext) && checkData(tlsaRecords[1], cert)) {
100+
switch {
101+
case (checkData(tlsaRecords[0], cert) && checkData(tlsaRecords[1], certNext)) ||
102+
(checkData(tlsaRecords[0], certNext) && checkData(tlsaRecords[1], cert)):
93103
log.Println("Nothing to do!")
94-
} else if checkData(tlsaRecords[0], cert) {
95-
modifyRequest(certNext, tlsaRecords[1].ID)
96-
} else if checkData(tlsaRecords[0], certNext) {
97-
modifyRequest(cert, tlsaRecords[1].ID)
98-
} else if checkData(tlsaRecords[1], cert) {
99-
modifyRequest(certNext, tlsaRecords[0].ID)
100-
} else if checkData(tlsaRecords[1], certNext) {
101-
modifyRequest(cert, tlsaRecords[0].ID)
102-
} else {
103-
modifyRequest(certNext, tlsaRecords[1].ID)
104-
modifyRequest(cert, tlsaRecords[0].ID)
104+
case checkData(tlsaRecords[0], cert):
105+
err = modifyRequest(certNext, tlsaRecords[1].ID)
106+
case checkData(tlsaRecords[0], certNext):
107+
err = modifyRequest(cert, tlsaRecords[1].ID)
108+
case checkData(tlsaRecords[1], cert):
109+
err = modifyRequest(certNext, tlsaRecords[0].ID)
110+
case checkData(tlsaRecords[1], certNext):
111+
err = modifyRequest(cert, tlsaRecords[0].ID)
112+
default:
113+
err = modifyRequest(certNext, tlsaRecords[1].ID)
114+
if err != nil {
115+
break
116+
}
117+
err = modifyRequest(cert, tlsaRecords[0].ID)
105118
}
119+
if err != nil {
120+
log.Fatalln("Fatal: failed to modify TLSA records:", err)
121+
}
122+
123+
os.Exit(0)
106124
}
107125

108126
func getTLSARecords() ([]tlsaRecord, error) {
@@ -131,12 +149,18 @@ func getTLSARecords() ([]tlsaRecord, error) {
131149
}
132150
}(resp.Body)
133151

152+
body, err := io.ReadAll(resp.Body)
153+
if err != nil {
154+
return nil, fmt.Errorf("failed reading response body: %v", err)
155+
}
156+
134157
if resp.StatusCode != http.StatusOK {
135-
return nil, fmt.Errorf("HTTP request failed with status code: %s", resp.Status)
158+
return nil, fmt.Errorf("recieved %d HTTP response status code for GET request, response body: %s", resp.StatusCode, string(body))
136159
}
137160

138161
var response tlsaRecordsResponse
139-
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
162+
err = json.Unmarshal(body, &response)
163+
if err != nil {
140164
return nil, fmt.Errorf("failed to decode JSON response: %v", err)
141165
}
142166

@@ -171,18 +195,23 @@ func generateCert(keyPath string) (string, error) {
171195
return hex.EncodeToString(hashSum), nil
172196
}
173197

174-
func deleteAll(tlsaRecords []tlsaRecord) {
198+
func deleteAll(tlsaRecords []tlsaRecord) error {
175199
zoneID, authToken := os.Getenv("ZONE_ID"), os.Getenv("API_TOKEN")
176200

177201
for _, record := range tlsaRecords {
178202
log.Println("Deleting DNS record:", record.ID)
179203
url := cloudflareAPI + zoneID + "/dns_records/" + record.ID
180204
resp, err := makeHTTPRequest("DELETE", url, authToken, nil)
181-
handleResponse(resp, err, "DELETE")
205+
err = handleResponse(resp, err, "DELETE")
206+
if err != nil {
207+
return err
208+
}
182209
}
210+
211+
return nil
183212
}
184213

185-
func addRequest(hash string) {
214+
func addRequest(hash string) error {
186215
log.Println("Adding DNS record with hash:", hash)
187216

188217
zoneID, authToken, domain := os.Getenv("ZONE_ID"), os.Getenv("API_TOKEN"), os.Getenv("DOMAIN")
@@ -193,10 +222,10 @@ func addRequest(hash string) {
193222
port, protocol, domain, usage, selector, matchingType, hash)
194223

195224
resp, err := makeHTTPRequest("POST", url, authToken, []byte(payload))
196-
handleResponse(resp, err, "POST")
225+
return handleResponse(resp, err, "POST")
197226
}
198227

199-
func modifyRequest(hash, id string) {
228+
func modifyRequest(hash, id string) error {
200229
log.Println("Modifying DNS record:", id, "with hash:", hash)
201230

202231
zoneID, authToken, domain := os.Getenv("ZONE_ID"), os.Getenv("API_TOKEN"), os.Getenv("DOMAIN")
@@ -207,7 +236,7 @@ func modifyRequest(hash, id string) {
207236
port, protocol, domain, usage, selector, matchingType, hash)
208237

209238
resp, err := makeHTTPRequest("PUT", url, authToken, []byte(payload))
210-
handleResponse(resp, err, "PUT")
239+
return handleResponse(resp, err, "PUT")
211240
}
212241

213242
func makeHTTPRequest(method, url, authToken string, payload []byte) (*http.Response, error) {
@@ -223,29 +252,27 @@ func makeHTTPRequest(method, url, authToken string, payload []byte) (*http.Respo
223252
return client.Do(req)
224253
}
225254

226-
func handleResponse(resp *http.Response, err error, action string) {
255+
func handleResponse(resp *http.Response, err error, action string) error {
227256
if err != nil {
228-
log.Println("Error:", err)
229-
os.Exit(1)
257+
return err
230258
}
231259
defer func(Body io.ReadCloser) {
232260
err := Body.Close()
233261
if err != nil {
234-
log.Println("Error closing HTTP body", err)
262+
log.Println("Error closing HTTP body:", err)
235263
}
236264
}(resp.Body)
237265

238-
log.Println(action, "HTTP Status Code:", resp.Status)
266+
body, err := io.ReadAll(resp.Body)
267+
if err != nil {
268+
return fmt.Errorf("failed reading response body: %v", err)
269+
}
239270

240-
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
241-
body, err := io.ReadAll(resp.Body)
242-
if err != nil {
243-
log.Println("Error reading response body:", err)
244-
} else {
245-
log.Println("Response Body:", string(body))
246-
}
247-
os.Exit(1)
271+
if resp.StatusCode != http.StatusOK {
272+
return fmt.Errorf("recieved %d HTTP response status code for %s request, response body: %s", resp.StatusCode, action, string(body))
248273
}
274+
275+
return nil
249276
}
250277

251278
func checkData(record tlsaRecord, hash string) (correct bool) {

0 commit comments

Comments
 (0)