diff --git a/src/main.go b/src/main.go index d34cfbb..0aa195f 100644 --- a/src/main.go +++ b/src/main.go @@ -18,6 +18,7 @@ func main() { inputDir := os.Args[1] outputDir := os.Args[2] + packageName := os.Args[3] files, err := getGoFiles(inputDir) if err != nil { fmt.Println(err) @@ -38,7 +39,7 @@ func main() { adaptedFile := strings.Replace(filepath.Base(file), ".go", "_adapted.go", 1) adaptedFile = filepath.Join(outputDir, adaptedFile) // Write the struct declarations and import declarations to the adapted file. - err = writeStruct(typeDecls, importDecls, adaptedFile) + err = writeStruct(typeDecls, importDecls, adaptedFile, packageName) if err != nil { fmt.Println(err) continue @@ -135,15 +136,18 @@ func extractTypes(file string) ([]string, []string, error) { } // writeStruct writes the given struct declarations and import declarations to a new file with the given name. -func writeStruct(structDecls []string, importDecls []string, file string) error { +func writeStruct(structDecls []string, importDecls []string, file string, packageName string) error { f, err := os.Create(file) if err != nil { return err } defer f.Close() + if packageName == "" { + packageName = "dto" + } // Add the package dto declaration at the top of the file. - _, err = io.WriteString(f, "package dto\n\n") + _, err = io.WriteString(f, fmt.Sprintf("package %s\n\n", packageName)) if err != nil { return err }