@@ -50,30 +50,26 @@ func main() {
5050 requiredEnvVars := []string {"KEY_FILE" , "KEY_FILE_NEXT" , "ZONE_ID" , "API_TOKEN" , "DOMAIN" }
5151 for _ , envVar := range requiredEnvVars {
5252 if os .Getenv (envVar ) == "" {
53- log .Println ("Error:" , envVar , "environment variable is not defined" )
54- os .Exit (1 )
53+ log .Fatalln ("Fatal:" , envVar , "environment variable is not defined" )
5554 }
5655 }
5756
5857 cert , err := generateCert (os .Getenv ("KEY_FILE" ))
5958 if err != nil {
60- log .Println ("Error generating cert:" , err )
61- os .Exit (1 )
59+ log .Fatalln ("Fatal: failed to generate current cert:" , err )
6260 }
6361
6462 certNext , err := generateCert (os .Getenv ("KEY_FILE_NEXT" ))
6563 if err != nil {
66- log .Println ("Error generating next cert:" , err )
67- os .Exit (1 )
64+ log .Fatalln ("Fatal: failed to generate next cert:" , err )
6865 }
6966
7067 log .Println ("Current cert:" , cert )
7168 log .Println ("Next cert:" , certNext )
7269
7370 tlsaRecords , err := getTLSARecords ()
7471 if err != nil {
75- log .Println ("Error:" , err )
76- return
72+ log .Fatalln ("Fatal: failed to get TLSA records:" , err )
7773 }
7874
7975 for i , record := range tlsaRecords {
@@ -82,27 +78,49 @@ func main() {
8278
8379 if len (tlsaRecords ) != 2 {
8480 log .Println ("Incorrect number of DNS entries. Deleting them and generating new ones." )
85- deleteAll (tlsaRecords )
86- addRequest (certNext )
87- addRequest (cert )
88- return
81+
82+ err = deleteAll (tlsaRecords )
83+ if err != nil {
84+ log .Fatalln ("Fatal: failed to delete all TLSA recors:" , err )
85+ }
86+
87+ err = addRequest (certNext )
88+ if err != nil {
89+ log .Fatalln ("Fatal: failed to add TLSA record for current cert:" , err )
90+ }
91+
92+ err = addRequest (cert )
93+ if err != nil {
94+ log .Fatalln ("Fatal: failed to add TLSA record for next cert:" , err )
95+ }
96+
97+ os .Exit (0 )
8998 }
9099
91- if (checkData (tlsaRecords [0 ], cert ) && checkData (tlsaRecords [1 ], certNext )) ||
92- (checkData (tlsaRecords [0 ], certNext ) && checkData (tlsaRecords [1 ], cert )) {
100+ switch {
101+ case (checkData (tlsaRecords [0 ], cert ) && checkData (tlsaRecords [1 ], certNext )) ||
102+ (checkData (tlsaRecords [0 ], certNext ) && checkData (tlsaRecords [1 ], cert )):
93103 log .Println ("Nothing to do!" )
94- } else if checkData (tlsaRecords [0 ], cert ) {
95- modifyRequest (certNext , tlsaRecords [1 ].ID )
96- } else if checkData (tlsaRecords [0 ], certNext ) {
97- modifyRequest (cert , tlsaRecords [1 ].ID )
98- } else if checkData (tlsaRecords [1 ], cert ) {
99- modifyRequest (certNext , tlsaRecords [0 ].ID )
100- } else if checkData (tlsaRecords [1 ], certNext ) {
101- modifyRequest (cert , tlsaRecords [0 ].ID )
102- } else {
103- modifyRequest (certNext , tlsaRecords [1 ].ID )
104- modifyRequest (cert , tlsaRecords [0 ].ID )
104+ case checkData (tlsaRecords [0 ], cert ):
105+ err = modifyRequest (certNext , tlsaRecords [1 ].ID )
106+ case checkData (tlsaRecords [0 ], certNext ):
107+ err = modifyRequest (cert , tlsaRecords [1 ].ID )
108+ case checkData (tlsaRecords [1 ], cert ):
109+ err = modifyRequest (certNext , tlsaRecords [0 ].ID )
110+ case checkData (tlsaRecords [1 ], certNext ):
111+ err = modifyRequest (cert , tlsaRecords [0 ].ID )
112+ default :
113+ err = modifyRequest (certNext , tlsaRecords [1 ].ID )
114+ if err != nil {
115+ break
116+ }
117+ err = modifyRequest (cert , tlsaRecords [0 ].ID )
105118 }
119+ if err != nil {
120+ log .Fatalln ("Fatal: failed to modify TLSA records:" , err )
121+ }
122+
123+ os .Exit (0 )
106124}
107125
108126func getTLSARecords () ([]tlsaRecord , error ) {
@@ -131,12 +149,18 @@ func getTLSARecords() ([]tlsaRecord, error) {
131149 }
132150 }(resp .Body )
133151
152+ body , err := io .ReadAll (resp .Body )
153+ if err != nil {
154+ return nil , fmt .Errorf ("failed reading response body: %v" , err )
155+ }
156+
134157 if resp .StatusCode != http .StatusOK {
135- return nil , fmt .Errorf ("HTTP request failed with status code: %s" , resp .Status )
158+ return nil , fmt .Errorf ("recieved %d HTTP response status code for GET request, response body : %s" , resp .StatusCode , string ( body ) )
136159 }
137160
138161 var response tlsaRecordsResponse
139- if err := json .NewDecoder (resp .Body ).Decode (& response ); err != nil {
162+ err = json .Unmarshal (body , & response )
163+ if err != nil {
140164 return nil , fmt .Errorf ("failed to decode JSON response: %v" , err )
141165 }
142166
@@ -171,18 +195,23 @@ func generateCert(keyPath string) (string, error) {
171195 return hex .EncodeToString (hashSum ), nil
172196}
173197
174- func deleteAll (tlsaRecords []tlsaRecord ) {
198+ func deleteAll (tlsaRecords []tlsaRecord ) error {
175199 zoneID , authToken := os .Getenv ("ZONE_ID" ), os .Getenv ("API_TOKEN" )
176200
177201 for _ , record := range tlsaRecords {
178202 log .Println ("Deleting DNS record:" , record .ID )
179203 url := cloudflareAPI + zoneID + "/dns_records/" + record .ID
180204 resp , err := makeHTTPRequest ("DELETE" , url , authToken , nil )
181- handleResponse (resp , err , "DELETE" )
205+ err = handleResponse (resp , err , "DELETE" )
206+ if err != nil {
207+ return err
208+ }
182209 }
210+
211+ return nil
183212}
184213
185- func addRequest (hash string ) {
214+ func addRequest (hash string ) error {
186215 log .Println ("Adding DNS record with hash:" , hash )
187216
188217 zoneID , authToken , domain := os .Getenv ("ZONE_ID" ), os .Getenv ("API_TOKEN" ), os .Getenv ("DOMAIN" )
@@ -193,10 +222,10 @@ func addRequest(hash string) {
193222 port , protocol , domain , usage , selector , matchingType , hash )
194223
195224 resp , err := makeHTTPRequest ("POST" , url , authToken , []byte (payload ))
196- handleResponse (resp , err , "POST" )
225+ return handleResponse (resp , err , "POST" )
197226}
198227
199- func modifyRequest (hash , id string ) {
228+ func modifyRequest (hash , id string ) error {
200229 log .Println ("Modifying DNS record:" , id , "with hash:" , hash )
201230
202231 zoneID , authToken , domain := os .Getenv ("ZONE_ID" ), os .Getenv ("API_TOKEN" ), os .Getenv ("DOMAIN" )
@@ -207,7 +236,7 @@ func modifyRequest(hash, id string) {
207236 port , protocol , domain , usage , selector , matchingType , hash )
208237
209238 resp , err := makeHTTPRequest ("PUT" , url , authToken , []byte (payload ))
210- handleResponse (resp , err , "PUT" )
239+ return handleResponse (resp , err , "PUT" )
211240}
212241
213242func makeHTTPRequest (method , url , authToken string , payload []byte ) (* http.Response , error ) {
@@ -223,29 +252,27 @@ func makeHTTPRequest(method, url, authToken string, payload []byte) (*http.Respo
223252 return client .Do (req )
224253}
225254
226- func handleResponse (resp * http.Response , err error , action string ) {
255+ func handleResponse (resp * http.Response , err error , action string ) error {
227256 if err != nil {
228- log .Println ("Error:" , err )
229- os .Exit (1 )
257+ return err
230258 }
231259 defer func (Body io.ReadCloser ) {
232260 err := Body .Close ()
233261 if err != nil {
234- log .Println ("Error closing HTTP body" , err )
262+ log .Println ("Error closing HTTP body: " , err )
235263 }
236264 }(resp .Body )
237265
238- log .Println (action , "HTTP Status Code:" , resp .Status )
266+ body , err := io .ReadAll (resp .Body )
267+ if err != nil {
268+ return fmt .Errorf ("failed reading response body: %v" , err )
269+ }
239270
240- if resp .StatusCode < 200 || resp .StatusCode >= 300 {
241- body , err := io .ReadAll (resp .Body )
242- if err != nil {
243- log .Println ("Error reading response body:" , err )
244- } else {
245- log .Println ("Response Body:" , string (body ))
246- }
247- os .Exit (1 )
271+ if resp .StatusCode != http .StatusOK {
272+ return fmt .Errorf ("recieved %d HTTP response status code for %s request, response body: %s" , resp .StatusCode , action , string (body ))
248273 }
274+
275+ return nil
249276}
250277
251278func checkData (record tlsaRecord , hash string ) (correct bool ) {
0 commit comments